import warnings

import numpy as np

from generic import generic
from issmsettings import issmsettings
from meshintersect3d import meshintersect3d
from miscellaneous import miscellaneous
from model import model
from modelmerge3d import modelmerge3d
from pairoptions import pairoptions
from private import private
from TwoDToThreeD import TwoDToThreeD


class sealevelmodel(object):
    """SEALEVELMODEL class definition

    Usage:
        slm = sealevelmodel(*args)

        where args is a variable list of options

    Example:
        slm = sealevelmodel(
            'icecap', md_greenland,
            'icecap', md_antarctica,
            'earth', md_earth
        )
    """

    def __init__(self, *args): #{{{
        self.icecaps        = [] # list of land/ice models; name should be changed later
        self.earth          = 0 # model for the whole earth
        self.basins         = [] # list of basins, matching icecaps, where shapefile info is held
        self.cluster        = 0
        self.miscellaneous  = 0
        self.settings       = 0
        self.private        = 0
        self.mergedcaps     = 0
        self.transitions    = []
        self.eltransitions  = []
        self.planet         = ''

        # Create a default object
        self.setdefaultparameters()

        if len(args):
            # Use provided options to set fields
            options = pairoptions(*args)

            # Recover all the icecap models
            self.icecaps = options.getfieldvalue('ice_cap', [])

            # Recover the earth models
            self.earth = options.getfieldvalue('earth', 0)

            # Set planet type
            self.planet = options.getfieldvalue('planet', 'earth')
    #}}}

    def __repr__(self): # {{{
        s = '{}\n'.format(fielddisplay(self, 'icecaps', 'ice caps'))
        s += '{}\n'.format(fielddisplay(self, 'earth', 'earth'))
        s += '{}\n'.format(fielddisplay(self, 'settings', 'settings properties'))
        s += '{}\n'.format(fielddisplay(self, 'cluster', 'cluster parameters (number of cpus...'))
        s += '{}\n'.format(fielddisplay(self, 'miscellaneous', 'miscellaneous fields'))
    #}}}

    def setdefaultparameters(self): # {{{
        self.icecaps        = []
        self.earth          = []
        self.cluster        = generic()
        self.miscellaneous  = miscellaneous()
        self.settings       = issmsettings()
        self.private        = private()
        self.transitions    = []
        self.eltransitions  = []
        self.planet         = 'earth'
    #}}}

    @staticmethod
    def checkconsistency(slm, solutiontype): # {{{
        # Is the coupler turned on?
        for i in range(len(slm.icecaps)):
            if slm.icecaps[i].transient.iscoupler == 0:
                warnings.warn('sealevelmodel checkconsistency error: icecap model {} should have the transient coupler option turned on!'.format(slm.icecaps[i].miscellaneous.name))

        if slm.earth.transient.iscoupler == 0:
            warnings.warn('sealevelmodel checkconsistency error: earth model should have the transient coupler option turned on!')

        # Check that the transition vectors have the right size
        for i in range(len(slm.icecaps)):
            if slm.icecaps[i].mesh.numberofvertices != len(slm.earth.slr.transitions[i]):
                raise RuntimeError('sealevelmodel checkconsistency issue with size of transition vector for ice cap: {} name: {}'.format(i, slm.icecaps[i].miscellaneous.name))

        # Check that run frequency is the same everywhere
        for i in range(len(slm.icecaps)):
            if slm.icecaps[i].slr.geodetic_run_frequency != slm.earth.geodetic_run_frequency:
                raise RuntimeError('sealevelmodel checkconsistency error: icecap model {} should have the same run frequency as earth!'.format(slm.icecaps[i].miscellaneous.name))

        # Make sure steric_rate is the same everywhere
        for i in range(len(slm.icecaps)):
            md = slm.icecaps[i]
            if np.nonzero(md.slr.steric_rate - slm.earth.slr.steric_rate[slm.earth.slr.transitions[i]]) != []:
                raise RuntimeError('steric rate on ice cap {} is not the same as for the earth'.format(md.miscellaneous.name))
    #}}}

    def mergeresults(self): # {{{
        champs = fieldnames(self.icecaps[i].results.TransientSolution)
        for i in range(len(self.mergedcaps / 2)):
            md = self.mergedcaps[2 * i]
            trans = self.mergedcaps[2 * i + 1]
            for j in range(len(self.icecaps[0].results.TransientSolution)):
                for k in range(len(champs)):
                    if isinstance(getattr(icecaps[0].results.TransientSolution[j], champs[k]), float):
                        # Vertex or element?
                        if len(getattr(icecaps[0].results.TransientSolution[j], champs[k]) == icecaps[0].mesh.numberofvertices):
                            setattr(md.results.TransientSolution[j], champs[k], np.zeros(md.mesh.numberofvertices))
                            for l in range(len(trans)):
                                resultcap = getattr(icecaps[l].results.TransientSolution[j], champs[k])
                                setattr(getattr(md.results.TransientSolution[j], champs[k]), trans[l], resultcap)
                        else:
                            if champs[k] == 'IceVolume' or champs[k] == 'IceVolumeAboveFlotation':
                                setattr(md.results.TransientSolution, champs[k], 0)
                                for l in range(len(trans)):
                                    resultcap = getattr(icecaps[l].results.TransientSolution[j], champs[k])
                                    setattr(md.results.TransientSolution[j], champs[k], getattr(md.results.TransientSolution[j], champs[k]) + resultcap)
                            elif champs[k] == 'time':
                                setattr(md.results.TransientSolution[j], champs[k], getattr(icecaps[0].results.TransientSolution[j], champs[k]))
                            else:
                                continue
                    else:
                        continue
            self.mergedcaps[2 * i] = md
    #}}}
    
    def listcaps(self): # {{{
        for i in range(len(self.icecaps)):
            print('{}: {}'.format(i, self.icecaps[i].miscellaneous.name))
    #}}}

    def continents(self): # {{{
        list = []
        for i in range(len(self.basins)):
            list.append = self.basins[i].continent
        return np.unique(list)
    #}}}

    def basinsfromcontinent(self, continent): # {{{
        list = []
        for i in range(len(self.icecaps)):
            if self.basins[i].continent == continent:
                list.append = self.basins[i].name
        return np.unique(list)
    #}}}

    def addbasin(self, bas): # {{{
        if bas.__class__.__name__ != 'basin':
            raise RuntimeError('addbasin method only takes a \'basin\' class object as input')
        self.basins.append(bas)
    #}}}

    def intersections(self, *args): #{{{
        options = pairoptions(*args)
        force = options.getfieldvalue('force', 0)

        # Initialize, to avoid issues of having more transitions than meshes
        self.transitions = []
        self.eltransitions = []

        # For elements
        onesmatrix = np.array([[1], [1], [1]])
        xe = self.earth.mesh.x[self.earth.mesh.elements] * onesmatrix / 3
        ye = self.earth.mesh.y[self.earth.mesh.elements] * onesmatrix / 3
        ze = self.earth.mesh.z[self.earth.mesh.elements] * onesmatrix / 3

        for i in range(len(self.icecaps)):
            mdi = self.icecaps[i]
            mdi = TwoDToThreeD(mdi, self.planet)

            # For elements
            zei = mdi.mesh.x[mdi.mesh.elements] * onesmatrix / 3
            yei = mdi.mesh.y[mdi.mesh.elements] * onesmatrix / 3
            zei = mdi.mesh.z[mdi.mesh.elements] * onesmatrix / 3

            print('Computing vertex intersections for basin {}'.format(self.basins[i].name))

            self.transitions.append(meshintersect3d(self.earth.mesh.x, self.earth.mesh.y, self.earth.mesh.z, mdi.mesh.x, mdi.mesh.y, mdi.mesh.z, 'force', force))
            self.eltransitions.append(meshintersect3d(xe, ye, ze, xei, yei, zei, 'force', force))
    #}}}

    def checkintersections(self): #{{{
        flags = np.zeros(self.earth.mesh.numberofvertices, 1)
        for i in range(len(self.basins)):
            flags[self.transitions[i]] = i
        plotmodel(self.earth, 'data', flags, 'coastline', 'on')
    #}}}

    def checkbasinconsistency(self): #{{{
        for i in range(len(self.basins)):
            self.basins[i].checkconsistency()
    #}}}

    def basinindx(self, *args): #{{{
        options = pairoptions(*args)
        continent = options.getfieldvalue('continent', 'all')
        bas = options.getfieldvalue('basin', 'all')

        # Expand continent list #{{{
        if type(continent) == np.ndarray:
            if continent.shape[1] == 1:
                if continent[0] == 'all':
                    # Need to transform this into a list of continents
                    continent = []
                    for i in range(len(self.basins)):
                        continent.append(self.basins[i].continent)
                    continent = np.unique(continent)
                else:
                    pass # Nothing to do: assume we have a list of continents
            else:
                pass # Nothing to do: assume we have a list of continents
        else:
            if continent == 'all':
                # Need to transform this into a list of continents
                continent = []
                for i in range(len(self.basins)):
                    continent.append(self.basins[i].continent)
                continent = np.unique(continent)
            else:
                pass # Nothing to do: assume we have a list of continents
        #}}}

        # Expand basins list using the continent list above and the extra bas discriminator #{{{
        if type(bas) == np.ndarray:
            if bas.shape[1] == 1:
                if bas[0] == 'all':
                    # Need to transform this into a list of basins
                    baslist = []
                    for i in range(len(self.basins)):
                        if self.basins[i].iscontinentany(continent):
                            baslist.append(i)
                    baslist = np.unique(baslist)
                else:
                    bas = bas[0]
                    baslist = []
                    for i in range(len(self.basins)):
                        if self.basins[i].iscontinentany(continent):
                            if self.basins[i].isnameany(bas):
                                baslist.append(i)
            else:
                # We have a list of basin names
                baslist = []
                for i in range(len(bas)):
                    basname = bas[i]
                    for j in range(len(self.basins)):
                        if self.basins[j].iscontinentany(continent):
                            if self.basins[j].isnameany(basname):
                                baslist.append(j)
                    baslist = np.unique(baslist)
        else:
            if bas == 'all':
                baslist = []
                for i in range(len(self.basins)):
                    if self.basins[i].iscontinentany(continent):
                        baslist.append(i)
                baslist = np.unique(baslist)
            else:
                baslist = []
                for i in range(len(self.basins)):
                    if self.basins[i].iscontinentany(continent):
                        if self.basins[i].isnameany(bas):
                            baslist.append(i)
                baslist = np.unique(baslist)

        return baslist
        #}}}
    #}}}

    def addicecap(self, md): #{{{
        if not type(md) ==  model:
            raise Exception("addicecap method only takes a 'model' class object as input")

        self.icecaps.append(md)
    #}}}

    def basinsplot3d(self, *args): #{{{
        for i in range(len(self.basins)):
            self.basins[i].plot3d(*args)
    #}}}

    def caticecaps(self, *args): #{{{
        # Recover options
        options = pairoptions(*args)
        tolerance = options.getfieldvalue('tolerance', .65)
        loneedgesdetect = options.getfieldvalue('loneedgesdetect', 0)

        # Make 3D model
        models = self.icecaps
        for i in range(len(models)):
            models[i] = TwoDToThreeD(models[i], self.planet)

        # Plug all models together
        md = models[0]
        for i in range(1, len(models)):
            md = modelmerge3d(md, models[i], 'tolerance', tolerance)
            md.private.bamg.landmask = np.vstack((md.private.bamg.landmask, models[i].private.bamg.landmask))

        # Look for lone edges if asked for it
        if loneedgesdetect:
            edges = loneedges(md)
            plotmodel(md, 'data', md.mask.land_levelset)
            for i in range(len(edges)):
                ind1 = edges(i, 1)
                ind1 = edges(i, 2)
                plot3([md.mesh.x[ind1], md.mesh.x[ind2]], [md.mesh.y[ind1], md.mesh.y[ind2]], [md.mesh.z[ind1], md.mesh.z[ind2]], 'g*-')

        # Plug into earth
        self.earth = md

        # Create mesh radius
        self.earth.mesh.r = planetradius('earth')
    #}}}

    def viscousiterations(self): #{{{
        for i in range(len(self.icecaps)):
            ic = self.icecaps[i]
            mvi = ic.resutls.TransientSolution[0].StressbalanceConvergenceNumSteps
            for j in range(1, len(ic.results.TransientSolution) - 1):
                mvi = np.max(mvi, ic.results.TransientSolution[j].StressbalanceConvergenceNumSteps)
            print("{}, {}: {}".format(i, self.icecaps[i].miscellaneous.name, mvi))
    #}}}

    def maxtimestep(self): #{{{
        for i in range(len(self.icecaps)):
            ic = self.icecaps[i]
            mvi = len(ic.results.TransientSolution)
            timei = ic.results.TransientSolution[-1].time
            print("{}, {}: {}/{}".format(i, self.icecaps[i].miscellaneous.name, mvi, timei))

        mvi = len(self.earth.results.TransientSolution)
        timei = self.earth.results.TransientSolution[-1].time
        print("Earth: {}/{}", mvi, timei)
    #}}}
    
    def transfer(self, string): #{{{
        # Recover field size in one icecap
        n = np.size(getattr(self.icecaps[i], string), 0)

        if n == self.icecaps[0].mesh.numberofvertices:
            setattr(self.earth, string, np.zeros((self.earth.mesh.numberofvertices, )))
            for i in range(len(self.icecaps)):
                getattr(self.earth, string)[self.transitions[i]] = getattr(self.icecaps[i], string)
        elif n == (self.self.icecaps[0].mesh.numberofvertices + 1):
            # Dealing with transient dataset
            # Check that all timetags are similar between all icecaps #{{{
            for i in range(len(self.icecaps)):
                capfieldi = getattr(self.icecaps[i], string)
                for j in range(1, len(self.icecaps)):
                    capfieldj = getattr(self.icecaps[j], string)
                    if capfieldi[-1, :] != capfieldj[-1, :]:
                        raise Exception("Time stamps for {} field is different between icecaps {} and {}".format(string, i, j))
            capfield1 = getattr(self.icecaps[0], string)
            times = capfield1[-1, :]
            nsteps = len(times)
            #}}}
            # Initialize #{{{
            field = np.zeros((self.earth.mesh.numberofvertices + 1, nsteps))
            field[-1, :] = times # Transfer the times only, not the values
            #}}}
            # Transfer all the time fields #{{{
            for i in range(len(self.icecaps)):
                capfieldi = getattr(self.icecaps[i], string)
                for j in range(nsteps):
                    field[self.transitions[i], j] = capfieldi[0:-2, j] # Transfer only the values, not the time
            setattr(self.earth, string, field) # Do not forget to plug the field variable into its final location
            #}}}
        elif n == (self.icecaps[0].mesh.numberofelements):
            setattr(self.earth, string, np.zeros((self.earth.mesh.numberofvertices, )))
            for i in range(len(self.icecaps)):
                getattr(self.earth, string)[self.eltransitions[i]] = getattr(self.icecaps[i], string)
        else:
            raise Exception('not supported yet')
    #}}}

    def homogenize(self, noearth=0): #{{{
        mintimestep = np.inf

        for i in range(len(self.icecaps)):
            ic = self.icecaps[i]
            mintimestep = np.min(mintimestep, len (ic.results.TransientSolution))

        if not noearth:
            mintimestep = np.min(mintimestep, len(self.earth.results.TransientSolution))

        for i in range(len(self.icecaps)):
            ic = self.icecaps[i]
            ic.resuts.TransientSolution = ic.results.TransientSolution[:mintimestep]
            self.icecaps[i] = ic

        ic = self.earth

        if not noearth:
            ic.results.TransientSolution = ic.resutls.TransientSolution[:mintimestep]

        self.earth = ic

        return self
    #}}}

    def initializemodels(self): #{{{
        for i in range(len(self.basins)):
            md = model()
            md.miscellaneous.name = self.basins[i].name
            self.addicecap(md)
    #}}}
