import warnings

import numpy as np

from generic import generic
from issmsettings import issmsettings
from meshintersect3d import meshintersect3d
from miscellaneous import miscellaneous
from pairoptions import pairoptions
from private import private
from TwoDToThreeD import TwoDToThreeD


class sealevelmodel(object):
    '''
    SEALEVELMODEL class definition

    Usage:
        slm = sealevelmodel(*args)

        where args is a variable list of options

    Example:
        slm = sealevelmodel(
            'icecap', md_greenland,
            'icecap', md_antarctica,
            'earth', md_earth
        )
    '''

    def __init__(self, *args): #{{{
        self.icecaps        = [] # list of land/ice models; name should be changed later
        self.earth          = 0 # model for the whole earth
        self.basins         = [] # list of basins, matching icecaps, where shapefile info is held
        self.cluster        = 0
        self.miscellaneous  = 0
        self.settings       = 0
        self.private        = 0
        self.mergedcaps     = 0
        self.transitions    = []
        self.eltransitions  = []
        self.planet         = ''

        # Create a default object
        self = self.setdefaultparameters()

        if len(args):
            # Use provided options to set fields
            options = pairoptions(*args)

            # Recover all the icecap models
            self.icecaps = options.getfieldvalue('ice_cap', [])

            # Recover the earth models
            self.earth = options.getfieldvalue('earth', 0)

            # Set planet type
            self.planet = options.getfieldvalue('planet', 'earth')
    #}}}

    def __repr__(self): # {{{
        s = '{}\n'.format(fielddisplay(self, 'icecaps', 'ice caps'))
        s += '{}\n'.format(fielddisplay(self, 'earth', 'earth'))
        s += '{}\n'.format(fielddisplay(self, 'settings', 'settings properties'))
        s += '{}\n'.format(fielddisplay(self, 'cluster', 'cluster parameters (number of cpus...'))
        s += '{}\n'.format(fielddisplay(self, 'miscellaneous', 'miscellaneous fields'))
    #}}}

    def setdefaultparameters(self): # {{{
        self.icecaps        = []
        self.earth          = []
        self.cluster        = generic()
        self.miscellaneous  = miscellaneous()
        self.settings       = issmsettings()
        self.private        = private()
        self.transitions    = []
        self.eltransitions  = []
        self.planet         = 'earth'
    #}}}

    @staticmethod
    def checkconsistency(slm, solutiontype): # {{{
        # Is the coupler turned on?
        for i in range(len(slm.icecaps)):
            if slm.icecaps[i].transient.iscoupler == 0:
                warnings.warn('sealevelmodel checkconsistency error: icecap model {} should have the transient coupler option turned on!'.format(slm.icecaps[i].miscellaneous.name))

        if slm.earth.transient.iscoupler == 0:
            warnings.warn('sealevelmodel checkconsistency error: earth model should have the transient coupler option turned on!')

        # Check that the transition vectors have the right size
        for i in range(len(slm.icecaps)):
            if slm.icecaps[i].mesh.numberofvertices != len(slm.earth.slr.transitions[i]):
                raise RuntimeError('sealevelmodel checkconsistency issue with size of transition vector for ice cap: {} name: {}'.format(i, slm.icecaps[i].miscellaneous.name))

        # Check that run frequency is the same everywhere
        for i in range(len(slm.icecaps)):
            if slm.icecaps[i].slr.geodetic_run_frequency != slm.earth.geodetic_run_frequency:
                raise RuntimeError('sealevelmodel checkconsistency error: icecap model {} should have the same run frequency as earth!'.format(slm.icecaps[i].miscellaneous.name))

        # Make sure steric_rate is the same everywhere
        for i in range(len(slm.icecaps)):
            md = slm.icecaps[i]
            if np.nonzero(md.slr.steric_rate - slm.earth.slr.steric_rate[slm.earth.slr.transitions[i]]) != []:
                raise RuntimeError('steric rate on ice cap {} is not the same as for the earth'.format(md.miscellaneous.name))
    #}}}

    def mergeresults(self): # {{{
        champs = fieldnames(self.icecaps[i].results.TransientSolution)
        for i in range(len(self.mergedcaps / 2)):
            md = self.mergedcaps[2 * i]
            trans = self.mergedcaps[2 * i + 1]
            for j in range(len(self.icecaps[0].results.TransientSolution)):
                for k in range(len(champs)):
                    if isinstance(getattr(icecaps[0].results.TransientSolution[j], champs[k]), float):
                        # Vertex or element?
                        if len(getattr(icecaps[0].results.TransientSolution[j], champs[k]) == icecaps[0].mesh.numberofvertices):
                            setattr(md.results.TransientSolution[j], champs[k], np.zeros(md.mesh.numberofvertices))
                            for l in range(len(trans)):
                                resultcap = getattr(icecaps[l].results.TransientSolution[j], champs[k])
                                setattr(getattr(md.results.TransientSolution[j], champs[k]), trans[l], resultcap)
                        else:
                            if champs[k] == 'IceVolume' or champs[k] == 'IceVolumeAboveFlotation':
                                setattr(md.results.TransientSolution, champs[k], 0)
                                for l in range(len(trans)):
                                    resultcap = getattr(icecaps[l].results.TransientSolution[j], champs[k])
                                    setattr(md.results.TransientSolution[j], champs[k], getattr(md.results.TransientSolution[j], champs[k]) + resultcap)
                            elif champs[k] == 'time':
                                setattr(md.results.TransientSolution[j], champs[k], getattr(icecaps[0].results.TransientSolution[j], champs[k]))
                            else:
                                continue
                    else:
                        continue
            self.mergedcaps[2 * i] = md
    #}}}
    
    def listcaps(self): # {{{
        for i in range(len(self.icecaps)):
            print('{}: {}'.format(i, self.icecaps[i].miscellaneous.name))
    #}}}

    def continents(self): # {{{
        list = []
        for i in range(len(self.basins)):
            list.append = self.basins[i].continent
        return np.unique(list)
    #}}}

    def basinsfromcontinent(self, continent): # {{{
        list = []
        for i in range(len(self.icecaps)):
            if self.basins[i].continent == continent:
                list.append = self.basins[i].name
        return np.unique(list)
    #}}}

    def addbasin(self, bas): # {{{
        if bas.__class__.__name__ != 'basin':
            raise RuntimeError('addbasin method only takes a \'basin\' class object as input')
        self.basins.append(bas)
    #}}}

    def intersections(self, *args): #{{{
        options = pairoptions(*args)
        force = options.getfieldvalue('force', 0)

        # Initialize, to avoid issues of having more transitions than meshes
        self.transitions = []
        self.eltransitions = []

        # For elements
        onesmatrix = np.array([[1], [1], [1]])
        xe = self.earth.mesh.x[self.earth.mesh.elements] * onesmatrix / 3
        ye = self.earth.mesh.y[self.earth.mesh.elements] * onesmatrix / 3
        ze = self.earth.mesh.z[self.earth.mesh.elements] * onesmatrix / 3

        for i in range(len(self.icecaps)):
            mdi = self.icecaps[i]
            mdi = TwoDToThreeD(mdi, self.planet)

            # For elements
            zei = mdi.mesh.x[mdi.mesh.elements] * onesmatrix / 3
            yei = mdi.mesh.y[mdi.mesh.elements] * onesmatrix / 3
            zei = mdi.mesh.z[mdi.mesh.elements] * onesmatrix / 3

            print('Computing vertex intersections for basin {}'.format(self.basins[i].name))

            self.transitions.append(meshintersect3d(self.earth.mesh.x, self.earth.mesh.y, self.earth.mesh.z, mdi.mesh.x, mdi.mesh.y, mdi.mesh.z, 'force', force))
            self.eltransitions.append(meshintersect3d(xe, ye, ze, xei, yei, zei, 'force', force))
    #}}}

    def checkintersections(self): #{{{
        flags = np.zeros(self.earth.mesh.numberofvertices, 1)
        for i in range(len(self.basins)):
            flags[self.transitions[i]] = i
        plotmodel(self.earth, 'data', flags, 'coastline', 'on')
    #}}}

    def checkbasinconsistency(self): #{{{
        for i in range(len(self.basins)):
            self.basins[i].checkconsistency()
    #}}}

    def basinindx(self, *args): #{{{
        options = pairoptions(*args)
        continent = options.getfieldvalue('continent', 'all')
        bas = options.getfieldvalue('basin', 'all')

        #expand continent list #{{{
        if type(continent) == np.ndarray:
            if continent.shape[1] == 1:
                if continent[0] == 'all':
                    #need to transform this into a list of continents
                    continent = []
                    for i in range(len(self.basins)):
                        continent.append(self.basins[i].continent)
                    continent = np.unique(continent)
            else:
                pass #nothing to do, we have a list of continents
        else:
            if continent == 'all':
                #need to transform this into a list of continents
                continent = []
                for i in range(len(self.basins)):
                    continent.append(self.basins[i].continent)
                continent = np.unique(continent)
            else:
                continent = [continent]
        #}}}
        #expand basins list using the continent list above and the extra bas discriminator #{{{
        if type(bas) == np.ndarray:
            if bas.shape[1] == 1:
                if bas[0] == 'all':
                    #need to transform this into a list of basins
                    baslist = []
                    for i in range(len(self.basins)):
                        if self.basins[i].iscontinentany(continent):
                            baslist.append(i)
                    baslist = np.unique(baslist)
                else:
                    bas = bas[0]
                    baslist = []
                    for i in range(len(self.basins)):
                        if self.basins[i].iscontinentany(continent):
                            if self.basins[i].isnameany(bas):
                                baslist.append(i)
            else:
                #we have a list of basin names
                baslist = []
                for i in range(len(bas)):
                    basname = bas[i]
                    for j in range(len(self.basins)):
                        if self.basins[j].iscontinentany(continent):
                            if self.basins[j].isnameany(basname):
                                baslist.append(j)
                    baslist = np.unique(baslist)
        else:
            if bas == 'all':
                baslist = []
                for i in range(len(self.basins)):
                    if self.basins[i].iscontinentany(continent):
                        baslist.append(i)
                baslist = np.unique(baslist)
            else:
                baslist = []
                for i in range(len(self.basins)):
                    if self.basins[i].iscontinentany(continent):
                        if self.basins[i].isnameany(bas):
                            baslist.append(i)
                baslist = np.unique(baslist)

        return baslist
        #}}}
    #}}}
