#module imports {{{
import numpy
import copy
import sys
from mesh import mesh
from mask import mask
from geometry import geometry
from constants import constants
from surfaceforcings import surfaceforcings
from SMB import SMB
from SMBpdd import SMBpdd
from SMBgradients import SMBgradients
from basalforcings import basalforcings
from matice import matice
from damage import damage
from friction import friction
from flowequation import flowequation
from timestepping import timestepping
from initialization import initialization
from rifts import rifts
from debug import debug
from verbose import verbose
from settings import settings
from toolkits import toolkits
from generic import generic
from balancethickness import balancethickness
from stressbalance import stressbalance
from groundingline import groundingline
from hydrologyshreve import hydrologyshreve
from masstransport import masstransport
from thermal import thermal
from steadystate import steadystate
from transient import transient
from gia import gia
from autodiff import autodiff
from flaim import flaim
from inversion import inversion
from outputdefinition import outputdefinition
from qmu import qmu
from results import results
from radaroverlay import radaroverlay
from miscellaneous import miscellaneous
from private import private
from EnumDefinitions import *
from mumpsoptions import *
from iluasmoptions import *
from project3d import *
from FlagElements import *
from NodeConnectivity import *
from ElementConnectivity import *
from contourenvelope import *
from PythonFuncs import *
#}}}

class model(object):
	#properties
	def __init__(self):#{{{
		self.mesh             = mesh()
		self.mask             = mask()
		self.geometry         = geometry()
		self.constants        = constants()
		self.surfaceforcings  = SMB()
		self.basalforcings    = basalforcings()
		self.materials        = matice()
		self.damage           = damage()
		self.friction         = friction()
		self.flowequation     = flowequation()
		self.timestepping     = timestepping()
		self.initialization   = initialization()
		self.rifts            = rifts()

		self.debug            = debug()
		self.verbose          = verbose()
		self.settings         = settings()
		self.toolkits         = toolkits()
		self.cluster          = generic()

		self.balancethickness = balancethickness()
		self.stressbalance       = stressbalance()
		self.groundingline    = groundingline()
		self.hydrology        = hydrologyshreve()
		self.masstransport       = masstransport()
		self.thermal          = thermal()
		self.steadystate      = steadystate()
		self.transient        = transient()
		self.gia              = gia()

		self.autodiff         = autodiff()
		self.flaim            = flaim()
		self.inversion        = inversion()
		self.qmu              = qmu()

		self.results          = results()
		self.outputdefinition = outputdefinition()
		self.radaroverlay     = radaroverlay()
		self.miscellaneous    = miscellaneous()
		self.private          = private()
		#}}}
	def properties(self):    # {{{
		# ordered list of properties since vars(self) is random
		return ['mesh',\
		        'mask',\
		        'geometry',\
		        'constants',\
		        'surfaceforcings',\
		        'basalforcings',\
		        'materials',\
		        'damage',\
		        'friction',\
		        'flowequation',\
		        'timestepping',\
		        'initialization',\
		        'rifts',\
		        'debug',\
		        'verbose',\
		        'settings',\
		        'toolkits',\
		        'cluster',\
		        'balancethickness',\
		        'stressbalance',\
		        'groundingline',\
		        'hydrology',\
		        'masstransport',\
		        'thermal',\
		        'steadystate',\
		        'transient',\
				  'gia',\
		        'autodiff',\
		        'flaim',\
		        'inversion',\
		        'qmu',\
		        'outputdefinition',\
		        'results',\
		        'radaroverlay',\
		        'miscellaneous',\
		        'private']
	# }}}
	def __repr__(obj): #{{{
		#print "Here %s the number: %d" % ("is", 37)
		string="%19s: %-22s -- %s" % ("mesh","[%s,%s]" % ("1x1",obj.mesh.__class__.__name__),"mesh properties")
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("mask","[%s,%s]" % ("1x1",obj.mask.__class__.__name__),"defines grounded and floating elements"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("geometry","[%s,%s]" % ("1x1",obj.geometry.__class__.__name__),"surface elevation, bedrock topography, ice thickness,..."))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("constants","[%s,%s]" % ("1x1",obj.constants.__class__.__name__),"physical constants"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("surfaceforcings","[%s,%s]" % ("1x1",obj.surfaceforcings.__class__.__name__),"surface forcings"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("basalforcings","[%s,%s]" % ("1x1",obj.basalforcings.__class__.__name__),"bed forcings"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("materials","[%s,%s]" % ("1x1",obj.materials.__class__.__name__),"material properties"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("damage","[%s,%s]" % ("1x1",obj.damage.__class__.__name__),"damage propagation laws"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("friction","[%s,%s]" % ("1x1",obj.friction.__class__.__name__),"basal friction/drag properties"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("flowequation","[%s,%s]" % ("1x1",obj.flowequation.__class__.__name__),"flow equations"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("timestepping","[%s,%s]" % ("1x1",obj.timestepping.__class__.__name__),"time stepping for transient models"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("initialization","[%s,%s]" % ("1x1",obj.initialization.__class__.__name__),"initial guess/state"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("rifts","[%s,%s]" % ("1x1",obj.rifts.__class__.__name__),"rifts properties"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("debug","[%s,%s]" % ("1x1",obj.debug.__class__.__name__),"debugging tools (valgrind, gprof)"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("verbose","[%s,%s]" % ("1x1",obj.verbose.__class__.__name__),"verbosity level in solve"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("settings","[%s,%s]" % ("1x1",obj.settings.__class__.__name__),"settings properties"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("toolkits","[%s,%s]" % ("1x1",obj.toolkits.__class__.__name__),"PETSc options for each solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("cluster","[%s,%s]" % ("1x1",obj.cluster.__class__.__name__),"cluster parameters (number of cpus...)"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("balancethickness","[%s,%s]" % ("1x1",obj.balancethickness.__class__.__name__),"parameters for balancethickness solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("stressbalance","[%s,%s]" % ("1x1",obj.stressbalance.__class__.__name__),"parameters for stressbalance solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("groundingline","[%s,%s]" % ("1x1",obj.groundingline.__class__.__name__),"parameters for groundingline solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("hydrology","[%s,%s]" % ("1x1",obj.hydrology.__class__.__name__),"parameters for hydrology solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("masstransport","[%s,%s]" % ("1x1",obj.masstransport.__class__.__name__),"parameters for masstransport solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("thermal","[%s,%s]" % ("1x1",obj.thermal.__class__.__name__),"parameters for thermal solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("steadystate","[%s,%s]" % ("1x1",obj.steadystate.__class__.__name__),"parameters for steadystate solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("transient","[%s,%s]" % ("1x1",obj.transient.__class__.__name__),"parameters for transient solution"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("autodiff","[%s,%s]" % ("1x1",obj.autodiff.__class__.__name__),"automatic differentiation parameters"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("flaim","[%s,%s]" % ("1x1",obj.flaim.__class__.__name__),"flaim parameters"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("inversion","[%s,%s]" % ("1x1",obj.inversion.__class__.__name__),"parameters for inverse methods"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("qmu","[%s,%s]" % ("1x1",obj.qmu.__class__.__name__),"dakota properties"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("outputdefinition","[%s,%s]" % ("1x1",obj.outputdefinition.__class__.__name__),"output definition"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("results","[%s,%s]" % ("1x1",obj.results.__class__.__name__),"model results"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("radaroverlay","[%s,%s]" % ("1x1",obj.radaroverlay.__class__.__name__),"radar image for plot overlay"))
		string="%s\n%s" % (string,"%19s: %-22s -- %s" % ("miscellaneous","[%s,%s]" % ("1x1",obj.miscellaneous.__class__.__name__),"miscellaneous fields"))
		return string
	# }}}
	def checkmessage(self,string):    # {{{
		print ("model not consistent: %s" % string)
		self.private.isconsistent=False
		return self
	# }}}
	def extract(md,area):    # {{{
		"""
		extract - extract a model according to an Argus contour or flag list

		   This routine extracts a submodel from a bigger model with respect to a given contour
		   md must be followed by the corresponding exp file or flags list
		   It can either be a domain file (argus type, .exp extension), or an array of element flags. 
		   If user wants every element outside the domain to be 
		   extract2d, add '~' to the name of the domain file (ex: '~HO.exp');
		   an empty string '' will be considered as an empty domain
		   a string 'all' will be considered as the entire domain

		   Usage:
		      md2=extract(md,area);

		   Examples:
		      md2=extract(md,'Domain.exp');

		   See also: EXTRUDE, COLLAPSE
		"""

		#copy model
		md1=copy.deepcopy(md)

		#get elements that are inside area
		flag_elem=FlagElements(md1,area)
		if not numpy.any(flag_elem):
			raise RuntimeError("extracted model is empty")

		#kick out all elements with 3 dirichlets
		spc_elem=numpy.nonzero(numpy.logical_not(flag_elem))[0]
		spc_node=numpy.unique(md1.mesh.elements[spc_elem,:])-1
		flag=numpy.ones(md1.mesh.numberofvertices)
		flag[spc_node]=0
		pos=numpy.nonzero(numpy.logical_not(numpy.sum(flag[md1.mesh.elements-1],axis=1)))[0]
		flag_elem[pos]=0

		#extracted elements and nodes lists
		pos_elem=numpy.nonzero(flag_elem)[0]
		pos_node=numpy.unique(md1.mesh.elements[pos_elem,:])-1

		#keep track of some fields
		numberofvertices1=md1.mesh.numberofvertices
		numberofelements1=md1.mesh.numberofelements
		numberofvertices2=numpy.size(pos_node)
		numberofelements2=numpy.size(pos_elem)
		flag_node=numpy.zeros(numberofvertices1)
		flag_node[pos_node]=1

		#Create Pelem and Pnode (transform old nodes in new nodes and same thing for the elements)
		Pelem=numpy.zeros(numberofelements1,int)
		Pelem[pos_elem]=numpy.arange(1,numberofelements2+1)
		Pnode=numpy.zeros(numberofvertices1,int)
		Pnode[pos_node]=numpy.arange(1,numberofvertices2+1)

		#renumber the elements (some node won't exist anymore)
		elements_1=copy.deepcopy(md1.mesh.elements)
		elements_2=elements_1[pos_elem,:]
		elements_2[:,0]=Pnode[elements_2[:,0]-1]
		elements_2[:,1]=Pnode[elements_2[:,1]-1]
		elements_2[:,2]=Pnode[elements_2[:,2]-1]
		if md1.mesh.dimension==3:
			elements_2[:,3]=Pnode[elements_2[:,3]-1]
			elements_2[:,4]=Pnode[elements_2[:,4]-1]
			elements_2[:,5]=Pnode[elements_2[:,5]-1]

		#OK, now create the new model!

		#take every field from model
		md2=copy.deepcopy(md1)

		#automatically modify fields

		#loop over model fields
		model_fields=vars(md1)
		for fieldi in model_fields:
			#get field
			field=getattr(md1,fieldi)
			fieldsize=numpy.shape(field)
			if hasattr(field,'__dict__') and not ismember(fieldi,['results'])[0]:    #recursive call
				object_fields=vars(field)
				for fieldj in object_fields:
					#get field
					field=getattr(getattr(md1,fieldi),fieldj)
					fieldsize=numpy.shape(field)
					if len(fieldsize):
						#size = number of nodes * n
						if   fieldsize[0]==numberofvertices1:
							setattr(getattr(md2,fieldi),fieldj,field[pos_node,:])
						elif fieldsize[0]==numberofvertices1+1:
							setattr(getattr(md2,fieldi),fieldj,numpy.vstack((field[pos_node,:],field[-1,:])))
						#size = number of elements * n
						elif fieldsize[0]==numberofelements1:
							setattr(getattr(md2,fieldi),fieldj,field[pos_elem,:])
			else:
				if len(fieldsize):
					#size = number of nodes * n
					if   fieldsize[0]==numberofvertices1:
						setattr(md2,fieldi,field[pos_node,:])
					elif fieldsize[0]==numberofvertices1+1:
						setattr(md2,fieldi,numpy.hstack((field[pos_node,:],field[-1,:])))
					#size = number of elements * n
					elif fieldsize[0]==numberofelements1:
						setattr(md2,fieldi,field[pos_elem,:])

		#modify some specific fields

		#Mesh
		md2.mesh.numberofelements=numberofelements2
		md2.mesh.numberofvertices=numberofvertices2
		md2.mesh.elements=elements_2

		#mesh.uppervertex mesh.lowervertex
		if md1.mesh.dimension==3:
			md2.mesh.uppervertex=md1.mesh.uppervertex[pos_node]
			pos=numpy.nonzero(numpy.logical_not(md2.mesh.uppervertex==-1))[0]
			md2.mesh.uppervertex[pos]=Pnode[md2.mesh.uppervertex[pos]-1]

			md2.mesh.lowervertex=md1.mesh.lowervertex[pos_node]
			pos=numpy.nonzero(numpy.logical_not(md2.mesh.lowervertex==-1))[0]
			md2.mesh.lowervertex[pos]=Pnode[md2.mesh.lowervertex[pos]-1]

			md2.mesh.upperelements=md1.mesh.upperelements[pos_elem]
			pos=numpy.nonzero(numpy.logical_not(md2.mesh.upperelements==-1))[0]
			md2.mesh.upperelements[pos]=Pelem[md2.mesh.upperelements[pos]-1]

			md2.mesh.lowerelements=md1.mesh.lowerelements[pos_elem]
			pos=numpy.nonzero(numpy.logical_not(md2.mesh.lowerelements==-1))[0]
			md2.mesh.lowerelements[pos]=Pelem[md2.mesh.lowerelements[pos]-1]

		#Initial 2d mesh 
		if md1.mesh.dimension==3:
			flag_elem_2d=flag_elem[numpy.arange(0,md1.mesh.numberofelements2d)]
			pos_elem_2d=numpy.nonzero(flag_elem_2d)[0]
			flag_node_2d=flag_node[numpy.arange(0,md1.mesh.numberofvertices2d)]
			pos_node_2d=numpy.nonzero(flag_node_2d)[0]

			md2.mesh.numberofelements2d=numpy.size(pos_elem_2d)
			md2.mesh.numberofvertices2d=numpy.size(pos_node_2d)
			md2.mesh.elements2d=md1.mesh.elements2d[pos_elem_2d,:]
			md2.mesh.elements2d[:,0]=Pnode[md2.mesh.elements2d[:,0]-1]
			md2.mesh.elements2d[:,1]=Pnode[md2.mesh.elements2d[:,1]-1]
			md2.mesh.elements2d[:,2]=Pnode[md2.mesh.elements2d[:,2]-1]

			md2.mesh.x2d=md1.mesh.x[pos_node_2d]
			md2.mesh.y2d=md1.mesh.y[pos_node_2d]

		#Edges
		if numpy.ndim(md2.mesh.edges)>1 and numpy.size(md2.mesh.edges,axis=1)>1:    #do not use ~isnan because there are some NaNs...
			#renumber first two columns
			pos=numpy.nonzero(md2.mesh.edges[:,3]!=-1)[0]
			md2.mesh.edges[:  ,0]=Pnode[md2.mesh.edges[:,0]-1]
			md2.mesh.edges[:  ,1]=Pnode[md2.mesh.edges[:,1]-1]
			md2.mesh.edges[:  ,2]=Pelem[md2.mesh.edges[:,2]-1]
			md2.mesh.edges[pos,3]=Pelem[md2.mesh.edges[pos,3]-1]
			#remove edges when the 2 vertices are not in the domain.
			md2.mesh.edges=md2.mesh.edges[numpy.nonzero(numpy.logical_and(md2.mesh.edges[:,0],md2.mesh.edges[:,1]))[0],:]
			#Replace all zeros by -1 in the last two columns
			pos=numpy.nonzero(md2.mesh.edges[:,2]==0)[0]
			md2.mesh.edges[pos,2]=-1
			pos=numpy.nonzero(md2.mesh.edges[:,3]==0)[0]
			md2.mesh.edges[pos,3]=-1
			#Invert -1 on the third column with last column (Also invert first two columns!!)
			pos=numpy.nonzero(md2.mesh.edges[:,2]==-1)[0]
			md2.mesh.edges[pos,2]=md2.mesh.edges[pos,3]
			md2.mesh.edges[pos,3]=-1
			values=md2.mesh.edges[pos,1]
			md2.mesh.edges[pos,1]=md2.mesh.edges[pos,0]
			md2.mesh.edges[pos,0]=values
			#Finally remove edges that do not belong to any element
			pos=numpy.nonzero(numpy.logical_and(md2.mesh.edges[:,1]==-1,md2.mesh.edges[:,2]==-1))[0]
			md2.mesh.edges=numpy.delete(md2.mesh.edges,pos,axis=0)

		#Penalties
		if numpy.any(numpy.logical_not(numpy.isnan(md2.stressbalance.vertex_pairing))):
			for i in xrange(numpy.size(md1.stressbalance.vertex_pairing,axis=0)):
				md2.stressbalance.vertex_pairing[i,:]=Pnode[md1.stressbalance.vertex_pairing[i,:]]
			md2.stressbalance.vertex_pairing=md2.stressbalance.vertex_pairing[numpy.nonzero(md2.stressbalance.vertex_pairing[:,0])[0],:]
		if numpy.any(numpy.logical_not(numpy.isnan(md2.masstransport.vertex_pairing))):
			for i in xrange(numpy.size(md1.masstransport.vertex_pairing,axis=0)):
				md2.masstransport.vertex_pairing[i,:]=Pnode[md1.masstransport.vertex_pairing[i,:]]
			md2.masstransport.vertex_pairing=md2.masstransport.vertex_pairing[numpy.nonzero(md2.masstransport.vertex_pairing[:,0])[0],:]

		#recreate segments
		if md1.mesh.dimension==2:
			[md2.mesh.vertexconnectivity]=NodeConnectivity(md2.mesh.elements,md2.mesh.numberofvertices)
			[md2.mesh.elementconnectivity]=ElementConnectivity(md2.mesh.elements,md2.mesh.vertexconnectivity)
			md2.mesh.segments=contourenvelope(md2)
			md2.mesh.vertexonboundary=numpy.zeros(numberofvertices2,bool)
			md2.mesh.vertexonboundary[md2.mesh.segments[:,0:2]-1]=True
		else:
			#First do the connectivity for the contourenvelope in 2d
			[md2.mesh.vertexconnectivity]=NodeConnectivity(md2.mesh.elements2d,md2.mesh.numberofvertices2d)
			[md2.mesh.elementconnectivity]=ElementConnectivity(md2.mesh.elements2d,md2.mesh.vertexconnectivity)
			md2.mesh.segments=contourenvelope(md2)
			md2.mesh.vertexonboundary=numpy.zeros(numberofvertices2/md2.mesh.numberoflayers,bool)
			md2.mesh.vertexonboundary[md2.mesh.segments[:,0:2]-1]=True
			md2.mesh.vertexonboundary=numpy.tile(md2.mesh.vertexonboundary,md2.mesh.numberoflayers)
			#Then do it for 3d as usual
			[md2.mesh.vertexconnectivity]=NodeConnectivity(md2.mesh.elements,md2.mesh.numberofvertices)
			[md2.mesh.elementconnectivity]=ElementConnectivity(md2.mesh.elements,md2.mesh.vertexconnectivity)

		#Boundary conditions: Dirichlets on new boundary
		#Catch the elements that have not been extracted
		orphans_elem=numpy.nonzero(numpy.logical_not(flag_elem))[0]
		orphans_node=numpy.unique(md1.mesh.elements[orphans_elem,:])-1
		#Figure out which node are on the boundary between md2 and md1
		nodestoflag1=numpy.intersect1d(orphans_node,pos_node)
		nodestoflag2=Pnode[nodestoflag1].astype(int)-1
		if numpy.size(md1.stressbalance.spcvx)>1 and numpy.size(md1.stressbalance.spcvy)>2 and numpy.size(md1.stressbalance.spcvz)>2:
			if numpy.size(md1.inversion.vx_obs)>1 and numpy.size(md1.inversion.vy_obs)>1:
				md2.stressbalance.spcvx[nodestoflag2]=md2.inversion.vx_obs[nodestoflag2] 
				md2.stressbalance.spcvy[nodestoflag2]=md2.inversion.vy_obs[nodestoflag2]
			else:
				md2.stressbalance.spcvx[nodestoflag2]=float('NaN')
				md2.stressbalance.spcvy[nodestoflag2]=float('NaN')
				print "\n!! extract warning: spc values should be checked !!\n\n"
			#put 0 for vz
			md2.stressbalance.spcvz[nodestoflag2]=0
		if numpy.any(numpy.logical_not(numpy.isnan(md1.thermal.spctemperature))):
			md2.thermal.spctemperature[nodestoflag2,0]=1

		#Results fields
		if md1.results:
			md2.results=results()
			for solutionfield,field in md1.results.__dict__.iteritems():
				if   isinstance(field,list):
					setattr(md2.results,solutionfield,[])
					#get time step
					for i,fieldi in enumerate(field):
						if isinstance(fieldi,results) and fieldi:
							getattr(md2.results,solutionfield).append(results())
							fieldr=getattr(md2.results,solutionfield)[i]
							#get subfields
							for solutionsubfield,subfield in fieldi.__dict__.iteritems():
								if   numpy.size(subfield)==numberofvertices1:
									setattr(fieldr,solutionsubfield,subfield[pos_node])
								elif numpy.size(subfield)==numberofelements1:
									setattr(fieldr,solutionsubfield,subfield[pos_elem])
								else:
									setattr(fieldr,solutionsubfield,subfield)
						else:
							getattr(md2.results,solutionfield).append(None)
				elif isinstance(field,results):
					setattr(md2.results,solutionfield,results())
					if isinstance(field,results) and field:
						fieldr=getattr(md2.results,solutionfield)
						#get subfields
						for solutionsubfield,subfield in field.__dict__.iteritems():
							if   numpy.size(subfield)==numberofvertices1:
								setattr(fieldr,solutionsubfield,subfield[pos_node])
							elif numpy.size(subfield)==numberofelements1:
								setattr(fieldr,solutionsubfield,subfield[pos_elem])
							else:
								setattr(fieldr,solutionsubfield,subfield)

		#Keep track of pos_node and pos_elem
		md2.mesh.extractedvertices=pos_node+1
		md2.mesh.extractedelements=pos_elem+1

		return md2
	# }}}
	def extrude(md,*args):    # {{{
		"""
		EXTRUDE - vertically extrude a 2d mesh

		   vertically extrude a 2d mesh and create corresponding 3d mesh.
		   The vertical distribution can:
		    - follow a polynomial law
		    - follow two polynomial laws, one for the lower part and one for the upper part of the mesh
		    - be discribed by a list of coefficients (between 0 and 1)
 

		   Usage:
		      md=extrude(md,numlayers,extrusionexponent);
		      md=extrude(md,numlayers,lowerexponent,upperexponent);
		      md=extrude(md,listofcoefficients);

		   Example:
		      md=extrude(md,8,3);
		      md=extrude(md,8,3,2);
		      md=extrude(md,[0 0.2 0.5 0.7 0.9 0.95 1]);

		   See also: MODELEXTRACT, COLLAPSE
		"""

		#some checks on list of arguments
		if len(args)>3 or len(args)<1:
			raise RuntimeError("extrude error message")

		#Extrude the mesh
		if   len(args)==1:    #list of coefficients
			clist=args[0]
			if any(clist<0) or any(clist>1):
				raise TypeError("extrusioncoefficients must be between 0 and 1")
			clist.extend([0.,1.])
			clist.sort()
			extrusionlist=list(set(clist))
			numlayers=len(extrusionlist)

		elif len(args)==2:    #one polynomial law
			if args[1]<=0:
				raise TypeError("extrusionexponent must be >=0")
			numlayers=args[0]
			extrusionlist=(numpy.arange(0.,float(numlayers-1)+1.,1.)/float(numlayers-1))**args[1]

		elif len(args)==3:    #two polynomial laws
			numlayers=args[0]
			lowerexp=args[1]
			upperexp=args[2]

			if args[1]<=0 or args[2]<=0:
				raise TypeError("lower and upper extrusionexponents must be >=0")

			lowerextrusionlist=(numpy.arange(0.,1.+2./float(numlayers-1),2./float(numlayers-1)))**lowerexp/2.
			upperextrusionlist=(numpy.arange(0.,1.+2./float(numlayers-1),2./float(numlayers-1)))**upperexp/2.
			extrusionlist=numpy.unique(numpy.concatenate((lowerextrusionlist,1.-upperextrusionlist)))

		if numlayers<2:
			raise TypeError("number of layers should be at least 2")
		if md.mesh.dimension==3:
			raise TypeError("Cannot extrude a 3d mesh (extrude cannot be called more than once)")

		#Initialize with the 2d mesh
		x3d=numpy.empty((0))
		y3d=numpy.empty((0))
		z3d=numpy.empty((0))    #the lower node is on the bed
		thickness3d=md.geometry.thickness    #thickness and bed for these nodes
		bed3d=md.geometry.bed

		#Create the new layers
		for i in xrange(numlayers):
			x3d=numpy.concatenate((x3d,md.mesh.x))
			y3d=numpy.concatenate((y3d,md.mesh.y))
			#nodes are distributed between bed and surface accordingly to the given exponent
			z3d=numpy.concatenate((z3d,(bed3d+thickness3d*extrusionlist[i]).reshape(-1)))
		number_nodes3d=numpy.size(x3d)    #number of 3d nodes for the non extruded part of the mesh

		#Extrude elements 
		elements3d=numpy.empty((0,6),int)
		for i in xrange(numlayers-1):
			elements3d=numpy.vstack((elements3d,numpy.hstack((md.mesh.elements+i*md.mesh.numberofvertices,md.mesh.elements+(i+1)*md.mesh.numberofvertices))))    #Create the elements of the 3d mesh for the non extruded part
		number_el3d=numpy.size(elements3d,axis=0)    #number of 3d nodes for the non extruded part of the mesh

		#Keep a trace of lower and upper nodes
		mesh.lowervertex=-1*numpy.ones(number_nodes3d,int)
		mesh.uppervertex=-1*numpy.ones(number_nodes3d,int)
		mesh.lowervertex[md.mesh.numberofvertices:]=numpy.arange(1,(numlayers-1)*md.mesh.numberofvertices+1)
		mesh.uppervertex[:(numlayers-1)*md.mesh.numberofvertices]=numpy.arange(md.mesh.numberofvertices+1,number_nodes3d+1)
		md.mesh.lowervertex=mesh.lowervertex
		md.mesh.uppervertex=mesh.uppervertex

		#same for lower and upper elements
		mesh.lowerelements=-1*numpy.ones(number_el3d,int)
		mesh.upperelements=-1*numpy.ones(number_el3d,int)
		mesh.lowerelements[md.mesh.numberofelements:]=numpy.arange(1,(numlayers-2)*md.mesh.numberofelements+1)
		mesh.upperelements[:(numlayers-2)*md.mesh.numberofelements]=numpy.arange(md.mesh.numberofelements+1,(numlayers-1)*md.mesh.numberofelements+1)
		md.mesh.lowerelements=mesh.lowerelements
		md.mesh.upperelements=mesh.upperelements

		#Save old mesh 
		md.mesh.x2d=md.mesh.x
		md.mesh.y2d=md.mesh.y
		md.mesh.elements2d=md.mesh.elements
		md.mesh.numberofelements2d=md.mesh.numberofelements
		md.mesh.numberofvertices2d=md.mesh.numberofvertices

		#Update mesh type
		md.mesh.dimension=3

		#Build global 3d mesh 
		md.mesh.elements=elements3d
		md.mesh.x=x3d
		md.mesh.y=y3d
		md.mesh.z=z3d
		md.mesh.numberofelements=number_el3d
		md.mesh.numberofvertices=number_nodes3d
		md.mesh.numberoflayers=numlayers

		#Ok, now deal with the other fields from the 2d mesh:

		#lat long
		md.mesh.lat=project3d(md,'vector',md.mesh.lat,'type','node')
		md.mesh.long=project3d(md,'vector',md.mesh.long,'type','node')

		#drag coefficient is limited to nodes that are on the bedrock.
		md.friction.coefficient=project3d(md,'vector',md.friction.coefficient,'type','node','layer',1)

		#p and q (same deal, except for element that are on the bedrock: )
		md.friction.p=project3d(md,'vector',md.friction.p,'type','element')
		md.friction.q=project3d(md,'vector',md.friction.q,'type','element')

		#observations
		md.inversion.vx_obs=project3d(md,'vector',md.inversion.vx_obs,'type','node')
		md.inversion.vy_obs=project3d(md,'vector',md.inversion.vy_obs,'type','node')
		md.inversion.vel_obs=project3d(md,'vector',md.inversion.vel_obs,'type','node')
		md.inversion.thickness_obs=project3d(md,'vector',md.inversion.thickness_obs,'type','node')
		md.surfaceforcings.extrude(md)
		md.balancethickness.thickening_rate=project3d(md,'vector',md.balancethickness.thickening_rate,'type','node')

		#results
		if not numpy.any(numpy.isnan(md.initialization.vx)):
			md.initialization.vx=project3d(md,'vector',md.initialization.vx,'type','node')
		if not numpy.any(numpy.isnan(md.initialization.vy)):
			md.initialization.vy=project3d(md,'vector',md.initialization.vy,'type','node')
		if not numpy.any(numpy.isnan(md.initialization.vz)):
			md.initialization.vz=project3d(md,'vector',md.initialization.vz,'type','node')
		if not numpy.any(numpy.isnan(md.initialization.vel)):
			md.initialization.vel=project3d(md,'vector',md.initialization.vel,'type','node')
		if not numpy.any(numpy.isnan(md.initialization.temperature)):
			md.initialization.temperature=project3d(md,'vector',md.initialization.temperature,'type','node')
		if not numpy.any(numpy.isnan(md.initialization.waterfraction)):
			md.initialization.waterfraction=project3d(md,'vector',md.initialization.waterfraction,'type','node')

		#bedinfo and surface info
		md.mesh.elementonbed=project3d(md,'vector',numpy.ones(md.mesh.numberofelements2d,bool),'type','element','layer',1)
		md.mesh.elementonsurface=project3d(md,'vector',numpy.ones(md.mesh.numberofelements2d,bool),'type','element','layer',md.mesh.numberoflayers-1)
		md.mesh.vertexonbed=project3d(md,'vector',numpy.ones(md.mesh.numberofvertices2d,bool),'type','node','layer',1)
		md.mesh.vertexonsurface=project3d(md,'vector',numpy.ones(md.mesh.numberofvertices2d,bool),'type','node','layer',md.mesh.numberoflayers)

		#elementstype
		if not numpy.any(numpy.isnan(md.flowequation.element_equation)):
			oldelements_type=md.flowequation.element_equation
			md.flowequation.element_equation=numpy.zeros(number_el3d,int)
			md.flowequation.element_equation=project3d(md,'vector',oldelements_type,'type','element')

		#verticestype
		if not numpy.any(numpy.isnan(md.flowequation.vertex_equation)):
			oldvertices_type=md.flowequation.vertex_equation
			md.flowequation.vertex_equation=numpy.zeros(number_nodes3d,int)
			md.flowequation.vertex_equation=project3d(md,'vector',oldvertices_type,'type','node')

		md.flowequation.borderSSA=project3d(md,'vector',md.flowequation.borderSSA,'type','node')
		md.flowequation.borderHO=project3d(md,'vector',md.flowequation.borderHO,'type','node')
		md.flowequation.borderFS=project3d(md,'vector',md.flowequation.borderFS,'type','node')

		#boundary conditions
		md.stressbalance.spcvx=project3d(md,'vector',md.stressbalance.spcvx,'type','node')
		md.stressbalance.spcvy=project3d(md,'vector',md.stressbalance.spcvy,'type','node')
		md.stressbalance.spcvz=project3d(md,'vector',md.stressbalance.spcvz,'type','node')
		md.thermal.spctemperature=project3d(md,'vector',md.thermal.spctemperature,'type','node','layer',md.mesh.numberoflayers,'padding',float('NaN'))
		md.masstransport.spcthickness=project3d(md,'vector',md.masstransport.spcthickness,'type','node')
		md.balancethickness.spcthickness=project3d(md,'vector',md.balancethickness.spcthickness,'type','node')
		md.damage.spcdamage=project3d(md,'vector',md.damage.spcdamage,'type','node')
		md.stressbalance.referential=project3d(md,'vector',md.stressbalance.referential,'type','node')
		md.stressbalance.loadingforce=project3d(md,'vector',md.stressbalance.loadingforce,'type','node')

		#connectivity
		md.mesh.elementconnectivity=numpy.tile(md.mesh.elementconnectivity,(numlayers-1,1))
		md.mesh.elementconnectivity[numpy.nonzero(md.mesh.elementconnectivity==0)]=-sys.maxint-1
		for i in xrange(1,numlayers-1):
			md.mesh.elementconnectivity[i*md.mesh.numberofelements2d:(i+1)*md.mesh.numberofelements2d,:] \
				=md.mesh.elementconnectivity[i*md.mesh.numberofelements2d:(i+1)*md.mesh.numberofelements2d,:]+md.mesh.numberofelements2d
		md.mesh.elementconnectivity[numpy.nonzero(md.mesh.elementconnectivity<0)]=0

		#materials
		md.materials.rheology_B=project3d(md,'vector',md.materials.rheology_B,'type','node')
		md.materials.rheology_n=project3d(md,'vector',md.materials.rheology_n,'type','element')

		#damage
		md.damage.D=project3d(md,'vector',md.damage.D,'type','node')

		#parameters
		md.geometry.surface=project3d(md,'vector',md.geometry.surface,'type','node')
		md.geometry.thickness=project3d(md,'vector',md.geometry.thickness,'type','node')
		md.gia.mantle_viscosity=project3d(md,'vector',md.gia.mantle_viscosity,'type','node')
		md.gia.lithosphere_thickness=project3d(md,'vector',md.gia.lithosphere_thickness,'type','node')
		md.geometry.hydrostatic_ratio=project3d(md,'vector',md.geometry.hydrostatic_ratio,'type','node')
		md.geometry.bed=project3d(md,'vector',md.geometry.bed,'type','node')
		md.geometry.bathymetry=project3d(md,'vector',md.geometry.bathymetry,'type','node')
		md.mesh.vertexonboundary=project3d(md,'vector',md.mesh.vertexonboundary,'type','node')
		md.mask.ice_levelset=project3d(md,'vector',md.mask.ice_levelset,'type','node')
		md.mask.groundedice_levelset=project3d(md,'vector',md.mask.groundedice_levelset,'type','node')
		if not numpy.any(numpy.isnan(md.inversion.cost_functions_coefficients)):
			md.inversion.cost_functions_coefficients=project3d(md,'vector',md.inversion.cost_functions_coefficients,'type','node');end;
		if not numpy.any(numpy.isnan(md.inversion.min_parameters)):
			md.inversion.min_parameters=project3d(md,'vector',md.inversion.min_parameters,'type','node')
		if not numpy.any(numpy.isnan(md.inversion.max_parameters)):
			md.inversion.max_parameters=project3d(md,'vector',md.inversion.max_parameters,'type','node')
		if not numpy.any(numpy.isnan(md.qmu.partition)):
			md.qmu.partition=project3d(md,'vector',numpy.transpose(md.qmu.partition),'type','node')

		#Put lithostatic pressure if there is an existing pressure
		if not numpy.any(numpy.isnan(md.initialization.pressure)):
			md.initialization.pressure=md.constants.g*md.materials.rho_ice*(md.geometry.surface-md.mesh.z.reshape(-1,1))

		#special for thermal modeling:
		md.basalforcings.melting_rate=project3d(md,'vector',md.basalforcings.melting_rate,'type','node','layer',1)
		if not numpy.any(numpy.isnan(md.basalforcings.geothermalflux)):
			md.basalforcings.geothermalflux=project3d(md,'vector',md.basalforcings.geothermalflux,'type','node','layer',1)    #bedrock only gets geothermal flux

		#increase connectivity if less than 25:
		if md.mesh.average_vertex_connectivity<=25:
			md.mesh.average_vertex_connectivity=100

		return md
		# }}}
