import numpy as np
import copy
from pairoptions import *
import MatlabFuncs as m
from adjacency import *
from Chaco import *
#from Scotch import *
from MeshPartition import *
from project3d import *
from mesh2d import *

def partitioner(md,*varargin):
	help ='''
PARTITIONER - partition mesh 

   List of options to partitioner: 

   package: 'chaco', 'metis'
   npart: number of partitions.
   weighting: 'on' or 'off': default off
   section:  1 by defaults(1=bisection, 2=quadrisection, 3=octasection)
   recomputeadjacency:  'on' by default (set to 'off' to compute existing one)
   Output: md.qmu.partition recover the partition vector
   
   Usage:
      md=partitioner(md,'package','chaco','npart',100,'weighting','on')
	'''

	#get options: 
	options=pairoptions(*varargin)

	#set defaults
	options.addfielddefault('package','chaco')
	options.addfielddefault('npart',10)
	options.addfielddefault('weighting','on')
	options.addfielddefault('section',1)
	options.addfielddefault('recomputeadjacency','on')

	#get package: 
	package=options.getfieldvalue('package')
	npart=options.getfieldvalue('npart')
	recomputeadjacency=options.getfieldvalue('recomputeadjacency')

	if(md.mesh.dimension()==3):
		#partitioning essentially happens in 2D. So partition in 2D, then 
		#extrude the partition vector vertically. 
		md3d = copy.deepcopy(md)
		md.mesh.elements=md.mesh.elements2d
		md.mesh.x=md.mesh.x2d
		md.mesh.y=md.mesh.y2d
		md.mesh.numberofvertices=md.mesh.numberofvertices2d
		md.mesh.numberofelements=md.mesh.numberofelements2d
		md.qmu.vertex_weight=[]
		md.mesh.vertexconnectivity=[]
		recomputeadjacency='on'

	#adjacency matrix if needed:
	if m.strcmpi(recomputeadjacency,'on'):
		md=adjacency(md)
	else:
		print('skipping adjacency matrix computation as requested in the options')

	if m.strcmpi(package,'chaco'):
		#raise RuntimeError('Chaco is not currently supported for this function')

		#  default method (from chaco.m)
		method=np.array([1,1,0,0,1,1,50,0,.001,7654321])
		method[0]=3    #  global method (3=inertial (geometric))
		method[2]=0    #  vertex weights (0=off, 1=on)

		#specify bisection
		method[5]=options.getfieldvalue('section')#  ndims (1=bisection, 2=quadrisection, 3=octasection)

		#are we using weights? 
		if m.strcmpi(options.getfieldvalue('weighting'),'on'):
			weights=np.floor(md.qmu.vertex_weight/min(md.qmu.vertex_weight))
			method[2]=1
		else:
			weights=[]
	
		method = method.reshape(-1,1)	# transpose to 1x10 instead of 10

		#  partition into nparts
		if isinstance(md.mesh,mesh2d):
			part=np.array(Chaco(md.qmu.adjacency,weights,np.array([]),md.mesh.x, md.mesh.y,np.zeros((md.mesh.numberofvertices,)),method,npart,np.array([]))).T+1 #index partitions from 1 up. like metis.
		else:
			part=np.array(Chaco(md.qmu.adjacency,weights,np.array([]),md.mesh.x, md.mesh.y,md.mesh.z,method,npart,np.array([]))).T+1 #index partitions from 1 up. like metis.
	
	elif m.strcmpi(package,'scotch'):
		raise RuntimeError('Scotch is not currently supported for this function')

		#are we using weights? 
		#if m.strcmpi(options.getfieldvalue('weighting'),'on'):
			#weights=np.floor(md.qmu.vertex_weight/min(md.qmu.vertex_weight))
		#else:
			#weights=[]
	
		#maptab=Scotch(md.qmu.adjacency,[],weights,[],'cmplt',[npart])

		#part=maptab[:,1]+1#index partitions from 1 up. like metis.

	elif m.strcmpi(package,'linear'):

		if (npart == md.mesh.numberofelements) or (md.qmu.numberofpartitions == md.mesh.numberofelements):
			part=np.arange(1,1+md.mesh.numberofelements,1)
			print('Linear partitioner requesting partitions on elements')
		else:
			part=np.arange(1,1+md.mesh.numberofvertices,1)

	elif m.strcmpi(package,'metis'):
		raise RuntimeError('Metis/MeshPartition is not currently supported for this function')
		#[element_partitioning,part]=MeshPartition(md,md.qmu.numberofpartitions)

	else:
		print(help)
		raise RuntimeError('partitioner error message: could not find '+str(package)+' partitioner')

	#extrude if we are in 3D:
	if md.mesh.dimension()==3:
		md3d.qmu.vertex_weight=md.qmu.vertex_weight
		md3d.qmu.adjacency=md.qmu.adjacency
		md=md3d
		part=project3d(md,'vector',np.squeeze(part),'type','node')

	md.qmu.partition=part

	return md
