import numpy as np
from checkfield import checkfield
from fielddisplay import fielddisplay
from WriteData import WriteData


class stochasticforcing(object):
    """STOCHASTICFORCING class definition

    Usage:
        stochasticforcing = stochasticforcing()
    """

    def __init__(self, *args):  # {{{
        self.isstochasticforcing = 0
        self.fields = np.nan
        self.defaultdimension = 0
        self.default_id = np.nan
        self.covariance = np.nan
        self.stochastictimestep = 0
        self.randomflag = 1

        if len(args) == 0:
            self.setdefaultparameters()
        else:
            raise RuntimeError('constructor not supported for stochasticforcing')

    def __repr__(self):  # {{{
        s = '   stochasticforcing parameters:\n'
        s += '{}\n'.format(fielddisplay(self, 'isstochasticforcing', 'is stochasticity activated?'))
        s += '{}\n'.format(fielddisplay(self, 'fields', 'fields with stochasticity applied, ex: [\'SMBautoregression\'], or [\'SMBforcing\',\'DefaultCalving\']'))
        s += '{}\n'.format(fielddisplay(self, 'defaultdimension', 'dimensionality of the noise terms (does not apply to fields with their specific dimension)'))
        s += '{}\n'.format(fielddisplay(self, 'default_id', 'id of each element for partitioning of the noise terms (does not apply to fields with their specific partition)'))
        s += '{}\n'.format(fielddisplay(self, 'covariance', 'covariance matrix for within- and between-fields covariance (units must be squared field units)'))
        s += '{}\n'.format(fielddisplay(self, 'stochastictimestep', 'timestep at which new stochastic noise terms are generated (default: md.timestepping.time_step)'))
        s += '{}\n'.format(fielddisplay(self, 'randomflag', 'whether to apply real randomness (true) or pseudo-randomness with fixed seed (false)'))
        s += 'Available fields:\n'
        s += '   BasalforcingsSpatialDeepwaterMeltingRate\n'
        s += '   DefaultCalving\n'
        s += '   FloatingMeltRate\n'
        s += '   FrictionWaterPressure\n'
        s += '   FrontalForcingsRignotAutoregression (thermal forcing)\n'
        s += '   SMBautoregression\n'
        s += '   SMBforcing\n'
        return s
    #}}}

    def setdefaultparameters(self):  # {{{
        # Type of stabilization used
        self.isstochasticforcing = 0  # stochasticforcing is turned off by default
        self.fields = []  # Need to initialize to list to avoid "RuntimeError: object of type 'float' has no len()" on import of class
        self.randomflag = 1  # true randomness is implemented by default
        return self
    #}}}

    def checkconsistency(self, md, solution, analyses):  # {{{
        # Early return
        if not self.isstochasticforcing:
            return md

        num_fields = len(self.fields)
        if(self.stochastictimestep==0):
            md.stochasticforcing.stochastictimestep = md.timestepping.time_step #by default: stochastictimestep set to ISSM time step
            print('      stochasticforcing.stocahstictimestep not specified: set to md.timestepping.time_step')

        # Check that covariance matrix is positive definite (this is done internally by linalg)
        try:
            np.linalg.cholesky(self.covariance)
        except:
            raise TypeError('md.stochasticforcing.covariance is not positive definite')

        # Check that all fields agree with the corresponding md class and if any field needs the default params
        checkdefaults = False  # Need to check defaults only if one of the fields does not have its own dimensionality
        structstoch = self.structstochforcing()
        for field in self.fields:
            # Checking agreement of classes
            if 'SMBautoregression' in field:
                mdname = structstoch[field]
                if (type(md.smb).__name__ != mdname):
                    raise TypeError('md.smb does not agree with stochasticforcing field {}'.format(field))
            if 'SMBforcing' in field:
                mdname = structstoch[field]
                if (type(md.smb).__name__ != mdname):
                    raise TypeError('md.smb does not agree with stochasticforcing field {}'.format(field))
            if 'FrontalForcings' in field:
                mdname = structstoch[field]
                if (type(md.frontalforcings).__name__ != mdname):
                    raise TypeError('md.frontalforcings does not agree with stochasticforcing field {}'.format(field))
            if 'Calving' in field:
                mdname = structstoch[field]
                if (type(md.calving).__name__ != mdname):
                    raise TypeError('md.calving does not agree with stochasticforcing field {}'.format(field))
            if 'BasalforcingsFloatingice' in field:
                mdname = structstoch[field]
                if (type(md.basalforcings).__name__ != mdname):
                    raise TypeError('md.basalforcings does not agree with stochasticforcing field {}'.format(field))
            if 'BasalforcingsSpatialDeepwaterMeltingRate' in field:
                mdname = structstoch[field]
                if (type(md.basalforcings).__name__ != mdname):
                    raise TypeError('md.basalforcings does not agree with stochasticforcing field {}'.format(field))
            if 'BasalforcingsDeepwaterMeltingRateAutoregression' in field:
                mdname = structstoch[field]
                if (type(md.basalforcings).__name__ != mdname):
                    raise TypeError('md.basalforcings does not agree with stochasticforcing field {}'.format(field))
            if 'WaterPressure' in field:
                mdname = structstoch[field]
                if (type(md.friction).__name__ != mdname):
                    raise TypeError('stochasticforcing field {} is only implemented for default friction'.format(field))
                if md.friction.coupling not in[0, 1, 2]:
                    raise TypeError('stochasticforcing field {} is only implemented for cases md.friction.coupling 0 or 1 or 2'.format(field))
                if (np.any(md.friction.q == 0)):
                    raise TypeError('stochasticforcing field {} requires non-zero q exponent'.format(field))

            # Checking for specific dimensions
            if field not in['SMBautoregression', 'FrontalForcingsRignotAutoregression','BasalforcingsDeepwaterMeltingRateAutoregression']:
                checkdefaults = True  # field with non-specific dimensionality

        # Retrieve sum of all the field dimensionalities
        dimensions = self.defaultdimension*np.ones((num_fields))
        indSMBar   = -1  # About to check for index of SMBautoregression
        indTFar    = -1  # About to check for index of FrontalForcingsRignotAutoregression
        indBDWar   = -1  # About to check for index of BasalforcingsDeepwaterMeltingRateAutoregression
        if ('SMBautoregression' in self.fields):
            indSMBar = self.fields.index('SMBautoregression')  # Index of SMBar, now check for consistency with other timesteps
            dimensions[indSMBar] = md.smb.num_basins
            if(md.smb.ar_timestep<self.stochastictimestep):
                raise TypeError('SMBautoregression cannot have a timestep shorter than stochastictimestep')
        if ('FrontalForcingsRignotAutoregression' in self.fields):
            indTFar = self.fields.index('FrontalForcingsRignotAutoregression')  # Index of TFar, now check for consistency with other timesteps
            dimensions[indTFar] = md.frontalforcings.num_basins
            if(md.frontalforcings.ar_timestep<self.stochastictimestep):
                raise TypeError('FrontalForcingsRignotAutoregression cannot have a timestep shorter than stochastictimestep')
        if ('BasalforcingsDeepwaterMeltingRateAutoregression' in self.fields):
            indBDWar = self.fields.index('BasalforcingsDeepwaterMeltingRateAutoregression')  # Index of BDWar, now check for consistency with other timesteps
            dimensions[indTFar] = md.basalforcings.num_basins
            if(md.basalforcings.ar_timestep<self.stochastictimestep):
                raise TypeError('BasalforcingsDeepwaterMeltingRateAutoregression cannot have a timestep shorter than stochastictimestep')
        size_tot = np.sum(dimensions)

        if (indSMBar != -1 and indTFar != -1):  # Both autoregressive models are used: check autoregressive time step consistency
            covsum = self.covariance[np.sum(dimensions[0:indSMBar]).astype(int):np.sum(dimensions[0:indSMBar + 1]).astype(int), np.sum(dimensions[0:indTFar]).astype(int):np.sum(dimensions[0:indTFar + 1]).astype(int)]
            if((md.smb.ar_timestep != md.frontalforcings.ar_timestep) and np.any(covsum != 0)):
                raise IOError('SMBautoregression and FrontalForcingsRignotAutoregression have different ar_timestep and non-zero covariance')
        if (indSMBar != -1 and indBDWar != -1):  # Both autoregressive models are used: check autoregressive time step consistency
            covsum = self.covariance[np.sum(dimensions[0:indSMBar]).astype(int):np.sum(dimensions[0:indSMBar + 1]).astype(int), np.sum(dimensions[0:indBDWar]).astype(int):np.sum(dimensions[0:indBDWar + 1]).astype(int)]
            if((md.smb.ar_timestep != md.basalforcings.ar_timestep) and np.any(covsum != 0)):
                raise IOError('SMBautoregression and BasalforcingsDeepwaterMeltingRateAutoregression have different ar_timestep and non-zero covariance')
        if (indTFar != -1 and indBDWar != -1):  # Both autoregressive models are used: check autoregressive time step consistency
            covsum = self.covariance[np.sum(dimensions[0:indTFar]).astype(int):np.sum(dimensions[0:indTFar + 1]).astype(int), np.sum(dimensions[0:indBDWar]).astype(int):np.sum(dimensions[0:indBDWar + 1]).astype(int)]
            if((md.frontalforcings.ar_timestep != md.basalforcings.ar_timestep) and np.any(covsum != 0)):
                raise IOError('FrontalForcingsRignotAutoregression and BasalforcingsDeepwaterMeltingRateAutoregression have different ar_timestep and non-zero covariance')

        md = checkfield(md, 'fieldname', 'stochasticforcing.isstochasticforcing', 'values', [0, 1])
        md = checkfield(md, 'fieldname', 'stochasticforcing.fields', 'numel', num_fields, 'cell', 1, 'values', self.supportedstochforcings())
        md = checkfield(md, 'fieldname', 'stochasticforcing.covariance', 'NaN', 1, 'Inf', 1, 'size', [size_tot, size_tot])  # global covariance matrix
        md = checkfield(md, 'fieldname', 'stochasticforcing.stochastictimestep', 'NaN', 1,'Inf', 1, '>=', md.timestepping.time_step)
        md = checkfield(md, 'fieldname', 'stochasticforcing.randomflag', 'numel', [1], 'values', [0, 1])
        if (checkdefaults):
            md = checkfield(md, 'fieldname', 'stochasticforcing.defaultdimension', 'numel', 1, 'NaN', 1, 'Inf', 1, '>', 0)
            md = checkfield(md, 'fieldname', 'stochasticforcing.default_id', 'Inf', 1, 'NaN', 1, '>=', 0, '<=', self.defaultdimension, 'size', [md.mesh.numberofelements])
        return md
    # }}}

    def extrude(self, md):  # {{{
        # Nothing for now
        return self
    # }}}

    def marshall(self, prefix, md, fid):  # {{{
        yts = md.constants.yts

        WriteData(fid, prefix, 'object', self, 'fieldname', 'isstochasticforcing', 'format', 'Boolean')
        if not self.isstochasticforcing:
            return md

        else:
            num_fields = len(self.fields)
            if(self.stochastictimestep==0):
                md.stochasticforcing.stochastictimestep = md.timestepping.time_step #by default: stochastictimestep set to ISSM time step
            # Retrieve dimensionality of each field
            dimensions = self.defaultdimension * np.ones((num_fields))
            for ind, field in enumerate(self.fields):
                # Checking for specific dimensions
                if (field == 'SMBautoregression'):
                    dimensions[ind] = md.smb.num_basins
                if (field == 'FrontalForcingsRignotAutoregression'):
                    dimensions[ind] = md.frontalforcings.num_basins
                if (field == 'BasalforcingsDeepwaterMeltingRateAutoregression'):
                    dimensions[ind] = md.basalforcings.num_basins

            # Scaling covariance matrix (scale column-by-column and row-by-row)
            scaledfields = ['BasalforcingsDeepwaterMeltingRateAutoregression','BasalforcingsSpatialDeepwaterMeltingRate','DefaultCalving', 'FloatingMeltRate', 'SMBautoregression', 'SMBforcing']  # list of fields that need scaling * 1/yts
            tempcovariance = np.copy(self.covariance)
            for i in range(num_fields):
                if self.fields[i] in scaledfields:
                    inds = range(int(np.sum(dimensions[0:i])), int(np.sum(dimensions[0:i + 1])))
                    for row in inds:  # scale rows corresponding to scaled field
                        tempcovariance[row, :] = 1 / yts * tempcovariance[row, :]
                    for col in inds:  # scale columns corresponding to scaled field
                        tempcovariance[:, col] = 1 / yts * tempcovariance[:, col]
            # Set dummy default_id vector if defaults not used
            if np.any(np.isnan(self.default_id)):
                self.default_id = np.zeros(md.mesh.numberofelements)
            # Reshape dimensions as column array for marshalling
            dimensions = dimensions.reshape(1, len(dimensions))

            WriteData(fid, prefix, 'data', num_fields, 'name', 'md.stochasticforcing.num_fields', 'format', 'Integer')
            WriteData(fid, prefix, 'object', self, 'fieldname', 'fields', 'format', 'StringArray')
            WriteData(fid, prefix, 'data', dimensions, 'name', 'md.stochasticforcing.dimensions', 'format', 'IntMat', 'mattype', 2)
            WriteData(fid, prefix, 'object', self, 'fieldname', 'default_id', 'data', self.default_id - 1, 'format', 'IntMat', 'mattype', 2)  #12Nov2021 make sure this is zero-indexed!
            WriteData(fid, prefix, 'object', self, 'fieldname', 'defaultdimension', 'format', 'Integer')
            WriteData(fid, prefix, 'data', tempcovariance, 'name', 'md.stochasticforcing.covariance', 'format', 'DoubleMat')
            WriteData(fid, prefix, 'object', self, 'fieldname', 'stochastictimestep', 'format', 'Double', 'scale', yts)
            WriteData(fid, prefix, 'object', self, 'fieldname', 'randomflag', 'format', 'Boolean')
    # }}}

    def supportedstochforcings(self):  # {{{
        """Defines list of fields supported by the class md.stochasticforcing
        """
        list1 = self.structstochforcing()
        list1 = list1.keys()
        return list(list1)
    #}}}

    def structstochforcing(self):  # {{{
        """Defines dictionary with list of fields
           supported and corresponding md names
        """
        structure = {'BasalforcingsDeepwaterMeltingRateAutoregression': 'autoregressionlinearbasalforcings',
                     'BasalforcingsSpatialDeepwaterMeltingRate': 'spatiallinearbasalforcings',
                     'DefaultCalving': 'calving',
                     'FloatingMeltRate': 'basalforcings',
                     'FrictionWaterPressure': 'friction',
                     'FrontalForcingsRignotAutoregression': 'frontalforcingsrignotautoregression',
                     'SMBautoregression': 'SMBautoregression',
                     'SMBforcing': 'SMBforcing'}
        return structure
    # }}}
