#include "./LsfReinitializationAnalysis.h"
#include "../toolkits/toolkits.h"
#include "../classes/classes.h"
#include "../shared/shared.h"
#include "../modules/modules.h"
#include "../solutionsequences/solutionsequences.h"
#include "../cores/cores.h"

#include "../modules/GetVectorFromInputsx/GetVectorFromInputsx.h"

/*Model processing*/
int  LsfReinitializationAnalysis::DofsPerNode(int** doflist,int domaintype,int approximation){/*{{{*/
	return 1;
}/*}}}*/
void LsfReinitializationAnalysis::UpdateParameters(Parameters* parameters,IoModel* iomodel,int solution_enum,int analysis_enum){/*{{{*/
	/* Do nothing for now */
}/*}}}*/
void LsfReinitializationAnalysis::UpdateElements(Elements* elements,IoModel* iomodel,int analysis_counter,int analysis_type){/*{{{*/
	int    finiteelement;

	/*Finite element type*/
	finiteelement = P1Enum;

	/*Update elements: */
	int counter=0;
	for(int i=0;i<iomodel->numberofelements;i++){
		if(iomodel->my_elements[i]){
			Element* element=(Element*)elements->GetObjectByOffset(counter);
			element->Update(i,iomodel,analysis_counter,analysis_type,finiteelement);
			counter++;
		}
	}

	iomodel->FetchDataToInput(elements,MaskIceLevelsetEnum);
}/*}}}*/
void LsfReinitializationAnalysis::CreateNodes(Nodes* nodes,IoModel* iomodel){/*{{{*/
	int finiteelement=P1Enum;
	if(iomodel->domaintype!=Domain2DhorizontalEnum) iomodel->FetchData(2,MeshVertexonbaseEnum,MeshVertexonsurfaceEnum);
	::CreateNodes(nodes,iomodel,LsfReinitializationAnalysisEnum,finiteelement);
	iomodel->DeleteData(2,MeshVertexonbaseEnum,MeshVertexonsurfaceEnum);
}/*}}}*/
void LsfReinitializationAnalysis::CreateConstraints(Constraints* constraints,IoModel* iomodel){/*{{{*/
	/* Do nothing for now */
}/*}}}*/
void LsfReinitializationAnalysis::CreateLoads(Loads* loads, IoModel* iomodel){/*{{{*/
	/* Do nothing for now */
}/*}}}*/

/*Finite element Analysis*/
void  LsfReinitializationAnalysis::Core(FemModel* femmodel){/*{{{*/

	/*parameters: */
	bool save_results;
	int maxiter = 3;
	int step;
	IssmDouble reltol = 0.05;

	Vector<IssmDouble>* lsfg     = NULL;
	Vector<IssmDouble>* lsfg_old = NULL;

	femmodel->parameters->FindParam(&save_results,SaveResultsEnum);

	/*activate formulation: */
	femmodel->SetCurrentConfiguration(LsfReinitializationAnalysisEnum);

	/* set spcs for reinitialization */
	if(VerboseSolution()) _printf0_("Update spcs for reinitialization:\n");
	SetReinitSPCs(femmodel);

	step = 1;
	for(;;){

		_printf_("smoothing lsf slope\n");
		/* smoothen slope of lsf for computation of normal on ice domain*/
		levelsetfunctionslope_core(femmodel);

		//solve current artificial time step
		if(VerboseSolution()) _printf0_("call computational core for reinitialization in step " << step << ":\n");
		solutionsequence_linear(femmodel);
		GetSolutionFromInputsx(&lsfg,femmodel);

		if(step>1){
			if(VerboseSolution()) _printf0_("   checking reinitialization convergence\n");
			if(ReinitConvergence(lsfg,lsfg_old,reltol)) break;
		}
		if(step>maxiter){
			if(VerboseSolution()) _printf0_("   maximum number reinitialization iterations " << maxiter << " reached\n");
			break;
		}

		/*update results and increase counter*/
		delete lsfg_old;lsfg_old=lsfg;
		step++;
	}

	if(save_results){
		if(VerboseSolution()) _printf0_("   saving results\n");
		int outputs[1] = {MaskIceLevelsetEnum};
		femmodel->RequestedOutputsx(&femmodel->results,&outputs[0],1);
	}

}/*}}}*/
ElementVector* LsfReinitializationAnalysis::CreateDVector(Element* element){/*{{{*/
	/*Default, return NULL*/
	return NULL;
}/*}}}*/
ElementMatrix* LsfReinitializationAnalysis::CreateJacobianMatrix(Element* element){/*{{{*/
	_error_("not implemented yet");
}/*}}}*/
ElementMatrix* LsfReinitializationAnalysis::CreateKMatrix(Element* element){/*{{{*/
	
	/*Intermediaries */
	const int dim = 2;
	int        i,row,col,stabilization;
	IssmDouble Jdet,D_scalar;
	IssmDouble dtau = 1.;
	IssmDouble mu = 1.;
	IssmDouble* xyz_list = NULL;

	/*Fetch number of nodes and dof for this finite element*/
	int numnodes = element->GetNumberOfNodes();

	/*Initialize Element vector and other vectors*/
	ElementMatrix* Ke     = element->NewElementMatrix();
	IssmDouble*    basis    = xNew<IssmDouble>(numnodes);
	IssmDouble*    Bprime = xNew<IssmDouble>(dim*numnodes);
	IssmDouble*    D		= xNew<IssmDouble>(dim*dim);
	IssmDouble*    dlsf	= xNew<IssmDouble>(dim);
	IssmDouble*    normal= xNew<IssmDouble>(dim);

	element->GetVerticesCoordinates(&xyz_list);

	/* Start  looping on the number of gaussian points: */
	Gauss* gauss=element->NewGauss(2);
	for(int ig=gauss->begin();ig<gauss->end();ig++){/*{{{*/
		gauss->GaussPoint(ig);

		element->JacobianDeterminant(&Jdet,xyz_list,gauss);
		D_scalar=gauss->weight*Jdet;

		if(dtau!=0.){
			element->NodalFunctions(basis,gauss);
			TripleMultiply(basis,numnodes,1,0,
						&D_scalar,1,1,0,
						basis,1,numnodes,0,
						&Ke->values[0],1);
			D_scalar*=dtau;
		}

		GetBprime(Bprime,element,xyz_list,gauss);

		for(row=0;row<dim;row++)
			for(col=0;col<dim;col++)
				if(row==col)
					D[row*dim+col]=mu*D_scalar;
				else
					D[row*dim+col]=0.;
		TripleMultiply(Bprime,dim,numnodes,1,
					D,dim,dim,0,
					Bprime,dim,numnodes,0,
					&Ke->values[0],1);

		/* Stabilization */
		stabilization=0;
		if (stabilization==0){/* no stabilization, do nothing*/}
		
	}/*}}}*/

	/*Clean up and return*/
	xDelete<IssmDouble>(xyz_list);
	xDelete<IssmDouble>(basis);
	xDelete<IssmDouble>(Bprime);
	xDelete<IssmDouble>(D);
	delete gauss;
	return Ke;
}/*}}}*/
ElementVector* LsfReinitializationAnalysis::CreatePVector(Element* element){/*{{{*/
	
	/*Intermediaries */
	int i,k;
	int dim = 2;
	IssmDouble dtau = 1.;
	IssmDouble mu = 1.;
	IssmDouble Jdet, D_scalar;
	IssmDouble lsf;
	IssmDouble norm_dlsf;
	IssmDouble dbasis_normal;

	/*Fetch number of nodes */
	int numnodes = element->GetNumberOfNodes();

	IssmDouble* xyz_list = NULL;
	IssmDouble* basis = xNew<IssmDouble>(numnodes);
	IssmDouble* dbasis=xNew<IssmDouble>(dim*numnodes);
	IssmDouble* dlsf = xNew<IssmDouble>(dim);
	IssmDouble* normal = xNew<IssmDouble>(dim);
	element->GetVerticesCoordinates(&xyz_list);

	/*Initialize Element vector*/
	ElementVector* pe = element->NewElementVector();

	/*Retrieve all inputs and parameters*/
	Input* lsf_input = element->GetInput(MaskIceLevelsetEnum); _assert_(lsf_input);
	Input* lsf_slopex_input=element->GetInput(LevelsetfunctionSlopeXEnum); _assert_(lsf_slopex_input);
	Input* lsf_slopey_input=element->GetInput(LevelsetfunctionSlopeYEnum); _assert_(lsf_slopey_input);

	Gauss* gauss=element->NewGauss(2);
	for(int ig=gauss->begin();ig<gauss->end();ig++){
		gauss->GaussPoint(ig);

		element->JacobianDeterminant(&Jdet,xyz_list,gauss);
		element->NodalFunctions(basis,gauss);
		element->NodalFunctionsDerivatives(dbasis,xyz_list,gauss);

		D_scalar=Jdet*gauss->weight;

		if(dtau!=0.){
			/* old function value */
			lsf_input->GetInputValue(&lsf,gauss);
			for(i=0;i<numnodes;i++) pe->values[i]+=D_scalar*lsf*basis[i];
			D_scalar*=dtau;
		}

		lsf_slopex_input->GetInputValue(&dlsf[0],gauss);
		lsf_slopey_input->GetInputValue(&dlsf[1],gauss);

		/*get normal*/
		norm_dlsf=0.;
		for(i=0;i<dim;i++) norm_dlsf+=dlsf[i]*dlsf[i]; 
		norm_dlsf=sqrt(norm_dlsf);
		if(norm_dlsf>0.)
			for(i=0;i<dim;i++)	normal[i]=dlsf[i]/norm_dlsf;
		else
			for(i=0;i<dim;i++)	normal[i]=0.;

		/* multiply normal and dbasis */
		for(i=0;i<numnodes;i++){
			dbasis_normal=0.;
			for(k=0;k<dim;k++) dbasis_normal+=dbasis[k*numnodes+i]*normal[k];
			pe->values[i]+=D_scalar*mu*dbasis_normal; 
		}
	}

	xDelete<IssmDouble>(basis);
	xDelete<IssmDouble>(dbasis);
	xDelete<IssmDouble>(xyz_list);
	xDelete<IssmDouble>(dlsf);
	xDelete<IssmDouble>(normal);
	return pe;
	}/*}}}*/
void LsfReinitializationAnalysis::GetSolutionFromInputs(Vector<IssmDouble>* solution,Element* element){/*{{{*/

	IssmDouble   lsf;
	int          dim;
	int*         doflist = NULL;

	/*Get some parameters*/
	element->FindParam(&dim,DomainDimensionEnum);

	/*Fetch number of nodes and dof for this finite element*/
	int numnodes = element->GetNumberOfNodes();

	/*Fetch dof list and allocate solution vector*/
	element->GetDofList(&doflist,NoneApproximationEnum,GsetEnum);
	IssmDouble* values = xNew<IssmDouble>(numnodes);

	/*Get inputs*/
	Input* lsf_input=element->GetInput(MaskIceLevelsetEnum); _assert_(lsf_input);

	Gauss* gauss=element->NewGauss();
	for(int i=0;i<numnodes;i++){
		gauss->GaussNode(element->FiniteElement(),i);

		lsf_input->GetInputValue(&lsf,gauss);
		values[i]=lsf;
	}

	solution->SetValues(numnodes,doflist,values,INS_VAL);

	/*Free ressources:*/
	delete gauss;
	xDelete<IssmDouble>(values);
	xDelete<int>(doflist);

}/*}}}*/
void LsfReinitializationAnalysis::InputUpdateFromSolution(IssmDouble* solution,Element* element){/*{{{*/

	int domaintype;
	element->FindParam(&domaintype,DomainTypeEnum);
	switch(domaintype){
		case Domain2DhorizontalEnum:
			element->InputUpdateFromSolutionOneDof(solution,MaskIceLevelsetEnum);
			break;
		case Domain3DEnum:
			element->InputUpdateFromSolutionOneDofCollapsed(solution,MaskIceLevelsetEnum);
			break;
		default: _error_("mesh "<<EnumToStringx(domaintype)<<" not supported yet");
	}
}/*}}}*/
void LsfReinitializationAnalysis::UpdateConstraints(FemModel* femmodel){/*{{{*/
	/* Do nothing for now */
}/*}}}*/
void LsfReinitializationAnalysis::GetB(IssmDouble* B,Element* element,IssmDouble* xyz_list,Gauss* gauss){/*{{{*/
	/*Compute B  matrix. B=[B1 B2 B3] where Bi is of size 3*NDOF2. 
	 * For node i, Bi can be expressed in the actual coordinate system
	 * by: 
	 *       Bi=[ N ]
	 *          [ N ]
	 * where N is the finiteelement function for node i.
	 *
	 * We assume B_prog has been allocated already, of size: 2x(NDOF1*numnodes)
	 */

	/*Fetch number of nodes for this finite element*/
	int numnodes = element->GetNumberOfNodes();

	/*Get nodal functions*/
	IssmDouble* basis=xNew<IssmDouble>(numnodes);
	element->NodalFunctions(basis,gauss);

	/*Build B: */
	for(int i=0;i<numnodes;i++){
		B[numnodes*0+i] = basis[i];
		B[numnodes*1+i] = basis[i];
	}

	/*Clean-up*/
	xDelete<IssmDouble>(basis);
}/*}}}*/
void LsfReinitializationAnalysis::GetBprime(IssmDouble* Bprime,Element* element,IssmDouble* xyz_list,Gauss* gauss){/*{{{*/
	/*Compute B'  matrix. B'=[B1' B2' B3'] where Bi' is of size 3*NDOF2. 
	 * For node i, Bi' can be expressed in the actual coordinate system
	 * by: 
	 *       Bi_prime=[ dN/dx ]
	 *                [ dN/dy ]
	 * where N is the finiteelement function for node i.
	 *
	 * We assume B' has been allocated already, of size: 3x(NDOF2*numnodes)
	 */

	/*Fetch number of nodes for this finite element*/
	int numnodes = element->GetNumberOfNodes();

	/*Get nodal functions derivatives*/
	IssmDouble* dbasis=xNew<IssmDouble>(2*numnodes);
	element->NodalFunctionsDerivatives(dbasis,xyz_list,gauss);

	/*Build B': */
	for(int i=0;i<numnodes;i++){
		Bprime[numnodes*0+i] = dbasis[0*numnodes+i];
		Bprime[numnodes*1+i] = dbasis[1*numnodes+i];
	}

	/*Clean-up*/
	xDelete<IssmDouble>(dbasis);

}/*}}}*/

/* Other */
void LsfReinitializationAnalysis::SetReinitSPCs(FemModel* femmodel){/*{{{*/

	int i,k, numnodes;
	Element* element;
	Node* node;

	/* deactivate all spcs */
	for(i=0;i<femmodel->elements->Size();i++){
		element=dynamic_cast<Element*>(femmodel->elements->GetObjectByOffset(i));
		for(k=0;k<element->GetNumberOfNodes();k++){
			node=element->GetNode(k);
			if(node->IsActive()){
				node->DofInFSet(0); 
			}
		}
	}

	SetDistanceOnIntersectedElements(femmodel);

	/* reactivate spcs on elements intersected by zero levelset */
	for(i=0;i<femmodel->elements->Size();i++){
		element=dynamic_cast<Element*>(femmodel->elements->GetObjectByOffset(i));
		if(element->IsZeroLevelset(MaskIceLevelsetEnum)){
			/*iterate over nodes and set spc */
			numnodes=element->GetNumberOfNodes();
			IssmDouble* lsf = xNew<IssmDouble>(numnodes);
			element->GetInputListOnNodes(&lsf[0],MaskIceLevelsetEnum);
			for(k=0;k<numnodes;k++){
				node=element->GetNode(k);
				if(node->IsActive()){
					node->ApplyConstraint(0,lsf[k]);
				}
			}
			xDelete<IssmDouble>(lsf);
		}
	}

}/*}}}*/
void LsfReinitializationAnalysis::SetDistanceOnIntersectedElements(FemModel* femmodel){/*{{{*/

	/* Intermediaries */
	int i,k;
	IssmDouble dmaxp,dmaxm,val;
	Element* element;

	/*Initialize vector with number of vertices*/
	int numvertices=femmodel->vertices->NumberOfVertices();

	Vector<IssmDouble>* vec_dist_zerolevelset = NULL;
	GetVectorFromInputsx(&vec_dist_zerolevelset, femmodel, MaskIceLevelsetEnum, VertexEnum);
	
	/* set distance on elements intersected by zero levelset */
	for(i=0;i<femmodel->elements->Size();i++){
		element=dynamic_cast<Element*>(femmodel->elements->GetObjectByOffset(i));
		if(element->IsZeroLevelset(MaskIceLevelsetEnum)){
			SetDistanceToZeroLevelsetElement(vec_dist_zerolevelset, element);
		}
	}
	vec_dist_zerolevelset->Assemble();

	/* Get maximum distance to interface along vertices */
	dmaxp=0.; dmaxm=0.;
	for(i=0;i<numvertices;i++){
		vec_dist_zerolevelset->GetValue(&val,i); 
		if((val>0.) && (val>dmaxp))
			 dmaxp=val;
		else if((val<0.) && (val<dmaxm))
			 dmaxm=val;
	}
	//wait until all values are computed

	/* set all none intersected vertices to max/min distance */
	for(i=0;i<numvertices;i++){
		vec_dist_zerolevelset->GetValue(&val,i);
		if(val==1.) //FIXME: improve check
			vec_dist_zerolevelset->SetValue(i,3.*dmaxp,INS_VAL);
		else if(val==-1.)
			vec_dist_zerolevelset->SetValue(i,3.*dmaxm,INS_VAL);
	}

	/*Assemble vector and serialize */
	vec_dist_zerolevelset->Assemble();
	IssmDouble* dist_zerolevelset=vec_dist_zerolevelset->ToMPISerial();
	InputUpdateFromVectorx(femmodel,dist_zerolevelset,MaskIceLevelsetEnum,VertexSIdEnum);

	/*Clean up and return*/
	delete vec_dist_zerolevelset;
	delete dist_zerolevelset;

}/*}}}*/
void LsfReinitializationAnalysis::SetDistanceToZeroLevelsetElement(Vector<IssmDouble>* vec_signed_dist, Element* element){/*{{{*/

	if(!element->IsZeroLevelset(MaskIceLevelsetEnum))
		return;

	/* Intermediaries */
	int dim=3;
	int i,d;
	IssmDouble dist,lsf_old;

	int numvertices=element->GetNumberOfVertices();

	IssmDouble* lsf = xNew<IssmDouble>(numvertices);
	IssmDouble* sign_lsf = xNew<IssmDouble>(numvertices);
	IssmDouble* signed_dist = xNew<IssmDouble>(numvertices);
	IssmDouble* s0 = xNew<IssmDouble>(dim);
	IssmDouble* s1 = xNew<IssmDouble>(dim);
	IssmDouble* v = xNew<IssmDouble>(dim);
	IssmDouble* xyz_list = NULL;
	IssmDouble* xyz_list_zero = NULL;

	/* retrieve inputs and parameters */
	element->GetVerticesCoordinates(&xyz_list);
	element->GetInputListOnVertices(lsf,MaskIceLevelsetEnum);

	/* get sign of levelset function */
	for(i=0;i<numvertices;i++)
		sign_lsf[i]=(lsf[i]>=0.?1.:-1.);

	element->ZeroLevelsetCoordinates(&xyz_list_zero, xyz_list, MaskIceLevelsetEnum);
	for(i=0;i<dim;i++){
		s0[i]=xyz_list_zero[0+i];
		s1[i]=xyz_list_zero[3+i];
	}

	/* get signed_distance of vertices to zero levelset straight */
	for(i=0;i<numvertices;i++){
		for(d=0;d<dim;d++)
			v[d]=xyz_list[dim*i+d];
		dist=GetDistanceToStraight(&v[0],&s0[0],&s1[0]);
		signed_dist[i]=sign_lsf[i]*dist;
	}
	
	/* insert signed_distance into vec_signed_dist, if computed distance is smaller */
	for(i=0;i<numvertices;i++){
		vec_signed_dist->GetValue(&lsf_old, element->vertices[i]->Sid());
		/* initial lsf values are +-1. Update those fields or if distance to interface smaller.*/
		if(fabs(lsf_old)==1. || fabs(signed_dist[i])<fabs(lsf_old))
			vec_signed_dist->SetValue(element->vertices[i]->Sid(),signed_dist[i],INS_VAL);
	}

	xDelete<IssmDouble>(lsf);
	xDelete<IssmDouble>(sign_lsf);
	xDelete<IssmDouble>(signed_dist);
	xDelete<IssmDouble>(s0);
	xDelete<IssmDouble>(s1);
	xDelete<IssmDouble>(v);

}/*}}}*/
IssmDouble LsfReinitializationAnalysis::GetDistanceToStraight(IssmDouble* q, IssmDouble* s0, IssmDouble* s1){/*{{{*/
	// returns distance d of point q to straight going through points s0, s1
	// d=|a x b|/|b|
	// with a=q-s0, b=s1-s0
	
	/* Intermediaries */
	const int dim=2;
	int i;
	IssmDouble a[dim], b[dim];
	IssmDouble norm_b;

	for(i=0;i<dim;i++){
		a[i]=q[i]-s0[i];
		b[i]=s1[i]-s0[i];
	}
	
	norm_b=0.;
	for(i=0;i<dim;i++)
		norm_b+=b[i]*b[i];
	norm_b=sqrt(norm_b);
	_assert_(norm_b>0.);

	return fabs(a[0]*b[1]-a[1]*b[0])/norm_b;
}/*}}}*/
bool LsfReinitializationAnalysis::ReinitConvergence(Vector<IssmDouble>* lsfg,Vector<IssmDouble>* lsfg_old,IssmDouble reltol){/*{{{*/

	/*Output*/
	bool converged = true;

	/*Intermediary*/
	Vector<IssmDouble>* dlsfg    = NULL;
	IssmDouble          norm_dlsf,norm_lsf;

	/*compute norm(du)/norm(u)*/
	dlsfg=lsfg_old->Duplicate(); lsfg_old->Copy(dlsfg); dlsfg->AYPX(lsfg,-1.0);
	norm_dlsf=dlsfg->Norm(NORM_TWO); norm_lsf=lsfg_old->Norm(NORM_TWO);
	if (xIsNan<IssmDouble>(norm_dlsf) || xIsNan<IssmDouble>(norm_lsf)) _error_("convergence criterion is NaN!");
	if((norm_dlsf/norm_lsf)<reltol){
		if(VerboseConvergence()) _printf0_("\n"<<setw(50)<<left<<"   Velocity convergence: norm(du)/norm(u)"<<norm_dlsf/norm_lsf*100<<" < "<<reltol*100<<" %\n");
	}
	else{ 
		if(VerboseConvergence()) _printf0_("\n"<<setw(50)<<left<<"   Velocity convergence: norm(du)/norm(u)"<<norm_dlsf/norm_lsf*100<<" > "<<reltol*100<<" %\n");
		converged=false;
	}

	/*Cleanup*/
	delete dlsfg;

	return converged;
}/*}}}*/
