import numpy
from model import *
from pairoptions import *
from MatlabFuncs import *
from FlagElements import *

def setflowequation(md,*args):
	"""
	SETELEMENTSTYPE - associate a solution type to each element

	   This routine works like plotmodel: it works with an even number of inputs
	   'hutter','macayeal','pattyn','l1l2','stokes' and 'fill' are the possible options
	   that must be followed by the corresponding exp file or flags list
	   It can either be a domain file (argus type, .exp extension), or an array of element flags. 
	   If user wants every element outside the domain to be 
	   setflowequationd, add '~' to the name of the domain file (ex: '~Pattyn.exp');
	   an empty string '' will be considered as an empty domain
	   a string 'all' will be considered as the entire domain
	   You can specify the type of coupling, 'penalties' or 'tiling', to use with the input 'coupling'

	   Usage:
	      md=setflowequation(md,varargin)

	   Example:
	      md=setflowequation(md,'pattyn','Pattyn.exp','macayeal',md.mask.elementonfloatingice,'fill','hutter');
	      md=setflowequation(md,'pattyn','Pattyn.exp',fill','hutter','coupling','tiling');
	"""

	#some checks on list of arguments
	if not isinstance(md,model) or not len(args):
		raise TypeError("setflowequation error message")

	#process options
	options=pairoptions(*args)
#	options=deleteduplicates(options,1);

	#Find_out what kind of coupling to use
	coupling_method=options.getfieldvalue('coupling','tiling')
	if not strcmpi(coupling_method,'tiling') and not strcmpi(coupling_method,'penalties'):
		raise TypeError("coupling type can only be: tiling or penalties")

	#recover elements distribution
	hutterflag   = FlagElements(md,options.getfieldvalue('hutter',''))
	macayealflag = FlagElements(md,options.getfieldvalue('macayeal',''))
	pattynflag   = FlagElements(md,options.getfieldvalue('pattyn',''))
	l1l2flag     = FlagElements(md,options.getfieldvalue('l1l2',''))
	stokesflag   = FlagElements(md,options.getfieldvalue('stokes',''))
	filltype     = options.getfieldvalue('fill','none')

	#Flag the elements that have not been flagged as filltype
	if   strcmpi(filltype,'hutter'):
		hutterflag[numpy.nonzero(numpy.logical_not(numpy.logical_or(macayealflag,pattynflag)))]=1
	elif strcmpi(filltype,'macayeal'):
		macayealflag[numpy.nonzero(numpy.logical_not(numpy.logical_or(hutterflag,numpy.logical_or(pattynflag,stokesflag))))]=1
	elif strcmpi(filltype,'pattyn'):
		pattynflag[numpy.nonzero(numpy.logical_not(numpy.logical_or(hutterflag,numpy.logical_or(macayealflag,stokesflag))))]=1

	#check that each element has at least one flag
	if not any(hutterflag+macayealflag+l1l2flag+pattynflag+stokesflag):
		raise TypeError("elements type not assigned, must be specified")

	#check that each element has only one flag
	if any(hutterflag+macayealflag+l1l2flag+pattynflag+stokesflag>1):
		print "setflowequation warning message: some elements have several types, higher order type is used for them"
		hutterflag[numpy.nonzero(numpy.logical_and(hutterflag,macayealflag))]=0
		hutterflag[numpy.nonzero(numpy.logical_and(hutterflag,pattynflag))]=0
		macayealflag[numpy.nonzero(numpy.logical_and(macayealflag,pattynflag))]=0

	#Check that no pattyn or stokes for 2d mesh
	if md.mesh.dimension==2:
		if numpy.any(numpy.logical_or(l1l2flag,stokesflag,pattynflag)):
			raise TypeError("stokes and pattyn elements not allowed in 2d mesh, extrude it first")

	#Stokes can only be used alone for now:
	if any(stokesflag) and any(hutterflag):
		raise TypeError("stokes cannot be used with any other model for now, put stokes everywhere")

	#Initialize node fields
	nodeonhutter=numpy.zeros(md.mesh.numberofvertices)
	nodeonhutter[md.mesh.elements[numpy.nonzero(hutterflag),:].astype(int)-1]=1
	nodeonmacayeal=numpy.zeros(md.mesh.numberofvertices)
	nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:].astype(int)-1]=1
	nodeonl1l2=numpy.zeros(md.mesh.numberofvertices)
	nodeonl1l2[md.mesh.elements[numpy.nonzero(l1l2flag),:].astype(int)-1]=1
	nodeonpattyn=numpy.zeros(md.mesh.numberofvertices)
	nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:].astype(int)-1]=1
	nodeonstokes=numpy.zeros(md.mesh.numberofvertices)
	noneflag=numpy.zeros(md.mesh.numberofelements)

	#First modify stokesflag to get rid of elements contrained everywhere (spc + border with pattyn or macayeal)
	if any(stokesflag):
#		fullspcnodes=double((~isnan(md.diagnostic.spcvx)+~isnan(md.diagnostic.spcvy)+~isnan(md.diagnostic.spcvz))==3 | (nodeonpattyn & nodeonstokes));         %find all the nodes on the boundary of the domain without icefront
		fullspcnodes=numpy.logical_or(numpy.logical_not(numpy.isnan(md.diagnostic.spcvx))+ \
		                              numpy.logical_not(numpy.isnan(md.diagnostic.spcvy))+ \
		                              numpy.logical_not(numpy.isnan(md.diagnostic.spcvz))==3, \
		                              numpy.logical_and(nodeonpattyn,nodeonstokes)).astype(int)    #find all the nodes on the boundary of the domain without icefront
#		fullspcelems=double(sum(fullspcnodes(md.mesh.elements),2)==6);         %find all the nodes on the boundary of the domain without icefront
		fullspcelems=(numpy.sum(fullspcnodes[md.mesh.elements.astype(int)-1],axis=1)==6).astype(int)    #find all the nodes on the boundary of the domain without icefront
		stokesflag[numpy.nonzero(fullspcelems)]=0
		nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:].astype(int)-1]=1

	#Then complete with NoneApproximation or the other model used if there is no stokes
	if any(stokesflag): 
		if   any(pattynflag):    #fill with pattyn
			pattynflag[numpy.logical_not(stokesflag)]=1
			nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:].astype(int)-1]=1
		elif any(macayealflag):    #fill with macayeal
			macayealflag[numpy.logical_not(stokesflag)]=1
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:].astype(int)-1]=1
		else:    #fill with none 
			noneflag[numpy.nonzero(numpy.logical_not(stokesflag))]=1

	#Now take care of the coupling between MacAyeal and Pattyn
	md.diagnostic.vertex_pairing=numpy.array([])
	nodeonmacayealpattyn=numpy.zeros(md.mesh.numberofvertices)
	nodeonpattynstokes=numpy.zeros(md.mesh.numberofvertices)
	nodeonmacayealstokes=numpy.zeros(md.mesh.numberofvertices)
	macayealpattynflag=numpy.zeros(md.mesh.numberofelements)
	macayealstokesflag=numpy.zeros(md.mesh.numberofelements)
	pattynstokesflag=numpy.zeros(md.mesh.numberofelements)
	if   strcmpi(coupling_method,'penalties'):
		#Create the border nodes between Pattyn and MacAyeal and extrude them
		numnodes2d=md.mesh.numberofvertices2d
		numlayers=md.mesh.numberoflayers
		bordernodes2d=numpy.nonzero(numpy.logical_and(nodeonpattyn[1:numnodes2d],nodeonmacayeal[1:numnodes2d]))    #Nodes connected to two different types of elements

		#initialize and fill in penalties structure
		if numpy.all(numpy.logical_not(numpy.isnan(bordernodes2d))):
			penalties=numpy.zeros((0,2))
			for	i in xrange(1,numlayers):
				penalties=numpy.concatenate((penalties,numpy.concatenate((bordernodes2d,bordernodes2d+md.mesh.numberofvertices2d*(i)),axis=1)),axis=0)
			md.diagnostic.vertex_pairing=penalties

	elif strcmpi(coupling_method,'tiling'):
		if   any(macayealflag) and any(pattynflag):    #coupling macayeal pattyn
			#Find node at the border
			nodeonmacayealpattyn[numpy.nonzero(numpy.logical_and(nodeonmacayeal,nodeonpattyn))]=1
			#Macayeal elements in contact with this layer become MacAyealPattyn elements
			matrixelements=ismember(md.mesh.elements,numpy.nonzero(nodeonmacayealpattyn))
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(pattynflag)]=0    #only one layer: the elements previously in macayeal
			macayealflag[numpy.nonzero(commonelements)]=0    #these elements are now macayealpattynelements
			macayealpattynflag[numpy.nonzero(commonelements)]=1
			nodeonmacayeal[:]=0
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:].astype(int)-1]=1

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(macayealpattynflag)
			elist=numpy.zeros(len(pos))
			elist = elist + numpy.any(numpy.sum(nodeonmacayeal[md.mesh.elements[pos,:].astype(int)-1],axis=1),axis=1)
			elist = elist - numpy.any(numpy.sum(nodeonpattyn[md.mesh.elements[pos,:].astype(int)-1]  ,axis=1),axis=1)
			pos1=numpy.nonzero(elist==1)
			macayealflag[pos[pos1]]=1
			macayealpattynflag[pos[pos1]]=0
			pos2=numpy.nonzero(elist==-1)
			pattynflag[pos[pos2]]=1
			macayealpattynflag[pos[pos2]]=0

			#Recompute nodes associated to these elements
			nodeonmacayeal[:]=0
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:].astype(int)-1]=1
			nodeonpattyn[:]=0
			nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:].astype(int)-1]=1
			nodeonmacayealpattyn[:]=0
			nodeonmacayealpattyn[md.mesh.elements[numpy.nonzero(macayealpattynflag),:].astype(int)-1]=1

		elif any(pattynflag) and any(stokesflag):    #coupling pattyn stokes
			#Find node at the border
			nodeonpattynstokes[numpy.nonzero(numpy.logical_and(nodeonpattyn,nodeonstokes))]=1
			#Stokes elements in contact with this layer become PattynStokes elements
			matrixelements=ismember(md.mesh.elements,numpy.nonzero(nodeonpattynstokes))
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(pattynflag)]=0    #only one layer: the elements previously in macayeal
			stokesflag[numpy.nonzero(commonelements)]=0    #these elements are now macayealpattynelements
			pattynstokesflag[numpy.nonzero(commonelements)]=1
			nodeonstokes=numpy.zeros(md.mesh.numberofvertices)
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:].astype(int)-1]=1

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(pattynstokesflag)
			elist=numpy.zeros(len(pos))
			elist = elist + numpy.any(numpy.sum(nodeonstokes[md.mesh.elements[pos,:].astype(int)-1],axis=1),axis=1)
			elist = elist - numpy.any(numpy.sum(nodeonpattyn[md.mesh.elements[pos,:].astype(int)-1],axis=1),axis=1)
			pos1=numpy.nonzero(elist==1)
			stokesflag[pos[pos1]]=1
			pattynstokesflag[pos[pos1]]=0
			pos2=numpy.nonzero(elist==-1)
			pattynflag[pos[pos2]]=1
			pattynstokesflag[pos[pos2]]=0

			#Recompute nodes associated to these elements
			nodeonstokes[:]=0
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:].astype(int)-1]=1
			nodeonpattyn[:]=0
			nodeonpattyn[md.mesh.elements[numpy.nonzero(pattynflag),:].astype(int)-1]=1
			nodeonpattynstokes[:]=0
			nodeonpattynstokes[md.mesh.elements[numpy.nonzero(pattynstokesflag),:].astype(int)-1]=1

		elif any(stokesflag) and any(macayealflag):
			#Find node at the border
			nodeonmacayealstokes[numpy.nonzero(numpy.logical_and(nodeonmacayeal,nodeonstokes))]=1
			#Stokes elements in contact with this layer become MacAyealStokes elements
			matrixelements=ismember(md.mesh.elements,numpy.nonzero(nodeonmacayealstokes))
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(macayealflag)]=0    #only one layer: the elements previously in macayeal
			stokesflag[numpy.nonzero(commonelements)]=0    #these elements are now macayealmacayealelements
			macayealstokesflag[numpy.nonzero(commonelements)]=1
			nodeonstokes=numpy.zeros(md.mesh.numberofvertices)
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:].astype(int)-1]=1

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(macayealstokesflag)
			elist=numpy.zeros(len(pos))
			elist = elist + numpy.any(numpy.sum(nodeonmacayeal[md.mesh.elements[pos,:].astype(int)-1],axis=1),axis=1)
			elist = elist - numpy.any(numpy.sum(nodeonstokes[md.mesh.elements[pos,:].astype(int)-1]  ,axis=1),axis=1)
			pos1=numpy.nonzero(elist==1)
			macayealflag[pos[pos1]]=1
			macayealstokesflag[pos[pos1]]=0
			pos2=numpy.nonzero(elist==-1)
			stokesflag[pos[pos2]]=1
			macayealstokesflag[pos[pos2]]=0

			#Recompute nodes associated to these elements
			nodeonmacayeal[:]=0
			nodeonmacayeal[md.mesh.elements[numpy.nonzero(macayealflag),:].astype(int)-1]=1
			nodeonstokes[:]=0
			nodeonstokes[md.mesh.elements[numpy.nonzero(stokesflag),:].astype(int)-1]=1
			nodeonmacayealstokes[:]=0
			nodeonmacayealstokes[md.mesh.elements[numpy.nonzero(macayealstokesflag),:].astype(int)-1]=1

		elif any(stokesflag) and any(hutterflag):
			raise TypeError("type of coupling not supported yet")

	#Create MacaAyealPattynApproximation where needed
	md.flowequation.element_equation=numpy.zeros(md.mesh.numberofelements)
	md.flowequation.element_equation[numpy.nonzero(noneflag)]=0
	md.flowequation.element_equation[numpy.nonzero(hutterflag)]=1
	md.flowequation.element_equation[numpy.nonzero(macayealflag)]=2
	md.flowequation.element_equation[numpy.nonzero(l1l2flag)]=8
	md.flowequation.element_equation[numpy.nonzero(pattynflag)]=3
	md.flowequation.element_equation[numpy.nonzero(stokesflag)]=4
	md.flowequation.element_equation[numpy.nonzero(macayealpattynflag)]=5
	md.flowequation.element_equation[numpy.nonzero(macayealstokesflag)]=6
	md.flowequation.element_equation[numpy.nonzero(pattynstokesflag)]=7

	#border
	md.flowequation.borderpattyn=nodeonpattyn
	md.flowequation.bordermacayeal=nodeonmacayeal
	md.flowequation.borderstokes=nodeonstokes

	#Create vertices_type
	md.flowequation.vertex_equation=numpy.zeros(md.mesh.numberofvertices)
	pos=numpy.nonzero(nodeonmacayeal)
	md.flowequation.vertex_equation[pos]=2
	pos=numpy.nonzero(nodeonl1l2)
	md.flowequation.vertex_equation[pos]=8
	pos=numpy.nonzero(nodeonpattyn)
	md.flowequation.vertex_equation[pos]=3
	pos=numpy.nonzero(nodeonhutter)
	md.flowequation.vertex_equation[pos]=1
	pos=numpy.nonzero(nodeonmacayealpattyn)
	md.flowequation.vertex_equation[pos]=5
	pos=numpy.nonzero(nodeonstokes)
	md.flowequation.vertex_equation[pos]=4
	if any(stokesflag):
		pos=numpy.nonzero(numpy.logical_not(nodeonstokes))
		if not (any(pattynflag) or any(macayealflag)):
			md.flowequation.vertex_equation[pos]=0
	pos=numpy.nonzero(nodeonpattynstokes)
	md.flowequation.vertex_equation[pos]=7
	pos=numpy.nonzero(nodeonmacayealstokes)
	md.flowequation.vertex_equation[pos]=6

	#figure out solution types
	md.flowequation.ishutter=float(any(md.flowequation.element_equation==1))
	md.flowequation.ismacayealpattyn=float(numpy.any(numpy.logical_or(md.flowequation.element_equation==2,md.flowequation.element_equation==3)))
	md.flowequation.isl1l2=float(any(md.flowequation.element_equation==8))
	md.flowequation.isstokes=float(any(md.flowequation.element_equation==4))

	return md

	#Check that tiling can work:
	if any(md.flowequation.bordermacayeal) and any(md.flowequation.borderpattyn) and any(md.flowequation.borderpattyn + md.flowequation.bordermacayeal !=1):
		raise TypeError("error coupling domain too irregular")
	if any(md.flowequation.bordermacayeal) and any(md.flowequation.borderstokes) and any(md.flowequation.borderstokes + md.flowequation.bordermacayeal !=1):
		raise TypeError("error coupling domain too irregular")
	if any(md.flowequation.borderstokes) and any(md.flowequation.borderpattyn) and any(md.flowequation.borderpattyn + md.flowequation.borderstokes !=1):
		raise TypeError("error coupling domain too irregular")

	return md

