/*!\file: transient_3d_core.cpp
 * \brief: core of the transient_3d solution
 */

#ifdef HAVE_CONFIG_H
#include <config.h>
#else
#error "Cannot compile with HAVE_CONFIG_H symbol! run configure first!"
#endif

#include <float.h>
#include "./cores.h"
#include "../toolkits/toolkits.h"
#include "../classes/classes.h"
#include "../shared/shared.h"
#include "../modules/modules.h"
#include "../solutionsequences/solutionsequences.h"

/*Prototypes*/
void transient_step(FemModel* femmodel);

void transient_core(FemModel* femmodel){/*{{{*/

	/*parameters: */
	IssmDouble finaltime,dt,yts;
	bool       isoceancoupling,iscontrol,isautodiff,isslr;
	int        timestepping;
	int        output_frequency,recording_frequency;
	int        amr_frequency,amr_restart;
	char     **requested_outputs = NULL;

	/*intermediary: */
	int        step;
	IssmDouble time;

	/*first, figure out if there was a check point, if so, do a reset of the FemModel* femmodel structure. */
	femmodel->parameters->FindParam(&recording_frequency,SettingsRecordingFrequencyEnum);
	if(recording_frequency) femmodel->Restart();

	/*then recover parameters common to all solutions*/
	femmodel->parameters->FindParam(&step,StepEnum);
	femmodel->parameters->FindParam(&time,TimeEnum);
	femmodel->parameters->FindParam(&finaltime,TimesteppingFinalTimeEnum);
	femmodel->parameters->FindParam(&yts,ConstantsYtsEnum);
	femmodel->parameters->FindParam(&output_frequency,SettingsOutputFrequencyEnum);
	femmodel->parameters->FindParam(&timestepping,TimesteppingTypeEnum);
	femmodel->parameters->FindParam(&isslr,TransientIsslrEnum);
	femmodel->parameters->FindParam(&isoceancoupling,TransientIsoceancouplingEnum);
	femmodel->parameters->FindParam(&amr_frequency,TransientAmrFrequencyEnum);
	femmodel->parameters->FindParam(&iscontrol,InversionIscontrolEnum);
	femmodel->parameters->FindParam(&isautodiff,AutodiffIsautodiffEnum);

	#if defined(_HAVE_BAMG_) && !defined(_HAVE_AD_)
	if(amr_frequency){
		femmodel->parameters->FindParam(&amr_restart,AmrRestartEnum);
		if(amr_restart) femmodel->ReMesh();
	}
	#endif

	#if defined(_HAVE_OCEAN_ )
	if(isoceancoupling) OceanExchangeDatax(femmodel,true);
	#endif

	DataSet* dependent_objects=NULL;
	if(iscontrol && isautodiff){
		femmodel->parameters->FindParam(&dependent_objects,AutodiffDependentObjectsEnum);
	}

	if(isslr) sealevelrise_core_geometry(femmodel);

	while(time < finaltime - (yts*DBL_EPSILON)){ //make sure we run up to finaltime.

		/*Time Increment*/
		switch(timestepping){
			case AdaptiveTimesteppingEnum:
				femmodel->TimeAdaptx(&dt);
				if(time+dt>finaltime) dt=finaltime-time;
				femmodel->parameters->SetParam(dt,TimesteppingTimeStepEnum);
				break;
			case FixedTimesteppingEnum:
				femmodel->parameters->FindParam(&dt,TimesteppingTimeStepEnum);
				break;
			default:
				_error_("Time stepping \""<<EnumToStringx(timestepping)<<"\" not supported yet");
		}
		step+=1;
		time+=dt;
		femmodel->parameters->SetParam(time,TimeEnum);
		femmodel->parameters->SetParam(step,StepEnum);

		if(VerboseSolution()){
			_printf0_("iteration " << step << "/" << ceil((finaltime-time)/dt)+step << \
						"  time [yr]: " << setprecision(4) << time/yts << " (time step: " << dt/yts << ")\n");
		}
		bool save_results=false;
		if(step%output_frequency==0 || (time >= finaltime - (yts*DBL_EPSILON)) || step==1) save_results=true;
		femmodel->parameters->SetParam(save_results,SaveResultsEnum);

		/*Run transient step!*/
		transient_step(femmodel);

		/*unload results*/
		if(save_results){
			if(VerboseSolution()) _printf0_("   saving temporary results\n");
			OutputResultsx(femmodel);
		}

		if(recording_frequency && (step%recording_frequency==0)){
			if(VerboseSolution()) _printf0_("   checkpointing model \n");
			femmodel->CheckPoint();
		}

		/*Adaptive mesh refinement*/
		if(amr_frequency){

			#if !defined(_HAVE_AD_)
			if(save_results) femmodel->WriteMeshInResults();
			if(step%amr_frequency==0 && time<finaltime){
				if(VerboseSolution()) _printf0_("   refining mesh\n");
				femmodel->ReMesh();//Do not refine the last step
			}

			#else
			_error_("AMR not suppored with AD");
			#endif
		}

		if(iscontrol && isautodiff){
			/*Go through our dependent variables, and compute the response:*/
			for(Object* & object:dependent_objects->objects){
				DependentObject* dep=(DependentObject*)object;
				IssmDouble  output_value;
				dep->Responsex(&output_value,femmodel);
				dep->AddValue(output_value);
			}
		}
	}

	if(!iscontrol || !isautodiff) femmodel->RequestedDependentsx();
	if(iscontrol && isautodiff) femmodel->parameters->SetParam(dependent_objects,AutodiffDependentObjectsEnum);

}/*}}}*/
void transient_step(FemModel* femmodel){/*{{{*/

	/*parameters: */
	bool isstressbalance,ismasstransport,issmb,isthermal,isgroundingline,isgia,isesa;
	bool isslr,ismovingfront,isdamageevolution,ishydrology,isoceancoupling,save_results;
	int  step,sb_coupling_frequency;
	int  domaintype,numoutputs;

	/*then recover parameters common to all solutions*/
	femmodel->parameters->FindParam(&domaintype,DomainTypeEnum);
	femmodel->parameters->FindParam(&save_results,SaveResultsEnum);
	femmodel->parameters->FindParam(&step,StepEnum);
	femmodel->parameters->FindParam(&sb_coupling_frequency,SettingsSbCouplingFrequencyEnum);
	femmodel->parameters->FindParam(&isstressbalance,TransientIsstressbalanceEnum);
	femmodel->parameters->FindParam(&ismasstransport,TransientIsmasstransportEnum);
	femmodel->parameters->FindParam(&issmb,TransientIssmbEnum);
	femmodel->parameters->FindParam(&isthermal,TransientIsthermalEnum);
	femmodel->parameters->FindParam(&isgia,TransientIsgiaEnum);
	femmodel->parameters->FindParam(&isesa,TransientIsesaEnum);
	femmodel->parameters->FindParam(&isslr,TransientIsslrEnum);
	femmodel->parameters->FindParam(&isgroundingline,TransientIsgroundinglineEnum);
	femmodel->parameters->FindParam(&ismovingfront,TransientIsmovingfrontEnum);
	femmodel->parameters->FindParam(&isoceancoupling,TransientIsoceancouplingEnum);
	femmodel->parameters->FindParam(&isdamageevolution,TransientIsdamageevolutionEnum);
	femmodel->parameters->FindParam(&ishydrology,TransientIshydrologyEnum);
	femmodel->parameters->FindParam(&numoutputs,TransientNumRequestedOutputsEnum);

	#if defined(_HAVE_OCEAN_ )
	if(isoceancoupling) OceanExchangeDatax(femmodel,false);
	#endif

	if(isthermal && domaintype==Domain3DEnum){
		if(issmb){
			bool isenthalpy;
			int  smb_model;
			femmodel->parameters->FindParam(&isenthalpy,ThermalIsenthalpyEnum);
			femmodel->parameters->FindParam(&smb_model,SmbEnum);
			if(isenthalpy){
				if(smb_model==SMBpddEnum || smb_model==SMBd18opddEnum || smb_model==SMBpddSicopolisEnum){
					femmodel->SetCurrentConfiguration(EnthalpyAnalysisEnum);
					ResetBoundaryConditions(femmodel,EnthalpyAnalysisEnum);
				}
			}
			else{
				if(smb_model==SMBpddEnum || smb_model==SMBd18opddEnum || smb_model==SMBpddSicopolisEnum){
					femmodel->SetCurrentConfiguration(ThermalAnalysisEnum);
					ResetBoundaryConditions(femmodel,ThermalAnalysisEnum);
				}
			}
		}
		thermal_core(femmodel);
	}

	/* Using Hydrology dc  coupled we need to compute smb in the hydrology inner time loop*/
	if(issmb) {
		if(VerboseSolution()) _printf0_("   computing smb\n");
		smb_core(femmodel);
	}

	if(ishydrology){
		if(VerboseSolution()) _printf0_("   computing hydrology\n");
		int hydrology_model;
		hydrology_core(femmodel);
		femmodel->parameters->FindParam(&hydrology_model,HydrologyModelEnum);
		if(hydrology_model!=HydrologydcEnum && issmb) smb_core(femmodel);
	}

	if(isstressbalance && (step%sb_coupling_frequency==0 || step==1) ) {
		if(VerboseSolution()) _printf0_("   computing stress balance\n");
		stressbalance_core(femmodel);
	}

	if(isdamageevolution) {
		if(VerboseSolution()) _printf0_("   computing damage\n");
		damage_core(femmodel);
	}

	if(ismovingfront)	{
		if(VerboseSolution()) _printf0_("   computing moving front\n");
		movingfront_core(femmodel);
	}

	/* from here on, prepare geometry for next time step*/

	if(ismasstransport){
		if(VerboseSolution()) _printf0_("   computing mass transport\n");
		bmb_core(femmodel);
		masstransport_core(femmodel);
		femmodel->UpdateVertexPositionsx();
	}

	if(isgroundingline){
		if(VerboseSolution()) _printf0_("   computing new grounding line position\n");
		groundingline_core(femmodel);
	}

	if(isgia){
		if(VerboseSolution()) _printf0_("   computing glacial isostatic adjustment\n");
		#ifdef _HAVE_GIA_
		gia_core(femmodel);
		#else
		_error_("ISSM was not compiled with gia capabilities. Exiting");
		#endif
	}

	/*esa: */
	if(isesa) esa_core(femmodel);

	/*Sea level rise: */
	if(isslr){
		if(VerboseSolution()) _printf0_("   computing sea level rise\n");
		sealevelchange_core(femmodel);
	}

	/*Any requested output that needs to be saved?*/
	if(numoutputs){
		char **requested_outputs = NULL;
		femmodel->parameters->FindParam(&requested_outputs,&numoutputs,TransientRequestedOutputsEnum);

		if(VerboseSolution()) _printf0_("   computing transient requested outputs\n");
		femmodel->RequestedOutputsx(&femmodel->results,requested_outputs,numoutputs,save_results);

		/*Free ressources:*/
		for(int i=0;i<numoutputs;i++){xDelete<char>(requested_outputs[i]);} xDelete<char*>(requested_outputs);
	}
}/*}}}*/

#ifdef _HAVE_CODIPACK_
void transient_ad(FemModel* femmodel){/*{{{*/

	/*parameters: */
	IssmDouble output_value;
	IssmDouble finaltime,dt,yts,time;
	bool       isoceancoupling,isslr;
	int        step,timestepping;
	DataSet*   dependent_objects=NULL;

	/*then recover parameters common to all solutions*/
	femmodel->parameters->FindParam(&step,StepEnum);
	femmodel->parameters->FindParam(&time,TimeEnum);
	femmodel->parameters->FindParam(&finaltime,TimesteppingFinalTimeEnum);
	femmodel->parameters->FindParam(&yts,ConstantsYtsEnum);
	femmodel->parameters->FindParam(&timestepping,TimesteppingTypeEnum);
	femmodel->parameters->FindParam(&isslr,TransientIsslrEnum);
	femmodel->parameters->FindParam(&dependent_objects,AutodiffDependentObjectsEnum);
	if(isslr) sealevelrise_core_geometry(femmodel);

	std::vector<IssmDouble> time_all;
	int                     Ysize = 0;
	CountDoublesFunctor   *hdl_countdoubles = NULL;
	RegisterInputFunctor  *hdl_regin        = NULL;
	RegisterOutputFunctor *hdl_regout       = NULL;
	SetAdjointFunctor     *hdl_setadjoint   = NULL;

	while(time < finaltime - (yts*DBL_EPSILON)){ //make sure we run up to finaltime.

		/*Time Increment*/
		if(timestepping!=FixedTimesteppingEnum) _error_("not supported yet, but easy to handle...");
		femmodel->parameters->FindParam(&dt,TimesteppingTimeStepEnum);
		step+=1;
		time+=dt;
		femmodel->parameters->SetParam(time,TimeEnum);
		femmodel->parameters->SetParam(step,StepEnum);
		femmodel->parameters->SetParam(false,SaveResultsEnum);
		time_all.push_back(time);

		/*Store Model State at the beginning of the step*/
		if(VerboseSolution()) _printf0_("   checkpointing model \n");
		femmodel->CheckPointAD(step);

		if(VerboseSolution()) _printf0_("   counting number of active variables\n");
		hdl_countdoubles = new CountDoublesFunctor();
		femmodel->Marshall(hdl_countdoubles);
		if(hdl_countdoubles->DoubleCount()>Ysize) Ysize= hdl_countdoubles->DoubleCount();
		delete hdl_countdoubles;

		/*Run transient step!*/
		transient_step(femmodel);

		/*Go through our dependent variables, and compute the response:*/
		for(Object* & object:dependent_objects->objects){
			DependentObject* dep=(DependentObject*)object;
			dep->Responsex(&output_value,femmodel);
			dep->AddValue(output_value);
		}
	}

	if(VerboseSolution()) _printf0_("   done with initial complete transient\n");

	/*__________________________________________________________________________________*/

	/*Get X (control)*/
	IssmDouble *X = NULL; int Xsize;
	GetVectorFromControlInputsx(&X,&Xsize,femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,"value");

	/*Initialize model state adjoint (Yb)*/
	double *Yb  = xNewZeroInit<double>(Ysize);
	int    *Yin = xNewZeroInit<int>(Ysize);

	/*Get final Ysize*/
	hdl_countdoubles = new CountDoublesFunctor();
	femmodel->Marshall(hdl_countdoubles);
	int Ysize_i= hdl_countdoubles->DoubleCount();
	delete hdl_countdoubles;

	/*Start tracing*/
	auto& tape_codi = IssmDouble::getGlobalTape();
	tape_codi.setActive();

	/*Reverse dependent (f)*/
	hdl_regin = new RegisterInputFunctor(Yin,Ysize);
	femmodel->Marshall(hdl_regin);
	delete hdl_regin;
	for(int i=0; i < Xsize; i++) tape_codi.registerInput(X[i]);

	SetControlInputsFromVectorx(femmodel,X);

	IssmDouble J = 0.;
	if(0){
		for(Object* & object:dependent_objects->objects){
			DependentObject* dep=(DependentObject*)object;
			J += dep->GetValue();
		}
	}
	else{
		femmodel->IceVolumex(&J,false);
	}
	if(IssmComm::GetRank()==0) tape_codi.registerOutput(J);

	tape_codi.setPassive();
	J.gradient() = 1.0;
	tape_codi.evaluate();

	/*Initialize Xb and Yb*/
	double *Xb  = xNewZeroInit<double>(Xsize);
	for(int i=0;i<Xsize  ;i++) Xb[i] += X[i].gradient();
	for(int i=0;i<Ysize_i;i++) Yb[i]  = tape_codi.gradient(Yin[i]);

	/*reverse loop for transient step (G)*/
	for(int reverse_step = step;reverse_step>0; reverse_step--){

		/*Restore model from this step*/
		tape_codi.reset();
		femmodel->RestartAD(reverse_step);
		tape_codi.setActive();

		/*Get new Ysize*/
		hdl_countdoubles = new CountDoublesFunctor();
		femmodel->Marshall(hdl_countdoubles);
		int Ysize_i= hdl_countdoubles->DoubleCount();
		delete hdl_countdoubles;

		/*We need to store the CoDiPack identifier here, since y is overwritten.*/
		hdl_regin = new RegisterInputFunctor(Yin,Ysize);
		femmodel->Marshall(hdl_regin);
		delete hdl_regin;

		/*Tell codipack that X is the independent*/
		for(int i=0; i<Xsize; i++) tape_codi.registerInput(X[i]);
		SetControlInputsFromVectorx(femmodel,X);

		/*Get New state*/
		transient_step(femmodel);

		/*Register output*/
		hdl_regout = new RegisterOutputFunctor();
		femmodel->Marshall(hdl_regout);
		delete hdl_regout;

		/*stop tracing*/
		tape_codi.setPassive();

		/*Reverse transient step (G)*/
		/* Using y_b here to seed the next reverse iteration there y_b is always overwritten*/
		hdl_setadjoint = new SetAdjointFunctor(Yb,Ysize);
		femmodel->Marshall(hdl_setadjoint);
		delete hdl_setadjoint;

		tape_codi.evaluate();

		/* here we access the gradient data via the stored identifiers.*/
		//for(int i=0; i<Ysize_i; i++) if(tape_codi.gradient(Yin[i])!=0.) printf(" %g (%i)",tape_codi.gradient(Yin[i]),Yin[i]);
		for(int i=0; i<Ysize_i; i++) Yb[i]  = tape_codi.gradient(Yin[i]);
		for(int i=0; i<Xsize;   i++) Xb[i] += X[i].gradient();
	}

	/*Now set final gradient*/
	IssmDouble* aG=xNew<IssmDouble>(Xsize);
	for(int i=0;i<Xsize;i++){
		aG[i] = reCast<IssmDouble>(Xb[i]);
	}
	ControlInputSetGradientx(femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,aG);
	xDelete<IssmDouble>(aG);

	xDelete<IssmDouble>(X);
	xDelete<double>(Xb);
	xDelete<double>(Yb);
	xDelete<int>(Yin);
}/*}}}*/
#endif
