import numpy as np
from checkfield import checkfield
from fielddisplay import fielddisplay
from WriteData import WriteData


class stochasticforcing(object):
    """STOCHASTICFORCING class definition

    Usage:
        stochasticforcing = stochasticforcing()
    """

    def __init__(self, *args):  # {{{
        self.isstochasticforcing = 0
        self.fields = np.nan
        self.defaultdimension = 0
        self.default_id = np.nan
        self.covariance = np.nan
        self.randomflag = 1

        if len(args) == 0:
            self.setdefaultparameters()
        else:
            raise RuntimeError('constructor not supported for stochasticforcing')

    def __repr__(self):  # {{{
        s = '   stochasticforcing parameters:\n'
        s += '{}\n'.format(fielddisplay(self, 'isstochasticforcing', 'is stochasticity activated?'))
        s += '{}\n'.format(fielddisplay(self, 'fields', 'fields with stochasticity applied, ex: [\'SMBautoregression\'], or [\'FrontalForcingsRignotAutoregression\']'))
        s += '{}\n'.format(fielddisplay(self, 'defaultdimension', 'dimensionality of the noise terms (does not apply to fields with their specific dimension)'))
        s += '{}\n'.format(fielddisplay(self, 'default_id', 'id of each element for partitioning of the noise terms (does not apply to fields with their specific partition)'))
        s += '{}\n'.format(fielddisplay(self, 'covariance', 'covariance matrix for within- and between-fields covariance (units must be squared field units)'))
        s += '{}\n'.format(fielddisplay(self, 'randomflag', 'whether to apply real randomness (true) or pseudo-randomness with fixed seed (false)'))
        s += 'Available fields:\n'
        s += '   DefaultCalving\n'
        s += '   FloatingMeltRate\n'
        s += '   SMBautoregression\n'
        s += '   FrontalForcingsRignotAutoregression (thermal forcing)\n'
        return s
    #}}}

    def setdefaultparameters(self):  # {{{
        # Type of stabilization used
        self.isstochasticforcing = 0 # stochasticforcing is turned off by default
        self.fields = [] # Need to initialize to list to avoid "RuntimeError: object of type 'float' has no len()" on import of class
        self.randomflag = 1 # true randomness is implemented by default
        return self
    #}}}

    def checkconsistency(self, md, solution, analyses):  # {{{
        # Early return
        if not self.isstochasticforcing:
            return md

        num_fields = len(self.fields)

        # Check that covariance matrix is positive definite (this is done internally by linalg)
        try:
            np.linalg.cholesky(self.covariance)
        except:
            raise TypeError('md.stochasticforcing.covariance is not positive definite')

        # Check that all fields agree with the corresponding md class and if any field needs the default params
        checkdefaults = False # Need to check defaults only if one of the fields does not have its own dimensionality
        structstoch = self.structstochforcing()
        for field in self.fields:
            # Checking agreement of classes
            if 'SMB' in field:
                mdname = structstoch[field]
                if (type(md.smb).__name__ != mdname):
                    raise TypeError('md.smb does not agree with stochasticforcing field {}'.format(mdname))
            if 'FrontalForcings' in field:
                mdname = structstoch[field]
                if (type(md.frontalforcings).__name__ != mdname):
                    raise TypeError('md.frontalforcings does not agree with stochasticforcing field {}'.format(mdname))
            if 'Calving' in field:
                mdname = structstoch[field]
                if (type(md.calving).__name__ != mdname):
                    raise TypeError('md.calving does not agree with stochasticforcing field {}'.format(mdname))
            if 'BasalforcingsFloatingice' in field:
                mdname = structstoch[field]
                if (type(md.basalforcings).__name__ != mdname):
                    raise TypeError('md.basalforcings does not agree with stochasticforcing field {}'.format(mdname))
            # Checking for specific dimensions
            if not (field == 'SMBautoregression' or field == 'FrontalForcingsRignotAutoregression'):
                checkdefaults = True # field with non-specific dimensionality

        # Retrieve sum of all the field dimensionalities
        size_tot = self.defaultdimension * num_fields
        indSMBar = -1 # About to check for index of SMBautoregression
        indTFar = -1 # About to check for index of FrontalForcingsRignotAutoregression
        if ('SMBautoregression' in self.fields):
            size_tot = size_tot - self.defaultdimension + md.smb.num_basins
            indSMBar = self.fields.index('SMBautoregression') # Index of SMBar, now check for consistency with TFar timestep (08Nov2021)
        if ('FrontalForcingsRignotAutoregression' in self.fields):
            size_tot = size_tot - self.defaultdimension + md.frontalforcings.num_basins
            indTFar = self.fields.index('FrontalForcingsRignotAutoregression') # Index of TFar, now check for consistency with SMBar timestep (08Nov2021)
        if (indSMBar != -1 and indTFar != -1): # Both autoregressive models are used: check autoregressive time step consistency
            covsum = self.covariance[np.sum(self.defaultdimensions[0:indSMBar]).astype(int):np.sum(self.defaultdimensions[0:indSMBar + 1]).astype(int), np.sum(self.defaultdimensions[0:indTFar]).astype(int):np.sum(self.defaultdimensions[0:indTFar + 1]).astype(int)]
            if((md.smb.ar_timestep != md.frontalforcings.ar_timestep) and np.any(covsum != 0)):
                raise IOError('SMBautoregression and FrontalForcingsRignotAutoregression have different ar_timestep and non-zero covariance')

        md = checkfield(md, 'fieldname', 'stochasticforcing.isstochasticforcing', 'values', [0, 1])
        md = checkfield(md, 'fieldname', 'stochasticforcing.fields', 'numel', num_fields, 'cell', 1, 'values', self.supportedstochforcings())
        md = checkfield(md, 'fieldname', 'stochasticforcing.covariance', 'NaN', 1, 'Inf', 1, 'size', [size_tot, size_tot])  # global covariance matrix
        md = checkfield(md, 'fieldname', 'stochasticforcing.randomflag', 'numel', [1], 'values', [0, 1])
        if (checkdefaults):
            md = checkfield(md, 'fieldname', 'stochasticforcing.defaultdimension', 'numel', 1, 'NaN', 1, 'Inf', 1, '>', 0)
            md = checkfield(md, 'fieldname', 'stochasticforcing.default_id','Inf',1,'NaN',1,'>=',0,'<=',self.defaultdimension,'size', [md.mesh.numberofelements])
        return md
    # }}}

    def extrude(self, md):  # {{{
        # Nothing for now
        return self
    # }}}

    def marshall(self, prefix, md, fid):  # {{{
        yts = md.constants.yts
        num_fields = len(self.fields)

        WriteData(fid, prefix, 'object', self, 'fieldname', 'isstochasticforcing', 'format', 'Boolean')
        if not self.isstochasticforcing:
            return md
        else:
            # Retrieve dimensionality of each field
            dimensions = self.defaultdimension * np.ones((num_fields,))
            for ind, field in enumerate(self.fields):
                # Checking for specific dimensions
                if (field == 'SMBautoregression'):
                    dimensions[ind] = md.smb.num_basins
                if (field == 'FrontalForcingsRignotAutoregression'):
                    dimensions[ind] = md.frontalforcings.num_basins

            # Scaling covariance matrix (scale column-by-column and row-by-row)
            scaledfields = ['DefaultCalving', 'SMBautoregression'] # list of fields that need scaling * 1/yts
            tempcovariance = np.copy(self.covariance)
            for i in range(num_fields):
                if self.fields[i] in scaledfields:
                    inds = range(int(np.sum(dimensions[0:i])), int(np.sum(dimensions[0:i + 1])))
                    for row in inds:  # scale rows corresponding to scaled field
                        tempcovariance[row, :] = 1 / yts * tempcovariance[row, :]
                    for col in inds:  # scale columns corresponding to scaled field
                        tempcovariance[:, col] = 1 / yts * tempcovariance[:, col]
            # Set dummy default_id vector if defaults not used
            if np.any(np.isnan(self.default_id)):
                self.default_id = np.zeros(md.mesh.numberofelements)
            # Reshape dimensions as column array for marshalling
            dimensions = dimensions.reshape(1,len(dimensions))

            WriteData(fid, prefix, 'data', num_fields, 'name', 'md.stochasticforcing.num_fields', 'format', 'Integer')
            WriteData(fid, prefix, 'object', self, 'fieldname', 'fields', 'format', 'StringArray')
            WriteData(fid, prefix, 'data', dimensions, 'name', 'md.stochasticforcing.dimensions', 'format', 'IntMat','mattype',2)
            WriteData(fid, prefix, 'object', self, 'fieldname', 'default_id', 'format', 'IntMat', 'mattype', 2)  #12Nov2021 make sure this is zero-indexed!
            WriteData(fid, prefix, 'object', self, 'fieldname', 'defaultdimension', 'format', 'Integer')
            WriteData(fid, prefix, 'data', tempcovariance, 'name', 'md.stochasticforcing.covariance', 'format', 'DoubleMat')
            WriteData(fid, prefix, 'object', self, 'fieldname', 'randomflag', 'format', 'Boolean')
    # }}}

    def supportedstochforcings(self):  # {{{
        """Defines list of fields supported by the class md.stochasticforcing
        """
        list1 = self.structstochforcing()
        list1 = list1.keys()
        return list(list1)
    #}}}

    def structstochforcing(self):  # {{{
        """Defines dictionary with list of fields
           supported and corresponding md names
        """
        structure = {'DefaultCalving': 'calving',
                     'FloatingMeltRate': 'basalforcings',
                     'FrontalForcingsRignotAutoregression': 'frontalforcingsrignotautoregression',
                     'SMBautoregression': 'SMBautoregression'}
        return structure
    # }}}
