import shelve
import numpy as np
from netCDF4 import Dataset
from re import findall
from collections import OrderedDict
from model import *
#hack to keep python 2 compatibility
try:
    #py3 import
    from dbm import whichdb
except ImportError:
    #py2 import
    from whichdb import whichdb


def loadvars(*args, OL):
    """
    LOADVARS - function to load variables to a file.

    This function loads one or more variables from a file.  The names of the variables
    must be supplied.  If more than one variable is specified, it may be done with
    a list of names or a dictionary of name as keys.  The output type will correspond
    to the input type.  All the variables in the file may be loaded by specifying only
    the file name.

    Usage:
        a = loadvars('shelve.dat', 'a')
        [a, b] = loadvars('shelve.dat', ['a', 'b'])
        nvdict = loadvars('shelve.dat', {'a':None, 'b':None})
        nvdict = loadvars('shelve.dat')

    """

    filename = ''
    nvdict = {}
    debug = False  #print messages if true

    if len(args) >= 1 and isinstance(args[0], str):
        filename = args[0]
        if not filename:
            filename = '/tmp/shelve.dat'

    else:
        raise TypeError("Missing file name.")

    if len(args) >= 2 and isinstance(args[1], str):  # (filename, name)
        for name in args[1:]:
            nvdict[name] = None

    elif len(args) == 2 and isinstance(args[1], list):  # (filename, [names])
        for name in args[1]:
            nvdict[name] = None

    elif len(args) == 2 and isinstance(args[1], dict):  # (filename, {names:values})
        nvdict = args[1]

    elif len(args) == 1:  #  (filename)
        pass

    else:
        raise TypeError("Unrecognized input arguments.")

    if whichdb(filename):   #We used python pickle for the save
        print("Loading variables from file {}.".format(filename))
        my_shelf = shelve.open(filename, 'r')  # 'r' for read - only
        if nvdict:
            for name in list(nvdict.keys()):
                try:
                    nvdict[name] = my_shelf[name]
                    print(("Variable '%s' loaded." % name))
                except KeyError:
                    value = None
                    print("Variable '{}' not found.".format(name))

        else:
            for name in list(my_shelf.keys()):
                nvdict[name] = my_shelf[name]
                print(("Variable '%s' loaded." % name))
        my_shelf.close()

    else:  #We used netcdf for the save
        try:
            NCFile = Dataset(filename, mode='r')
            NCFile.close()
        except RuntimeError:
            raise IOError("File '{}' not found.".format(filename))

        classtype, classtree = netCDFread(filename)
        nvdict['md'] = model()
        NCFile = Dataset(filename, mode='r')
        for mod in dict.keys(classtype):
            #==== First we create the model structure  {{{
            if debug:
                print(' - Now treating classtype {}'.format(mod))
            if np.size(classtree[mod]) > 1:
                # this points to a subclass (results.TransientSolution for example)
                curclass = NCFile.groups[classtree[mod][0]].groups[classtree[mod][1]]
                if debug:
                    print("now checking {} for list : {}".format(mod, classtype[mod][0] == 'list' or (classtype[mod][0] == 'results' and 'Time' in NCFile.dimensions)))
                if classtype[mod][0] == 'list' or (classtype[mod][0] == 'results' and 'Time' in NCFile.dimensions):  #We have a list of variables
                    keylist = [key for key in curclass.groups]
                    # this is related to the old structure of NC files where every steps of results had its own group
                    if len(keylist) > 0:
                        #this is kept for compatibility
                        #and treatment of list of dicts??
                        try:
                            #group are named after their step
                            steplist = [int(key) for key in curclass.groups]
                        except ValueError:
                            #or a number is appended at the end of the name
                            steplist = [int(findall(r'\d + ', key)[0]) for key in keylist]
                        indexlist = [int(index * (len(curclass.groups) - 1) / max(1, max(steplist))) for index in steplist]
                        listtype = curclass.groups[keylist[0]].classtype
                        #discriminate between dict and results
                        if listtype == 'dict':
                            nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [OrderedDict() for i in range(max(1, len(curclass.groups)))]
                        else:
                            if OL:   #we load only the last result to save on time and memory
                                nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(__import__(listtype), listtype)()]
                                Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
                            else:
                                nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(__import__(listtype), listtype)() for i in range(max(1, len(curclass.groups)))]
                                Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]][:]
                    else:
                        #that is the current treatment
                        #here we have a more NC approach with time being a dimension
                        keylist = [key for key in curclass.variables]
                        dimlist = [curclass.variables[key].dimensions for key in keylist]
                        indexlist = np.arange(0, len(NCFile.dimensions['Time']))
                        AllHaveTime = np.all(['Time' in dimtuple for dimtuple in dimlist])
                        listtype = curclass.classtype
                        if AllHaveTime:
                            #Time dimension is in all the variables so we take that as stepnumber for the results
                            if OL:   #we load only the last result to save on time and memory
                                nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(__import__(listtype), listtype)()]
                                Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
                            else:
                                nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(__import__(listtype), listtype)() for i in range(max(1, len(NCFile.dimensions['Time'])))]
                                Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]][:]
                        elif len(NCFile.dimensions['Time']) == 1:
                            nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = getattr(__import__(listtype), listtype)()
                            Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
                        else:

                            print("ERROR: Time dimension is not in all results. That has been overlooked for now but your resulat are not saved.")

                else:
                    nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = getattr(classtype[mod][1], classtype[mod][0])()
                    Tree = nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
            else:
                curclass = NCFile.groups[classtree[mod][0]]
                nvdict['md'].__dict__[mod] = getattr(classtype[mod][1], classtype[mod][0])()
                Tree = nvdict['md'].__dict__[classtree[mod][0]]
            if debug:
                print("for {} Tree is a {}".format(mod, Tree.__class__.__name__))
            # }}}
            #==== Then we populate it {{{
            for i in range(0, max(1, len(curclass.groups))):
                if len(curclass.groups) > 0:
                    listclass = curclass.groups[keylist[i]]
                else:
                    listclass = curclass
                #==== We deal with Variables {{{
                for var in listclass.variables:
                    if debug:
                        print("treating var {}".format(var))
                    if var not in ['errlog', 'outlog']:
                        varval = listclass.variables[str(var)]
                        vardim = varval.ndim
                        #There is a special treatment for results to account for its specific structure
                        #that is the new export version where time is a named dimension
                        NewFormat = 'Time' in NCFile.dimensions
                        if type(Tree) == list and NewFormat:
                            if OL:
                                if vardim == 1:
                                    Tree[0].__dict__[str(var)] = varval[-1].data
                                elif vardim == 2:
                                    Tree[0].__dict__[str(var)] = varval[-1, :].data
                                elif vardim == 3:
                                    Tree[0].__dict__[str(var)] = varval[-1, :, :].data
                                else:
                                    print('table dimension greater than 3 not implemented yet')
                            else:
                                for t in indexlist:
                                    if vardim == 0:
                                        Tree[t].__dict__[str(var)] = varval[:].data
                                    if vardim == 1:
                                        Tree[t].__dict__[str(var)] = varval[t].data
                                    elif vardim == 2:
                                        Tree[t].__dict__[str(var)] = varval[t, :].data
                                    elif vardim == 3:
                                        Tree[t].__dict__[str(var)] = varval[t, :, :].data
                                    else:
                                        print('table dimension greater than 3 not implemented yet')
                        else:
                            if vardim == 0:  #that is a scalar
                                if type(Tree) == list:
                                    t = indexlist[i]
                                    if listtype == 'dict':
                                        Tree[t][str(var)] = varval[0].data
                                    else:
                                        Tree[t].__dict__[str(var)] = varval[0].data
                                else:
                                    if str(varval[0]) == '':  #no value
                                        Tree.__dict__[str(var)] = []
                                    elif varval[0] == 'True':  #treatin bool
                                        Tree.__dict__[str(var)] = True
                                    elif varval[0] == 'False':  #treatin bool
                                        Tree.__dict__[str(var)] = False
                                    else:
                                        Tree.__dict__[str(var)] = varval[0].item()

                            elif vardim == 1:  #that is a vector
                                if varval.dtype == str:
                                    if varval.shape[0] == 1:
                                        Tree.__dict__[str(var)] = [str(varval[0]), ]
                                    elif 'True' in varval[:] or 'False' in varval[:]:
                                        Tree.__dict__[str(var)] = np.asarray([V == 'True' for V in varval[:]], dtype=bool)
                                    else:
                                        Tree.__dict__[str(var)] = [str(vallue) for vallue in varval[:]]
                                else:
                                    if type(Tree) == list:
                                        t = indexlist[i]
                                        if listtype == 'dict':
                                            Tree[t][str(var)] = varval[:].data
                                        else:
                                            Tree[t].__dict__[str(var)] = varval[:].data
                                    else:
                                        try:
                                            #some thing specifically require a list
                                            mdtype = type(Tree.__dict__[str(var)])
                                        except KeyError:
                                            mdtype = float
                                        if mdtype == list:
                                            Tree.__dict__[str(var)] = [mdval for mdval in varval[:]]
                                        else:
                                            Tree.__dict__[str(var)] = varval[:].data

                            elif vardim == 2:
                                #dealling with dict
                                if varval.dtype == str:  #that is for toolkits wich needs to be ordered
                                    if any(varval[:, 0] == 'toolkit'):  #toolkit definition have to be first
                                        Tree.__dict__[str(var)] = OrderedDict([('toolkit', str(varval[np.where(varval[:, 0] == 'toolkit')[0][0], 1]))])
                                        strings1 = [str(arg[0]) for arg in varval if arg[0] != 'toolkits']
                                        strings2 = [str(arg[1]) for arg in varval if arg[0] != 'toolkits']
                                        Tree.__dict__[str(var)].update(list(zip(strings1, strings2)))
                                else:
                                    if type(Tree) == list:
                                        t = indexlist[i]
                                        if listtype == 'dict':
                                            Tree[t][str(var)] = varval[:, :].data
                                        else:
                                            Tree[t].__dict__[str(var)] = varval[:, :].data
                                    else:
                                        Tree.__dict__[str(var)] = varval[:, :].data
                            elif vardim == 3:
                                if type(Tree) == list:
                                    t = int(indexlist[i])
                                    if listtype == 'dict':
                                        Tree[t][str(var)] = varval[:, :, :].data
                                    else:
                                        Tree[t].__dict__[str(var)] = varval[:, :, :]
                                else:
                                    Tree.__dict__[str(var)] = varval[:, :, :].data
                            else:
                                print('table dimension greater than 3 not implemented yet')
                    # }}}
                #==== And with atribute {{{
                for attr in listclass.ncattrs():
                    if attr != 'classtype':  #classtype is for treatment, don't get it back
                        if type(Tree) == list:
                            t = int(indexlist[i])
                            if listtype == 'dict':
                                Tree[t][str(attr).swapcase()] = str(listclass.getncattr(attr))
                            else:
                                Tree[t].__dict__[str(attr).swapcase()] = str(listclass.getncattr(attr))
                        else:
                            Tree.__dict__[str(attr).swapcase()] = str(listclass.getncattr(attr))
                            if listclass.getncattr(attr) == 'True':
                                Tree.__dict__[str(attr).swapcase()] = True
                            elif listclass.getncattr(attr) == 'False':
                                Tree.__dict__[str(attr).swapcase()] = False
                # }}}
            # }}}
        NCFile.close()
    if len(args) >= 2 and isinstance(args[1], str):  # (value)
        value = [nvdict[name] for name in args[1:]]
        return value

    elif len(args) == 2 and isinstance(args[1], list):  # ([values])
        value = [nvdict[name] for name in args[1]]
        return value

    elif (len(args) == 2 and isinstance(args[1], dict)) or (len(args) == 1):  # ({names:values})
        return nvdict


def netCDFread(filename):
    print(('Opening {} for reading '.format(filename)))
    NCData = Dataset(filename, 'r')
    class_dict = {}
    class_tree = {}

    for group in NCData.groups:
        if len(NCData.groups[group].groups) > 0:
            for subgroup in NCData.groups[group].groups:
                classe = str(group) + '.' + str(subgroup)
                class_dict[classe] = [str(getattr(NCData.groups[group].groups[subgroup], 'classtype')), ]
                if class_dict[classe][0] not in ['dict', 'list', 'cell']:
                    class_dict[classe].append(__import__(class_dict[classe][0]))
                class_tree[classe] = [group, subgroup]
        else:
            classe = str(group)
            try:
                class_dict[classe] = [str(getattr(NCData.groups[group], 'classtype')), ]
                if class_dict[classe][0] not in ['dict', 'list', 'cell']:
                    class_dict[classe].append(__import__(class_dict[classe][0]))
                    class_tree[classe] = [group, ]
            except AttributeError:
                print(('group {} is empty'.format(group)))
    NCData.close()
    return class_dict, class_tree
