import numpy
from model import *
from pairoptions import *
from MatlabFuncs import *
from PythonFuncs import *
from FlagElements import *

def setflowequation(md,*args):
	"""
	SETELEMENTSTYPE - associate a solution type to each element

	   This routine works like plotmodel: it works with an even number of inputs
	   'SIA','SSA','HO','L1L2','FS' and 'fill' are the possible options
	   that must be followed by the corresponding exp file or flags list
	   It can either be a domain file (argus type, .exp extension), or an array of element flags. 
	   If user wants every element outside the domain to be 
	   setflowequationd, add '~' to the name of the domain file (ex: '~HO.exp');
	   an empty string '' will be considered as an empty domain
	   a string 'all' will be considered as the entire domain
	   You can specify the type of coupling, 'penalties' or 'tiling', to use with the input 'coupling'

	   Usage:
	      md=setflowequation(md,varargin)

	   Example:
	      md=setflowequation(md,'HO','HO.exp','SSA',md.mask.elementonfloatingice,'fill','SIA');
	      md=setflowequation(md,'HO','HO.exp',fill','SIA','coupling','tiling');
	"""

	#some checks on list of arguments
	if not isinstance(md,model) or not len(args):
		raise TypeError("setflowequation error message")

	#process options
	options=pairoptions(*args)
#	options=deleteduplicates(options,1);

	#Find_out what kind of coupling to use
	coupling_method=options.getfieldvalue('coupling','tiling')
	if not strcmpi(coupling_method,'tiling') and not strcmpi(coupling_method,'penalties'):
		raise TypeError("coupling type can only be: tiling or penalties")

	#recover elements distribution
	SIAflag   = FlagElements(md,options.getfieldvalue('SIA',''))
	SSAflag = FlagElements(md,options.getfieldvalue('SSA',''))
	HOflag   = FlagElements(md,options.getfieldvalue('HO',''))
	L1L2flag     = FlagElements(md,options.getfieldvalue('L1L2',''))
	FSflag   = FlagElements(md,options.getfieldvalue('FS',''))
	filltype     = options.getfieldvalue('fill','none')

	#Flag the elements that have not been flagged as filltype
	if   strcmpi(filltype,'SIA'):
		SIAflag[numpy.nonzero(numpy.logical_not(logical_or_n(SSAflag,HOflag)))]=True
	elif strcmpi(filltype,'SSA'):
		SSAflag[numpy.nonzero(numpy.logical_not(logical_or_n(SIAflag,HOflag,FSflag)))]=True
	elif strcmpi(filltype,'HO'):
		HOflag[numpy.nonzero(numpy.logical_not(logical_or_n(SIAflag,SSAflag,FSflag)))]=True

	#check that each element has at least one flag
	if not any(SIAflag+SSAflag+L1L2flag+HOflag+FSflag):
		raise TypeError("elements type not assigned, must be specified")

	#check that each element has only one flag
	if any(SIAflag+SSAflag+L1L2flag+HOflag+FSflag>1):
		print "setflowequation warning message: some elements have several types, higher order type is used for them"
		SIAflag[numpy.nonzero(numpy.logical_and(SIAflag,SSAflag))]=False
		SIAflag[numpy.nonzero(numpy.logical_and(SIAflag,HOflag))]=False
		SSAflag[numpy.nonzero(numpy.logical_and(SSAflag,HOflag))]=False

	#Check that no HO or FS for 2d mesh
	if md.mesh.dimension==2:
		if numpy.any(logical_or_n(L1L2flag,FSflag,HOflag)):
			raise TypeError("FS and HO elements not allowed in 2d mesh, extrude it first")

	#FS can only be used alone for now:
	if any(FSflag) and any(SIAflag):
		raise TypeError("FS cannot be used with any other model for now, put FS everywhere")

	#Initialize node fields
	nodeonSIA=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonSIA[md.mesh.elements[numpy.nonzero(SIAflag),:]-1]=True
	nodeonSSA=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonSSA[md.mesh.elements[numpy.nonzero(SSAflag),:]-1]=True
	nodeonL1L2=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonL1L2[md.mesh.elements[numpy.nonzero(L1L2flag),:]-1]=True
	nodeonHO=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonHO[md.mesh.elements[numpy.nonzero(HOflag),:]-1]=True
	nodeonFS=numpy.zeros(md.mesh.numberofvertices,bool)
	noneflag=numpy.zeros(md.mesh.numberofelements,bool)

	#First modify FSflag to get rid of elements contrained everywhere (spc + border with HO or SSA)
	if any(FSflag):
#		fullspcnodes=double((~isnan(md.diagnostic.spcvx)+~isnan(md.diagnostic.spcvy)+~isnan(md.diagnostic.spcvz))==3 | (nodeonHO & nodeonFS));         %find all the nodes on the boundary of the domain without icefront
		fullspcnodes=numpy.logical_or(numpy.logical_not(numpy.isnan(md.diagnostic.spcvx)).astype(int)+ \
		                              numpy.logical_not(numpy.isnan(md.diagnostic.spcvy)).astype(int)+ \
		                              numpy.logical_not(numpy.isnan(md.diagnostic.spcvz)).astype(int)==3, \
		                              numpy.logical_and(nodeonHO,nodeonFS).reshape(-1,1)).astype(int)    #find all the nodes on the boundary of the domain without icefront
#		fullspcelems=double(sum(fullspcnodes(md.mesh.elements),2)==6);         %find all the nodes on the boundary of the domain without icefront
		fullspcelems=(numpy.sum(fullspcnodes[md.mesh.elements-1],axis=1)==6).astype(int)    #find all the nodes on the boundary of the domain without icefront
		FSflag[numpy.nonzero(fullspcelems.reshape(-1))]=False
		nodeonFS[md.mesh.elements[numpy.nonzero(FSflag),:]-1]=True

	#Then complete with NoneApproximation or the other model used if there is no FS
	if any(FSflag): 
		if   any(HOflag):    #fill with HO
			HOflag[numpy.logical_not(FSflag)]=True
			nodeonHO[md.mesh.elements[numpy.nonzero(HOflag),:]-1]=True
		elif any(SSAflag):    #fill with SSA
			SSAflag[numpy.logical_not(FSflag)]=True
			nodeonSSA[md.mesh.elements[numpy.nonzero(SSAflag),:]-1]=True
		else:    #fill with none 
			noneflag[numpy.nonzero(numpy.logical_not(FSflag))]=True

	#Now take care of the coupling between SSA and HO
	md.diagnostic.vertex_pairing=numpy.array([])
	nodeonSSAHO=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonHOFS=numpy.zeros(md.mesh.numberofvertices,bool)
	nodeonSSAFS=numpy.zeros(md.mesh.numberofvertices,bool)
	SSAHOflag=numpy.zeros(md.mesh.numberofelements,bool)
	SSAFSflag=numpy.zeros(md.mesh.numberofelements,bool)
	HOFSflag=numpy.zeros(md.mesh.numberofelements,bool)
	if   strcmpi(coupling_method,'penalties'):
		#Create the border nodes between HO and SSA and extrude them
		numnodes2d=md.mesh.numberofvertices2d
		numlayers=md.mesh.numberoflayers
		bordernodes2d=numpy.nonzero(numpy.logical_and(nodeonHO[0:numnodes2d],nodeonSSA[0:numnodes2d]))[0]+1    #Nodes connected to two different types of elements

		#initialize and fill in penalties structure
		if numpy.all(numpy.logical_not(numpy.isnan(bordernodes2d))):
			penalties=numpy.zeros((0,2))
			for	i in xrange(1,numlayers):
				penalties=numpy.vstack((penalties,numpy.hstack((bordernodes2d.reshape(-1,1),bordernodes2d.reshape(-1,1)+md.mesh.numberofvertices2d*(i)))))
			md.diagnostic.vertex_pairing=penalties

	elif strcmpi(coupling_method,'tiling'):
		if   any(SSAflag) and any(HOflag):    #coupling SSA HO
			#Find node at the border
			nodeonSSAHO[numpy.nonzero(numpy.logical_and(nodeonSSA,nodeonHO))]=True
			#SSA elements in contact with this layer become SSAHO elements
			matrixelements=ismember(md.mesh.elements-1,numpy.nonzero(nodeonSSAHO)[0])
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(HOflag)]=False    #only one layer: the elements previously in SSA
			SSAflag[numpy.nonzero(commonelements)]=False    #these elements are now SSAHOelements
			SSAHOflag[numpy.nonzero(commonelements)]=True
			nodeonSSA[:]=False
			nodeonSSA[md.mesh.elements[numpy.nonzero(SSAflag),:]-1]=True

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(SSAHOflag)[0]
			elist=numpy.zeros(numpy.size(pos),dtype=int)
			elist = elist + numpy.sum(nodeonSSA[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			elist = elist - numpy.sum(nodeonHO[md.mesh.elements[pos,:]-1]  ,axis=1).astype(bool)
			pos1=numpy.nonzero(elist==1)[0]
			SSAflag[pos[pos1]]=True
			SSAHOflag[pos[pos1]]=False
			pos2=numpy.nonzero(elist==-1)[0]
			HOflag[pos[pos2]]=True
			SSAHOflag[pos[pos2]]=False

			#Recompute nodes associated to these elements
			nodeonSSA[:]=False
			nodeonSSA[md.mesh.elements[numpy.nonzero(SSAflag),:]-1]=True
			nodeonHO[:]=False
			nodeonHO[md.mesh.elements[numpy.nonzero(HOflag),:]-1]=True
			nodeonSSAHO[:]=False
			nodeonSSAHO[md.mesh.elements[numpy.nonzero(SSAHOflag),:]-1]=True

		elif any(HOflag) and any(FSflag):    #coupling HO FS
			#Find node at the border
			nodeonHOFS[numpy.nonzero(numpy.logical_and(nodeonHO,nodeonFS))]=True
			#FS elements in contact with this layer become HOFS elements
			matrixelements=ismember(md.mesh.elements-1,numpy.nonzero(nodeonHOFS)[0])
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(HOflag)]=False    #only one layer: the elements previously in SSA
			FSflag[numpy.nonzero(commonelements)]=False    #these elements are now SSAHOelements
			HOFSflag[numpy.nonzero(commonelements)]=True
			nodeonFS=numpy.zeros(md.mesh.numberofvertices,bool)
			nodeonFS[md.mesh.elements[numpy.nonzero(FSflag),:]-1]=True

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(HOFSflag)[0]
			elist=numpy.zeros(numpy.size(pos),dtype=int)
			elist = elist + numpy.sum(nodeonFS[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			elist = elist - numpy.sum(nodeonHO[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			pos1=numpy.nonzero(elist==1)[0]
			FSflag[pos[pos1]]=True
			HOFSflag[pos[pos1]]=False
			pos2=numpy.nonzero(elist==-1)[0]
			HOflag[pos[pos2]]=True
			HOFSflag[pos[pos2]]=False

			#Recompute nodes associated to these elements
			nodeonFS[:]=False
			nodeonFS[md.mesh.elements[numpy.nonzero(FSflag),:]-1]=True
			nodeonHO[:]=False
			nodeonHO[md.mesh.elements[numpy.nonzero(HOflag),:]-1]=True
			nodeonHOFS[:]=False
			nodeonHOFS[md.mesh.elements[numpy.nonzero(HOFSflag),:]-1]=True

		elif any(FSflag) and any(SSAflag):
			#Find node at the border
			nodeonSSAFS[numpy.nonzero(numpy.logical_and(nodeonSSA,nodeonFS))]=True
			#FS elements in contact with this layer become SSAFS elements
			matrixelements=ismember(md.mesh.elements-1,numpy.nonzero(nodeonSSAFS)[0])
			commonelements=numpy.sum(matrixelements,axis=1)!=0
			commonelements[numpy.nonzero(SSAflag)]=False    #only one layer: the elements previously in SSA
			FSflag[numpy.nonzero(commonelements)]=False    #these elements are now SSASSAelements
			SSAFSflag[numpy.nonzero(commonelements)]=True
			nodeonFS=numpy.zeros(md.mesh.numberofvertices,bool)
			nodeonFS[md.mesh.elements[numpy.nonzero(FSflag),:]-1]=True

			#rule out elements that don't touch the 2 boundaries
			pos=numpy.nonzero(SSAFSflag)[0]
			elist=numpy.zeros(numpy.size(pos),dtype=int)
			elist = elist + numpy.sum(nodeonSSA[md.mesh.elements[pos,:]-1],axis=1).astype(bool)
			elist = elist - numpy.sum(nodeonFS[md.mesh.elements[pos,:]-1]  ,axis=1).astype(bool)
			pos1=numpy.nonzero(elist==1)[0]
			SSAflag[pos[pos1]]=True
			SSAFSflag[pos[pos1]]=False
			pos2=numpy.nonzero(elist==-1)[0]
			FSflag[pos[pos2]]=True
			SSAFSflag[pos[pos2]]=False

			#Recompute nodes associated to these elements
			nodeonSSA[:]=False
			nodeonSSA[md.mesh.elements[numpy.nonzero(SSAflag),:]-1]=True
			nodeonFS[:]=False
			nodeonFS[md.mesh.elements[numpy.nonzero(FSflag),:]-1]=True
			nodeonSSAFS[:]=False
			nodeonSSAFS[md.mesh.elements[numpy.nonzero(SSAFSflag),:]-1]=True

		elif any(FSflag) and any(SIAflag):
			raise TypeError("type of coupling not supported yet")

	#Create SSAHOApproximation where needed
	md.flowequation.element_equation=numpy.zeros(md.mesh.numberofelements,int)
	md.flowequation.element_equation[numpy.nonzero(noneflag)]=0
	md.flowequation.element_equation[numpy.nonzero(SIAflag)]=1
	md.flowequation.element_equation[numpy.nonzero(SSAflag)]=2
	md.flowequation.element_equation[numpy.nonzero(L1L2flag)]=8
	md.flowequation.element_equation[numpy.nonzero(HOflag)]=3
	md.flowequation.element_equation[numpy.nonzero(FSflag)]=4
	md.flowequation.element_equation[numpy.nonzero(SSAHOflag)]=5
	md.flowequation.element_equation[numpy.nonzero(SSAFSflag)]=6
	md.flowequation.element_equation[numpy.nonzero(HOFSflag)]=7

	#border
	md.flowequation.borderHO=nodeonHO
	md.flowequation.borderSSA=nodeonSSA
	md.flowequation.borderFS=nodeonFS

	#Create vertices_type
	md.flowequation.vertex_equation=numpy.zeros(md.mesh.numberofvertices,int)
	pos=numpy.nonzero(nodeonSSA)
	md.flowequation.vertex_equation[pos]=2
	pos=numpy.nonzero(nodeonL1L2)
	md.flowequation.vertex_equation[pos]=8
	pos=numpy.nonzero(nodeonHO)
	md.flowequation.vertex_equation[pos]=3
	pos=numpy.nonzero(nodeonSIA)
	md.flowequation.vertex_equation[pos]=1
	pos=numpy.nonzero(nodeonSSAHO)
	md.flowequation.vertex_equation[pos]=5
	pos=numpy.nonzero(nodeonFS)
	md.flowequation.vertex_equation[pos]=4
	if any(FSflag):
		pos=numpy.nonzero(numpy.logical_not(nodeonFS))
		if not (any(HOflag) or any(SSAflag)):
			md.flowequation.vertex_equation[pos]=0
	pos=numpy.nonzero(nodeonHOFS)
	md.flowequation.vertex_equation[pos]=7
	pos=numpy.nonzero(nodeonSSAFS)
	md.flowequation.vertex_equation[pos]=6

	#figure out solution types
	md.flowequation.isSIA=any(md.flowequation.element_equation==1)
	md.flowequation.isSSA=any(md.flowequation.element_equation==2)
	md.flowequation.isL1L2=any(md.flowequation.element_equation==8)
	md.flowequation.isHO=any(md.flowequation.element_equation==3)
	md.flowequation.isFS=any(md.flowequation.element_equation==4)

	return md

	#Check that tiling can work:
	if any(md.flowequation.borderSSA) and any(md.flowequation.borderHO) and any(md.flowequation.borderHO + md.flowequation.borderSSA !=1):
		raise TypeError("error coupling domain too irregular")
	if any(md.flowequation.borderSSA) and any(md.flowequation.borderFS) and any(md.flowequation.borderFS + md.flowequation.borderSSA !=1):
		raise TypeError("error coupling domain too irregular")
	if any(md.flowequation.borderFS) and any(md.flowequation.borderHO) and any(md.flowequation.borderHO + md.flowequation.borderFS !=1):
		raise TypeError("error coupling domain too irregular")

	return md

