/*!\file: solutionsequence_linear.cpp
 * \brief: numerical core of linear solutions
 */ 

#include "../toolkits/toolkits.h"
#include "../classes/classes.h"
#include "../shared/shared.h"
#include "../modules/modules.h"

#ifdef _HAVE_PETSC_
void CreateDMatrix(Mat* pD,Mat K){/*{{{*/
	/*Create D matrix such that:
	 *
	 * d_ij = max( -k_ij,0,-k_ji) off diagonal
	 *
	 * d_ii = - sum_{i!=j} d_ij for the diagonal
	 *
	 */

	/*Intermediaries*/
	int        ncols,ncols2,rstart,rend;
	double     d,diagD;
	Mat        D        = NULL;
	Mat        K_transp = NULL;
	int*       cols  = NULL;
	int*       cols2 = NULL;
	double*    vals  = NULL;
	double*    vals2 = NULL;

	/*First, we need to transpose K so that we access both k_ij and k_ji*/
	MatTranspose(K,MAT_INITIAL_MATRIX,&K_transp);

	/*Initialize output (D has the same non zero pattern as K)*/
	MatDuplicate(K,MAT_SHARE_NONZERO_PATTERN,&D);

	/*Go through the rows of K an K' and build D*/
	MatGetOwnershipRange(K,&rstart,&rend);
	for(int row=rstart; row<rend; row++){
		diagD = 0.;
		MatGetRow(K       ,row,&ncols, (const int**)&cols, (const double**)&vals);
		MatGetRow(K_transp,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
		_assert_(ncols==ncols2);
		for(int j=0; j<ncols; j++) {
			_assert_(cols[j]==cols2[j]);
			d = max(max(-vals[j],-vals2[j]),0.);
			MatSetValue(D,row,cols[j],(const double)d,INSERT_VALUES);
			if(cols[j]!=row) diagD -= d;
		}
		MatSetValue(D,row,row,(const double)diagD,INSERT_VALUES);
		MatRestoreRow(K       ,row,&ncols, (const int**)&cols, (const double**)&vals);
		MatRestoreRow(K_transp,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
	}
	MatAssemblyBegin(D,MAT_FINAL_ASSEMBLY);
	MatAssemblyEnd(  D,MAT_FINAL_ASSEMBLY);

	/*Clean up and assign output pointer*/
	MatFree(&K_transp);
	*pD = D;
}/*}}}*/
void CreateLHS(Mat* pLHS,IssmDouble* pdmax,Mat K,Mat D,Vec Ml,IssmDouble theta,IssmDouble deltat,FemModel* femmodel,int configuration_type){/*{{{*/
	/*Create Left Hand side of Lower order solution
	 *
	 * LHS = [ML − theta*detlat *(K+D)^n+1]
	 *
	 */

	/*Intermediaries*/
	int        dof,ncols,ncols2,rstart,rend;
	double     d,mi,dmax = 0.;
	Mat        LHS   = NULL;
	int*       cols  = NULL;
	int*       cols2 = NULL;
	double*    vals  = NULL;
	double*    vals2 = NULL;

	MatDuplicate(K,MAT_SHARE_NONZERO_PATTERN,&LHS);
	MatGetOwnershipRange(K,&rstart,&rend);
	for(int row=rstart; row<rend; row++){
		MatGetRow(K,row,&ncols, (const int**)&cols, (const double**)&vals);
		MatGetRow(D,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
		_assert_(ncols==ncols2);
		for(int j=0; j<ncols; j++) {
			_assert_(cols[j]==cols2[j]);
			d = -theta*deltat*(vals[j] + vals2[j]);
			if(cols[j]==row){
				VecGetValues(Ml,1,(const int*)&cols[j],&mi);
				d += mi;
			}
			if(fabs(d)>dmax) dmax = fabs(d);
			MatSetValue(LHS,row,cols[j],(const double)d,INSERT_VALUES);
		}
		MatRestoreRow(K,row,&ncols, (const int**)&cols, (const double**)&vals);
		MatRestoreRow(D,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
	}

	/*Penalize Dirichlet boundary*/
	dmax = dmax * 1.e+3;
	for(int i=0;i<femmodel->constraints->Size();i++){
		Constraint* constraint=(Constraint*)femmodel->constraints->GetObjectByOffset(i);
		if(constraint->InAnalysis(configuration_type)){
			constraint->PenaltyDofAndValue(&dof,&d,femmodel->nodes,femmodel->parameters);
			if(dof!=-1){
				MatSetValue(LHS,dof,dof,(const double)dmax,INSERT_VALUES);
			}
		}
	}
	MatAssemblyBegin(LHS,MAT_FINAL_ASSEMBLY);
	MatAssemblyEnd(  LHS,MAT_FINAL_ASSEMBLY);

	/*Clean up and assign output pointer*/
	*pdmax = dmax;
	*pLHS  = LHS;
}/*}}}*/
void CreateRHS(Vec* pRHS,Mat K,Mat D,Vec Ml,Vec u,IssmDouble theta,IssmDouble deltat,IssmDouble dmax,FemModel* femmodel,int configuration_type){/*{{{*/
	/*Create Left Hand side of Lower order solution
	 *
	 * RHS = [ML + (1 − theta) deltaT L^n] u^n
	 *
	 * where L = K + D
	 *
	 */

	/*Intermediaries*/
	Vec         Ku  = NULL;
	Vec         Du  = NULL;
	Vec         RHS = NULL;
	int         dof;
	IssmDouble  d;

	/*Initialize vectors*/
	VecDuplicate(u,&Ku);
	VecDuplicate(u,&Du);
	VecDuplicate(u,&RHS);

	/*Create RHS = M*u + (1-theta)*deltat*K*u + (1-theta)*deltat*D*u*/
	MatMult(K,u,Ku);
	MatMult(D,u,Du);
	VecPointwiseMult(RHS,Ml,u);
	VecAXPBYPCZ(RHS,(1-theta)*deltat,(1-theta)*deltat,1,Ku,Du);
	VecFree(&Ku);
	VecFree(&Du);

	/*Penalize Dirichlet boundary*/
	for(int i=0;i<femmodel->constraints->Size();i++){
		Constraint* constraint=(Constraint*)femmodel->constraints->GetObjectByOffset(i);
		if(constraint->InAnalysis(configuration_type)){
			constraint->PenaltyDofAndValue(&dof,&d,femmodel->nodes,femmodel->parameters);
			d = d*dmax;
			if(dof!=-1){
				VecSetValues(RHS,1,&dof,(const double*)&d,INSERT_VALUES);
			}
		}
	}
	VecAssemblyBegin(RHS);
	VecAssemblyEnd(  RHS);

	/*Assign output pointer*/
	*pRHS = RHS;

}/*}}}*/
#endif
void solutionsequence_fct(FemModel* femmodel){

	/*intermediary: */
	Vector<IssmDouble>*  Ml = NULL;
	Matrix<IssmDouble>*  K  = NULL;
	Matrix<IssmDouble>*  Mc = NULL;
	Vector<IssmDouble>*  ug = NULL;
	Vector<IssmDouble>*  uf = NULL;

	IssmDouble theta,deltat,dmax;
	int        dof,ncols,ncols2,rstart,rend;
	int        configuration_type;
	double     d;
	int*       cols  = NULL;
	int*       cols2 = NULL;
	double*    vals  = NULL;
	double*    vals2 = NULL;

	/*Create analysis*/
	MasstransportAnalysis* analysis = new MasstransportAnalysis();

	/*Recover parameters: */
	femmodel->parameters->FindParam(&deltat,TimesteppingTimeStepEnum);
	femmodel->parameters->FindParam(&configuration_type,ConfigurationTypeEnum);
	femmodel->UpdateConstraintsx();
	theta = 0.5;

	/*Create lumped mass matrix*/
	analysis->LumpedMassMatrix(&Ml,femmodel);
	analysis->MassMatrix(&Mc,femmodel);
	analysis->FctKMatrix(&K,NULL,femmodel);

	/*Convert matrices to PETSc matrices*/
	Mat D_petsc  = NULL;
	Mat LHS      = NULL;
	Vec RHS      = NULL;
	Vec u        = NULL;
	Mat K_petsc  = K->pmatrix->matrix;
	Vec Ml_petsc = Ml->pvector->vector;
	Mat Mc_petsc = Mc->pmatrix->matrix;

	/*Create D Matrix*/
	#ifdef _HAVE_PETSC_
	CreateDMatrix(&D_petsc,K_petsc);

	/*Create LHS: [ML − theta*detlat *(K+D)^n+1]*/
	CreateLHS(&LHS,&dmax,K_petsc,D_petsc,Ml_petsc,theta,deltat,femmodel,configuration_type);

	/*Create RHS: [ML + (1 − theta) deltaT L^n] u^n */
	GetSolutionFromInputsx(&ug,femmodel);
	Reducevectorgtofx(&uf, ug, femmodel->nodes,femmodel->parameters);
	delete ug;
	CreateRHS(&RHS,K_petsc,D_petsc,Ml_petsc,uf->pvector->vector,theta,deltat,dmax,femmodel,configuration_type);
	delete uf;

	/*Go solve!*/
	SolverxPetsc(&u,LHS,RHS,NULL,NULL, femmodel->parameters); 
	MatFree(&LHS);
	VecFree(&RHS);

	/*Richardson to calculate udot*/
	Vec udot = NULL;
	VecDuplicate(u,&udot);
	VecZeroEntries(udot);
	Vec temp1 = NULL; VecDuplicate(u,&temp1);
	Vec temp2 = NULL; VecDuplicate(u,&temp2);
	for(int i=0;i<5;i++){
		/*udot_new = udot_old + Ml^-1 (K^(n+1) u -- Mc udot_old)*/
		MatMult(Mc_petsc,udot,temp1);
		MatMult(K_petsc, u,   temp2);
		VecAXPBY(temp2,-1.,1.,temp1); // temp2 = (K^(n+1) u -- Mc udot_old)
		VecPointwiseDivide(temp1,temp2,Ml_petsc); //temp1 = Ml^-1 temp2
		VecAXPBY(udot,1.,1.,temp1);
	}
	VecFree(&temp1);
	VecFree(&temp2);

	/*Serialize u and udot*/
	IssmDouble* udot_serial = NULL;
	IssmDouble* u_serial    = NULL;
	IssmDouble* ml_serial    = NULL;
	VecToMPISerial(&udot_serial,udot    ,IssmComm::GetComm());
	VecToMPISerial(&u_serial   ,u       ,IssmComm::GetComm());
	VecToMPISerial(&ml_serial  ,Ml_petsc,IssmComm::GetComm());

	/*Anti diffusive fluxes*/
	Vec Ri_plus  = NULL;
	Vec Ri_minus = NULL;
	double uiLmax = 3.;
	double uiLmin = 2.;
	VecDuplicate(u,&Ri_plus);
	VecDuplicate(u,&Ri_minus);
	MatGetOwnershipRange(K_petsc,&rstart,&rend);
	for(int row=rstart; row<rend; row++){
		double Pi_plus  = 0.;
		double Pi_minus = 0.;
		MatGetRow(Mc_petsc,row,&ncols, (const int**)&cols, (const double**)&vals);
		MatGetRow(D_petsc ,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
		_assert_(ncols==ncols2);
		for(int j=0; j<ncols; j++) {
			_assert_(cols[j]==cols2[j]);
			d = vals[j]*(udot_serial[row] - udot_serial[cols[j]]) + vals2[j]*(u_serial[row] - u_serial[cols[j]]);
			if(row!=cols[j]){
				if(d>0.){
					Pi_plus  += d;
				}
				else{
					Pi_minus += d;
				}
			}
		}

		/*Compute Qis and Ris*/
		double Qi_plus  = ml_serial[row]/deltat*(uiLmax - u_serial[row]);
		double Qi_minus = ml_serial[row]/deltat*(uiLmin - u_serial[row]);
		d = 1.;
		if(Pi_plus!=0.) d = min(1.,Qi_plus/Pi_plus);
		VecSetValue(Ri_plus,row,(const double)d,INSERT_VALUES);
		d = 1.;
		if(Pi_minus!=0.) d = min(1.,Qi_minus/Pi_minus);
		VecSetValue(Ri_minus,row,(const double)d,INSERT_VALUES);

		MatRestoreRow(Mc_petsc, row,&ncols, (const int**)&cols, (const double**)&vals);
		MatRestoreRow(D_petsc,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
	}
	VecAssemblyBegin(Ri_plus);
	VecAssemblyEnd(  Ri_plus);
	VecAssemblyBegin(Ri_minus);
	VecAssemblyEnd(  Ri_minus);

	/*Serialize Ris*/
	IssmDouble* Ri_plus_serial  = NULL;
	IssmDouble* Ri_minus_serial = NULL;
	VecToMPISerial(&Ri_plus_serial,Ri_plus,IssmComm::GetComm());
	VecToMPISerial(&Ri_minus_serial,Ri_minus,IssmComm::GetComm());
	VecFree(&Ri_plus);
	VecFree(&Ri_minus);

	/*Build fbar*/
	Vec Fbar = NULL;
	VecDuplicate(u,&Fbar);
	for(int row=rstart; row<rend; row++){
		MatGetRow(Mc_petsc,row,&ncols, (const int**)&cols, (const double**)&vals);
		MatGetRow(D_petsc ,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
		_assert_(ncols==ncols2);
		d = 0.;
		for(int j=0; j<ncols; j++) {
			_assert_(cols[j]==cols2[j]);
			if(row==cols[j]) continue;
			double f_ij = vals[j]*(udot_serial[row] - udot_serial[cols[j]]) + vals2[j]*(u_serial[row] - u_serial[cols[j]]);
			if(f_ij>0){
				d += min(Ri_plus_serial[row],Ri_minus_serial[cols[j]]) * f_ij;
			}
			else{
				d += min(Ri_minus_serial[row],Ri_plus_serial[cols[j]]) * f_ij;
			}
		}
		VecSetValue(Fbar,row,(const double)d,INSERT_VALUES);
		MatRestoreRow(Mc_petsc, row,&ncols, (const int**)&cols, (const double**)&vals);
		MatRestoreRow(D_petsc,row,&ncols2,(const int**)&cols2,(const double**)&vals2);
	}
	VecAssemblyBegin(Fbar);
	VecAssemblyEnd(  Fbar);

	MatFree(&D_petsc);
	delete Mc;
	xDelete<IssmDouble>(udot_serial);
	xDelete<IssmDouble>(u_serial);
	xDelete<IssmDouble>(ml_serial);
	xDelete<IssmDouble>(Ri_plus_serial);
	xDelete<IssmDouble>(Ri_minus_serial);

	/*Compute solution u^n+1 = u_L + deltat Ml^-1 fbar*/
	VecDuplicate(u,&temp1);
	VecPointwiseDivide(temp1,Fbar,Ml_petsc); //temp1 = Ml^-1 temp2
	VecAXPBY(udot,1.,1.,temp1);
	VecAXPY(u,deltat,temp1);
	VecFree(&Fbar);
	VecFree(&udot);
	VecFree(&temp1);

	uf =new Vector<IssmDouble>(u);
	VecFree(&u);

	InputUpdateFromSolutionx(femmodel,uf); 
	delete uf;

	#else
	_error_("PETSc needs to be installed");
	#endif

	delete Ml;
	delete K;
	delete analysis;

}
