import numpy
from model import *
from pairoptions import *
from MatlabFuncs import *
from FlagElements import *

def setflowequation(md,*args):
	"""
	SETELEMENTSTYPE - associate a solution type to each element

	   This routine works like plotmodel: it works with an even number of inputs
	   'hutter','macayeal','pattyn','l1l2','stokes' and 'fill' are the possible options
	   that must be followed by the corresponding exp file or flags list
	   It can either be a domain file (argus type, .exp extension), or an array of element flags. 
	   If user wants every element outside the domain to be 
	   setflowequationd, add '~' to the name of the domain file (ex: '~Pattyn.exp');
	   an empty string '' will be considered as an empty domain
	   a string 'all' will be considered as the entire domain
	   You can specify the type of coupling, 'penalties' or 'tiling', to use with the input 'coupling'

	   Usage:
	      md=setflowequation(md,varargin)

	   Example:
	      md=setflowequation(md,'pattyn','Pattyn.exp','macayeal',md.mask.elementonfloatingice,'fill','hutter');
	      md=setflowequation(md,'pattyn','Pattyn.exp',fill','hutter','coupling','tiling');
	"""

	#some checks on list of arguments
	if not isinstance(md,model) or not len(args):
		raise TypeError("setflowequation error message")

	#process options
	options=pairoptions(*args)
#	options=deleteduplicates(options,1);

	#Find_out what kind of coupling to use
	coupling_method=options.getfieldvalue('coupling','tiling')
	if not strcmpi(coupling_method,'tiling') and not strcmpi(coupling_method,'penalties'):
		raise TypeError("coupling type can only be: tiling or penalties")

	#recover elements distribution
	hutterflag   = FlagElements(md,options.getfieldvalue('hutter',''))
	macayealflag = FlagElements(md,options.getfieldvalue('macayeal',''))
	pattynflag   = FlagElements(md,options.getfieldvalue('pattyn',''))
	l1l2flag     = FlagElements(md,options.getfieldvalue('l1l2',''))
	stokesflag   = FlagElements(md,options.getfieldvalue('stokes',''))
	filltype     = options.getfieldvalue('fill','none')

	#Flag the elements that have not been flagged as filltype
	if   strcmpi(filltype,'hutter'):
		hutterflag[numpy.nonzero(numpy.logical_not(numpy.logical_or(macayealflag,pattynflag)))]=True
	elif strcmpi(filltype,'macayeal'):
		macayealflag[numpy.nonzero(numpy.logical_not(numpy.logical_or(hutterflag,numpy.logical_or(pattynflag,stokesflag))))]=True
	elif strcmpi(filltype,'pattyn'):
		pattynflag[numpy.nonzero(numpy.logical_not(numpy.logical_or(hutterflag,numpy.logical_or(macayealflag,stokesflag))))]=True

	#check that each element has at least one flag
	if not any(hutterflag+macayealflag+l1l2flag+pattynflag+stokesflag):
		raise TypeError("elements type not assigned, must be specified")

	#check that each element has only one flag
	if any(hutterflag+macayealflag+l1l2flag+pattynflag+stokesflag>1):
		print "setflowequation warning message: some elements have several types, higher order type is used for them"
		hutterflag[numpy.nonzero(numpy.logical_and(hutterflag,macayealflag))]=False
		hutterflag[numpy.nonzero(numpy.logical_and(hutterflag,pattynflag))]=False
		macayealflag[numpy.nonzero(numpy.logical_and(macayealflag,pattynflag))]=False

	#Check that no pattyn or stokes for 2d mesh
	if md.mesh.dimension==2:
		if numpy.any(numpy.logical_or(l1l2flag,stokesflag,pattynflag)):
			raise TypeError("stokes and pattyn elements not allowed in 2d mesh, extrude it first")

	#Stokes can only be used alone for now:
	if any(stokesflag) and any(hutterflag):
		raise TypeError("stokes cannot be used with any other model for now, put stokes everywhere")

	#Initialize node fields
	nodeonhutter=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonhutter[md.mesh.elements[numpy.nonzero(hutterflag),:]-1]=True
	nodeonmacayeal=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:]-1]=True
	nodeonl1l2=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonl1l2[md.mesh.elements[numpy.nonzero(l1l2flag),:]-1]=True
	nodeonpattyn=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:]-1]=True
	nodeonstokes=numpy.zeros(md.mesh.numberofvertices,bool)
	noneflag=numpy.zeros(md.mesh.numberofelements,bool)

	#First modify stokesflag to get rid of elements contrained everywhere (spc + border with pattyn or macayeal)
	if any(stokesflag):
#		fullspcnodes=double((~isnan(md.diagnostic.spcvx)+~isnan(md.diagnostic.spcvy)+~isnan(md.diagnostic.spcvz))==3 | (nodeonpattyn & nodeonstokes));         %find all the nodes on the boundary of the domain without icefront
		fullspcnodes=numpy.logical_or(numpy.logical_not(numpy.isnan(md.diagnostic.spcvx)).astype(int)+ \
		                              numpy.logical_not(numpy.isnan(md.diagnostic.spcvy)).astype(int)+ \
		                              numpy.logical_not(numpy.isnan(md.diagnostic.spcvz)).astype(int)==3, \
		                              numpy.logical_and(nodeonpattyn,nodeonstokes).reshape(-1,1)).astype(int)    #find all the nodes on the boundary of the domain without icefront
#		fullspcelems=double(sum(fullspcnodes(md.mesh.elements),2)==6);         %find all the nodes on the boundary of the domain without icefront
		fullspcelems=(numpy.sum(fullspcnodes[md.mesh.elements-1],axis=1)==6).astype(int)    #find all the nodes on the boundary of the domain without icefront
		stokesflag[numpy.nonzero(fullspcelems.reshape(-1))]=False
		nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:]-1]=True

	#Then complete with NoneApproximation or the other model used if there is no stokes
	if any(stokesflag): 
		if   any(pattynflag):    #fill with pattyn
			pattynflag[numpy.logical_not(stokesflag)]=True
			nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:]-1]=True
		elif any(macayealflag):    #fill with macayeal
			macayealflag[numpy.logical_not(stokesflag)]=True
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:]-1]=True
		else:    #fill with none 
			noneflag[numpy.nonzero(numpy.logical_not(stokesflag))]=True

	#Now take care of the coupling between MacAyeal and Pattyn
	md.diagnostic.vertex_pairing=numpy.array([])
	nodeonmacayealpattyn=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonpattynstokes=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonmacayealstokes=numpy.zeros(md.mesh.numberofvertices,bool)
	macayealpattynflag=numpy.zeros(md.mesh.numberofelements,bool)
	macayealstokesflag=numpy.zeros(md.mesh.numberofelements,bool)
	pattynstokesflag=numpy.zeros(md.mesh.numberofelements,bool)
	if   strcmpi(coupling_method,'penalties'):
		#Create the border nodes between Pattyn and MacAyeal and extrude them
		numnodes2d=md.mesh.numberofvertices2d
		numlayers=md.mesh.numberoflayers
		bordernodes2d=numpy.nonzero(numpy.logical_and(nodeonpattyn[0:numnodes2d],nodeonmacayeal[0:numnodes2d]))[0]+1    #Nodes connected to two different types of elements

		#initialize and fill in penalties structure
		if numpy.all(numpy.logical_not(numpy.isnan(bordernodes2d))):
			penalties=numpy.zeros((0,2))
			for	i in xrange(1,numlayers):
				penalties=numpy.vstack((penalties,numpy.hstack((bordernodes2d.reshape(-1,1),bordernodes2d.reshape(-1,1)+md.mesh.numberofvertices2d*(i)))))
			md.diagnostic.vertex_pairing=penalties

	elif strcmpi(coupling_method,'tiling'):
		if   any(macayealflag) and any(pattynflag):    #coupling macayeal pattyn
			#Find node at the border
			nodeonmacayealpattyn[numpy.nonzero(numpy.logical_and(nodeonmacayeal,nodeonpattyn))]=True
			#Macayeal elements in contact with this layer become MacAyealPattyn elements
			matrixelements=ismember(md.mesh.elements-1,numpy.nonzero(nodeonmacayealpattyn)[0])
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(pattynflag)]=False    #only one layer: the elements previously in macayeal
			macayealflag[numpy.nonzero(commonelements)]=False    #these elements are now macayealpattynelements
			macayealpattynflag[numpy.nonzero(commonelements)]=True
			nodeonmacayeal[:]=False
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:]-1]=True

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(macayealpattynflag)[0]
			elist=numpy.zeros(numpy.size(pos),dtype=int)
			elist = elist + numpy.sum(nodeonmacayeal[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			elist = elist - numpy.sum(nodeonpattyn[md.mesh.elements[pos,:]-1]  ,axis=1).astype(bool)
			pos1=numpy.nonzero(elist==1)[0]
			macayealflag[pos[pos1]]=True
			macayealpattynflag[pos[pos1]]=False
			pos2=numpy.nonzero(elist==-1)[0]
			pattynflag[pos[pos2]]=True
			macayealpattynflag[pos[pos2]]=False

			#Recompute nodes associated to these elements
			nodeonmacayeal[:]=False
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:]-1]=True
			nodeonpattyn[:]=False
			nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:]-1]=True
			nodeonmacayealpattyn[:]=False
			nodeonmacayealpattyn[md.mesh.elements[numpy.nonzero(macayealpattynflag),:]-1]=True

		elif any(pattynflag) and any(stokesflag):    #coupling pattyn stokes
			#Find node at the border
			nodeonpattynstokes[numpy.nonzero(numpy.logical_and(nodeonpattyn,nodeonstokes))]=True
			#Stokes elements in contact with this layer become PattynStokes elements
			matrixelements=ismember(md.mesh.elements-1,numpy.nonzero(nodeonpattynstokes)[0])
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(pattynflag)]=False    #only one layer: the elements previously in macayeal
			stokesflag[numpy.nonzero(commonelements)]=False    #these elements are now macayealpattynelements
			pattynstokesflag[numpy.nonzero(commonelements)]=True
			nodeonstokes=numpy.zeros(md.mesh.numberofvertices,bool)
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:]-1]=True

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(pattynstokesflag)[0]
			elist=numpy.zeros(numpy.size(pos),dtype=int)
			elist = elist + numpy.sum(nodeonstokes[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			elist = elist - numpy.sum(nodeonpattyn[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			pos1=numpy.nonzero(elist==1)[0]
			stokesflag[pos[pos1]]=True
			pattynstokesflag[pos[pos1]]=False
			pos2=numpy.nonzero(elist==-1)[0]
			pattynflag[pos[pos2]]=True
			pattynstokesflag[pos[pos2]]=False

			#Recompute nodes associated to these elements
			nodeonstokes[:]=False
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:]-1]=True
			nodeonpattyn[:]=False
			nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:]-1]=True
			nodeonpattynstokes[:]=False
			nodeonpattynstokes[md.mesh.elements[numpy.nonzero(pattynstokesflag),:]-1]=True

		elif any(stokesflag) and any(macayealflag):
			#Find node at the border
			nodeonmacayealstokes[numpy.nonzero(numpy.logical_and(nodeonmacayeal,nodeonstokes))]=True
			#Stokes elements in contact with this layer become MacAyealStokes elements
			matrixelements=ismember(md.mesh.elements-1,numpy.nonzero(nodeonmacayealstokes)[0])
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(macayealflag)]=False    #only one layer: the elements previously in macayeal
			stokesflag[numpy.nonzero(commonelements)]=False    #these elements are now macayealmacayealelements
			macayealstokesflag[numpy.nonzero(commonelements)]=True
			nodeonstokes=numpy.zeros(md.mesh.numberofvertices,bool)
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:]-1]=True

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(macayealstokesflag)[0]
			elist=numpy.zeros(numpy.size(pos),dtype=int)
			elist = elist + numpy.sum(nodeonmacayeal[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			elist = elist - numpy.sum(nodeonstokes[md.mesh.elements[pos,:]-1]  ,axis=1).astype(bool)
			pos1=numpy.nonzero(elist==1)[0]
			macayealflag[pos[pos1]]=True
			macayealstokesflag[pos[pos1]]=False
			pos2=numpy.nonzero(elist==-1)[0]
			stokesflag[pos[pos2]]=True
			macayealstokesflag[pos[pos2]]=False

			#Recompute nodes associated to these elements
			nodeonmacayeal[:]=False
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:]-1]=True
			nodeonstokes[:]=False
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:]-1]=True
			nodeonmacayealstokes[:]=False
			nodeonmacayealstokes[md.mesh.elements[numpy.nonzero(macayealstokesflag),:]-1]=True

		elif any(stokesflag) and any(hutterflag):
			raise TypeError("type of coupling not supported yet")

	#Create MacaAyealPattynApproximation where needed
	md.flowequation.element_equation=numpy.zeros(md.mesh.numberofelements,int)
	md.flowequation.element_equation[numpy.nonzero(noneflag)]=0
	md.flowequation.element_equation[numpy.nonzero(hutterflag)]=1
	md.flowequation.element_equation[numpy.nonzero(macayealflag)]=2
	md.flowequation.element_equation[numpy.nonzero(l1l2flag)]=8
	md.flowequation.element_equation[numpy.nonzero(pattynflag)]=3
	md.flowequation.element_equation[numpy.nonzero(stokesflag)]=4
	md.flowequation.element_equation[numpy.nonzero(macayealpattynflag)]=5
	md.flowequation.element_equation[numpy.nonzero(macayealstokesflag)]=6
	md.flowequation.element_equation[numpy.nonzero(pattynstokesflag)]=7

	#border
	md.flowequation.borderpattyn=nodeonpattyn
	md.flowequation.bordermacayeal=nodeonmacayeal
	md.flowequation.borderstokes=nodeonstokes

	#Create vertices_type
	md.flowequation.vertex_equation=numpy.zeros(md.mesh.numberofvertices,int)
	pos=numpy.nonzero(nodeonmacayeal)
	md.flowequation.vertex_equation[pos]=2
	pos=numpy.nonzero(nodeonl1l2)
	md.flowequation.vertex_equation[pos]=8
	pos=numpy.nonzero(nodeonpattyn)
	md.flowequation.vertex_equation[pos]=3
	pos=numpy.nonzero(nodeonhutter)
	md.flowequation.vertex_equation[pos]=1
	pos=numpy.nonzero(nodeonmacayealpattyn)
	md.flowequation.vertex_equation[pos]=5
	pos=numpy.nonzero(nodeonstokes)
	md.flowequation.vertex_equation[pos]=4
	if any(stokesflag):
		pos=numpy.nonzero(numpy.logical_not(nodeonstokes))
		if not (any(pattynflag) or any(macayealflag)):
			md.flowequation.vertex_equation[pos]=0
	pos=numpy.nonzero(nodeonpattynstokes)
	md.flowequation.vertex_equation[pos]=7
	pos=numpy.nonzero(nodeonmacayealstokes)
	md.flowequation.vertex_equation[pos]=6

	#figure out solution types
	md.flowequation.ishutter=any(md.flowequation.element_equation==1)
	md.flowequation.ismacayealpattyn=bool(numpy.any(numpy.logical_or(md.flowequation.element_equation==2,md.flowequation.element_equation==3)))
	md.flowequation.isl1l2=any(md.flowequation.element_equation==8)
	md.flowequation.isstokes=any(md.flowequation.element_equation==4)

	return md

	#Check that tiling can work:
	if any(md.flowequation.bordermacayeal) and any(md.flowequation.borderpattyn) and any(md.flowequation.borderpattyn + md.flowequation.bordermacayeal !=1):
		raise TypeError("error coupling domain too irregular")
	if any(md.flowequation.bordermacayeal) and any(md.flowequation.borderstokes) and any(md.flowequation.borderstokes + md.flowequation.bordermacayeal !=1):
		raise TypeError("error coupling domain too irregular")
	if any(md.flowequation.borderstokes) and any(md.flowequation.borderpattyn) and any(md.flowequation.borderpattyn + md.flowequation.borderstokes !=1):
		raise TypeError("error coupling domain too irregular")

	return md

