/*!\file: MatPartition.cpp
 * \brief partition matrix according to node sets
 */ 

#ifdef HAVE_CONFIG_H
	#include <config.h>
#else
#error "Cannot compile with HAVE_CONFIG_H symbol! run configure first!"
#endif

/*Petsc includes: */
#include "petscmat.h"
#include "petscvec.h"
#include "petscksp.h"

#include "../../../shared/shared.h"

int MatPartition(Mat* poutmatrix,Mat matrixA,double* row_partition_vector,int row_partition_vector_size ,
		double* col_partition_vector,int col_partition_vector_size){

	int i;
	
	/*Petsc matrix*/
	int d_nz;
	int o_nz;
	int* node_rows=NULL;
	int* node_cols=NULL;
	int count;
	IS col_index=NULL;
	IS row_index=NULL;
	int lower_row,upper_row,range;
	int MA,NA; //matrixA dimensions
	const char* type=NULL;
	int csize;

	/*output*/
	Mat outmatrix=NULL;
	
	/*get input matrix size: */
	MatGetSize(matrixA,&MA,&NA);

    /*If one of the partitioning row or col vectors has a 0 dimension, we return a NILL matrix. Return it with the same type as matrixA.*/
	
    if ((row_partition_vector_size==0) || (col_partition_vector==0)){
		MatGetType(matrixA,&type);
		if (strcmp(type,"mpiaij")==0){
			d_nz=0;
			o_nz=0;
			MatCreateMPIAIJ(MPI_COMM_WORLD,PETSC_DETERMINE,PETSC_DETERMINE, 0,0,d_nz,PETSC_NULL,o_nz,PETSC_NULL,&outmatrix);
		}
		else if (strcmp(type,"mpidense")==0){
			MatCreateMPIDense(MPI_COMM_WORLD,PETSC_DETERMINE,PETSC_DETERMINE, 0,0,PETSC_NULL,&outmatrix);
		}
		/*Assemble*/
		MatAssemblyBegin(outmatrix,MAT_FINAL_ASSEMBLY);
		MatAssemblyEnd(outmatrix,MAT_FINAL_ASSEMBLY);
	}
	else{
		/*Both vectors are non nill, use MatGetSubMatrix to condense out*/
		/*Figure out which rows each node is going to get from matrix A.*/
		MatGetOwnershipRange(matrixA,&lower_row,&upper_row);
		upper_row--;
		range=upper_row-lower_row+1;

		count=0;
		if (range){
			node_rows=(int*)xmalloc(range*sizeof(int)); //this is the maximum number of rows one node can extract.
		
			for (i=0;i<row_partition_vector_size;i++){
				if ( ((int)(*(row_partition_vector+i))>=(lower_row+1)) && ((int)(*(row_partition_vector+i))<=(upper_row+1)) ){
					*(node_rows+count)=(int)*(row_partition_vector+i)-1;
					count++;
				}
			}
		}
		else{
			count=0;
		}
		
		/*Now each node has a node_rows vectors holding which rows they should extract from matrixA. Create an Index Set from node_rows.*/
		ISCreateGeneral(MPI_COMM_WORLD,count,node_rows,&row_index);
		
		/*Same deal for columns*/
		node_cols=(int*)xmalloc(col_partition_vector_size*sizeof(int));
		for (i=0;i<col_partition_vector_size;i++){
			*(node_cols+i)=(int)*(col_partition_vector+i)-1;
		}
		ISCreateGeneral(MPI_COMM_WORLD,col_partition_vector_size,node_cols,&col_index);

		/*Call MatGetSubMatrix*/
		csize=PetscDetermineLocalSize(col_partition_vector_size);
		if(col_partition_vector_size==row_partition_vector_size){
			#if _PETSC_VERSION_ == 3 
			MatGetSubMatrix(matrixA,row_index,col_index,MAT_INITIAL_MATRIX,&outmatrix);
			#else
			MatGetSubMatrix(matrixA,row_index,col_index,count,MAT_INITIAL_MATRIX,&outmatrix);
			#endif
		}
		else{
			#if _PETSC_VERSION_ == 3 
			MatGetSubMatrix(matrixA,row_index,col_index,MAT_INITIAL_MATRIX,&outmatrix);
			#else
			MatGetSubMatrix(matrixA,row_index,col_index,csize,MAT_INITIAL_MATRIX,&outmatrix);
			#endif
		}

	}

	/*Free ressources:*/
	xfree((void**)&node_rows);
	xfree((void**)&node_cols);
	ISFree(&col_index);
	ISFree(&row_index);

	/*Assign output pointers:*/
	*poutmatrix=outmatrix;
}
