/*!\file: solutionsequence_schurcg.cpp
 * \brief: numerical core of 
 */ 

#include "./solutionsequences.h"
#include "../toolkits/toolkits.h"
#include "../classes/classes.h"
#include "../shared/shared.h"
#include "../modules/modules.h"
#include "../analyses/analyses.h"


#ifdef _HAVE_PETSC_
void SchurCGSolver(Vector<IssmDouble>** puf,Mat Kff,Mat Iff,Vec pf, Vec uf0,Vec df,Parameters* parameters){/*{{{*/

	Mat                  A, B, BT;				/* Saddle point block matrices */
	Mat						IP;						/* Preconditioner matrix */
	IS                   isv=NULL;				/* Index set free velocity nodes */
	IS                   isp=NULL;				/* Index set free pressure nodes */
	int                  nu, np;					/* No of. free nodes in velocity / pressure space */
   Vec                  p,uold,unew;			/* Solution vectors for pressure / vel. */ 
	Vec						tmpu, tmpp, rhsu,rhsp; /* temp. vectors, arbitrary RHS in vel. / pressure space */
	Vec						gold,gnew,wold,wnew,chi,thetaold,thetanew,eta; /* CG intermediaries */
	Vec						f1,f2;					/* RHS of the global system */
	double					rho,gamma,tmpScalar; /* Step sizes, arbitrary double */
	KSP						kspu,kspp;				/* KSP contexts for vel. / pressure systems*/
	KSPConvergedReason	reason;					/* Convergence reason for troubleshooting */
	int						its;						/* No. of iterations for troubleshooting */
	double					initRnorm, rnorm, TOL; /* residual norms, STOP tolerance */
	PC							pcu,pcp;					/* Preconditioner contexts pertaining the KSP contexts*/
	PetscViewer				viewer;					/* Viewer for troubleshooting */
	IssmPDouble				t1,t2;					/* Time measurement for bottleneck analysis */

	/*STOP tolerance for the rel. residual*/
	TOL = 0.4;

	/*Initialize output*/
	Vector<IssmDouble>* out_uf=new Vector<IssmDouble>(uf0);
	
	/* Get velocity and pressure index sets for extraction */
	#if _PETSC_MAJOR_==3
		/*Make indices out of doftypes: */
		if(!df)_error_("need doftypes for FS solver!\n");
	   DofTypesToIndexSet(&isv,&isp,df,FSSolverEnum);
	#else
	   _error_("Petsc 3.X required");
	#endif


	/* Extract block matrices from the saddle point matrix */
	/* [ A   B ] = Kff
    * [ B^T 0 ] 
	 *         */
	MatGetSubMatrix(Kff,isv,isv,MAT_INITIAL_MATRIX,&A);
	MatGetSubMatrix(Kff,isv,isp,MAT_INITIAL_MATRIX,&B);
	MatGetSubMatrix(Kff,isp,isv,MAT_INITIAL_MATRIX,&BT);
	
	/* Extract preconditioner matrix on the pressure space*/
	MatGetSubMatrix(Iff,isp,isp,MAT_INITIAL_MATRIX,&IP);
	
	/* Get number of velocity / pressure nodes */
	MatGetSize(B,&nu,&np);

	/* Extract initial guesses for uold and pold */
	VecCreate(IssmComm::GetComm(),&p);VecSetSizes(p,PETSC_DECIDE,np);VecSetFromOptions(p);
	VecAssemblyBegin(p);VecAssemblyEnd(p);
	VecCreate(IssmComm::GetComm(),&uold);VecSetSizes(uold,PETSC_DECIDE,nu);VecSetFromOptions(uold);
	VecAssemblyBegin(uold);VecAssemblyEnd(uold);

	VecGetSubVector(out_uf->pvector->vector,isv,&uold);
	VecGetSubVector(out_uf->pvector->vector,isp,&p);


	/* Set up intermediaries */
	VecDuplicate(uold,&f1);VecSet(f1,0.0);
	VecDuplicate(p,&f2);VecSet(f2,0.0);
	VecDuplicate(uold,&tmpu);VecSet(tmpu,0.0);
	VecDuplicate(p,&tmpp);VecSet(tmpp,0.0);
	VecDuplicate(p,&rhsp);VecSet(rhsp,0.0);
	VecDuplicate(uold,&rhsu);VecSet(rhsu,0.0);
	VecDuplicate(p,&gold);VecSet(gold,0.0);
	VecDuplicate(p,&wnew);VecSet(wnew,0.0);
	VecDuplicate(uold,&chi);VecSet(chi,0.0);
	VecDuplicate(p,&thetanew);VecSet(thetanew,0.0);
	VecDuplicate(p,&thetaold);VecSet(thetaold,0.0);
	VecDuplicate(p,&eta);VecSet(eta,0.0);
	
	/* Get global RHS (for each block sub-problem respectively)*/
	VecGetSubVector(pf,isv,&f1);
	VecGetSubVector(pf,isp,&f2);

   /* ------------------------------------------------------------ */

	/* Generate initial value for the velocity from the pressure */
	/* a(u0,v) = f1(v)-b(p0,v)  i.e.  Au0 = F1-Bp0 */
	/* u0 = u_DIR on \Gamma_DIR */
	
	/* Create KSP context */
	KSPCreate(IssmComm::GetComm(),&kspu);
	KSPSetOperators(kspu,A,A);
	KSPSetType(kspu,KSPCG);
	KSPSetInitialGuessNonzero(kspu,PETSC_TRUE);
	//KSPSetTolerances(kspu,1e-12,PETSC_DEFAULT,PETSC_DEFAULT,PETSC_DEFAULT);
	//KSPMonitorSet(kspu,KSPMonitorDefault,NULL,NULL);
	KSPGetPC(kspu,&pcu);
	PCSetType(pcu,PCSOR);
	KSPSetUp(kspu);

	
	/* Create RHS */
	/* RHS = F1-B * pold */
	VecScale(p,-1.);MatMultAdd(B,p,f1,rhsu);VecScale(p,-1.);

	/* Go solve Au0 = F1-Bp0*/
	KSPSolve(kspu,rhsu,uold);
	

	/* Set up u_new */
	VecDuplicate(uold,&unew);VecCopy(uold,unew);
	VecAssemblyBegin(unew);VecAssemblyEnd(unew);



	/* ------------------------------------------------------------- */

	/*Get initial residual*/
	/*(1/mu(x) * g0, q) = b(q,u0) - (f2,q)  i.e.  IP * g0 = BT * u0 - F2*/
	
	/* Create KSP context */
	KSPCreate(IssmComm::GetComm(),&kspp);
	KSPSetOperators(kspp,IP,IP);
	
	/* Create RHS */
	/* RHS = BT * uold - F2 */
	VecScale(f2,-1.);MatMultAdd(BT,uold,f2,rhsp);VecScale(f2,-1.);

	/* Set KSP & PC options */
	KSPSetType(kspp,KSPCG);
	KSPSetInitialGuessNonzero(kspp,PETSC_TRUE);
	KSPGetPC(kspp,&pcp);
	PCSetType(pcp,PCJACOBI);
	/* Note: Systems in the pressure space are cheap, so we can afford a better tolerance */
	KSPSetTolerances(kspp,1e-10,PETSC_DEFAULT,PETSC_DEFAULT,PETSC_DEFAULT);
	KSPSetUp(kspp);
	
	/* Go solve */
	KSPSolve(kspp,rhsp,gold);
	
	/*Initial residual*/
	VecNorm(gold,NORM_INFINITY,&initRnorm);
	
	/* Further setup */
	VecDuplicate(gold,&gnew);VecCopy(gold,gnew);
	VecAssemblyBegin(gnew);VecAssemblyEnd(gnew);


	/* ------------------------------------------------------------ */

	/*Set initial search direction*/
	/*w0 = g0*/
	VecDuplicate(gold,&wold);VecCopy(gold,wold);
	VecAssemblyBegin(wold);VecAssemblyEnd(wold);

	/*Realizing the step size part 1: thetam */
	/*IP * theta = BT * uold - F2*/
	VecScale(f2,-1.);MatMultAdd(BT,uold,f2,rhsp);VecScale(f2,-1.);
	KSPSolve(kspp,rhsp,thetaold);


	/* Count number of iterations */
	int count = 0;

	/* CG iteration*/
	for(;;){

		/*Realizing the step size part 2: chim */
		/*a(chim,v) = -b(wm,v)  i.e.  A * chim = -B * wm */
		/*chim_DIR = 0*/
		VecScale(wold,-1.);MatMult(B,wold,rhsu);VecScale(wold,-1.);
		KSPSolve(kspu,rhsu,chi);

		/*Realizing the step size part 3: etam */
		MatMult(BT,chi,rhsp);
		KSPSolve(kspp,rhsp,eta);
	
		/* ---------------------------------------------------------- */


		/*Set step size*/
		/*rhom = [(wm)^T * IP^-1 * (BT * um - F2)]/[(wm)^T * IP^-1 * BT * chim]*/
		VecDot(wold,thetaold,&rho);
		VecDot(wold,eta,&tmpScalar);
		rho = rho/tmpScalar;


		/* ---------------------------------------------------------- */


		/*Pressure update*/
		/*p(m+1) = pm - rhom * wm*/
		VecAXPY(p,-1.*rho,wold);


		/*Velocity update*/
		/*u(m+1) = um - rhom * chim*/
		VecWAXPY(unew,-1.*rho,chi,uold);


		/* ---------------------------------------------------------- */

		/*Theta update*/
		/*IP * theta = BT * uold - F2*/
		VecScale(f2,-1.);MatMultAdd(BT,unew,f2,rhsp);VecScale(f2,-1.);
		KSPSolve(kspp,rhsp,thetanew);


		/* ---------------------------------------------------------- */

		/*Residual update*/
		/*g(m+1) = gm - rhom * BT * chim*/
		VecWAXPY(gnew,-1.*rho,eta,gold);

		/* ---------------------------------------------------------- */


		/*BREAK if norm(g(m+0),2) < TOL or pressure space has been full searched*/
		VecNorm(gnew,NORM_INFINITY,&rnorm);
		if(rnorm < TOL*initRnorm) 
		 break;
		else if(rnorm > 100*initRnorm)
		 _error_("Solver diverged. This shouldn't happen\n");
		else
		 PetscPrintf(PETSC_COMM_WORLD,"rel. residual at step %d: %g, at TOL = %g\n",count,rnorm/initRnorm,TOL);
		
		if(count > np-1) break;
	

		/* ---------------------------------------------------------- */


		/*Directional update*/
		/*gamma = [g(m+1)^T * theta(m+1)]/[g(m)^T * thetam]*/
		VecDot(gnew,thetanew,&gamma);
		VecDot(gold,thetaold,&tmpScalar);
		gamma = gamma/tmpScalar;

		/*w(m+1) = g(m+1) + gamma * w(m)*/
		VecWAXPY(wnew,gamma,wold,gnew);

		/* Assign new to old iterates */
		VecCopy(wnew,wold);VecCopy(gnew,gold);VecCopy(unew,uold);VecCopy(thetanew,thetaold);
		
		count++;
	}


	/* Restore pressure and velocity sol. vectors to its global form */
	VecRestoreSubVector(out_uf->pvector->vector,isv,&unew);
	VecRestoreSubVector(out_uf->pvector->vector,isp,&p);

	/*return output pointer*/
	*puf=out_uf;


	/* Cleanup */
	KSPDestroy(&kspu);KSPDestroy(&kspp);

	MatDestroy(&A);MatDestroy(&B);MatDestroy(&BT);MatDestroy(&IP);
	
	VecDestroy(&p);VecDestroy(&uold);VecDestroy(&unew);VecDestroy(&rhsu);VecDestroy(&rhsp);
	VecDestroy(&gold);VecDestroy(&gnew);VecDestroy(&wold);VecDestroy(&wnew);VecDestroy(&chi);
	VecDestroy(&tmpp);VecDestroy(&tmpu);VecDestroy(&f1);VecDestroy(&f2);VecDestroy(&eta);
	VecDestroy(&thetanew);VecDestroy(&thetaold);

}/*}}}*/
void solutionsequence_schurcg(FemModel* femmodel){/*{{{*/

	/*intermediary: */
	Matrix<IssmDouble>* Kff = NULL;
	Matrix<IssmDouble>* Kfs = NULL;
	Vector<IssmDouble>* ug  = NULL;
	Vector<IssmDouble>* uf  = NULL;
	Vector<IssmDouble>* old_uf = NULL;
	Vector<IssmDouble>* pf  = NULL;
	Vector<IssmDouble>* df  = NULL;
	Vector<IssmDouble>* ys  = NULL;
	Matrix<IssmDouble>* Iff = NULL;


	/*parameters:*/
	int max_nonlinear_iterations;
	int configuration_type;
	IssmDouble eps_res,eps_rel,eps_abs;

	/*Recover parameters: */
	femmodel->parameters->FindParam(&max_nonlinear_iterations,StressbalanceMaxiterEnum);
	femmodel->parameters->FindParam(&eps_res,StressbalanceRestolEnum);
	femmodel->parameters->FindParam(&eps_rel,StressbalanceReltolEnum);
	femmodel->parameters->FindParam(&eps_abs,StressbalanceAbstolEnum);
	femmodel->parameters->FindParam(&configuration_type,ConfigurationTypeEnum);
	femmodel->UpdateConstraintsx();
	int size;
	int  count=0;
	bool converged=false;

	/*Start non-linear iteration using input velocity: */
	GetSolutionFromInputsx(&ug,femmodel);
	Reducevectorgtofx(&uf, ug, femmodel->nodes,femmodel->parameters);

	/*Update once again the solution to make sure that vx and vxold are similar*/
	InputUpdateFromConstantx(femmodel,converged,ConvergedEnum);
	InputUpdateFromSolutionx(femmodel,ug);

	for(;;){

		/*save pointer to old velocity*/
		delete old_uf; old_uf=uf;
		delete ug;

		/*Get stiffness matrix and Load vector*/
		SystemMatricesx(&Kff,&Kfs,&pf,&df,NULL,femmodel);
		CreateNodalConstraintsx(&ys,femmodel->nodes,configuration_type);
		Reduceloadx(pf, Kfs, ys); delete Kfs;

		/*Create mass matrix*/
		int fsize; Kff->GetSize(&fsize,&fsize);
		Iff=new Matrix<IssmDouble>(fsize,fsize,300,4);
		StressbalanceAnalysis* analysis = new StressbalanceAnalysis();
		/*Get complete stiffness matrix without penalties*/
		for(int i=0;i<femmodel->elements->Size();i++){
			Element* element=xDynamicCast<Element*>(femmodel->elements->GetObjectByOffset(i));
			ElementMatrix* Ie = analysis->CreateSchurPrecondMatrix(element);
			if(Ie) Ie->AddToGlobal(Iff,NULL);
			delete Ie;
		}
		Iff->Assemble();
		delete analysis;

		/*Solve*/
		femmodel->profiler->Start(SOLVER);
		_assert_(Kff->type==PetscMatType); 
		
		SchurCGSolver(&uf,
					Kff->pmatrix->matrix,
					Iff->pmatrix->matrix,
					pf->pvector->vector,
					old_uf->pvector->vector,
					df->pvector->vector,
					femmodel->parameters);
		femmodel->profiler->Stop(SOLVER);
		delete Iff;
		
		/*Merge solution from f set to g set*/
		Mergesolutionfromftogx(&ug, uf,ys,femmodel->nodes,femmodel->parameters);delete ys;

		/*Check for convergence and update inputs accordingly*/
		convergence(&converged,Kff,pf,uf,old_uf,eps_res,eps_rel,eps_abs); delete Kff; delete pf; delete df;
		count++;

		if(count>=max_nonlinear_iterations){
			_printf0_("   maximum number of nonlinear iterations (" << max_nonlinear_iterations << ") exceeded\n"); 
			converged=true;
		}
		InputUpdateFromConstantx(femmodel,converged,ConvergedEnum);
		InputUpdateFromSolutionx(femmodel,ug);

		/*Increase count: */
		if(converged==true){
			femmodel->results->AddResult(new GenericExternalResult<IssmDouble>(femmodel->results->Size()+1,StressbalanceConvergenceNumStepsEnum,count));
			break;
		}
	}

	if(VerboseConvergence()) _printf0_("\n   total number of iterations: " << count << "\n");

	/*clean-up*/
	delete uf;
	delete ug;
	delete old_uf;

}/*}}}*/
#else
void solutionsequence_schurcg(FemModel* femmodel){_error_("PETSc needs to be installed");}
#endif
