#include <matlabint_misc.h>
#include <getfem_derivatives.h>
#include <getfem_norm.h>
#include <getfem_export.h>

using namespace matlabint;

/* Count the convexes of 'mf' whose FEM is not a Lagrange element.

   If at least one such convex exists, a matlabint_error is thrown; when
   'warning_only' is true a warning is printed with mexPrintf instead.

   mf           : the mesh_fem to inspect (all convexes of
                  mf.convex_index() are visited).
   warning_only : print a warning instead of throwing. */
static void
error_for_non_lagrange_elements(getfem::mesh_fem &mf, bool warning_only = false)
{
  /* working copy: the 'cv << nn' loop presumably pops indices out of nn
     until it yields ST_NIL */
  dal::bit_vector nn = mf.convex_index();

  size_type cv, cnt=0, total=0;
  for (cv << nn; cv != getfem::ST_NIL; cv << nn) {
    if (!mf.fem_of_element(cv)->is_lagrange()) cnt++; 
    total++;
  }
  if (cnt) {
    if (!warning_only) {
      DAL_THROW(matlabint_error, "Error: " << cnt << " elements on " << total << " are NOT lagrange elements -- Unable to compute a derivative");
    } else {
      /* size_type is not an int: passing it for %d is undefined behavior
         in a varargs call, so convert explicitly to match %lu */
      mexPrintf("Warning: %lu elements on %lu are NOT lagrange elements\n",
                static_cast<unsigned long>(cnt),
                static_cast<unsigned long>(total));
    }
  }
}

/* utility function for eval_on_P1_tri_mesh */
static void 
eval_sub_nodes(unsigned N, bgeot::pconvex_ref cv, std::vector<getfem::base_node>& spt)
{
  assert(N>0);
  spt.resize(((N+1)*(N+2))/2);
  
  if (cv->nb_points() != 3) {
    DAL_THROW(matlabint_error, "Error, non triangular element with " << 
	      cv->nb_points() << " vertices");
  } else {   
    /*
      .      pt[0]
      .        /\
      .       /  \
      . pt[2]/____\pt[1]
      
      refinment:
      
      .         0     <- layer 0
      .        111    <- layer 1 etc..
      .       22222
      
    */
    
    spt[0] = cv->points()[0];
    size_type pcnt = 1;

    /* find the three nodes of the each sub-triangle */
    for (size_type layer = 1; layer <= N; layer++) {
      getfem::base_node A,B;
      getfem::scalar_type c;
      c = ((getfem::scalar_type)layer)/N;

      A = cv->points()[0] + (cv->points()[2] - cv->points()[0]) * c;
      B = cv->points()[0] + (cv->points()[1] - cv->points()[0]) * c;

      for (size_type inode = 0; inode <= layer; inode++, pcnt++) {
	spt[pcnt] = A + (B-A) * ((inode)/(getfem::scalar_type)layer);
	/*	mexPrintf("layer %d, inode %d: [%f,%f]\n",
		layer, inode, spt[pcnt][0], spt[pcnt][1]);*/
      }
    }
    if (!(pcnt == spt.size())) THROW_INTERNAL_ERROR;
  }
}

/* Implementation of the 'eval on refined P1 tri mesh' command: evaluate a
   scalar field on a regular refinement of each selected triangle.

   in  : Nrefine (1..1000) [, CVLST]
   out : a matrix with 3*mesh_dim+3 rows and nb_cv*Nrefine^2 columns, one
         column per sub-triangle. Rows [ipt*mesh_dim, (ipt+1)*mesh_dim)
         hold the real coordinates of vertex ipt (ipt = 0,1,2), and row
         3*mesh_dim+ipt holds the value of U interpolated at that vertex.

   mf  : mesh_fem on which U is defined. The FEM reference convex of each
         selected convex must be a triangle (checked by eval_sub_nodes).
   U   : the field; it must be a row vector.

   Throws matlabint_bad_arg if U is not a row vector, and matlabint_error
   if the mesh is neither 2D nor 3D. */
static void
eval_on_P1_tri_mesh(getfem::mesh_fem *mf, matlabint::mlab_vect& U, 
                    matlabint::mexargs_in& in, matlabint::mexargs_out& out)
{
  int Nrefine = in.pop().to_integer(1, 1000);
  dal::bit_vector cvlst;
  if (in.remaining()) {
    /* converted against the convexes of mf, like the norm commands do */
    cvlst = in.pop().to_bit_vector(&mf->convex_index());
  } else {
    cvlst = mf->convex_index();
  }
  unsigned nb_cv = cvlst.card();
  if (U.getm() != 1) {
    DAL_THROW(matlabint_bad_arg, "Only row vector authorized");
  }
  unsigned mesh_dim = mf->linked_mesh().dim();
  if (mesh_dim < 2 || mesh_dim > 3) {
    DAL_THROW(matlabint_error, "This function do not handle " << 
              mesh_dim << "D meshes (only 2D or 3D)");
  }
  mlab_vect w = out.pop().create_vector(3 * mesh_dim + 3 , nb_cv*(Nrefine)*(Nrefine));
  /* sub-nodes of the current convex; eval_sub_nodes resizes it to
     (Nrefine+1)*(Nrefine+2)/2 entries */
  std::vector<getfem::base_node> pt;
  getfem::size_type cv;
  size_type cv_cnt = 0;
  size_type refined_tri_cnt = 0;
  for (cv << cvlst; cv != getfem::ST_NIL; cv << cvlst, cv_cnt++) {
    getfem::pfem cv_fem(mf->fem_of_element(cv));
    bgeot::pconvex_ref cv_ref(cv_fem->ref_convex());
    bgeot::pgeometric_trans pgt = mf->linked_mesh().trans_of_convex(cv);
    eval_sub_nodes(Nrefine, cv_ref, pt);
    std::vector<getfem::scalar_type> pt_val(pt.size());

    /* interpolate while pt still holds reference coordinates */
    interpolate_on_convex_ref(mf, cv, pt, U, 1, pt_val);

    /* apply the geometric transformation to find the real location of
       each sub-node. The image must have mesh_dim components (the
       dimension of the mesh points), not the 2 components of the
       reference node, so that triangles of a 3D mesh are handled. */
    for (std::vector<getfem::base_node>::iterator it = pt.begin();
         it != pt.end(); ++it) {
      getfem::base_node P(mesh_dim); P.fill(0.0);
      for (getfem::size_type j = 0; j < pgt->nb_points(); ++j) {
        P.addmul(pgt->poly_vector()[j].eval(it->begin()),
                 mf->linked_mesh().points_of_convex(cv)[j]);
      }
      *it = P;
    }

    /* emit the Nrefine^2 sub-triangles: layer 'layer' holds 2*layer+1
       triangles, alternately pointing up (even itri) and down (odd
       itri); the nodes of layer l start at index l*(l+1)/2 */
    for (int layer = 0; layer < Nrefine; layer++) {
      for (int itri = 0; itri < layer*2+1; itri++, refined_tri_cnt++) {
        getfem::size_type n[3];

        if ((itri & 1) == 0) {
          /*
            .           0
            .          /\
            .       2 /__\ 1
          */
          n[0] = (layer*(layer+1))/2 + itri/2;
          n[1] = n[0] + layer+1; 
          n[2] = n[1]+1;
        } else {
          /*
            .       1 ____ 2
            .         \  / 
            .          \/
            .           0
          */
          n[1] = (layer*(layer+1))/2 + itri/2;
          n[2] = n[1]+1;
          n[0] = n[2]+layer+1;
        }

        if (!(n[0] < pt.size()) || !(n[1] < pt.size()) || !(n[2] < pt.size())) {
          /* size_type values: convert explicitly to match %lu */
          mexPrintf("n=%lu,%lu,%lu pt.size = %lu\n",
                    static_cast<unsigned long>(n[0]),
                    static_cast<unsigned long>(n[1]),
                    static_cast<unsigned long>(n[2]),
                    static_cast<unsigned long>(pt.size()));
          THROW_INTERNAL_ERROR;
        }

        for (size_type ipt = 0; ipt < 3; ipt++) {
          for (size_type idim = 0; idim < mesh_dim; idim++) {
            w(ipt * mesh_dim + idim, refined_tri_cnt) = pt[n[ipt]][idim];
          }
          w(3*mesh_dim + ipt, refined_tri_cnt) = pt_val[n[ipt]];
        }
      }
    }
  }
  if (cv_cnt != nb_cv) THROW_INTERNAL_ERROR;
  if (refined_tri_cnt != w.getn()) THROW_INTERNAL_ERROR;
}


/* Implementation of the 'mesh edges deformation' command.

   For each edge selected by build_edge_list (driven by the remaining
   input arguments), N equally spaced points are sampled, end points
   included. Each point is moved by the value of the field U interpolated
   there. The deformed positions are written into a
   (mesh dim) x N x (number of edges) output array: w(:, i, e) is the
   deformed position of point i of edge e.

   U must have one row per mesh dimension, and N must be >= 2.
   Throws matlabint_bad_arg otherwise. */
void
mesh_edges_deformation(getfem::mesh_fem *mf, mlab_vect &U, unsigned N, 
                       mexargs_in &in, mexargs_out &out)
{
  if (U.getm() != mf->linked_mesh().dim()) {
    DAL_THROW(matlabint_bad_arg, "Error, the supplied field is a flow of dimension " << 
              U.getm() << " while the mesh is of dimension " << mf->linked_mesh().dim());
  }
  /* the sampling below divides by N-1 */
  if (N < 2) {
    DAL_THROW(matlabint_bad_arg, "Error, the edge refinement level must be at least 2, got " << N);
  }
  getfem::edge_list el;
  getfem::getfem_mesh &m = mf->linked_mesh();

  build_edge_list(m, el, in);
  
  mlab_vect w   = out.pop().create_vector(m.dim(), N, el.size());

  /* global edge index (third dimension of w), shared by all convexes */
  size_type edge_idx = 0;
  getfem::edge_list::const_iterator it = el.begin();
  while (it != el.end()) {
    /* [it, nit) is the run of consecutive edges of one convex; the
       end-of-list test is required for the last run */
    getfem::edge_list::const_iterator nit = it;
    size_type ecnt = 0;
    while (nit != el.end() && (*nit).cv == (*it).cv) {
      ecnt++; nit++;
    }
    unsigned cv = (*it).cv;
    bgeot::pgeometric_trans pgt = m.trans_of_convex(cv);
    /* reference element of the geometric transformation: both
       interpolate_on_convex_ref and pgt->transform take reference
       coordinates, whereas m.convex(cv) holds real mesh points */
    bgeot::pconvex_ref cv_ref = pgt->convex_ref();

    std::vector<getfem::base_node> pt;
    pt.reserve(ecnt * N);

    /* for every edge of the convex, push the N points of its refined
       edge (in reference coordinates) on the vector 'pt' */
    for (getfem::edge_list::const_iterator eit = it; eit != nit; ++eit) {
      /* local point numbers of the edge end points in the convex */
      bgeot::size_type iA = m.local_ind_of_convex_point(cv, (*eit).i);
      bgeot::size_type iB = m.local_ind_of_convex_point(cv, (*eit).j);
      
      getfem::base_node A = cv_ref->points()[iA];
      getfem::base_node B = cv_ref->points()[iB];
      for (size_type i = 0; i < N; i++) {
        pt.push_back(A +  (B-A)*(i/(double)(N-1)));
      }
    }
    if (pt.size() != ecnt * N) THROW_INTERNAL_ERROR;

    /* evaluate the field U on every point of pt at once */
    std::vector<getfem::scalar_type> pt_val;
    interpolate_on_convex_ref(mf, cv, pt, U, U.getm(), pt_val);

    if (pt_val.size() != ecnt * N * U.getm()) THROW_INTERNAL_ERROR;

    /* map each point to the real mesh, add the deformation interpolated
       from U, and write the result in the destination array */
    for (size_type e = 0; it != nit; ++it, ++e, ++edge_idx) {
      for (size_type i = 0; i < N; i++) {
        size_type ip = e * N + i;   /* index of this point in pt */
        getfem::base_node def_pt = pgt->transform(pt[ip], m.points_of_convex(cv));
        for (size_type k = 0; k < U.getm(); k++) {
          def_pt[k] += pt_val[ip * U.getm() + k];
        }
        std::copy(def_pt.begin(), def_pt.end(), &w(0, i, edge_idx));
      }
    }
  }
  if (edge_idx != el.size()) THROW_INTERNAL_ERROR;
}


/*MLABCOM
  FUNCTION [x] = gf_compute(meshfem MF, vec U, operation [, args])

  Various computations involving the solution U of the finite element problem.

  * N = gf_compute(MF, U, 'L2 norm' [,CVLST])
  Computes the L2 norm of U. If CVLST is indicated, the norm will be
  computed only on the listed convexes.

  * N = gf_compute(MF, U, 'H1 semi norm' [,CVLST])
  Computes the L2 norm of grad(U).

  * N = gf_compute(MF, U, 'H1 norm' [,CVLST])
  Computes the H1 norm of U.

  * DU = gf_compute(MF, U, 'gradient', mesh_fem MFGRAD)
  Computes the gradient of the field U defined on meshfem MF. The
  gradient is interpolated on the meshfem MFGRAD, and returned in DU.
  For example, if U is defined on a P2 mesh_fem, DU should be
  evaluated on a P1-discontinuous mesh_fem. MF and MFGRAD should share
  the same mesh.

  * U2 = gf_compute(MF, U, 'interpolate on', MF2)
  Interpolates a field defined on mesh_fem MF on another (lagrangian)
  mesh_fem MF2. If MF and MF2 share the same mesh object, the 
  interpolation will be much faster.


  * [U2[,MF2,[,X[,Y[,Z]]]]] = gf_compute(MF,U,'interpolate on Q1 grid', 
                               {'regular h', hxyz | 'regular N',Nxyz |
           			   X[,Y[,Z]]}

  Creates a cartesian Q1 mesh_fem and interpolates U on it. The
  returned field U2 is organized in a matrix such that it can be drawn
  via the MATLAB command 'pcolor'.

  * E = gf_compute(MF, U, 'mesh edges deformation', N [,vec or 
                   mat CVLIST])

  Evaluates the deformation of the mesh caused by the field U (for a
  2D mesh, U must be a [2 x nb_dof] matrix). N is the refinement level
  (N>=2) of the edges. CVLIST can be used to restrict the computation
  to the edges of the listed convexes (if CVLIST is a row vector),
  or to restrict the computations to certain faces of certain convexes
  when CVLIST is a two-rows matrix, the first row containing convex
  numbers and the second face numbers. E is an array of size
  [mesh dim x N x number of edges].

  * UP = gf_compute(MF, U, 'eval on refined P1 tri mesh', int Nrefine,
                    [vec CVLIST])
  Utility function designed for 2D triangular meshes: returns a list
  of triangle coordinates with interpolated U values. This can be
  used for the accurate visualisation of data defined on a
  discontinuous high order element. On output, the first six rows of UP
  (nine rows for a 3D mesh) contain the triangle coordinates, and the
  last three rows contain the interpolated values of U (one for each
  triangle vertex). CVLIST may indicate the list of convex numbers that
  should be considered; if it is not given, all the mesh convexes are
  used. U should be a row vector.

  TODO : rewrite this function to handle curved triangles &
  quadrangles -- check if 'interpolate on' would work on discont
  mesh_fems ..

  $Id: gf_compute.C,v 1.6 2002/09/05 14:48:34 pommier Exp $
MLABCOM*/

void gf_compute(matlabint::mexargs_in& in, matlabint::mexargs_out& out)
{
  if (in.narg() < 3) {
    DAL_THROW(matlabint_bad_arg, "Wrong number of input arguments");
  }

  mesh_fem_int *mf       = in.pop().to_mesh_fem();
  mlab_vect U            = in.pop().to_scalar_vector(-1, mf->nb_dof());
  std::string cmd        = in.pop().to_string();

  unsigned dim = U.getm();
  
  if (check_cmd(cmd, "L2 norm", in, out, 0, 1, 0, 1)) {
    dal::bit_vector bv = in.remaining() ? 
      in.pop().to_bit_vector(&mf->convex_index()) : mf->convex_index();
    out.pop().from_scalar(getfem::L2_norm(*mf, U, dim, bv));
  } else if (check_cmd(cmd, "H1 semi norm", in, out, 0, 1, 0, 1)) {
    dal::bit_vector bv = in.remaining() ? 
      in.pop().to_bit_vector(&mf->convex_index()) : mf->convex_index();
    out.pop().from_scalar(getfem::H1_semi_norm(*mf, U, dim, bv));
  } else if (check_cmd(cmd, "H1 norm", in, out, 0, 1, 0, 1)) {
    dal::bit_vector bv = in.remaining() ? 
      in.pop().to_bit_vector(&mf->convex_index()) : mf->convex_index();
    out.pop().from_scalar(getfem::H1_norm(*mf, U, dim, bv));
  } else if (check_cmd(cmd, "gradient", in, out, 1, 1, 0, 1)) {
    mesh_fem_int *mf_grad = in.pop().to_mesh_fem();
    error_for_non_lagrange_elements(*mf_grad);
    mlab_vect DU = out.pop().create_vector(dim*mf->linked_mesh().dim(), mf_grad->nb_dof());
    /* compute_gradient also check that the meshes are the same */
    getfem::compute_gradient(*mf, *mf_grad, U, DU, dim);
  } else if (check_cmd(cmd, "eval on refined P1 tri mesh", in, out, 1, 2, 0, 1)) {
    eval_on_P1_tri_mesh(mf, U, in, out);
  } else if (check_cmd(cmd, "interpolate on", in, out, 1, 1, 0, 1)) {
    mesh_fem_int *mf_dest = in.pop().to_mesh_fem();
    error_for_non_lagrange_elements(*mf_dest, true);
    mlab_vect U2 = out.pop().create_vector(dim, mf_dest->nb_dof());
    getfem::interpolation_solution(*mf, *mf_dest,
				   U, U2, dim);
  } else if (check_cmd(cmd, "interpolate on Q1 grid", in, out, 1, 100, 0, 100)) {
    in.restore(0); in.restore(1); /* push back the mf_u and u */
    call_matlab_function("gf_compute_Q1grid_interp", in,out);
  } else if (check_cmd(cmd, "mesh edges deformation", in, out, 1, 2, 0, 1)) {
    unsigned N = in.pop().to_integer(2,10000);
    mesh_edges_deformation(mf, U, N, in, out);
  } else  bad_cmd(cmd);
}

/* MATLAB MEX entry point: hands the MEX arguments to gf_compute through
   catch_errors (a matlabint helper defined elsewhere; presumably it
   converts C++ exceptions into MATLAB errors -- not visible here). */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  catch_errors(nlhs, plhs, nrhs, prhs, gf_compute);
}
