from collections import OrderedDict
# Hack to keep python 2 compatibility
try:
    from dbm import whichdb # Python 3
except ImportError:
    from whichdb import whichdb # Python 2
from re import findall, split
import shelve
from netCDF4 import Dataset
import numpy as np
from model import *


def loadvars(*args, **kwargs):
    """LOADVARS - function to load variables from a file

    This function loads one or more variables from a file. The names of the
    variables must be supplied. If more than one variable is specified, it may
    be done with a list of names or a dictionary of name as keys. The output
    type will correspond to the input type. All the variables in the file may
    be loaded by specifying only the file name.

    Usage:
        a = loadvars('shelve.dat', 'a')
        [a, b] = loadvars('shelve.dat', ['a', 'b'])
        nvdict = loadvars('shelve.dat', {'a':None, 'b':None})
        nvdict = loadvars('shelve.dat')
    """

    filename = ''
    nvdict = {}
    debug = False  #print messages if true

    if len(args) >= 1 and isinstance(args[0], str):
        filename = args[0]
        if not filename:
            filename = '/tmp/shelve.dat'
    else:
        raise TypeError("Missing file name.")

    if len(args) >= 2 and isinstance(args[1], str):  # (filename, name)
        for name in args[1:]:
            nvdict[name] = None
    elif len(args) == 2 and isinstance(args[1], list):  # (filename, [names])
        for name in args[1]:
            nvdict[name] = None
    elif len(args) == 2 and isinstance(args[1], dict):  # (filename, {names:values})
        nvdict = args[1]
    elif len(args) == 1:  #  (filename)
        pass
    else:
        raise TypeError("Unrecognized input arguments.")

    onlylast = False

    for key, value in kwargs.items():
        if key == 'onlylast':
            onlylast = value

    if whichdb(filename):   #We used python pickle for the save
        print("Loading variables from file {}.".format(filename))
        my_shelf = shelve.open(filename, 'r')  # 'r' for read - only
        if nvdict:
            for name in list(nvdict.keys()):
                try:
                    nvdict[name] = my_shelf[name]
                    print(("Variable '%s' loaded." % name))
                except KeyError:
                    value = None
                    print("Variable '{}' not found.".format(name))

        else:
            for name in list(my_shelf.keys()):
                nvdict[name] = my_shelf[name]
                print(("Variable '%s' loaded." % name))
        my_shelf.close()
    else:  #We used netcdf for the save
        try:
            NCFile = Dataset(filename, mode='r')
            NCFile.close()
        except RuntimeError:
            raise IOError("File '{}' not found.".format(filename))

        classtype, classtree = netCDFread(filename)
        nvdict['md'] = model()
        NCFile = Dataset(filename, mode='r')
        for mod in dict.keys(classtype):
            #==== First we create the model structure  {{{
            if debug:
                print(' ==== Now treating classtype {}'.format(mod))
            if mod not in classtree.keys():
                print("WARNING: {} classe is not in the model anymore and will be omited.".format(mod))
            elif np.size(classtree[mod]) > 1:
                # this points to a subclass (results.TransientSolution for example)
                curclass = NCFile.groups[classtree[mod][0]].groups[classtree[mod][1]]
                if debug:
                    print("    ==> {} is of class {}".format(mod, classtype[mod]))
                if classtype[mod][0] == 'results.solutionstep':  #Treating results {{{
                    keylist = [key for key in curclass.groups]
                    #that is the current treatment
                    #here we have a more NC approach with time being a dimension
                    listtype = split(r'\.', classtype[mod][0])[0]
                    if len(NCFile.dimensions['Time']) == 1:
                        nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = getattr(classtype[mod][1], listtype)()
                        Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
                    else:
                        #Time dimension is in all the variables so we take that as stepnumber for the results
                        if onlylast:   #we load only the last result to save on time and memory
                            nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(classtype[mod][1], listtype)()]
                            Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
                        else:
                            setattr(nvdict['md'].__dict__[classtree[mod][0]], classtree[mod][1], getattr(classtype[mod][1], 'solution')([]))
                            for i in range(max(1, len(NCFile.dimensions['Time']))):
                                nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]].steps.append(getattr(classtype[mod][1], 'solutionstep')())
                            Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]][:]
                # }}}
                elif "results" in mod and classtype[mod][0] == 'list':  #this is the old style of results where every step has a group{{{
                    keylist = [key for key in curclass.groups]
                    #one group per step so use that in place of time
                    stepnum = len(NCFile.groups[classtree[mod][0]].groups[classtree[mod][1]].groups)
                    #we need to redefine classtype from list to result
                    listtype = 'results'
                    classtype[mod].append(__import__(listtype))
                    if stepnum == 1:
                        nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = getattr(classtype[mod][1], listtype)()
                        Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
                    else:
                        if onlylast:   #we load only the last result to save on time and memory
                            nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(classtype[mod][1], listtype)()]
                            Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
                        else:
                            #nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(classtype[mod][1], listtype)() for i in range(max(1, len(NCFile.dimensions['Time'])))]
                            #Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]][:]
                            setattr(nvdict['md'].__dict__[classtree[mod][0]], classtree[mod][1], getattr(classtype[mod][1], 'solution')([]))
                            for i in range(max(1, stepnum)):
                                nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]].steps.append(getattr(classtype[mod][1], 'solutionstep')())
                            Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]][:]
                    #}}}
                elif classtype[mod][0] == 'massfluxatgate':  #this is for output definitions {{{
                    defname = split('Output|[0-9]+', classtree[mod][1])[1] + 's'
                    defindex = int(findall('[0-9]+', classtree[mod][1])[0])
                    nvdict['md'].__dict__[classtree[mod][0]].__dict__[defname].append(getattr(classtype[mod][1], classtype[mod][0])())
                    Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[defname][defindex - 1]
                #}}}
                else:
                    if debug:
                        print("    Using the default for md.{}.{}, is that right??".format(classtree[mod][0], classtree[mod][1]))
                    try:
                        modulename = split(r'\.', classtype[mod][0])[0]
                        if debug:
                            print("    trying to import {} from {}".format(classtype[mod][0], modulename))
                        nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = getattr(classtype[mod][1], modulename)()
                    except AttributeError:
                        print("WARNING: md.{}.{} is not initialized, hopefully that was done in the main group:".format(classtree[mod][0], classtree[mod][1]))
                    Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
            else:
                curclass = NCFile.groups[classtree[mod][0]]
                modulename = split(r'\.', classtype[mod][0])[0]
                nvdict['md'].__dict__[mod] = getattr(classtype[mod][1], modulename)()
                Tree = nvdict['md'].__dict__[classtree[mod][0]]
            if debug:
                print("    for {} Tree is a {} with len {}".format(mod, Tree.__class__.__name__, len(curclass.groups)))
            # }}}
            #==== Then we populate it {{{
            #for i in range(0, max(1, len(curclass.groups))):
            if len(curclass.groups) > 0:  #that is presumably only for old style NC where each result step had its own group
                if onlylast:
                    groupclass = [curclass.groups[keylist[len(curclass.groups) - 1]]]
                else:
                    groupclass = [curclass.groups[key] for key in keylist]
            else:
                groupclass = [curclass]
            #==== We deal with Variables {{{
            for groupindex, listclass in enumerate(groupclass):
                for var in listclass.variables:
                    if var not in ['errlog', 'outlog']:
                        varval = listclass.variables[str(var)]
                        vardim = varval.ndim
                        if debug:
                            print("    ==> treating var {} of dimension {}".format(var, vardim))
                        #There is a special treatment for results to account for its specific structure
                        #that is the new export version where time is a named dimension
                        NewFormat = 'Time' in NCFile.dimensions
                        if type(Tree) == list:  # and NewFormat:
                            if onlylast:
                                if NewFormat:
                                    if vardim == 1:
                                        Tree[0].__dict__[str(var)] = varval[-1].data
                                    elif vardim == 2:
                                        Tree[0].__dict__[str(var)] = varval[-1, :].data
                                    elif vardim == 3:
                                        Tree[0].__dict__[str(var)] = varval[-1, :, :].data
                                    else:
                                        print('table dimension greater than 3 not implemented yet')
                                else:  #old format had step sorted in difeerent group so last group is last time
                                    Tree[0].__dict__[str(var)] = varval[:].data
                            else:
                                if NewFormat:
                                    incomplete = 'Time' not in varval.dimensions and NewFormat
                                    if incomplete:
                                        chosendim = varval.dimensions[0]
                                        timelist = np.arange(0, len(NCFile.dimensions[chosendim]))
                                        print('WARNING, {} is not present on every times, we chose {}({}) as the dimension to write it with'.format(var, chosendim, len(NCFile.dimensions[chosendim])))
                                    else:
                                        timelist = np.arange(0, len(NCFile.dimensions['Time']))
                                    for t in timelist:
                                        if debug:
                                            print("filing step {} for {}".format(t, var))
                                        if vardim == 0:
                                            Tree[t].__dict__[str(var)] = varval[:].data
                                        elif vardim == 1:
                                            Tree[t].__dict__[str(var)] = varval[t].data
                                        elif vardim == 2:
                                            Tree[t].__dict__[str(var)] = varval[t, :].data
                                        elif vardim == 3:
                                            Tree[t].__dict__[str(var)] = varval[t, :, :].data
                                        else:
                                            print('table dimension greater than 3 not implemented yet')
                                else:
                                    if debug:
                                        print("filing step {} for {}".format(groupindex, var))
                                    Tree[groupindex].__dict__[str(var)] = varval[:].data
                        else:
                            if vardim == 0:  #that is a scalar
                                if str(varval[0]) == '':  #no value
                                    Tree.__dict__[str(var)] = []
                                elif varval[0] == 'True':  #treatin bool
                                    Tree.__dict__[str(var)] = True
                                elif varval[0] == 'False':  #treatin bool
                                    Tree.__dict__[str(var)] = False
                                else:
                                    Tree.__dict__[str(var)] = varval[0].item()

                            elif vardim == 1:  #that is a vector
                                if varval.dtype == str:
                                    if varval.shape[0] == 1:
                                        Tree.__dict__[str(var)] = [str(varval[0]), ]
                                    elif 'True' in varval[:] or 'False' in varval[:]:
                                        Tree.__dict__[str(var)] = np.asarray([V == 'True' for V in varval[:]], dtype=bool)
                                    else:
                                        Tree.__dict__[str(var)] = [str(vallue) for vallue in varval[:]]
                                else:
                                    try:
                                        #some thing specifically require a list
                                        mdtype = type(Tree.__dict__[str(var)])
                                    except KeyError:
                                        mdtype = float
                                    if mdtype == list:
                                        Tree.__dict__[str(var)] = [mdval for mdval in varval[:]]
                                    else:
                                        Tree.__dict__[str(var)] = varval[:].data

                            elif vardim == 2:
                                #dealling with dict
                                if varval.dtype == str:  #that is for dictionaries
                                    if any(varval[:, 0] == 'toolkit'):  #toolkit definition have to be first
                                        Tree.__dict__[str(var)] = OrderedDict([('toolkit', str(varval[np.where(varval[:, 0] == 'toolkit')[0][0], 1]))])
                                        strings1 = [str(arg[0]) for arg in varval if arg[0] != 'toolkits']
                                        strings2 = [str(arg[1]) for arg in varval if arg[0] != 'toolkits']
                                        Tree.__dict__[str(var)].update(list(zip(strings1, strings2)))
                                    else:
                                        strings1 = [str(arg[0]) for arg in varval]
                                        strings2 = [str(arg[1]) for arg in varval]
                                        Tree.__dict__[str(var)] = OrderedDict(list(zip(strings1, strings2)))
                                else:
                                    if type(Tree) == list:
                                        t = indexlist[i]
                                        if listtype == 'dict':
                                            Tree[t][str(var)] = varval[:, :].data
                                        else:
                                            Tree[t].__dict__[str(var)] = varval[:, :].data
                                    else:
                                        Tree.__dict__[str(var)] = varval[:, :].data
                            elif vardim == 3:
                                Tree.__dict__[str(var)] = varval[:, :, :].data
                            else:
                                print('table dimension greater than 3 not implemented yet')
                # }}}
                #==== And with atribute {{{
                for attr in listclass.ncattrs():
                    if debug:
                        print("      ==> treating attribute {}".format(attr))
                    if attr != 'classtype':  #classtype is for treatment, don't get it back
                        attribute = str(attr).swapcase()  #there is a reason for swapcase, no sure what it isanymore
                        if attr == 'VARNAME':
                            attribute = 'name'
                        if type(Tree) == list:
                            if debug:
                                print("        printing with index 0")
                            if listtype == 'dict':
                                Tree[0][attribute] = str(listclass.getncattr(attr))
                            else:
                                Tree[0].__dict__[attribute] = str(listclass.getncattr(attr))
                        else:
                            Tree.__dict__[attribute] = str(listclass.getncattr(attr))
                            if listclass.getncattr(attr) == 'True':
                                Tree.__dict__[attribute] = True
                            elif listclass.getncattr(attr) == 'False':
                                Tree.__dict__[attribute] = False
                # }}}
            # }}}
        NCFile.close()
    if len(args) >= 2 and isinstance(args[1], str):  # (value)
        value = [nvdict[name] for name in args[1:]]
        return value

    elif len(args) == 2 and isinstance(args[1], list):  # ([values])
        value = [nvdict[name] for name in args[1]]
        return value

    elif (len(args) == 2 and isinstance(args[1], dict)) or (len(args) == 1):  # ({names:values})
        return nvdict


def netCDFread(filename):
    print(('Opening {} for reading '.format(filename)))
    NCData = Dataset(filename, 'r')
    class_dict = {}
    class_tree = {}

    for group in NCData.groups:
        if len(NCData.groups[group].groups) > 0:
            for subgroup in NCData.groups[group].groups:
                classe = str(group) + '.' + str(subgroup)
                grpclass = str(getattr(NCData.groups[group].groups[subgroup], 'classtype'))
                class_dict[classe] = [grpclass, ]
                if class_dict[classe][0] not in ['dict', 'list', 'cell']:
                    try:
                        modulename = split(r'\.', class_dict[classe][0])[0]
                        class_dict[classe].append(__import__(modulename))
                    except ModuleNotFoundError:
                        #submodule probably has a different name
                        modulename = str(getattr(NCData.groups[group].groups[subgroup], 'classtype'))
                        print("WARNING importing {} rather than {}".format(modulename, class_dict[classe][0]))
                        class_dict[classe].append(__import__(modulename))
                class_tree[classe] = [group, subgroup]
        else:
            classe = str(group)
            try:
                class_dict[classe] = [str(getattr(NCData.groups[group], 'classtype')), ]
                if class_dict[classe][0] not in ['dict', 'list', 'cell']:
                    modulename = split(r'\.', class_dict[classe][0])[0]
                    if modulename == "giaivins":
                        print("WARNING: module {} does not exist anymore and is skipped".format(modulename))
                    else:
                        class_dict[classe].append(__import__(modulename))
                        class_tree[classe] = [group, ]
            except AttributeError:
                print(('group {} is empty'.format(group)))
    NCData.close()
    return class_dict, class_tree
