# -*- coding: utf-8 -*-
import numpy as np
from checkfield import checkfield
from fielddisplay import fielddisplay
from MatlabFuncs import *
from WriteData import WriteData


class frontalforcingsrignotarma(object):
    """FRONTALFORCINGSRIGNOTARMA class definition

    Usage:
        frontalforcingsrignotarma = frontalforcingsrignotarma()
    """

    def __init__(self, *args):  # {{{
        self.num_basins = 0
        self.num_params = 0
        self.num_breaks = 0
        self.polynomialparams = np.nan
        self.datebreaks       = np.nan
        self.ar_order = 0
        self.ma_order = 0
        self.arma_timestep = 0
        self.arlag_coefs = np.nan
        self.malag_coefs = np.nan
        self.monthly_effects = np.nan
        self.basin_id = np.nan
        self.subglacial_discharge = np.nan

        if len(args) == 0:
            self.setdefaultparameters()
        else:
            error('constructor not supported')

    def __repr__(self):  # {{{
        s = '   Frontalforcings parameters:\n'
        s += '{}\n'.format(fielddisplay(self, 'num_basins', 'number of different basins [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'basin_id', 'basin number assigned to each element [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'subglacial_discharge', 'sum of subglacial discharge for each basin [m/d]'))
        s += '{}\n'.format(fielddisplay(self, 'num_breaks', 'number of different breakpoints in the piecewise-polynomial (separating num_breaks+1 periods)'))
        s += '{}\n'.format(fielddisplay(self, 'num_params', 'number of different parameters in the piecewise-polynomial (1:intercept only, 2:with linear trend, 3:with quadratic trend, etc.)'))
        s += '{}\n'.format(fielddisplay(self, 'polynomialparams', 'coefficients for the polynomial (const,trend,quadratic,etc.),dim1 for basins,dim2 for periods,dim3 for orders, ex: polyparams=cat(num_params,intercepts,trendlinearcoefs,trendquadraticcoefs)'))
        s += '{}\n'.format(fielddisplay(self, 'datebreaks', 'dates at which the breakpoints in the piecewise polynomial occur (1 row per basin) [yr]'))
        s += '{}\n'.format(fielddisplay(self, 'ar_order', 'order of the autoregressive model [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'ma_order', 'order of the moving-average model [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'arma_timestep', 'time resolution of the ARMA model [yr]'))
        s += '{}\n'.format(fielddisplay(self, 'arlag_coefs', 'basin-specific vectors of AR lag coefficients [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'malag_coefs', 'basin-specific vectors of MA lag coefficients [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'monthly_effects', 'basin-specific monthly values of TF added at corresponding month (default: all 0) [°C]'))
        return s
    #}}}

    def setdefaultparameters(self):  # {{{
        self.basin_id = np.nan
        self.num_basins = 0
        self.subglacial_discharge = np.nan
        self.ar_order = 0.0  # Autoregression model of order 0
        self.ma_order = 0.0  # Moving-average model of order 0
        return self
    #}}}

    def checkconsistency(self, md, solution, analyses):  # {{{
        # Early return
        if not (solution == 'TransientSolution') or not md.transient.ismovingfront:
            return md

        nbas = md.frontalforcings.num_basins;
        nprm = md.frontalforcings.num_params;
        nbrk = md.frontalforcings.num_breaks;
        md = checkfield(md, 'fieldname', 'frontalforcings.num_basins', 'numel', 1, 'NaN', 1, 'Inf', 1, '>', 0)
        md = checkfield(md, 'fieldname', 'frontalforcings.num_params', 'numel', 1, 'NaN', 1, 'Inf', 1, '>', 0)
        md = checkfield(md, 'fieldname', 'frontalforcings.num_breaks', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', 0)
        md = checkfield(md, 'fieldname', 'frontalforcings.basin_id', 'Inf', 1, '>=', 0, '<=', md.frontalforcings.num_basins, 'size', [md.mesh.numberofelements])
        md = checkfield(md, 'fieldname', 'frontalforcings.subglacial_discharge', '>=', 0, 'NaN', 1, 'Inf', 1, 'timeseries', 1)
        if len(np.shape(self.polynomialparams)) == 1:
            self.polynomialparams = np.array([[self.polynomialparams]])
        if(nbas>1 and nbrk>=1 and nprm>1):
            md = checkfield(md,'fieldname','frontalforcings.polynomialparams','NaN',1,'Inf',1,'size',[nbas,nbrk+1,nprm],'numel',nbas*(nbrk+1)*nprm)
        elif(nbas==1):
            md = checkfield(md,'fieldname','frontalforcings.polynomialparams','NaN',1,'Inf',1,'size',[nprm,nbrk+1],'numel',nbas*(nbrk+1)*nprm)
        elif(nbrk==0):
            md = checkfield(md,'fieldname','frontalforcings.polynomialparams','NaN',1,'Inf',1,'size',[nbas,nprm],'numel',nbas*(nbrk+1)*nprm)
        elif(nprm==1):
            md = checkfield(md,'fieldname','frontalforcings.polynomialparams','NaN',1,'Inf',1,'size',[nbas,nbrk],'numel',nbas*(nbrk+1)*nprm)
        md = checkfield(md, 'fieldname', 'frontalforcings.ar_order', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', 0)
        md = checkfield(md, 'fieldname', 'frontalforcings.ma_order', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', 0)
        md = checkfield(md, 'fieldname', 'frontalforcings.arma_timestep', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', md.timestepping.time_step) # ARMA time step cannot be finer than ISSM timestep
        md = checkfield(md, 'fieldname', 'frontalforcings.arlag_coefs', 'NaN', 1, 'Inf', 1, 'size', [md.frontalforcings.num_basins, md.frontalforcings.ar_order])
        md = checkfield(md, 'fieldname', 'frontalforcings.malag_coefs', 'NaN', 1, 'Inf', 1, 'size', [md.frontalforcings.num_basins, md.frontalforcings.ma_order])
        if(nbrk>0):
            md = checkfield(md, 'fieldname', 'frontalforcings.datebreaks', 'NaN', 1, 'Inf', 1, 'size', [nbas,nbrk])
        elif(np.size(md.frontalforcings.datebreaks)==0 or np.all(np.isnan(md.frontalforcings.datebreaks))):
            pass
        else:
            raise RuntimeError('md.frontalforcings.num_breaks is 0 but md.frontalforcings.datebreaks is not empty')
        if(np.any(np.isnan(md.frontalforcings.monthly_effects)==False)):
            md = checkfield(md, 'fieldname', 'frontalforcings.monthly_effects', 'NaN', 1, 'Inf', 1, 'size', [md.frontalforcings.num_basins, 12])
            if(md.timestepping.time_step>=1):
                raise RuntimeError('md.frontalforcings.monthly_effects are provided but md.timestepping.time_step>=1')
        return md
    # }}}

    def extrude(self, md):  # {{{
        # Nothing for now
        return self
    # }}}

    def marshall(self, prefix, md, fid):  # {{{
        yts = md.constants.yts
        nbas = md.frontalforcings.num_basins;
        nprm = md.frontalforcings.num_params;
        nper = md.frontalforcings.num_breaks+1;
        # Scale the parameters #
        polyparamsScaled   = np.copy(md.frontalforcings.polynomialparams)
        polyparams2dScaled = np.zeros((nbas,nper*nprm))
        if(nprm>1):
            # Case 3D #
            if(nbas>1 and nper>1):
                for ii in range(nprm):
                    polyparamsScaled[:,:,ii] = polyparamsScaled[:,:,ii]*(1/yts)**ii
                # Fit in 2D array #
                for ii in range(nprm):
                    polyparams2dScaled[:,ii*nper:(ii+1)*nper] = 1*polyparamsScaled[:,:,ii]
            # Case 2D and higher-order params at increasing row index #
            elif(nbas==1):
                for ii in range(nprm):
                    polyparamsScaled[ii,:] = polyparamsScaled[ii,:]*(1/yts)**ii
                # Fit in row array #
                for ii in range(nprm):
                    polyparams2dScaled[0,ii*nper:(ii+1)*nper] = 1*polyparamsScaled[ii,:]
            # Case 2D and higher-order params at incrasing column index #
            elif(nper==1):
                for ii in range(nprm):
                    polyparamsScaled[:,ii] = polyparamsScaled[:,ii]*(1/yts)**ii
                # 2D array is already in correct format #
                polyparams2dScaled = np.copy(polyparamsScaled)
        else:
            # 2D array is already in correct format and no need for scaling #
            polyparams2dScaled = np.copy(polyparamsScaled)
        if(np.any(np.isnan(md.frontalforcings.monthly_effects))): #monthly effects not provided, set to 0
            meffects = np.zeros((md.frontalforcings.num_basins,12))
        else:
            meffects = 1*md.frontalforcings.monthly_effects
        if(nper==1):
            dbreaks = np.zeros((nbas,1))
        else:
            dbreaks = np.copy(md.frontalforcings.datebreaks)

        WriteData(fid, prefix, 'name', 'md.frontalforcings.parameterization', 'data', 3, 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'num_basins', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'num_breaks', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'num_params', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'subglacial_discharge', 'format', 'DoubleMat', 'mattype', 1, 'timeserieslength', md.mesh.numberofvertices + 1, 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'ar_order', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'ma_order', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'arma_timestep', 'format', 'Double', 'scale', yts)
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'basin_id', 'data', self.basin_id - 1, 'name', 'md.frontalforcings.basin_id', 'format', 'IntMat', 'mattype', 2)  # 0-indexed
        WriteData(fid, prefix, 'data', polyparams2dScaled, 'name', 'md.frontalforcings.polynomialparams', 'format', 'DoubleMat')
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'arlag_coefs', 'format', 'DoubleMat', 'name', 'md.frontalforcings.arlag_coefs', 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'class', 'frontalforcings', 'fieldname', 'malag_coefs', 'format', 'DoubleMat', 'name', 'md.frontalforcings.malag_coefs', 'yts', yts)
        WriteData(fid, prefix, 'data', dbreaks, 'name', 'md.frontalforcings.datebreaks', 'format', 'DoubleMat','scale',yts)
        WriteData(fid, prefix, 'data', meffects, 'name', 'md.frontalforcings.monthly_effects', 'format', 'DoubleMat')
    # }}}
