import numpy as np

from checkfield import checkfield
from fielddisplay import fielddisplay
from MatlabFuncs import *
from WriteData import WriteData

class stochasticforcing(object):
    """STOCHASTICFORCING class definition

    Usage:
        stochasticforcing = stochasticforcing()
    """

    def __init__(self, *args):  # {{{
        self.isstochasticforcing = 0
        self.fields = np.nan
        self.dimensions = np.nan
        self.covariance = np.nan
        self.randomflag = 1

        if len(args) == 0:
            self.setdefaultparameters()
        else:
            error('constructor not supported')

    def __repr__(self):  # {{{
        s = '   stochasticforcing parameters:\n'
        s += '{}\n'.format(fielddisplay(self, 'isstochasticforcing', 'is stochasticity activated?'))
        s += '{}\n'.format(fielddisplay(self, 'fields', 'fields with stochasticity applied, ex: [\'SMBautoregression\'], or [\'FrontalForcingsRignotAutoregression\']'))
        s += '{}\n'.format(fielddisplay(self, 'covariance', 'covariance matrix for within- and between-fields covariance (units must be squared field units)'))
        s += '{}\n'.format(fielddisplay(self, 'randomflag', 'whether to apply real randomness (true) or pseudo-randomness with fixed seed (false)'))
        s += 'Available fields:\n'
        s += '   SMBautoregression\n'
        s += '   FrontalForcingsRignotAutoregression (thermal forcing)\n'
        return s
    #}}}

    def setdefaultparameters(self):  # {{{
        # Type of stabilization used
        self.isstochasticforcing = 0 # stochasticforcing is turned off by default
        self.randomflag          = 1 # true randomness is implemented by default
        return self
    #}}}

    def checkconsistency(self, md, solution, analyses):  # {{{
        # Early return
        if not self.isstochasticforcing:
            return md

        num_fields  = numel(self.fields)
        size_tot    = np.sum(self.dimensions)

        md = checkfield(md, 'fieldname', 'stochasticforcing.isstochasticforcing', 'values', [0, 1])
        md = checkfield(md, 'fieldname', 'stochasticforcing.fields', 'numel', num_fields, 'cell', 1, 'values', supportedstochforcings()) # VV check here 'cell' (19Oct2021)
        md = checkfield(md, 'fieldname', 'stochasticforcing.dimensions', 'NaN', 1, 'Inf', 1, '>', 0, 'size', [num_fields]) # specific dimension for each field; NOTE: As opposed to MATLAB implementation, pass list
        md = checkfield(md, 'fieldname', 'stochasticforcing.covariance', 'NaN', 1, 'Inf', 1, 'size', [size_tot, size_tot]) # global covariance matrix
        md = checkfield(md, 'fieldname', 'stochasticforcing.randomflag', 'numel', [1], 'values', [0, 1])

        # Check that all fields agree with the corresponding md class
        for field in self.fields:
            if (contains(field, 'SMB')):
                if not (type(md.smb) == field):
                    error('md.smb does not agree with stochasticforcing field {}'.format(field))
            if (contains(field, 'frontalforcings')):
                if not (type(md.frontalforcings) == field):
                    error('md.frontalforcings does not agree with stochasticforcing field {}'.format(field))
        return md
    # }}}

    def extrude(self, md):  # {{{
        # Nothing for now
        return self
    # }}}

    def marshall(self, prefix, md, fid):  # {{{
        yts = md.constants.yts
        if (type(self.fields) is list):
            num_fields = len(self.fields)
            # Scaling covariance matrix (scale column-by-column and row-by-row)
            scaledfields = ['SMBautoregression'] # list of fields that need scaling * 1/yts
            for i in range(num_fields):
                if self.fields[i] in scaledfields:
                    inds = range(1 + np.sum(self.dimensions[0:i]), np.sum(self.dimensions[0:i]))
                    for row in inds: # scale rows corresponding to scaled field
                        self.covariance[row, :] = 1 / yts * self.covariance[row, :]
                    for col in inds: # scale columns corresponding to scaled field
                        self.covariance[:, col] = 1 / yts * self.covariance[:, col]

        WriteData(fid, prefix, 'object', self, 'fieldname', 'isstochasticforcing', 'format', 'Boolean')
        if not self.isstochasticforcing:
            return md
        else:
            WriteData(fid, prefix, 'data', num_fields, 'name', 'md.stochasticforcing.num_fields', 'format', 'Integer')
            WriteData(fid, prefix, 'object', self, 'fieldname', 'fields', 'format', 'StringArray')
            WriteData(fid, prefix, 'object', self, 'fieldname','dimensions', 'format', 'IntMat')
            WriteData(fid, prefix, 'object', self, 'fieldname', 'covariance', 'format', 'DoubleMat')
            WriteData(fid, prefix, 'object', self, 'fieldname', 'randomflag', 'format', 'Boolean')
    # }}}

def supportedstochforcings():
    """ Defines list of fields supported  by the class stochasticforcings
    """
    return [
        'SMBautoregression',
        'FrontalForcingsRignotAutoregression'
    ]
