/*!\file: controlad_core.cpp
 * \brief: core of the ad control solution 
 */ 

#include <config.h>
#include "./cores.h"
#include "../toolkits/toolkits.h"
#include "../classes/classes.h"
#include "../shared/shared.h"
#include "../modules/modules.h"
#include "../solutionsequences/solutionsequences.h"

#if defined (_HAVE_M1QN3_)  & defined (_HAVE_ADOLC_)
/*m1qn3 prototypes {{{*/
extern "C" void *ctonbe_; // DIS mode : Conversion
extern "C" void *ctcabe_; // DIS mode : Conversion
extern "C" void *euclid_; // Scalar product
typedef void (*SimulFunc) (long* indic,long* n, double* x, double* pf,double* g,long [],float [],void* dzs);
extern "C" void m1qn3_ (void f(long* indic,long* n, double* x, double* pf,double* g,long [],float [],void* dzs),
			void **, void **, void **,
			long *, double [], double *, double [], double*, double *,
			double *, char [], long *, long *, long *, long *, long *, long *, long [], double [], long *,
			long *, long *, long [], float [],void* );

/*Cost function prototype*/
void simulad(long* indic,long* n,double* X,double* pf,double* G,long izs[1],float rzs[1],void* dzs);
/*}}}*/
void controlad_core(FemModel* femmodel){ /*{{{*/

	/*Intermediaries*/
	int          i;
	long         omode;
	IssmPDouble  f,dxmin,gttol;
	IssmDouble   dxmind,gttold; 
	int          maxsteps,maxiter;
	int          intn,numberofvertices,num_controls,solution_type;
	IssmDouble  *scaling_factors = NULL;
	IssmPDouble  *X  = NULL;
	IssmDouble   *Xd  = NULL;
	IssmDouble   *Gd  = NULL;
	IssmPDouble  *G  = NULL;
	bool onsid=true;

	/*Recover some parameters*/
	femmodel->parameters->FindParam(&solution_type,SolutionTypeEnum);
	femmodel->parameters->FindParam(&num_controls,InversionNumControlParametersEnum);
	femmodel->parameters->FindParam(&maxsteps,InversionMaxstepsEnum);
	femmodel->parameters->FindParam(&maxiter,InversionMaxiterEnum);
	femmodel->parameters->FindParam(&dxmind,InversionDxminEnum); dxmin=reCast<IssmPDouble>(dxmind);
	femmodel->parameters->FindParam(&gttold,InversionGttolEnum); gttol=reCast<IssmPDouble>(gttold);
	femmodel->parameters->FindParam(&scaling_factors,NULL,InversionControlScalingFactorsEnum);
	femmodel->parameters->SetParam(false,SaveResultsEnum);
	numberofvertices=femmodel->vertices->NumberOfVertices();

	/*Initialize M1QN3 parameters*/
	if(VerboseControl())_printf0_("   Initialize M1QN3 parameters\n");
	SimulFunc costfuncion  = &simulad;  /*Cost function address*/
	void**    prosca       = &euclid_;  /*Dot product function (euclid is the default)*/
	char      normtype[]   = "dfn";     /*Norm type: dfn = scalar product defined by prosca*/
	long      izs[5];                   /*Arrays used by m1qn3 subroutines*/
	long      iz[5];                    /*Integer m1qn3 working array of size 5*/
	float     rzs[1];                   /*Arrays used by m1qn3 subroutines*/
	long      impres       = 0;         /*verbosity level*/
	long      imode[3]     = {0};       /*scaling and starting mode, 0 by default*/
	long      indic        = 4;         /*compute f and g*/
	long      reverse      = 0;         /*reverse or direct mode*/
	long      io           = 6;         /*Channel number for the output*/

	/*Optimization criterions*/
	long niter = long(maxsteps); /*Maximum number of iterations*/
	long nsim  = long(maxiter);/*Maximum number of function calls*/

	/*Get initial guess*/
	Vector<IssmDouble> *Xad = NULL;
	GetVectorFromControlInputsx(&Xad,femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,"value",onsid);
	Xd = Xad->ToMPISerial();
	Xad->GetSize(&intn);
	X=xNew<IssmPDouble>(intn);
	for(i=0;i<intn;i++) X[i]=reCast<IssmPDouble>(Xd[i]);
	delete Xad;
	_assert_(intn==numberofvertices*num_controls);
	
	/*Get problem dimension and initialize gradient and initial guess*/
	long n = long(intn);
	G = xNew<IssmPDouble>(n);
	Gd = xNew<IssmDouble>(n);

	/*Scale control for M1QN3*/
	for(int i=0;i<numberofvertices;i++){
		for(int c=0;c<num_controls;c++){
			int index = num_controls*i+c;
			X[index] = X[index]/reCast<IssmPDouble>(scaling_factors[c]);
		}
	}

	/*Allocate m1qn3 working arrays (see doc)*/
	long      m   = 100;
	long      ndz = 4*n+m*(2*n+1);
	double*   dz  = xNew<double>(ndz);

	if(VerboseControl())_printf0_("   Computing initial solution\n");
	_printf0_("\n");
	_printf0_("Cost function f(x)   | Gradient norm |g(x)| |  List of contributions\n");
	_printf0_("____________________________________________________________________\n");

	//first run before firing up the control optimization
	simulad(&indic,&n,X,&f,G,izs,rzs,(void*)femmodel);
	double f1=f;

	m1qn3_(costfuncion,prosca,&ctonbe_,&ctcabe_,
				&n,X,&f,G,&dxmin,&f1,
				&gttol,normtype,&impres,&io,imode,&omode,&niter,&nsim,iz,dz,&ndz,
				&reverse,&indic,izs,rzs,(void*)femmodel);

	switch(int(omode)){
		case 0:  _printf0_("   Stop requested (indic = 0)\n"); break;
		case 1:  _printf0_("   Convergence reached (gradient satisfies stopping criterion)\n"); break;
		case 2:  _printf0_("   Bad initialization\n"); break;
		case 3:  _printf0_("   Line search failure\n"); break;
		case 4:  _printf0_("   Maximum number of iterations exceeded\n");break;
		case 5:  _printf0_("   Maximum number of function calls exceeded\n"); break;
		case 6:  _printf0_("   stopped on dxmin during line search\n"); break;
		case 7:  _printf0_("   <g,d> > 0  or  <y,s> <0\n"); break;
		default: _printf0_("   Unknown end condition\n");
	}
	
	/*Constrain solution vector*/
	IssmDouble  *XL = NULL;
	IssmDouble  *XU = NULL;
	GetVectorFromControlInputsx(&XL,femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,"lowerbound",onsid);
	GetVectorFromControlInputsx(&XU,femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,"upperbound",onsid);
	for(int i=0;i<numberofvertices;i++){
		for(int c=0;c<num_controls;c++){
			int index = num_controls*i+c;
			X[index] = X[index]*reCast<IssmPDouble>(scaling_factors[c]);
			if(X[index]>XU[index]) X[index]=reCast<IssmPDouble>(XU[index]);
			if(X[index]<XL[index]) X[index]=reCast<IssmPDouble>(XL[index]);
		}
	}

	/*Save results:*/
	femmodel->results->AddObject(new GenericExternalResult<IssmPDouble*>(femmodel->results->Size()+1,AutodiffJacobianEnum,G,n,1,1,0.0));
	femmodel->results->AddObject(new GenericExternalResult<IssmPDouble*>(femmodel->results->Size()+1,AutodiffXpEnum,X,intn,1,1,0.0));

	/*Clean-up and return*/
	xDelete<double>(G);
	xDelete<double>(X);
	xDelete<double>(dz);
}/*}}}*/
void simulad(long* indic,long* n,double* X,double* pf,double* G,long izs[1],float rzs[1],void* dzs){ /*{{{*/

	/*Intermediaries:*/
	char* rootpath=NULL;
	char* inputfilename=NULL;
	char* outputfilename=NULL;
	char* toolkitsfilename=NULL;
	char* lockfilename=NULL;
	IssmPDouble* G2=NULL;
	bool onsid=true;

	/*Recover Femmodel*/
	int         solution_type;
	FemModel   *femmodel  = (FemModel*)dzs;
	FemModel   *femmodelad  = NULL;
	IssmDouble    pfd;
	int            i;

	/*Recover number of cost functions responses*/
	int num_responses,num_controls,numberofvertices;
	IssmDouble* scaling_factors = NULL;
	femmodel->parameters->FindParam(&num_responses,InversionNumCostFunctionsEnum);
	femmodel->parameters->FindParam(&num_controls,InversionNumControlParametersEnum);
	femmodel->parameters->FindParam(&scaling_factors,NULL,InversionControlScalingFactorsEnum);
	numberofvertices=femmodel->vertices->NumberOfVertices();

	/*Constrain input vector*/
	IssmDouble  *XL = NULL;
	IssmDouble  *XU = NULL;
	GetVectorFromControlInputsx(&XL,femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,"lowerbound",onsid);
	GetVectorFromControlInputsx(&XU,femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,"upperbound",onsid);
	for(int i=0;i<numberofvertices;i++){
		for(int c=0;c<num_controls;c++){
			int index = num_controls*i+c;
			X[index] = X[index]*reCast<IssmPDouble>(scaling_factors[c]);
			if(X[index]>reCast<IssmPDouble>(XU[index])) X[index]=reCast<IssmPDouble>(XU[index]);
			if(X[index]<reCast<IssmPDouble>(XL[index])) X[index]=reCast<IssmPDouble>(XL[index]);
		}
	}

	/*Now things get complicated. The femmodel we recovered did not initialize an AD trace, so we can't compute gradients with it. We are going to recreate 
	 *a new femmodel, identical in all aspects to the first one, with trace on though, which will allow us to run the forward mode and get the gradient 
	 in one run of the solution core. So first recover the filenames required for the FemModel constructor, then call a new ad tailored constructor:*/
	femmodel->parameters->FindParam(&rootpath,RootPathEnum);
	femmodel->parameters->FindParam(&inputfilename,InputFileNameEnum);
	femmodel->parameters->FindParam(&outputfilename,OutputFileNameEnum);
	femmodel->parameters->FindParam(&toolkitsfilename,ToolkitsFileNameEnum);
	femmodel->parameters->FindParam(&lockfilename,LockFileNameEnum);

	femmodelad=new FemModel(rootpath, inputfilename, outputfilename, toolkitsfilename, lockfilename, femmodel->comm, femmodel->solution_type,X);
	femmodel=femmodelad; //We can do this, because femmodel is being called from outside, not by reference, so we won't erase it
	
	/*Recover some parameters*/
	femmodel->parameters->FindParam(&solution_type,SolutionTypeEnum);

	/*Compute solution:*/
	void (*solutioncore)(FemModel*)=NULL;
	CorePointerFromSolutionEnum(&solutioncore,femmodel->parameters,solution_type);
	solutioncore(femmodel);

	/*Compute objective function*/
	IssmDouble* Jlist = NULL;
	femmodel->CostFunctionx(&pfd,&Jlist,NULL); *pf=reCast<IssmPDouble>(pfd);
	_printf0_("f(x) = "<<setw(12)<<setprecision(7)<<*pf<<"  |  ");
	
	/*Compute gradient using AD. Gradient is in the results after the ad_core is called*/
	ad_core(femmodel); 

	if(IssmComm::GetRank()==0){
		GenericExternalResult<IssmPDouble*>* gradient=(GenericExternalResult<IssmPDouble*>*)femmodel->results->FindResult(AutodiffJacobianEnum); _assert_(gradient);
		G2=gradient->GetValues();
	}
	else G2=xNew<IssmPDouble>(*n);

	/*MPI broadcast results:*/
	ISSM_MPI_Bcast(G2,*n,ISSM_MPI_PDOUBLE,0,IssmComm::GetComm());
	
	/*Send gradient to m1qn3 core: */
	for(long i=0;i<*n;i++) G[i] = G2[i];
	
	/*Constrain X and G*/
	IssmDouble  Gnorm = 0.;
	for(int i=0;i<numberofvertices;i++){
		for(int c=0;c<num_controls;c++){
			int index = num_controls*i+c;
			if(X[index]>=XU[index]) G[index]=0.;
			if(X[index]<=XL[index]) G[index]=0.;
			G[index] = G[index]*reCast<IssmPDouble>(scaling_factors[c]);
			X[index] = X[index]/reCast<IssmPDouble>(scaling_factors[c]);
			Gnorm += G[index]*G[index];
		}
	}
	Gnorm = sqrt(Gnorm);

	/*Print info*/
	_printf0_("       "<<setw(12)<<setprecision(7)<<Gnorm<<" |");
	for(int i=0;i<num_responses;i++) _printf0_(" "<<setw(12)<<setprecision(7)<<Jlist[i]);
	_printf0_("\n");

	/*Clean-up and return*/
	xDelete<IssmDouble>(Jlist);
	xDelete<IssmDouble>(XU);
	xDelete<IssmDouble>(XL);
	xDelete<IssmPDouble>(G2);
	//if(femmodelad)delete femmodelad;
} /*}}}*/
#else
void controlad_core(FemModel* femmodel){ /*{{{*/
	_error_("AD and/or M1QN3 not installed");
}/*}}}*/
#endif //_HAVE_M1QN3_
