#include "./FreeSurfaceBaseAnalysis.h"
#include "../toolkits/toolkits.h"
#include "../classes/classes.h"
#include "../shared/shared.h"
#include "../modules/modules.h"

/*Model processing*/
void FreeSurfaceBaseAnalysis::CreateConstraints(Constraints* constraints,IoModel* iomodel){/*{{{*/
}/*}}}*/
void FreeSurfaceBaseAnalysis::CreateLoads(Loads* loads, IoModel* iomodel){/*{{{*/

	/*Intermediaries*/
	int penpair_ids[2];
	int count=0;
	int numvertex_pairing;

	/*Create Penpair for vertex_pairing: */
	IssmDouble *vertex_pairing=NULL;
	IssmDouble *nodeonbase=NULL;
	iomodel->FetchData(&vertex_pairing,&numvertex_pairing,NULL,"md.masstransport.vertex_pairing");
	if(iomodel->domaintype!=Domain2DhorizontalEnum) iomodel->FetchData(&nodeonbase,NULL,NULL,"md.mesh.vertexonbase");
	for(int i=0;i<numvertex_pairing;i++){

		if(iomodel->my_vertices[reCast<int>(vertex_pairing[2*i+0])-1]){

			/*In debugging mode, check that the second node is in the same cpu*/
			_assert_(iomodel->my_vertices[reCast<int>(vertex_pairing[2*i+1])-1]);

			/*Skip if one of the two is not on the bed*/
			if(iomodel->domaintype!=Domain2DhorizontalEnum){
				if(!(reCast<bool>(nodeonbase[reCast<int>(vertex_pairing[2*i+0])-1])) || !(reCast<bool>(nodeonbase[reCast<int>(vertex_pairing[2*i+1])-1]))) continue;
			}

			/*Get node ids*/
			penpair_ids[0]=reCast<int>(vertex_pairing[2*i+0]);
			penpair_ids[1]=reCast<int>(vertex_pairing[2*i+1]);

			/*Create Load*/
			loads->AddObject(new Penpair(count+1, &penpair_ids[0]));
			count++;
		}
	}

	/*free ressources: */
	iomodel->DeleteData(vertex_pairing,"md.masstransport.vertex_pairing");
	iomodel->DeleteData(nodeonbase,"md.mesh.vertexonbase");
}/*}}}*/
void FreeSurfaceBaseAnalysis::CreateNodes(Nodes* nodes,IoModel* iomodel,bool isamr){/*{{{*/

	if(iomodel->domaintype!=Domain2DhorizontalEnum) iomodel->FetchData(2,"md.mesh.vertexonbase","md.mesh.vertexonsurface");
	::CreateNodes(nodes,iomodel,FreeSurfaceBaseAnalysisEnum,P1Enum);
	iomodel->DeleteData(2,"md.mesh.vertexonbase","md.mesh.vertexonsurface");
}/*}}}*/
int  FreeSurfaceBaseAnalysis::DofsPerNode(int** doflist,int domaintype,int approximation){/*{{{*/
	return 1;
}/*}}}*/
void FreeSurfaceBaseAnalysis::UpdateElements(Elements* elements,Inputs2* inputs2,IoModel* iomodel,int analysis_counter,int analysis_type){/*{{{*/

	/*Now, is the model 3d? otherwise, do nothing: */
	if (iomodel->domaintype==Domain2DhorizontalEnum)return;

	/*Finite element type*/
	int finiteelement = P1Enum;

	/*Update elements: */
	int counter=0;
	for(int i=0;i<iomodel->numberofelements;i++){
		if(iomodel->my_elements[i]){
			Element* element=(Element*)elements->GetObjectByOffset(counter);
			element->Update(inputs2,i,iomodel,analysis_counter,analysis_type,finiteelement);
			counter++;
		}
	}

	iomodel->FetchDataToInput(inputs2,elements,"md.geometry.surface",SurfaceEnum);
	iomodel->FetchDataToInput(inputs2,elements,"md.solidearth.sealevel",SealevelEnum,0);
	iomodel->FetchDataToInput(inputs2,elements,"md.mask.ice_levelset",MaskIceLevelsetEnum);
	iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.groundedice_melting_rate",BasalforcingsGroundediceMeltingRateEnum,0.);
	iomodel->FetchDataToInput(inputs2,elements,"md.initialization.vx",VxEnum);
	iomodel->FetchDataToInput(inputs2,elements,"md.initialization.vy",VyEnum);
	if(iomodel->domaindim==3){
		iomodel->FetchDataToInput(inputs2,elements,"md.initialization.vz",VzEnum);
	}
	if(iomodel->domaintype!=Domain2DhorizontalEnum){
		iomodel->FetchDataToInput(inputs2,elements,"md.mesh.vertexonbase",MeshVertexonbaseEnum);
		iomodel->FetchDataToInput(inputs2,elements,"md.mesh.vertexonsurface",MeshVertexonsurfaceEnum);
	}

	/*Get what we need for ocean-induced basal melting*/
	int basalforcing_model;
	iomodel->FindConstant(&basalforcing_model,"md.basalforcings.model");
	switch(basalforcing_model){
		case FloatingMeltRateEnum:
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.floatingice_melting_rate",BasalforcingsFloatingiceMeltingRateEnum);
			break;
		case LinearFloatingMeltRateEnum:
			break;
		case MismipFloatingMeltRateEnum:
			break;
		case MantlePlumeGeothermalFluxEnum:
			break;
		case SpatialLinearFloatingMeltRateEnum:
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.deepwater_melting_rate",BasalforcingsDeepwaterMeltingRateEnum);
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.deepwater_elevation",BasalforcingsDeepwaterElevationEnum);
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.upperwater_elevation",BasalforcingsUpperwaterElevationEnum);
			break;
		case BasalforcingsPicoEnum:
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.basin_id",BasalforcingsPicoBasinIdEnum);
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.overturning_coeff",BasalforcingsPicoOverturningCoeffEnum);
			break;
		case BasalforcingsIsmip6Enum:
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.basin_id",BasalforcingsIsmip6BasinIdEnum);
			break;
		case BeckmannGoosseFloatingMeltRateEnum:
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.ocean_salinity",BasalforcingsOceanSalinityEnum);
			iomodel->FetchDataToInput(inputs2,elements,"md.basalforcings.ocean_temp",BasalforcingsOceanTempEnum);
			break;
		default:
			_error_("Basal forcing model "<<EnumToStringx(basalforcing_model)<<" not supported yet");
	}
}/*}}}*/
void FreeSurfaceBaseAnalysis::UpdateParameters(Parameters* parameters,IoModel* iomodel,int solution_enum,int analysis_enum){/*{{{*/
}/*}}}*/

/*Finite Element Analysis*/
void           FreeSurfaceBaseAnalysis::Core(FemModel* femmodel){/*{{{*/
	_error_("not implemented");
}/*}}}*/
ElementVector* FreeSurfaceBaseAnalysis::CreateDVector(Element* element){/*{{{*/
	/*Default, return NULL*/
	return NULL;
}/*}}}*/
ElementMatrix* FreeSurfaceBaseAnalysis::CreateJacobianMatrix(Element* element){/*{{{*/
_error_("Not implemented");
}/*}}}*/
ElementMatrix* FreeSurfaceBaseAnalysis::CreateKMatrix(Element* element){/*{{{*/

	/*Intermediaries*/
	int         domaintype,dim,stabilization;
	Element*    basalelement = NULL;
	IssmDouble *xyz_list  = NULL;
	IssmDouble  Jdet,D_scalar,dt,h;
	IssmDouble  vel,vx,vy;

	/*Get basal element*/
	element->FindParam(&domaintype,DomainTypeEnum);
	switch(domaintype){
		case Domain2DhorizontalEnum:
			basalelement = element;
			dim = 2;
			break;
		case Domain2DverticalEnum:
			if(!element->IsOnBase()) return NULL;
			basalelement = element->SpawnBasalElement();
			dim = 1;
			break;
		case Domain3DEnum:
			if(!element->IsOnBase()) return NULL;
			basalelement = element->SpawnBasalElement();
			dim = 2;
			break;
		default: _error_("mesh "<<EnumToStringx(domaintype)<<" not supported yet");
	}

	/*Fetch number of nodes and dof for this finite element*/
	int numnodes = basalelement->GetNumberOfNodes();

	/*Initialize Element vector*/
	ElementMatrix* Ke     = basalelement->NewElementMatrix(NoneApproximationEnum);
	IssmDouble*    basis  = xNew<IssmDouble>(numnodes);
	IssmDouble*    B      = xNew<IssmDouble>(dim*numnodes);
	IssmDouble*    Bprime = xNew<IssmDouble>(dim*numnodes);
	IssmDouble*    D      = xNew<IssmDouble>(dim*dim);

	/*Retrieve all inputs and parameters*/
	basalelement->GetVerticesCoordinates(&xyz_list);
	basalelement->FindParam(&dt,TimesteppingTimeStepEnum);
	basalelement->FindParam(&stabilization,MasstransportStabilizationEnum);
	Input2* vx_input=basalelement->GetInput2(VxEnum); _assert_(vx_input);
	Input2* vy_input=NULL;
	if(dim>1){vy_input = basalelement->GetInput2(VyEnum); _assert_(vy_input);}
	h = basalelement->CharacteristicLength();

	/* Start  looping on the number of gaussian points: */
	Gauss* gauss=basalelement->NewGauss(2);
	for(int ig=gauss->begin();ig<gauss->end();ig++){
		gauss->GaussPoint(ig);

		basalelement->JacobianDeterminant(&Jdet,xyz_list,gauss);
		basalelement->NodalFunctions(basis,gauss);

		vx_input->GetInputValue(&vx,gauss);
		if(dim==2) vy_input->GetInputValue(&vy,gauss);

		D_scalar=gauss->weight*Jdet;
		TripleMultiply(basis,1,numnodes,1,
					&D_scalar,1,1,0,
					basis,1,numnodes,0,
					&Ke->values[0],1);

		GetB(B,basalelement,dim,xyz_list,gauss);
		GetBprime(Bprime,basalelement,dim,xyz_list,gauss);

		D_scalar=dt*gauss->weight*Jdet;
		for(int i=0;i<dim*dim;i++) D[i]=0.;
		D[0] = D_scalar*vx;
		if(dim==2) D[1*dim+1] = D_scalar*vy;

		TripleMultiply(B,dim,numnodes,1,
					D,dim,dim,0,
					Bprime,dim,numnodes,0,
					&Ke->values[0],1);

		if(stabilization==2){
			/*Streamline upwinding*/
			if(dim==1){
			 vel=fabs(vx)+1.e-8;
			 D[0] = h/(2.*vel)*vx*vx;
			}
			else{
			 vel=sqrt(vx*vx+vy*vy)+1.e-8;
			 D[0*dim+0]=h/(2*vel)*vx*vx;
			 D[1*dim+0]=h/(2*vel)*vy*vx;
			 D[0*dim+1]=h/(2*vel)*vx*vy;
			 D[1*dim+1]=h/(2*vel)*vy*vy;
			}
		}
		else if(stabilization==1){
			/*SSA*/
			if(dim==1){
				vx_input->GetInputAverage(&vx);
				D[0]=h/2.*fabs(vx);
			}
			else{
				vx_input->GetInputAverage(&vx);
				vy_input->GetInputAverage(&vy);
				D[0*dim+0]=h/2.0*fabs(vx);
				D[1*dim+1]=h/2.0*fabs(vy);
			}
		}
		if(stabilization==1 || stabilization==2){
			for(int i=0;i<dim*dim;i++) D[i]=D_scalar*D[i];
			TripleMultiply(Bprime,dim,numnodes,1,
						D,dim,dim,0,
						Bprime,dim,numnodes,0,
						&Ke->values[0],1);
		}
	}

	/*Clean up and return*/
	xDelete<IssmDouble>(xyz_list);
	xDelete<IssmDouble>(basis);
	xDelete<IssmDouble>(B);
	xDelete<IssmDouble>(Bprime);
	xDelete<IssmDouble>(D);
	delete gauss;
	if(domaintype!=Domain2DhorizontalEnum){basalelement->DeleteMaterials(); delete basalelement;};
	return Ke;
}/*}}}*/
ElementVector* FreeSurfaceBaseAnalysis::CreatePVector(Element* element){/*{{{*/
	/*Intermediaries*/
	int         domaintype,dim;
	IssmDouble  Jdet,dt;
	IssmDouble  gmb,fmb,mb,bed,phi,vz;
	Element*    basalelement = NULL;
	IssmDouble *xyz_list  = NULL;

	/*Get basal element*/
	element->FindParam(&domaintype,DomainTypeEnum);
	switch(domaintype){
		case Domain2DhorizontalEnum:
			basalelement = element;
			dim = 2;
			break;
		case Domain2DverticalEnum:
			if(!element->IsOnBase()) return NULL;
			basalelement = element->SpawnBasalElement();
			dim = 1;
			break;
		case Domain3DEnum:
			if(!element->IsOnBase()) return NULL;
			basalelement = element->SpawnBasalElement();
			dim = 2;
			break;
		default: _error_("mesh "<<EnumToStringx(domaintype)<<" not supported yet");
	}

	/*Fetch number of nodes and dof for this finite element*/
	int numnodes = basalelement->GetNumberOfNodes();

	/*Initialize Element vector and other vectors*/
	ElementVector* pe    = basalelement->NewElementVector();
	IssmDouble*    basis = xNew<IssmDouble>(numnodes);

	/*Retrieve all inputs and parameters*/
	basalelement->GetVerticesCoordinates(&xyz_list);
	basalelement->FindParam(&dt,TimesteppingTimeStepEnum);
	Input2* groundedice_input   = basalelement->GetInput2(MaskOceanLevelsetEnum);              _assert_(groundedice_input);
	Input2* gmb_input           = basalelement->GetInput2(BasalforcingsGroundediceMeltingRateEnum);  _assert_(gmb_input);
	Input2* fmb_input           = basalelement->GetInput2(BasalforcingsFloatingiceMeltingRateEnum);  _assert_(fmb_input);
	Input2* base_input          = basalelement->GetInput2(BaseEnum);                                 _assert_(base_input);
	Input2* vz_input      = NULL;
	switch(dim){
		case 1: vz_input = basalelement->GetInput2(VyEnum); _assert_(vz_input); break;
		case 2: vz_input = basalelement->GetInput2(VzEnum); _assert_(vz_input); break;
		default: _error_("not implemented");
	}

	/* Start  looping on the number of gaussian points: */
	Gauss* gauss=basalelement->NewGauss(2);
	for(int ig=gauss->begin();ig<gauss->end();ig++){
		gauss->GaussPoint(ig);

		basalelement->JacobianDeterminant(&Jdet,xyz_list,gauss);
		basalelement->NodalFunctions(basis,gauss);

		vz_input->GetInputValue(&vz,gauss);
		gmb_input->GetInputValue(&gmb,gauss);
		fmb_input->GetInputValue(&fmb,gauss);
		base_input->GetInputValue(&bed,gauss);
		groundedice_input->GetInputValue(&phi,gauss);
		if(phi>0) mb=gmb;
		else mb=fmb;

		for(int i=0;i<numnodes;i++) pe->values[i]+=Jdet*gauss->weight*(bed+dt*(mb) + dt*vz)*basis[i];
	}

	/*Clean up and return*/
	xDelete<IssmDouble>(xyz_list);
	xDelete<IssmDouble>(basis);
	delete gauss;
	if(domaintype!=Domain2DhorizontalEnum){basalelement->DeleteMaterials(); delete basalelement;};
	return pe;

}/*}}}*/
void           FreeSurfaceBaseAnalysis::GetB(IssmDouble* B,Element* element,int dim,IssmDouble* xyz_list,Gauss* gauss){/*{{{*/
	/*Compute B  matrix. B=[B1 B2 B3] where Bi is of size 3*2. 
	 * For node i, Bi can be expressed in the actual coordinate system
	 * by: 
	 *       Bi=[ N ]
	 *          [ N ]
	 * where N is the finiteelement function for node i.
	 *
	 * We assume B_prog has been allocated already, of size: 2x(1*numnodes)
	 */

	/*Fetch number of nodes for this finite element*/
	int numnodes = element->GetNumberOfNodes();

	/*Get nodal functions*/
	IssmDouble* basis=xNew<IssmDouble>(numnodes);
	element->NodalFunctions(basis,gauss);

	/*Build B: */
	for(int i=0;i<numnodes;i++){
		for(int j=0;j<dim;j++){
			B[numnodes*j+i] = basis[i];
		}
	}

	/*Clean-up*/
	xDelete<IssmDouble>(basis);
}/*}}}*/
void           FreeSurfaceBaseAnalysis::GetBprime(IssmDouble* Bprime,Element* element,int dim,IssmDouble* xyz_list,Gauss* gauss){/*{{{*/
	/*Compute B'  matrix. B'=[B1' B2' B3'] where Bi' is of size 3*2. 
	 * For node i, Bi' can be expressed in the actual coordinate system
	 * by: 
	 *       Bi_prime=[ dN/dx ]
	 *                [ dN/dy ]
	 * where N is the finiteelement function for node i.
	 *
	 * We assume B' has been allocated already, of size: 3x(2*numnodes)
	 */

	/*Fetch number of nodes for this finite element*/
	int numnodes = element->GetNumberOfNodes();

	/*Get nodal functions derivatives*/
	IssmDouble* dbasis=xNew<IssmDouble>(dim*numnodes);
	element->NodalFunctionsDerivatives(dbasis,xyz_list,gauss);

	/*Build B': */
	for(int i=0;i<numnodes;i++){
		for(int j=0;j<dim;j++){
			Bprime[numnodes*j+i] = dbasis[j*numnodes+i];
		}
	}

	/*Clean-up*/
	xDelete<IssmDouble>(dbasis);

}/*}}}*/
void           FreeSurfaceBaseAnalysis::GetSolutionFromInputs(Vector<IssmDouble>* solution,Element* element){/*{{{*/
	   _error_("not implemented yet");
}/*}}}*/
void           FreeSurfaceBaseAnalysis::GradientJ(Vector<IssmDouble>* gradient,Element* element,int control_type,int control_index){/*{{{*/
	_error_("Not implemented yet");
}/*}}}*/
void           FreeSurfaceBaseAnalysis::InputUpdateFromSolution(IssmDouble* solution,Element* element){/*{{{*/
	element->InputUpdateFromSolutionOneDof(solution,BaseEnum);
}/*}}}*/
void           FreeSurfaceBaseAnalysis::UpdateConstraints(FemModel* femmodel){/*{{{*/

	/*Intermediary*/
	IssmDouble phi,isonbase,base;

	for(int i=0;i<femmodel->elements->Size();i++){

		Element* element=xDynamicCast<Element*>(femmodel->elements->GetObjectByOffset(i));
		if(!element->IsOnBase()) continue;

		int             numnodes = element->GetNumberOfNodes();
		Input2* groundedice_input = element->GetInput2(MaskOceanLevelsetEnum);  _assert_(groundedice_input);
		Input2* onbase_input       = element->GetInput2(MeshVertexonbaseEnum);          _assert_(onbase_input);
		Input2* base_input        = element->GetInput2(BaseEnum);                     _assert_(base_input);

		Gauss* gauss=element->NewGauss();
		for(int iv=0;iv<numnodes;iv++){
			gauss->GaussNode(element->GetElementType(),iv);
			onbase_input->GetInputValue(&isonbase,gauss);
			if(isonbase==1.){
				groundedice_input->GetInputValue(&phi,gauss);
				if(phi>=0.){
					base_input->GetInputValue(&base,gauss);
					element->nodes[iv]->ApplyConstraint(0,base);
				}
				else{
					element->nodes[iv]->DofInFSet(0);
				}
			}
		}
		delete gauss;
	}
}/*}}}*/
