/*!\file MpiDenseMumpsSolve.cpp
 * \brief: solve dense matrix system with MUMPS
 */

/*Header files: {{{*/
#ifdef HAVE_CONFIG_H
	#include <config.h>
#else
#error "Cannot compile with HAVE_CONFIG_H symbol! run configure first!"
#endif

#include "../../shared/Numerics/types.h"
#include "../../shared/MemOps/MemOps.h"
#include "../../shared/Exceptions/exceptions.h"
#include "../../shared/io/Comm/IssmComm.h"
#include "../mpi/issmmpi.h"

/*Mumps header files: */
#include "dmumps_c.h"
/*}}}*/

void MumpsInit(DMUMPS_STRUC_C &theMumpsStruc) { 
	theMumpsStruc.par          = 1;  
	theMumpsStruc.sym          = 0;
	theMumpsStruc.comm_fortran = MPI_Comm_c2f(IssmComm::GetComm());
	theMumpsStruc.job          = -1;
	dmumps_c(&theMumpsStruc);
}

// must be preceded by a call to MumpsInit
void MumpsSettings(DMUMPS_STRUC_C &theMumpsStruc) { 
	/*Control statements:{{{ */
	theMumpsStruc.icntl[1-1] = 6; //error verbose
	theMumpsStruc.icntl[2-1] = 1; //std verbose
	theMumpsStruc.icntl[4-1] = 4; //verbose everything
	theMumpsStruc.icntl[5-1] = 0;
	theMumpsStruc.icntl[18-1] = 3;

	theMumpsStruc.icntl[20-1] = 0;
	theMumpsStruc.icntl[21-1] = 0;
	theMumpsStruc.icntl[30-1] = 0;
	/*}}}*/
}

// must be preceded by a call to MumpsInit
void MumpsAnalyze(DMUMPS_STRUC_C &theMumpsStruc) { 
	theMumpsStruc.job          = 1;
	dmumps_c(&theMumpsStruc);
}

// must be preceded by a call to MumpsAnalyze
void MumpsFactorize(DMUMPS_STRUC_C &theMumpsStruc) { 
	theMumpsStruc.job          = 2;
	dmumps_c(&theMumpsStruc);
}

// must be preceded by a call to MumpsFactorize
void MumpsBacksubstitute(DMUMPS_STRUC_C &theMumpsStruc) { 
	theMumpsStruc.job          = 3;
	dmumps_c(&theMumpsStruc);
}

// must be preceded at least  by a call to MumpsInit
void MumpsFinalize(DMUMPS_STRUC_C &theMumpsStruc) { 
	theMumpsStruc.job          = -2;
	dmumps_c(&theMumpsStruc);
}

void MumpsSolve(int n,
		int nnz,
		int local_nnz,
		int* irn_loc,
		int* jcn_loc,
		IssmPDouble *a_loc,
		IssmPDouble *rhs) { 
	/*Initialize mumps: {{{*/
	DMUMPS_STRUC_C theMumpsStruc;
	MumpsInit(theMumpsStruc);
	MumpsSettings(theMumpsStruc);
	/*}}}*/
	// now setup the rest of theMumpsStruc 
	theMumpsStruc.n=n;
	theMumpsStruc.nz=nnz;
	theMumpsStruc.nz_loc=local_nnz;
	theMumpsStruc.irn_loc=irn_loc;
	theMumpsStruc.jcn_loc=jcn_loc;
	theMumpsStruc.a_loc=a_loc;
	theMumpsStruc.rhs=rhs;
	theMumpsStruc.nrhs=1;
	theMumpsStruc.lrhs=1;
	/*Solve system: {{{*/
	MumpsAnalyze(theMumpsStruc);
	MumpsFactorize(theMumpsStruc);
	MumpsBacksubstitute(theMumpsStruc);
	/*}}}*/
	MumpsFinalize(theMumpsStruc);
}

#ifdef _HAVE_ADOLC_
// prototype for active variant
void MumpsSolve(int n,
		int nnz,
		int local_nnz,
		int* irn_loc,
		int* jcn_loc,
		IssmDouble *a_loc,
		IssmDouble *rhs);
#endif 

void MpiDenseMumpsSolve( /*output: */ IssmDouble* uf, int uf_M, int uf_m, /*matrix input: */ IssmDouble* Kff, int Kff_M, int Kff_N, int Kff_m, /*right hand side vector: */ IssmDouble* pf, int pf_M, int pf_m){ /*{{{*/

	/*Variables: {{{*/

	ISSM_MPI_Comm   comm;
	int        my_rank;
	int        num_procs;
	int        i;
	int        j;
	int         nnz       ,local_nnz;
	int        *irn_loc = NULL;
	int        *jcn_loc = NULL;
	IssmDouble *a_loc   = NULL;
	int         count;
	int         lower_row;
	int         upper_row;
	IssmDouble* rhs=NULL;
	int*        recvcounts=NULL;
	int*        displs=NULL;
	/*}}}*/
	/*Communicator info:{{{ */
	my_rank=IssmComm::GetRank();
	num_procs=IssmComm::GetSize();
	comm=IssmComm::GetComm();
	/*}}}*/
	/*First, some checks:{{{ */
	if (Kff_M!=Kff_N)_error_("stiffness matrix Kff should be square");
	if (uf_M!=Kff_M | uf_M!=pf_M)_error_("solution vector should be the same size as stiffness matrix Kff and load vector pf");
	if (uf_m!=Kff_m | uf_m!=pf_m)_error_("solution vector should be locally the same size as stiffness matrix Kff and load vector pf");
	/*}}}*/
	/*Initialize matrix:{{{ */

	/*figure out number of non-zero entries: */
	local_nnz=0;
	for(i=0;i<Kff_m;i++){
		for(j=0;j<Kff_N;j++){
			if (Kff[i*Kff_N+j]!=0)local_nnz++;
		}
	}

	ISSM_MPI_Reduce(&local_nnz,&nnz,1,ISSM_MPI_INT,ISSM_MPI_SUM,0,comm);
	ISSM_MPI_Bcast(&nnz,1,ISSM_MPI_INT,0,comm);

	/*Allocate: */
	if(local_nnz){
		irn_loc=xNew<int>(local_nnz);
		jcn_loc=xNew<int>(local_nnz);
		a_loc=xNew<IssmDouble>(local_nnz);
	}

	/*Populate the triplets: */
	GetOwnershipBoundariesFromRange(&lower_row,&upper_row,Kff_m,comm);
	count=0;
	for(i=0;i<Kff_m;i++){
		for(j=0;j<Kff_N;j++){
			if (Kff[i*Kff_N+j]!=0){
				irn_loc[count]=lower_row+i+1; //fortran indexing
				jcn_loc[count]=j+1; //fortran indexing
				a_loc[count]=Kff[i*Kff_N+j];
				count++;
			}
		}
	}
	/*Deal with right hand side. We need to ISSM_MPI_Gather it onto cpu 0: */
	rhs=xNew<IssmDouble>(pf_M);

	recvcounts=xNew<int>(num_procs);
	displs=xNew<int>(num_procs);

	/*recvcounts:*/
	ISSM_MPI_Allgather(&pf_m,1,ISSM_MPI_INT,recvcounts,1,ISSM_MPI_INT,comm);

	/*displs: */
	ISSM_MPI_Allgather(&lower_row,1,ISSM_MPI_INT,displs,1,ISSM_MPI_INT,comm);

	/*Gather:*/
	ISSM_MPI_Gatherv(pf, pf_m, ISSM_MPI_DOUBLE, rhs, recvcounts, displs, ISSM_MPI_DOUBLE,0,comm);

	MumpsSolve(Kff_M,
		   nnz,
		   local_nnz,
		   irn_loc,
		   jcn_loc,
		   a_loc,
		   rhs);

if (my_rank==0) for (int i=0;i<Kff_M;++i) std::cout << i << " : " << rhs[i] << std::endl;
	/*Now scatter from cpu 0 to all other cpus: {{{*/
	ISSM_MPI_Scatterv( rhs, recvcounts, displs, ISSM_MPI_DOUBLE, uf, uf_m, ISSM_MPI_DOUBLE, 0, comm); 

	/*}}}*/
	/*Cleanup: {{{*/
	xDelete<int>(irn_loc);
	xDelete<int>(jcn_loc);
	xDelete<IssmDouble>(a_loc);
	xDelete<IssmDouble>(rhs);
	xDelete<int>(recvcounts);
	xDelete<int>(displs);
	/*}}}*/
} /*}}}*/

#ifdef _HAVE_ADOLC_

int mumpsSolveEDF(int iArrLength, int* iArr, int nPlusNz /* we can ignore it*/, double* dp_x, int m, double* dp_y) {
  // unpack parameters
  int n=iArr[0];
  int nz=iArr[1];
  int *irn=new int[nz];
  int *jcn=new int[nz];
  double *A=new double[nz];
  for (int i=0;i<nz;++i) { 
    irn[i]=iArr[2+i];
    jcn[i]=iArr[2+nz+i];
    A[i]=dp_x[i];
  }
  double *rhs_sol=new double[n];
  for (int i=0;i<n;++i) { 
    rhs_sol[i]=dp_x[nz+i];
  }
  mumpsSolve(n,nz,irn,jcn,A,rhs_sol);
  for (int i=0;i<m;++i) { 
    dp_y[i]=rhs_sol[i];
  }
  return 0;
}

void MumpsSolve(int n,
		int nnz,
		int local_nnz,
		int* irn_loc,
		int* jcn_loc,
		IssmDouble *a_loc,
		IssmDouble *rhs) { 
  int packedDimsSparseArrLength=1+1+1+local_nnz+local_nnz;
  int *packedDimsSparseArr=xNew<int>(packedDimsSparseArrLength);
  packedDimsSparseArr[0]=n;
  packedDimsSparseArr[1]=nnz;
  packedDimsSparseArr[2]=local_nnz;
  for (int i=0;i<local_nnz;++i) {
    packedDimsSparseArr[3+i]=irn_loc[i];
    packedDimsSparseArr[3+local_nnz+i]=jcn_loc[i];
  }
  ensureContiguousLocations(local_nnz+n);
  adouble *pack_A_rhs=xNew<IssmDouble>(local_nnz+n);
  for (int i=0;i<local_nnz;++i) { 
    pack_A_rhs[i]=a_loc[i];
  }
  for (int i=0;i<n;++i) { 
    pack_A_rhs[local_nnz+i]=rhs[i];
  }
  double *passivePack_A_rhs=xNew<IssmPDouble>(local_nnz+n);
  double *passiveSol=xNew<IssmPDouble>(n);
  ensureContiguousLocations(n);
  adouble *sol=xNew<IssmDouble>(n);
  call_ext_fct(ourEDF_p,
	       packedDimsSparseArrLength, packedDimsSparseArr,
	       local_nnz+n, passivePack_A_rhs, pack_A_rhs, 
	       n, passiveSol,sol);
  for (int i=0;i<n;++i) { 
    rhs[i]=sol[i];
  }
  xDelete(sol);
  xDelete(passiveSol);
  xDelete(passivePack_A_rhs);
  xDelete(pack_A_rhs);
  xDelete(packedDimsSparseArr);
}

int fos_forward_mumpsSolveEDF(int iArrLength, int* iArr, int nPlusNz /* we can ignore it*/, 
			      double *dp_x, double *dp_X, int m, double *dp_y, double *dp_Y) {
  // unpack parameters
  int n=iArr[0];
  int nz=iArr[1];
  int *irn=new int[nz];
  int *jcn=new int[nz];
  double *A=new double[nz];
  for (int i=0;i<nz;++i) { 
    irn[i]=iArr[2+i];
    jcn[i]=iArr[2+nz+i];
    A[i]=dp_x[i];
  }
  double *rhs_sol=new double[n];
  for (int i=0;i<n;++i) { 
    rhs_sol[i]=dp_x[nz+i];
  }
  DMUMPS_STRUC_C id;
  id.par = 1; // one processor=sequential code
  id.sym = 0; // asymmetric
  id.job = JOB_INIT;
  dmumps_c(&id);

  id.icntl[1-1] = 6; //error verbose
  id.icntl[2-1] = 0; //std verbose
  id.icntl[3-1] = 0; // 
  id.icntl[4-1] = 0; // 
  id.icntl[5-1] = 0; // matrix is assembled
  id.icntl[18-1] = 0; // centralized
  id.icntl[20-1] = 0; // rhs is dense and centralized
  id.icntl[21-1] = 0; // solution is centralized
  id.n=n;
  id.nz=nz;
  id.irn=irn;
  id.jcn=jcn;
  id.a=A;
  id.job = JOB_ANALYSIS;
  dmumps_c(&id);
  id.job = JOB_FACTORIZATION; 
  dmumps_c (&id);
  // solve the orifginal system
  id.rhs=rhs_sol;
  id.nrhs=1;
  id.lrhs=1;
  id.job = JOB_BACKSUBST; 
  dmumps_c (&id);
  for (int i=0;i<m;++i) { 
    dp_y[i]=rhs_sol[i];
  }
  // solve for the derivative
  for (int i=0;i<n;++i) { 
    rhs_sol[i]=dp_X[nz+i]; 
  }
  for (int i=0;i<nz;++i) { 
    rhs_sol[irn[i]-1]-=dp_X[i]*dp_y[jcn[i]-1];
  }
  dmumps_c (&id);
  for (int i=0;i<m;++i) { 
    dp_Y[i]=rhs_sol[i];
  }
  id.job = JOB_END; 
  dmumps_c (&id);
  return 3;
}

int fos_reverse_mumpsSolveEDF(int iArrLength, int* iArr, 
			      int m, double *dp_U, 
			      int nPlusNz, double *dp_Z, 
			      double *dp_x, double *dp_y) {
  // unpack parameters
  int n=iArr[0];
  int nz=iArr[1];
  int *irn=new int[nz];
  int *jcn=new int[nz];
  double *A=new double[nz];
  for (int i=0;i<nz;++i) { 
    irn[i]=iArr[2+i];
    jcn[i]=iArr[2+nz+i];
    A[i]=dp_x[i];
  }
  DMUMPS_STRUC_C id;
  id.par = 1; // one processor=sequential code
  id.sym = 0; // asymmetric
  id.job = JOB_INIT;
  dmumps_c(&id);

  id.icntl[1-1] = 6; //error verbose
  id.icntl[2-1] = 0; //std verbose
  id.icntl[3-1] = 0; // 
  id.icntl[4-1] = 0; // 
  id.icntl[5-1] = 0; // matrix is assembled
  id.icntl[9-1] = 0; //solve for the transpose
  id.icntl[18-1] = 0; // centralized
  id.icntl[20-1] = 0; // rhs is dense and centralized
  id.icntl[21-1] = 0; // solution is centralized
  id.n=n;
  id.nz=nz;
  id.irn=irn;
  id.jcn=jcn;
  id.a=A;
  id.job = JOB_ANALYSIS;
  dmumps_c(&id);
  id.job = JOB_FACTORIZATION; 
  dmumps_c (&id);
  double *rhs_sol=new double[n];
  for (int i=0;i<n;++i) { 
    rhs_sol[i]=dp_U[i];
  }
  id.rhs=rhs_sol;
  id.nrhs=1;
  id.lrhs=1;
  id.job = JOB_BACKSUBST; 
  dmumps_c (&id);
  // update the adhoint of the rhs: 
  for (int i=0;i<m;++i) { 
    dp_Z[nz+i]+=rhs_sol[i];
  }
  // update the adjoint of the matrix: 
  for (int i=0;i<nz;++i) { 
    dp_Z[i]+=-dp_U[irn[i]-1]*dp_y[jcn[i]-1];
  }
  return 3;
}

void MpiDenseMumpsSolve( /*output: */ IssmDouble* uf, int uf_M, int uf_m, /*matrix input: */ IssmDouble* Kff, int Kff_M, int Kff_N, int Kff_m, /*right hand side vector: */ IssmDouble* pf, int pf_M, int pf_m){ /*{{{*/
	_error_("not supported yet!");
} /*}}}*/

#endif
