/*!\file:  MatMultPatch
 * \brief: relocalize vector when MatMult yields non conforming object sizes errors.
 */ 

#ifdef HAVE_CONFIG_H
	#include "config.h"
#else
#error "Cannot compile with HAVE_CONFIG_H symbol! run configure first!"
#endif

/*Petsc includes: */
#include "petscmat.h"
#include "petscvec.h"
#include "petscksp.h"

#include "../../../shared/shared.h"

/*Function prototypes: */
int MatMultCompatible(Mat A,Vec x);
void VecRelocalize(Vec* outvector,Vec vector,int m);

void MatMultPatch(Mat A,Vec X, Vec AX){ //same prototype as MatMult in Petsc

	int m,n;
	Vec X_rel=NULL;

	_assert_(A); _assert_(X);

	if (MatMultCompatible(A,X)){
		MatMult(A,X,AX); 
	}
	else{
		MatGetLocalSize(A,&m,&n);;
		VecRelocalize(&X_rel,X,n);
		MatMult(A,X_rel,AX); ;
		VecDestroy(X_rel);;
	}
}

int MatMultCompatible(Mat A,Vec x){
	
	/*error management*/
	
	int local_m,local_n;
	int lower_row,upper_row,range;
	int result=1;
	int sumresult;
	extern int num_procs;

	MatGetLocalSize(A,&local_m,&local_n);;
	VecGetLocalSize(x,&range);;
	
	if (local_n!=range)result=0;
	
	/*synchronize result: */
	MPI_Reduce (&result,&sumresult,1,MPI_INT,MPI_SUM,0,MPI_COMM_WORLD );
	MPI_Bcast(&sumresult,1,MPI_INT,0,MPI_COMM_WORLD);                
	if (sumresult!=num_procs){
		result=0;
	}
	else{
		result=1;\
	}
	return result;
}

void VecRelocalize(Vec* poutvector,Vec vector,int m){

	/*vector index and vector values*/
	int* index=NULL;
	double* values=NULL;
	int lower_row,upper_row,range;

	/*output: */
	Vec outvector=NULL;
	
	/*Create outvector with local size m*/
	VecCreate(PETSC_COMM_WORLD,&outvector); ; 
	VecSetSizes(outvector,m,PETSC_DECIDE); ; 
	VecSetFromOptions(outvector); ; 

	/*Go through vector, get values, and plug them into outvector*/
	VecGetOwnershipRange(vector,&lower_row,&upper_row); ; 
	upper_row--;
	range=upper_row-lower_row+1;
	if (range){
		index=(int*)xmalloc(range*sizeof(int));
		values=(double*)xmalloc(range*sizeof(double));
		for (int i=0;i<range;i++){
			*(index+i)=lower_row+i;
		}
		VecGetValues(vector,range,index,values);
		VecSetValues(outvector,range,index,values,INSERT_VALUES);
	}

	/*Assemble outvector*/
	VecAssemblyBegin(outvector);; 
	VecAssemblyEnd(outvector);; 

	/*Free ressources:*/
	xfree((void**)&index);
	xfree((void**)&values);	

	/*Assign output pointers:*/
	*poutvector=outvector;

}
