import numpy as np

from checkfield import *
from fielddisplay import fielddisplay
from project3d import *
from WriteData import *
from GetAreas import *

class SMBautoregression(object):
    """SMBAUTOREGRESSION class definition

    Usage:
        SMBautoregression = SMBautoregression()
    """

    def __init__(self, *args):  # {{{
        self.num_basins = 0
        self.beta0 = np.nan
        self.beta1 = np.nan
        self.ar_order = 0
        self.ar_initialtime = 0
        self.ar_timestep = 0
        self.phi = np.nan
        self.basin_id = np.nan
        self.lapserate_pos = np.nan
        self.lapserate_neg = np.nan
        self.refelevation = np.nan
        self.steps_per_step = 1
        self.averaging = 0
        self.requested_outputs = []

        nargs = len(args)
        if nargs == 0:
            self.setdefaultparameters()
        else:
            raise Exception('constructor not supported')
    # }}}

    def __repr__(self):  # {{{
        s = '   surface forcings parameters:\n'
        s += '{}\n'.format(fielddisplay(self, 'num_basins', 'number of different basins [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'basin_id', 'basin number assigned to each element [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'beta0', 'basin-specific intercept values [m ice eq./yr] (if beta_1==0 mean=beta_0/(1-sum(phi)))'))
        s += '{}\n'.format(fielddisplay(self, 'beta1', 'basin-specific trend values [m ice eq. yr^(-2)]'))
        s += '{}\n'.format(fielddisplay(self, 'ar_order', 'order of the autoregressive model [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'ar_initialtime', 'initial time assumed in the autoregressive model parameterization [yr]'))
        s += '{}\n'.format(fielddisplay(self, 'ar_timestep', 'time resolution of the autoregressive model [yr]'))
        s += '{}\n'.format(fielddisplay(self, 'phi', 'basin-specific vectors of lag coefficients [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'lapserate_pos', 'basin-specific SMB lapse rates applied in range of SMB>=0 [m ice eq yr^-1 m^-1] (default: no lapse rate)'))
        s += '{}\n'.format(fielddisplay(self, 'lapserate_neg', 'basin-specific SMB lapse rates applied in range of SMB<0 [m ice eq yr^-1 m^-1] (default: no lapse rate)'))
        s += '{}\n'.format(fielddisplay(self, 'refelevation', 'basin-specific reference elevations on which SMB lapse rates are applied (default: basin mean elevation) [m]'))
        s += '{}\n'.format(fielddisplay(self, 'steps_per_step', 'number of smb steps per time step'))
        s += '{}\n'.format(fielddisplay(self, 'averaging', 'averaging methods from short to long steps'))
        s += '\t\t{}\n'.format('0: Arithmetic (default)')
        s += '\t\t{}\n'.format('1: Geometric')
        s += '\t\t{}\n'.format('2: Harmonic')
        s += '{}\n'.format(fielddisplay(self, 'requested_outputs', 'additional outputs requested'))
        return s
    # }}}

    def setdefaultparameters(self): #{{{
        self.ar_order = 0.0 # Autoregression model of order 0
    # }}}

    def extrude(self, md):  # {{{
        return self # Nothing for now
    # }}}

    def defaultoutputs(self, md):  # {{{
        return []
    # }}}

    def initialize(self, md):  # {{{
        if np.all(np.isnan(self.beta1)):
            self.beta1 = np.zeros((1, self.num_basins)) # No trend in SMB
            print('      smb.beta1 (trend) not specified: value set to 0')
        if self.ar_order == 0:
            self.ar_order = 1 # Dummy 1 value for autoregression
            self.phi = np.zeros((self.num_basins, self.ar_order)) # Autorgression coefficients all set to 0
            print('      smb.ar_order (order of autoregressive model) not specified: order of autoregressive model set to 0')
        if self.ar_initialtime == 0:
            self.ar_initialtime = md.timestepping.start_time # Autoregression model has no prescribed initial time
            print('      smb.ar_initialtime (initial time in the autoregressive model parameterization) not specified: set to md.timestepping.start_time')
        if self.ar_timestep == 0:
            self.ar_timestep = md.timestepping.time_step # Autoregression model has no prescribed time step
            print('      smb.ar_timestep (timestep of autoregressive model) not specified: set to md.timestepping.time_step')
        if np.all(np.isnan(self.phi)):
            self.phi = np.zeros((self.num_basins, self.ar_order)) # Autoregression model of order 0
            print('      smb.phi (lag coefficients) not specified: order of autoregressive model set to 0')
        return self
    # }}}

    def checkconsistency(self, md, solution, analyses):  # {{{
        if 'MasstransportAnalysis' in analyses:
            md = checkfield(md, 'fieldname', 'smb.num_basins', 'numel', 1, 'NaN', 1, 'Inf', 1, '>', 0)
            md = checkfield(md, 'fieldname', 'smb.basin_id', 'Inf', 1, '>=', 0, '<=', md.smb.num_basins, 'size', [md.mesh.numberofelements])
            md = checkfield(md, 'fieldname', 'smb.beta0', 'NaN', 1, 'Inf', 1, 'size', [1, md.smb.num_basins], 'numel', md.smb.num_basins) # Scheme fails if passed as column vector
            md = checkfield(md, 'fieldname', 'smb.beta1', 'NaN', 1, 'Inf', 1, 'size', [1, md.smb.num_basins], 'numel', md.smb.num_basins) # Scheme fails if passed as column vector; NOTE: As opposed to MATLAB implementation, pass list
            md = checkfield(md, 'fieldname', 'smb.ar_order', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', 0)
            md = checkfield(md, 'fieldname', 'smb.ar_initialtime', 'numel', 1, 'NaN', 1, 'Inf', 1)
            md = checkfield(md, 'fieldname', 'smb.ar_timestep', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', md.timestepping.time_step) # Autoregression time step cannot be finer than ISSM timestep
            md = checkfield(md, 'fieldname', 'smb.phi', 'NaN', 1, 'Inf', 1, 'size', [md.smb.num_basins, md.smb.ar_order])
            if(np.any(np.isnan(md.smb.lapserate_pos)==False) or np.size(md.smb.lapserate_pos)>1):
                md = checkfield(md, 'fieldname', 'smb.lapserate_pos', 'NaN', 1, 'Inf', 1, 'size', [1, md.smb.num_basins], 'numel', md.smb.num_basins)
            if(np.any(np.isnan(md.smb.lapserate_neg)==False) or np.size(md.smb.lapserate_neg)>1):
                md = checkfield(md, 'fieldname', 'smb.lapserate_neg', 'NaN', 1, 'Inf', 1, 'size', [1, md.smb.num_basins], 'numel', md.smb.num_basins)
            if(np.any(np.isnan(md.smb.refelevation)==False) or np.size(md.smb.refelevation)>1):
                md = checkfield(md, 'fieldname', 'smb.refelevation', 'NaN', 1, 'Inf', 1, '>=', 0, 'size', [1, md.smb.num_basins], 'numel', md.smb.num_basins)

        md = checkfield(md, 'fieldname', 'smb.steps_per_step', '>=', 1, 'numel', [1])
        md = checkfield(md, 'fieldname', 'smb.averaging', 'numel', [1], 'values', [0, 1, 2])
        md = checkfield(md, 'fieldname', 'smb.requested_outputs', 'stringrow', 1)
        return md
    # }}}

    def marshall(self, prefix, md, fid):  # {{{
        yts = md.constants.yts

        templapserate_pos = np.copy(md.smb.lapserate_pos)
        templapserate_neg = np.copy(md.smb.lapserate_neg)
        temprefelevation  = np.copy(md.smb.lapserate_neg)
        if(np.any(np.isnan(md.smb.lapserate_pos))):
            templapserate_pos = np.zeros((md.smb.num_basins))
            print('      smb.lapserate_pos not specified: set to 0')
        if(np.any(np.isnan(md.smb.lapserate_neg))):
            templapserate_neg = np.zeros((md.smb.num_basins))
            print('      smb.lapserate_neg not specified: set to 0')
        if(np.any(np.isnan(md.smb.refelevation))):
            temprefelevation = np.zeros((md.smb.num_basins))
            areas = GetAreas(md.mesh.elements, md.mesh.x, md.mesh.y)
            for ii in range(int(md.smb.num_basins)):
                indices = np.where(md.smb.basin_id==ii)[0]
                elemsh  = np.zeros((len(indices)))
                for jj in range(len(indices)):
                    elemsh[jj] = np.mean(md.geometry.surface[md.mesh.elements[indices[jj],:]])
                temprefelevation[ii] = np.sum(areas[indices]*elemsh)/np.sum(areas[indices])
            if(np.any(templapserate_pos!=0) or np.any(templapserate_neg!=0)):
                print('      smb.refelevation not specified: Reference elevations set to mean surface elevation of basins')

        WriteData(fid, prefix, 'name', 'md.smb.model', 'data', 55, 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'num_basins', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'ar_order', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'ar_initialtime', 'format', 'Double', 'scale', yts)
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'ar_timestep', 'format', 'Double', 'scale', yts)
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'basin_id', 'data', self.basin_id, 'name', 'md.smb.basin_id', 'format', 'IntMat', 'mattype', 2) # 0-indexed
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'beta0', 'format', 'DoubleMat', 'name', 'md.smb.beta0', 'scale', 1 / yts, 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'beta1', 'format', 'DoubleMat', 'name', 'md.smb.beta1', 'scale', 1 / (yts ** 2), 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'class', 'smb', 'fieldname', 'phi', 'format', 'DoubleMat', 'name', 'md.smb.phi', 'yts', yts)
        WriteData(fid, prefix, 'data', templapserate_pos, 'name', 'md.smb.lapserate_pos', 'format', 'DoubleMat','scale',1/yts,'yts',yts)
        WriteData(fid, prefix, 'data', templapserate_neg, 'name', 'md.smb.lapserate_neg', 'format', 'DoubleMat','scale',1/yts,'yts',yts)
        WriteData(fid, prefix, 'data', temprefelevation, 'name', 'md.smb.refelevation', 'format', 'DoubleMat')
        WriteData(fid, prefix, 'object', self, 'fieldname', 'steps_per_step', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'fieldname', 'averaging', 'format', 'Integer')

        # Process requested outputs
        outputs = self.requested_outputs
        indices = [i for i, x in enumerate(outputs) if x == 'default']
        if len(indices) > 0:
            outputscopy = outputs[0:max(0, indices[0] - 1)] + self.defaultoutputs(md) + outputs[indices[0] + 1:]
            outputs = outputscopy
        WriteData(fid, prefix, 'data', outputs, 'name', 'md.smb.requested_outputs', 'format', 'StringArray')

    # }}}
