#include "collision_body_mesh.hpp"
#include "collision_body.hpp"
#include "distance.hpp"
#include "intersection.hpp"
#include "my_assert.hpp"

// Class-wide (static) switch read by get_point_info(): when true the query
// trilinearly interpolates between the eight surrounding grid nodes; when
// false (the default) it uses the single nearest grid node.
bool Collision_body_mesh::m_interpolate_mesh = false;

// Larger of two scalars. When a > b is false (equal values, or a NaN
// operand) b is returned.
inline Scalar max2(Scalar a, Scalar b)
{
  if (a > b)
    return a;
  return b;
}
// Largest of three scalars: a compared against max2(b, c).
inline Scalar max3(Scalar a, Scalar b, Scalar c) 
{
  const Scalar largest_bc = max2(b, c);
  return max2(a, largest_bc);
}
// Largest of four scalars: a compared against max3(b, c, d).
inline Scalar max4(Scalar a, Scalar b, Scalar c, Scalar d) 
{
  const Scalar largest_bcd = max3(b, c, d);
  return max2(a, largest_bcd);
}

// Smaller of two scalars. When a < b is false (equal values, or a NaN
// operand) b is returned.
inline Scalar min2(Scalar a, Scalar b)
{
  if (a < b)
    return a;
  return b;
}
// Smallest of three scalars: a compared against min2(b, c).
inline Scalar min3(Scalar a, Scalar b, Scalar c) 
{
  const Scalar smallest_bc = min2(b, c);
  return min2(a, smallest_bc);
}
// Smallest of four scalars: a compared against min3(b, c, d).
inline Scalar min4(Scalar a, Scalar b, Scalar c, Scalar d) 
{
  const Scalar smallest_bcd = min3(b, c, d);
  return min2(a, smallest_bcd);
}

//==============================================================
// Collision_body_mesh
//==============================================================
// Default constructor. It performs no set-up of its own; the grid bounds,
// spacing and samples are established by construct_mesh().
Collision_body_mesh::Collision_body_mesh() {}

//==============================================================
// construct_mesh
//==============================================================
void Collision_body_mesh::construct_mesh(const Collision_body_mesh & mesh)
{
  *this = mesh;
}

//==============================================================
// construct_mesh
//==============================================================
void Collision_body_mesh::construct_mesh(int nx, int ny, int nz, 
                                         Scalar extra_frac,
                                         const std::vector<Collision_triangle> & triangles,
                                         const Collision_mesh_object * body)
{
  // first set up the mesh bounds
  m_nx = nx;
  m_ny = ny;
  m_nz = nz;

  assert1(m_nx > 2);
  assert1(m_nx > 2);
  assert1(m_nx > 2);

  int i;
  const Scalar big_val = 99999999.0f;
  m_min_x = m_min_y = m_min_z = big_val;
  m_max_x = m_max_y = m_max_z = -big_val;

  for (i = 0 ; i < (int) triangles.size() ; ++i)
  {
    m_min_x = min4(m_min_x, triangles[i].v0[0], triangles[i].v1[0], triangles[i].v2[0]);
    m_min_y = min4(m_min_y, triangles[i].v0[1], triangles[i].v1[1], triangles[i].v2[1]);
    m_min_z = min4(m_min_z, triangles[i].v0[2], triangles[i].v1[2], triangles[i].v2[2]);
    m_max_x = max4(m_max_x, triangles[i].v0[0], triangles[i].v1[0], triangles[i].v2[0]);
    m_max_y = max4(m_max_y, triangles[i].v0[1], triangles[i].v1[1], triangles[i].v2[1]);
    m_max_z = max4(m_max_z, triangles[i].v0[2], triangles[i].v1[2], triangles[i].v2[2]);
  }

  Scalar mid_x = 0.5f * (m_min_x + m_max_x);
  Scalar mid_y = 0.5f * (m_min_y + m_max_y);
  Scalar mid_z = 0.5f * (m_min_z + m_max_z);

  m_min_x = mid_x + extra_frac * (m_min_x - mid_x);
  m_min_y = mid_y + extra_frac * (m_min_y - mid_y);
  m_min_z = mid_z + extra_frac * (m_min_z - mid_z);
  m_max_x = mid_x + extra_frac * (m_max_x - mid_x);
  m_max_y = mid_y + extra_frac * (m_max_y - mid_y);
  m_max_z = mid_z + extra_frac * (m_max_z - mid_z);

  m_dx = (m_max_x - m_min_x) / (m_nx - 1);
  m_dy = (m_max_y - m_min_y) / (m_ny - 1);
  m_dz = (m_max_z - m_min_z) / (m_nz - 1);

  Position corners[8];
  corners[0] = Position(m_max_x, m_max_y, m_max_z);
  corners[1] = Position(m_max_x, m_max_y, m_min_z);
  corners[2] = Position(m_max_x, m_min_y, m_max_z);
  corners[3] = Position(m_max_x, m_min_y, m_min_z);
  corners[4] = Position(m_min_x, m_max_y, m_max_z);
  corners[5] = Position(m_min_x, m_max_y, m_min_z);
  corners[6] = Position(m_min_x, m_min_y, m_max_z);
  corners[7] = Position(m_min_x, m_min_y, m_min_z);
  m_bounding_radius = corners[0].mag();
  for (i = 1 ; i < 8 ; ++i)
  {
    if (corners[i].mag() > m_bounding_radius)
      m_bounding_radius = corners[i].mag();
  }

  if (body)
    populate_mesh(*body);
  else
    populate_mesh(triangles);
}

//==============================================================
// count_triangle_hits
//==============================================================
int count_triangle_hits(const Segment & seg, const std::vector<Collision_triangle> & triangles)
{
  int i;
  const int num = triangles.size();
  int hits = 0;
  Position pos;
  Scalar S;
  bool seg_in_dir;
  for (i = 0 ; i < num ; ++i)
  {
    Triangle tri(triangles[i].v0,
                 triangles[i].v1 - triangles[i].v0,
                 triangles[i].v2 - triangles[i].v0);
    if (intersect_segment_triangle(seg, tri,
                                   pos, S, seg_in_dir))
    {
      ++hits;
    }
  }
  return hits;
}

//==============================================================
// find_closest_triangle
//==============================================================
void find_closest_triangle(const Position & pos,
                           const std::vector<Collision_triangle> & triangles, 
                           int &triangle_index, 
                           Position & point_on_tri)
{
  int i;
  const int num = triangles.size();
  Position this_point;
  Scalar closest_sqr = 1.0E10;
  bool got_one = false;
  bool point_on_triangle_edge;

  for (i = 0 ; i < num ; ++i)
  {
    Scalar dist_sqr = distance_sqr_point_triangle(pos, 
                                                  triangles[i].v0, triangles[i].v1, triangles[i].v2, 
                                                  this_point,
                                                  point_on_triangle_edge);
    if (dist_sqr < closest_sqr)
    {
      got_one = true;
      triangle_index = i;
      point_on_tri = this_point;
      closest_sqr = dist_sqr;
    }
  }
  assert1(got_one);
}

//==============================================================
// populate_mesh
//==============================================================
// Fills every grid node from body.get_mesh_info().
//
// Sign convention for each Datum:
//   inside  : dir = vector_to_surface / |v|,  dist = +|v|
//   outside : dir = -vector_to_surface / |v|, dist = -|v|
//   |v| tiny: dir = normalised node position, dist = +0.00001f
void Collision_body_mesh::populate_mesh(const Collision_mesh_object & body)
{
  m_data.resize(m_nx, m_ny, m_nz);

  for (int ix = 0 ; ix < m_nx ; ++ix)
  {
    for (int iy = 0 ; iy < m_ny ; ++iy)
    {
      for (int iz = 0 ; iz < m_nz ; ++iz)
      {
        Position pos(m_min_x + ix * m_dx, m_min_y + iy * m_dy, m_min_z + iz * m_dz);
        bool inside = false;
        Vector3 vector_to_surface;

        if (!body.get_mesh_info(pos, inside, vector_to_surface))
        {
          // The body does not supply mesh information; the node is left
          // untouched.
          assert1(!"Need to implement Collision_body::get_mesh_info");
          continue;
        }

        Datum & datum = m_data(ix, iy, iz);
        Scalar dist = vector_to_surface.mag();
        if (dist > 0.00001f)
        {
          if (inside)
          {
            datum.dir = vector_to_surface / dist;
            datum.dist = dist;
          }
          else
          {
            datum.dir = -vector_to_surface / dist;
            datum.dist = -dist;
          }
        }
        else
        {
          // Too short to take a direction from: fall back to the direction
          // of the node from the origin and a tiny positive distance.
          datum.dir = pos;
          datum.dir.normalise();
          datum.dist = 0.00001f;
        }
      }
    }
  }
}

//==============================================================
// populate_mesh
//==============================================================
// Fills every grid node from the triangle soup.
//
// For each node: dir/dist come from the closest point on the closest
// triangle, and the inside/outside sign comes from a ray-parity vote over
// num_outside_dirs random probe segments.
// Conventions: inside => dist > 0 and dir points towards the closest surface
// point; outside => dist < 0 and dir is negated (points away from it).
void Collision_body_mesh::populate_mesh(const std::vector<Collision_triangle> & triangles)
{
  if (triangles.size() == 0)
  {
    TRACE_FILE_IF(ONCE_1)
      TRACE("no triangles to construct mesh - hope you're not intending to use it\n");
    return;
  }

  m_data.resize(m_nx, m_ny, m_nz);
  int i, j, k;

  // Random probe directions shared by all nodes for the inside/outside vote.
  static const int num_outside_dirs = 10;
  Vector3 outside_dirs[num_outside_dirs];

  for (i = 0 ; i < num_outside_dirs ; ++i)
  {
    // NOTE(review): the asymmetric range (-1.0f, 1.10f) looks like a typo for
    // (-1.0f, 1.0f). It only skews the random probe directions, so it is
    // kept unchanged here - confirm intent.
    outside_dirs[i] = Vector3(ranged_random(-1.0f, 1.10f), ranged_random(-1.0f, 1.10f), ranged_random(-1.0f, 1.10f));
    outside_dirs[i].normalise();
    // Probe length is twice the bounding radius, intended to reach past the
    // geometry from any node.
    outside_dirs[i] *= 2.0f * get_bounding_radius();
  }

  for (i = 0 ; i < m_nx ; ++i)
  {
    for (j = 0 ; j < m_ny ; ++j)
    {
      for (k = 0 ; k < m_nz ; ++k)
      {
        Datum & datum = m_data(i, j, k);
        Position pos(m_min_x + i * m_dx, m_min_y + j * m_dy, m_min_z + k * m_dz);
        Position point_on_triangle;
        int triangle_index = 0; // always valid: triangles is non-empty here
        find_closest_triangle(pos, triangles, triangle_index, point_on_triangle);
        // Start with the "inside" convention: dir points from the node to the
        // closest surface point, dist is positive.
        datum.dir = point_on_triangle - pos;
        datum.dist = datum.dir.mag();
        if (datum.dist < 0.0001f)
        {
          // The node lies (almost) on the surface: record a small positive
          // distance, i.e. treat it as inside...
          datum.dist = 0.0001f;
          // ...and take the direction from the closest face's normal.
          datum.dir = cross(triangles[triangle_index].v1 - triangles[triangle_index].v0,
                            triangles[triangle_index].v2 - triangles[triangle_index].v0);
          datum.dir.normalise();
        }
        else
        {
          datum.dir /= datum.dist;

          // Ray-parity vote: an odd number of crossings along a probe
          // counts as "inside".
          int num_in = 0;
          int num_out = 0;
          for (int l = 0 ; l < num_outside_dirs ; ++l)
          {
            Segment seg(pos, outside_dirs[l]);
            int hits = count_triangle_hits(seg, triangles);
            if (hits % 2)
              ++num_in;
            else
              ++num_out;
          }
          if (num_out > num_in)
          {
            // Majority says outside: switch to the outside convention.
            // (A tied vote leaves the node classified as inside.)
            datum.dist = -datum.dist;
            datum.dir = -datum.dir;
          }
          if ( (num_in > 0) && (num_out > 0) &&
               (fabs((Scalar) num_out - num_in) / num_outside_dirs < 0.3f) )
          {
            // Close vote (margin under 30% of the probes): this node's
            // classification is unreliable.
            TRACE_FILE_IF(ONCE_3)
              TRACE("eeeeek: (%d, %d, %d) in = %d, out = %d\n", i, j, k, num_in, num_out);
          }
        }
      }
    }
  }
}

// returns the value interpolated from the four corner points
// x,y,z must be 0.0-1.0
template<class T>
T interp3d(const T & val000,
           const T & val001,
           const T & val010,
           const T & val011,
           const T & val100,
           const T & val101,
           const T & val110,
           const T & val111,
           Scalar x, Scalar y, Scalar z)
{
  T vxy0 = (1.0f - y) * (val000 * (1.0f - x) + val100 * x) + 
                    y * (val010 * (1.0f - x) + val110 * x);
  T vxy1 = (1.0f - y) * (val001 * (1.0f - x) + val101 * x) + 
                    y * (val011 * (1.0f - x) + val111 * x);

  return (1.0f - z) * vxy0 + z * vxy1;
}

//==============================================================
// get_point_info
//==============================================================
// Looks up the sampled surface information at pos.
//
// pos                   - query point, in the same frame as the grid bounds.
// dir                   - out: sampled direction (normalised after
//                         interpolation; the nearest node's stored dir otherwise).
// dist                  - out: distance estimate, always returned non-negative.
// accurate_outside_dist - if false, a pos outside the grid bounds returns
//                         false immediately WITHOUT writing dir or dist.
//
// Returns true when the sampled signed distance is positive (pos is inside),
// and false otherwise (outside, or outside the grid bounds as above).
bool Collision_body_mesh::get_point_info(const Position & pos,
                                         Vector3 & dir,
                                         Scalar & dist,
                                         bool accurate_outside_dist)
{
  // When m_interpolate_mesh is false, this just returns the data at the
  // nearest grid node (corrected for the offset), for speed over accuracy.
  // This avoids having to average and normalise.

  if (!accurate_outside_dist)
  {
    if ( (pos[0] < m_min_x) || (pos[0] > m_max_x) ||
         (pos[1] < m_min_y) || (pos[1] > m_max_y) ||
         (pos[2] < m_min_z) || (pos[2] > m_max_z) )
    {
      return false;
    }
  }

  // the floating-point indices of the point.
  Scalar fi = (pos[0] - m_min_x) / m_dx;
  Scalar fj = (pos[1] - m_min_y) / m_dy;
  Scalar fk = (pos[2] - m_min_z) / m_dz;

  if (m_interpolate_mesh)
  {
    // find the indices of the box surrounding the point
    int ii0 = (int) floor(fi);
    int ij0 = (int) floor(fj);
    int ik0 = (int) floor(fk);
    // Fractions are taken before clamping, so they always lie in [0, 1).
    Scalar x = fi - ii0;
    Scalar y = fj - ij0;
    Scalar z = fk - ik0;
    int ii1 = ii0 + 1;
    int ij1 = ij0 + 1;
    int ik1 = ik0 + 1;
    // Clamp to the grid. For a point beyond the grid along an axis, both
    // indices on that axis clamp to the same boundary node, so the boundary
    // value is used (no extrapolation).
    if (ii0 < 0) ii0 = 0; else if (ii0 >= m_nx) ii0 = m_nx - 1;
    if (ij0 < 0) ij0 = 0; else if (ij0 >= m_ny) ij0 = m_ny - 1;
    if (ik0 < 0) ik0 = 0; else if (ik0 >= m_nz) ik0 = m_nz - 1;
    if (ii1 < 0) ii1 = 0; else if (ii1 >= m_nx) ii1 = m_nx - 1;
    if (ij1 < 0) ij1 = 0; else if (ij1 >= m_ny) ij1 = m_ny - 1;
    if (ik1 < 0) ik1 = 0; else if (ik1 >= m_nz) ik1 = m_nz - 1;

    const Datum & datum000 = m_data(ii0, ij0, ik0);
    const Datum & datum001 = m_data(ii0, ij0, ik1);
    const Datum & datum010 = m_data(ii0, ij1, ik0);
    const Datum & datum011 = m_data(ii0, ij1, ik1);
    const Datum & datum100 = m_data(ii1, ij0, ik0);
    const Datum & datum101 = m_data(ii1, ij0, ik1);
    const Datum & datum110 = m_data(ii1, ij1, ik0);
    const Datum & datum111 = m_data(ii1, ij1, ik1);

    dir = interp3d<Vector3>(datum000.dir, 
                            datum001.dir,
                            datum010.dir,
                            datum011.dir,
                            datum100.dir,
                            datum101.dir,
                            datum110.dir,
                            datum111.dir,
                            x, y, z);
    dir.normalise();
    dist = interp3d<Scalar>(datum000.dist, 
                            datum001.dist,
                            datum010.dist,
                            datum011.dist,
                            datum100.dist,
                            datum101.dist,
                            datum110.dist,
                            datum111.dist,
                            x, y, z);
    if (dist > 0)
    {
      return true;
    }
    else
    {
      dist = -dist;
      // NOTE(review): unlike the nearest-node branch below, no extra distance
      // is added here for points beyond the grid, so accurate_outside_dist
      // has no effect in this branch. The disabled code below cannot be
      // enabled as-is: it refers to ii/ij/ik, which only exist in the other
      // branch.
      /*
      if ((accurate_outside_dist) ||
          (ii == 0) || (ii == (m_nx - 1)) ||
          (ij == 0) || (ij == (m_ny - 1)) ||
          (ik == 0) || (ik == (m_nz - 1)))
      {
        // query point was probably outside - add on the extra distance.
        dist += (pos - Position(m_min_x + ii * m_dx, m_min_y + ij * m_dy, m_min_z + ik * m_dz)).mag();
      }
      */
      return false;
    }

  }
  else // use the closest grid node
  {
    // Round to the nearest node, then clamp into the grid.
    int ii = (int) floor(fi + 0.5f);
    int ij = (int) floor(fj + 0.5f);
    int ik = (int) floor(fk + 0.5f);

    if (ii < 0) ii = 0; else if (ii >= m_nx) ii = m_nx - 1;
    if (ij < 0) ij = 0; else if (ij >= m_ny) ij = m_ny - 1;
    if (ik < 0) ik = 0; else if (ik >= m_nz) ik = m_nz - 1;

    const Datum & datum = m_data(ii, ij, ik);
    dir = datum.dir;
    dist = datum.dist;
    // need to account for the difference between the actual point and
    // the datum point: project the offset onto the stored direction.
    Scalar extra_dist = dot(
      dir,
      pos - Position(m_min_x + ii * m_dx, 
                     m_min_y + ij * m_dy,
                     m_min_z + ik * m_dz) );
    dist -= extra_dist;
  
    if (dist > 0)
    {
      return true;
    }
    else
    {
      dist = -dist;
      // NOTE(review): dist already includes the offset component along dir
      // (extra_dist above). Adding the full offset magnitude again
      // double-counts that component, e.g. an offset parallel to dir is
      // counted twice. This yields an over-estimate - confirm that this is
      // the intended conservative behaviour.
      if ((accurate_outside_dist) ||
          (ii == 0) || (ii == (m_nx - 1)) ||
          (ij == 0) || (ij == (m_ny - 1)) ||
          (ik == 0) || (ik == (m_nz - 1)))
      {
        // query point was probably outside - add on the extra distance.
        dist += (pos - Position(m_min_x + ii * m_dx, m_min_y + ij * m_dy, m_min_z + ik * m_dz)).mag();
      }
      return false;
    }
  }
}

//==============================================================
// display_object
//==============================================================
// Debug display hook: the mesh currently draws nothing.
void Collision_body_mesh::display_object() {}
