/*!\file:  MeshPartition.cpp
 * \brief: partition mesh according to number of areas, using Metis library.

	usage:
	[element_partitioning,node_partitioning]=MeshPartition(model,num_areas)
	
	%Info needed from model are the following: 
	%mesh info: 
	numberofelements,numberofgrids,elements,elements_width
	%Non-extruded 2d mesh info
	nel2d,nods2d,elements2d,
	%Extruded 2d mesh info
	nel2d_ext,nods2d_ext,elements2d_ext,
	%Diverse
	numlayers,meshtype)

	output:
	vector of partitioning area numbers, for every element.
	vector of partitioning area numbers, for every node.
*/
	
#include "./MeshPartition.h"


void mexFunction( int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {


	/*Indexing: */
	int i,j;

	/* required input: */
	char* meshtype=NULL;

	int numberofelements, numberofgrids, elements_width;
	double* elements=NULL;

	int nel2d, nods2d;
	double* elements2d=NULL;

	int nel2d_ext, nods2d_ext;
	double* elements2d_ext=NULL;
	int numlayers;
	int num_areas=1;

	/* output: */
	int*    int_element_partitioning=NULL;
	int*    int_node_partitioning   =NULL;
	double* element_partitioning    =NULL;
	double* node_partitioning       =NULL;


	/*Boot module: */
	MODULEBOOT();

	/*checks on arguments on the matlab side: */
	CheckNumMatlabArguments(nlhs,NLHS,nrhs,NRHS,__FUNCT__,&MeshPartitionUsage);

	/*Fetch data: */
	FetchData((void**)&meshtype,NULL,NULL,mxGetField(MODEL,0,"type"),"String",NULL);
	FetchData((void**)&numberofelements,NULL,NULL,mxGetField(MODEL,0,"numberofelements"),"Integer",NULL);
	FetchData((void**)&numberofgrids,NULL,NULL,mxGetField(MODEL,0,"numberofgrids"),"Integer",NULL);
	FetchData((void**)&elements,NULL,NULL,mxGetField(MODEL,0,"elements"),"Matrix","Mat");

	
	int MeshPartition(int** pepart, int numberofelements,int numberofgrids,double* elements,
		        int numberofelements2d,int numberofgrids2d,double* elements2d,int numlayers,int elements_width, char* meshtype,int num_procs){

	if (strcmp(meshtype,"3d")==0){
	
		FetchData((void**)&numberofelements2d,NULL,NULL,mxGetField(MODEL,0,"numberofelements2d"),"Integer",NULL);
		FetchData((void**)&numberofgrids2d,NULL,NULL,mxGetField(MODEL,0,"numberofgrids2d"),"Integer",NULL);

	}

	/*Number of extrusion layers: */
	pfield=mxGetField(prhs[0],0,"numlayers");
	numlayers= (int)(*(double*)mxGetPr(pfield));

	#ifdef _DEBUG_
	printf("meshtype: %s\n",meshtype);
	printf("numberofelements %i numberofgrids %i elements_width %i \n",numberofelements,numberofgrids,elements_width);
	for (i=0;i<numberofelements;i++){
		for(j=0;j<elements_width;j++){
			printf(" %i ",(int)*(elements+elements_width*i+j));
		}
		printf("\n");
	}
	if (strcmp(meshtype,"3d")==0){
		printf("numberofelements 2d %i numberofgrids 2d %i\n",nel2d,nods2d);
		for (i=0;i<nel2d;i++){
			for(j=0;j<3;j++){
				printf(" %i ",(int)*(elements2d+3*i+j));
			}
			printf("\n");
		}
		
		printf("numberofelements 2d_ext %i numberofgrids 2d_ext %i\n",nel2d_ext,nods2d_ext);
		for (i=0;i<nel2d_ext;i++){
			for(j=0;j<3;j++){
				printf(" %i ",(int)*(elements2d_ext+3*i+j));
			}
			printf("\n");
		}
	}
	printf("Number of extrusion layers: %i\n",numlayers);
	#endif



	/*Fetch number of processors: */
	num_areas= (int)(*(double*)mxGetPr(prhs[1]));
	#ifdef _DEBUG_
		printf("num_areas: %i\n",num_areas);
	#endif

	/*Run partitioning algorithm based on a "clever" use of the Metis partitioner: */
	MeshPartitionx(&int_element_partitioning,&int_node_partitioning,numberofelements,numberofgrids,elements,
		nel2d,nods2d,elements2d,nel2d_ext,nods2d_ext,elements2d_ext, numlayers,elements_width,meshtype,num_areas);

	/* output: */
	element_partitioning=mxMalloc(numberofelements*sizeof(double));
	for (i=0;i<numberofelements;i++){
		element_partitioning[i]=(double)int_element_partitioning[i]+1; //Metis indexing from 0, matlab from 1.
	}
	plhs[0]=mxCreateDoubleMatrix(numberofelements,1,mxREAL);
	mxSetPr(plhs[0],element_partitioning);

	node_partitioning=mxMalloc(numberofgrids*sizeof(double));
	for (i=0;i<numberofgrids;i++){
		node_partitioning[i]=(double)int_node_partitioning[i]+1; //Metis indexing from 0, matlab from 1.
	}
	plhs[1]=mxCreateDoubleMatrix(numberofgrids,1,mxREAL);
	mxSetPr(plhs[1],node_partitioning);

	/*end module: */
	MODULEEND();
}

void MeshPartitionUsage(void){
	printf("   usage:\n");
	printf("   [element_partitioning,node_partitioning]=MeshPartition(model,num_areas)");
	printf("   where:\n");
	printf("      model is a @model class object instance,num_areas is the number of processors on which partitioning will occur.\n");
	printf("      element_partitioning is a vector of partitioning area numbers, for every element.\n");
	printf("      node_partitioning is a vector of partitioning area numbers, for every node.\n");
	printf("\n");
}
