import numpy as np

from checkfield import *
from fielddisplay import fielddisplay
from project3d import *
from WriteData import *

class autoregressionlinearbasalforcings(object):
    """autoregressionlinearbasalforcings class definition

    Usage:
        autoregressionlinearbasalforcings = autoregressionlinearbasalforcings()
    """

    def __init__(self, *args):  # {{{
        self.num_basins = 0
        self.beta0 = np.nan
        self.beta1 = np.nan
        self.ar_order = 0
        self.ar_initialtime = 0
        self.ar_timestep = 0
        self.phi = np.nan
        self.basin_id = np.nan
        self.groundedice_melting_rate = np.nan
        self.deepwater_elevation = np.nan
        self.upperwater_melting_rate = np.nan
        self.upperwater_elevation = np.nan
        self.geothermalflux = np.nan

        nargs = len(args)
        if nargs == 0:
            self.setdefaultparameters()
        else:
            raise Exception('constructor not supported')
    # }}}

    def __repr__(self):  # {{{
        s = '   surface forcings parameters:\n'
        s += '   autoregressive model is applied for deepwater_melting_rate\n'
        s += '{}\n'.format(fielddisplay(self, 'num_basins', 'number of different basins [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'basin_id', 'basin number assigned to each element [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'beta0', 'basin-specific intercept values [m ice eq./yr] (if beta_1==0 mean=beta_0/(1-sum(phi)))'))
        s += '{}\n'.format(fielddisplay(self, 'beta1', 'basin-specific trend values [m ice eq. yr^(-2)]'))
        s += '{}\n'.format(fielddisplay(self, 'ar_order', 'order of the autoregressive model [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'ar_initialtime', 'initial time assumed in the autoregressive model parameterization [yr]'))
        s += '{}\n'.format(fielddisplay(self, 'ar_timestep', 'time resolution of the autoregressive model [yr]'))
        s += '{}\n'.format(fielddisplay(self, 'phi', 'basin-specific vectors of lag coefficients [unitless]'))
        s += '{}\n'.format(fielddisplay(self, 'deepwater_elevation', 'basin-specific elevation of ocean deepwater [m]'))
        s += '{}\n'.format(fielddisplay(self, 'upperwater_melting_rate', 'basin-specic basal melting rate (positive if melting applied for floating ice whith base >= upperwater_elevation) [m/yr]'))
        s += '{}\n'.format(fielddisplay(self, 'upperwater_elevation', 'basin-specific elevation of ocean upperwater [m]'))
        s += '{}\n'.format(fielddisplay(self, 'groundedice_melting_rate','node-specific basal melting rate (positive if melting) [m/yr]'))
        s += '{}\n'.format(fielddisplay(self, 'geothermalflux','node-specific geothermal heat flux [W/m^2]'))
        return s
    # }}}

    def setdefaultparameters(self): #{{{
        #Nothing for now
        return self
    # }}}

    def extrude(self, md):  # {{{
        return self # Nothing for now
    # }}}

    def initialize(self, md):  # {{{
        if np.all(np.isnan(self.groundedice_melting_rate)):
            self.groundedice_melting_rate = np.zeros((md.mesh.numberofvertices))
            print("      no basalforcings.groundedice_melting_rate specified: values set as zero")
        return self
    # }}}

    def checkconsistency(self, md, solution, analyses):  # {{{
        if 'MasstransportAnalysis' in analyses:
            md = checkfield(md, 'fieldname', 'basalforcings.num_basins', 'numel', 1, 'NaN', 1, 'Inf', 1, '>', 0)
            md = checkfield(md,'fieldname','basalforcings.groundedice_melting_rate','NaN',1,'Inf',1,'timeseries',1)
            md = checkfield(md, 'fieldname', 'basalforcings.deepwater_elevation', 'NaN', 1, 'Inf', 1, 'size', [1, md.basalforcings.num_basins], 'numel', md.basalforcings.num_basins) 
            md = checkfield(md, 'fieldname', 'basalforcings.upperwater_elevation', 'NaN', 1, 'Inf', 1,'<=',0, 'size', [1, md.basalforcings.num_basins], 'numel', md.basalforcings.num_basins) 
            md = checkfield(md, 'fieldname', 'basalforcings.upperwater_melting_rate', 'NaN', 1, 'Inf', 1,'>=',0, 'size', [1, md.basalforcings.num_basins], 'numel', md.basalforcings.num_basins) 
            md = checkfield(md, 'fieldname', 'basalforcings.basin_id', 'Inf', 1, '>=', 0, '<=', md.basalforcings.num_basins, 'size', [md.mesh.numberofelements])
            md = checkfield(md, 'fieldname', 'basalforcings.beta0', 'NaN', 1, 'Inf', 1, 'size', [1, md.basalforcings.num_basins], 'numel', md.basalforcings.num_basins) # Scheme fails if passed as column vector
            md = checkfield(md, 'fieldname', 'basalforcings.beta1', 'NaN', 1, 'Inf', 1, 'size', [1, md.basalforcings.num_basins], 'numel', md.basalforcings.num_basins) # Scheme fails if passed as column vector; NOTE: As opposed to MATLAB implementation, pass list
            md = checkfield(md, 'fieldname', 'basalforcings.ar_order', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', 0)
            md = checkfield(md, 'fieldname', 'basalforcings.ar_initialtime', 'numel', 1, 'NaN', 1, 'Inf', 1)
            md = checkfield(md, 'fieldname', 'basalforcings.ar_timestep', 'numel', 1, 'NaN', 1, 'Inf', 1, '>=', md.timestepping.time_step) # Autoregression time step cannot be finer than ISSM timestep
            md = checkfield(md, 'fieldname', 'basalforcings.phi', 'NaN', 1, 'Inf', 1, 'size', [md.basalforcings.num_basins, md.basalforcings.ar_order])
        if 'BalancethicknessAnalysis' in analyses:
            raise Exception('not implemented yet!')
        if 'ThermalAnalysis' in analyses and not solution == 'TransientSolution' and not md.transient.isthermal:
            raise Exception('not implemented yet!')

        return md
    # }}}

    def marshall(self, prefix, md, fid):  # {{{
        yts = md.constants.yts

        WriteData(fid, prefix, 'name', 'md.basalforcings.model', 'data', 9, 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'fieldname', 'groundedice_melting_rate', 'name', 'md.basalforcings.groundedice_melting_rate', 'format', 'DoubleMat', 'mattype', 1, 'scale', 1. / yts, 'timeserieslength', md.mesh.numberofvertices + 1, 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'geothermalflux', 'name', 'md.basalforcings.geothermalflux', 'format', 'DoubleMat', 'mattype', 1, 'timeserieslength', md.mesh.numberofvertices + 1, 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'num_basins', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'fieldname', 'ar_order', 'format', 'Integer')
        WriteData(fid, prefix, 'object', self, 'fieldname', 'ar_initialtime', 'format', 'Double', 'scale', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'ar_timestep', 'format', 'Double', 'scale', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'basin_id', 'data', self.basin_id, 'name', 'md.basalforcings.basin_id', 'format', 'IntMat', 'mattype', 2) # 0-indexed
        WriteData(fid, prefix, 'object', self, 'fieldname', 'beta0', 'format', 'DoubleMat', 'name', 'md.basalforcings.beta0', 'scale', 1 / yts, 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'beta1', 'format', 'DoubleMat', 'name', 'md.basalforcings.beta1', 'scale', 1 / (yts ** 2), 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'phi', 'format', 'DoubleMat', 'name', 'md.basalforcings.phi', 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'deepwater_elevation', 'format', 'DoubleMat', 'name', 'md.basalforcings.deepwater_elevation')
        WriteData(fid, prefix, 'object', self, 'fieldname', 'upperwater_melting_rate', 'format', 'DoubleMat', 'name', 'md.basalforcings.upperwater_melting_rate', 'scale', 1 / yts, 'yts', yts)
        WriteData(fid, prefix, 'object', self, 'fieldname', 'upperwater_elevation', 'format', 'DoubleMat', 'name', 'md.basalforcings.upperwater_elevation')

    # }}}
