# imports
import netCDF4
from netCDF4 import Dataset
import numpy as np
import numpy.ma as ma
import time
import os
from model import *
from results import *
from m1qn3inversion import m1qn3inversion
from taoinversion import taoinversion
#import OrderedStruct


'''
Given a md, this set of functions will perform the following:
    1. Enter each nested class of the md.
    2. View each attribute of each nested class.
    3. Compare state of attribute in the model to an empty model class.
    4. If states are identical, pass.
    5. Otherwise, create nested groups named after class structure.
    6. Create variable named after class attribute and assign value to it.
'''


def write_netCDF(md, filename: str, verbose = False):
    if verbose:
        print('Python C2NetCDF4 v1.1.14')
    else: pass
    '''
    md = model() class instance to be saved
    filename = path and name to save file under
    verbose = T/F muted or show log statements. Naturally muted
    '''
    
    # Create a NetCDF file to write to
    make_NetCDF(filename, verbose)
    
    # Create an instance of an empty md class to compare md_var against
    global empty_model
    empty_model = model()

    # Walk through the md class and compare subclass states to empty_model
    walk_through_model(md, verbose)

    # in order to handle some subclasses in the results class, we have to utilize this band-aid
    # there will likely be more band-aids added unless a class name library is created with all class names that might be added to a md
    try:
        # if results has meaningful data, save the name of the subclass and class instance
        NetCDF.groups['results']
        results_subclasses_bandaid(md, verbose)
        # otherwise, ignore
    except KeyError:
        pass
        
    NetCDF.close()
    if verbose:
        print('Model successfully saved as NetCDF4')
    else: pass
    

def results_subclasses_bandaid(md, verbose = False):
    # since the results class may have nested classes within it, we need to record the name of the 
    # nested class instance variable as it appears in the md that we're trying to save
    quality_control = []

    # we save lists of instances to the netcdf
    solutions = []
    solutionsteps = []
    resultsdakotas = []
    
    for class_instance_name in md.results.__dict__.keys():
        if verbose:
            print(class_instance_name)
        # for each class instance in results, see which class its from and record that info in the netcdf to recreate structure later
        # check to see if there is a solutionstep class instance
        if isinstance(md.results.__dict__[class_instance_name],solutionstep):
            quality_control.append(1)
            solutionsteps.append(class_instance_name)

        # check to see if there is a solution class instance
        if isinstance(md.results.__dict__[class_instance_name],solution):
            quality_control.append(1)
            solutions.append(class_instance_name)

        # check to see if there is a resultsdakota class instance
        if isinstance(md.results.__dict__[class_instance_name],resultsdakota):
            quality_control.append(1)
            resultsdakotas.append(class_instance_name)

    if solutionsteps != []:
        write_string_to_netcdf(variable_name=str('solutionstep'), address_of_child=solutionsteps, group=NetCDF.groups['results'], list=True, verbose=verbose)

    if solutions != []:
        write_string_to_netcdf(variable_name=str('solution'), address_of_child=solutions, group=NetCDF.groups['results'], list=True, verbose=verbose)

    if resultsdakotas != []:
        write_string_to_netcdf(variable_name=str('resultsdakota'), address_of_child=resultsdakotas, group=NetCDF.groups['results'], list=True, verbose=verbose)

    
    if len(quality_control) != len(md.results.__dict__.keys()):
        print('Error: The class instance within your md.results class is not currently supported by this application')
        print(type(md.results.__dict__[class_instance_name]))
    else:
        if verbose:
            print('The results class was successfully stored on disk')
        else: pass


def make_NetCDF(filename: str, verbose = False):
    # If file already exists delete / rename it
    if os.path.exists(filename):
        print('File {} allready exist'.format(filename))
    
        # If so, inqure for a new name or to do delete the existing file
        newname = input('Give a new name or "delete" to replace: ')

        if newname == 'delete':
            os.remove(filename)
        else:
            print(('New file name is {}'.format(newname)))
            filename = newname
    else:
        # Otherwise create the file and define it globally so other functions can call it
        global NetCDF
        NetCDF = Dataset(filename, 'w', format='NETCDF4')
        NetCDF.history = 'Created ' + time.ctime(time.time())
        NetCDF.createDimension('Unlim', None)  # unlimited dimension
        NetCDF.createDimension('float', 1)     # single integer dimension
        NetCDF.createDimension('int', 1)       # single float dimension
    
    if verbose:
        print('Successfully created ' + filename)


def walk_through_model(md, verbose = False):
    # Iterate over first layer of md_var attributes and assume this first layer is only classes
    for group in md.__dict__.keys():
        address = md.__dict__[group]
        empty_address = empty_model.__dict__[group]
        # we need to record the layers of the md so we can save them to the netcdf file
        layers = [group]

        # Recursively walk through subclasses
        walk_through_subclasses(address, empty_address, layers, verbose)       


def walk_through_subclasses(address, empty_address, layers: list, verbose = False):
    # See if we have an object with keys or a not
    try:
        address.__dict__.keys()
        is_object = True
    except: is_object = False # this is not an object with keys

    if is_object:
        # enter the subclass, see if it has nested classes and/or attributes
        # then compare attributes between mds and write to netCDF if they differ
        # if subclass found, walk through it and repeat
        for child in address.__dict__.keys():
            # record the current location
            current_layer = layers.copy()
            current_layer.append(child)
            
            # navigate to child in each md
            address_of_child = address.__dict__[child]
            
            # if the current object is a results.<solution> object and has the steps attr it needs special treatment
            if isinstance(address_of_child, solution) and len(address_of_child.steps) != 0:
                create_group(address_of_child, current_layer, is_struct = True, verbose = verbose)

            # if the variable is an array, assume it has relevant data (this is because the next line cannot evaluate "==" with an array)
            elif isinstance(address_of_child, np.ndarray):
                create_group(address_of_child, current_layer, is_struct = False, verbose = verbose)
            
            # see if the child exists in the empty md. If not, record it in the netcdf
            else:
                try: 
                    address_of_child_in_empty_class = empty_address.__dict__[child]
                    # if that line worked, we can see how the mds' attributes at this layer compare:
    
                    # if the attributes are identical we don't need to save anything
                    if address_of_child == address_of_child_in_empty_class:
                        walk_through_subclasses(address_of_child, address_of_child_in_empty_class, current_layer, verbose)
    
                    # If it has been modified, record it in the NetCDF file
                    else:
                        create_group(address_of_child, current_layer, is_struct = False, verbose = verbose)
                        walk_through_subclasses(address_of_child, address_of_child_in_empty_class, current_layer, verbose)
    
                except KeyError: # record in netcdf and continue to walk thru md
                    walk_through_subclasses(address_of_child, empty_address, current_layer, verbose)
                    create_group(address_of_child, current_layer, is_struct = False, verbose = verbose)
    else: pass


def create_group(address_of_child, layers, is_struct = False, verbose = False):

    # Handle the first layer of the group(s)
    group_name = layers[0]
    try:
        group = NetCDF.createGroup(str(group_name))
    except:
        group = NetCDF.groups[str(group_name)]

    # need to check if inversion or m1qn3inversion class
    if group_name == 'inversion':
        check_inversion_class(address_of_child, verbose)
    else: pass

    # if the data is nested, create nested groups to match class structure
    if len(layers) > 2:
        for name in layers[1:-1]:
            try:
                group = group.createGroup(str(name))
            except:
                group = NetCDF.groups[str(name)]
    else: pass

    # Lastly, handle the variable(s)
    if is_struct:
        parent_struct_name = layers[-1]
        copy_nested_results_struct(parent_struct_name, address_of_child, group, verbose)
    
    else:
        variable_name = layers[-1]
        create_var(variable_name, address_of_child, group, verbose)
            

def singleton(func):
    """
    A decorator to ensure a function is only executed once.
    """
    def wrapper(*args, **kwargs):
        if not wrapper.has_run:
            wrapper.result = func(*args, **kwargs)
            wrapper.has_run = True
        return wrapper.result
    wrapper.has_run = False
    wrapper.result = None
    return wrapper
    

@singleton
def check_inversion_class(address_of_child, verbose = False):
    # need to make sure that we have the right inversion class: inversion, m1qn3inversion, taoinversion
    if isinstance(address_of_child, m1qn3inversion):
        write_string_to_netcdf(variable_name=str('inversion_class_name'), address_of_child=str('m1qn3inversion'), group=NetCDF.groups['inversion'], verbose = verbose)
        if verbose:
            print('Successfully saved inversion class instance ' + 'm1qn3inversion')
    elif isinstance(address_of_child, taoinversion):
        write_string_to_netcdf(variable_name=str('inversion_class_name'), address_of_child=str('taoinversion'), group=NetCDF.groups['inversion'], verbose = verbose)
        if verbose:
            print('Successfully saved inversion class instance ' + 'taoinversion')
    else:
        write_string_to_netcdf(variable_name=str('inversion_class_name'), address_of_child=str('inversion'), group=NetCDF.groups['inversion'], verbose = verbose)
        if verbose:
            print('Successfully saved inversion class instance ' + 'inversion')


def copy_nested_results_struct(parent_struct_name, address_of_struct, group, verbose = False):
    '''
        This function takes a solution class instance and saves the solutionstep instances from <solution>.steps to the netcdf. 

        To do this, we get the number of dimensions (substructs) of the parent struct.
        Next, we iterate through each substruct and record the data. 
        For each substruct, we create a subgroup of the main struct.
        For each variable, we create dimensions that are assigned to each subgroup uniquely.
    '''
    if verbose:
        print("Beginning transfer of nested MATLAB struct to the NetCDF")
    
    # make a new subgroup to contain all the others:
    group = group.createGroup(str(parent_struct_name))

    # make sure other systems can flag the nested struct type
    write_string_to_netcdf('this_is_a_nested', 'struct', group, list=False, verbose = verbose)

    # other systems know the name of the parent struct because it's covered by the results/qmu functions above
    no_of_dims = len(address_of_struct)
    for substruct in range(0, no_of_dims):
        # we start by making subgroups with nice names like "1x4"
        name_of_subgroup = '1x' + str(substruct)
        subgroup = group.createGroup(str(name_of_subgroup))

        # do some housekeeping to keep track of the current layer
        current_substruct = address_of_struct[substruct]
        substruct_fields = current_substruct.__dict__.keys()

        # now we need to iterate over each variable of the nested struct and save it to this new subgroup
        for variable in substruct_fields:
            address_of_child = current_substruct.__dict__[variable]
            create_var(variable, address_of_child, subgroup, verbose = verbose)
    
    if verbose:
        print(f'Successfully transferred struct {parent_struct_name} to the NetCDF\n')
    
        
def create_var(variable_name, address_of_child, group, verbose = False):
    # There are lots of different variable types that we need to handle from the md class
    
    # This first conditional statement will catch numpy arrays of any dimension and save them
    if isinstance(address_of_child, np.ndarray):
        write_numpy_array_to_netcdf(variable_name, address_of_child, group, verbose=verbose)
    
    # check if it's an int
    elif isinstance(address_of_child, int) or isinstance(address_of_child, np.integer):
        variable = group.createVariable(variable_name, int, ('int',))
        variable[:] = address_of_child
    
    # or a float
    elif isinstance(address_of_child, float) or isinstance(address_of_child, np.floating):
        variable = group.createVariable(variable_name, float, ('float',))
        variable[:] = address_of_child

    # or a string
    elif isinstance(address_of_child, str):
        write_string_to_netcdf(variable_name, address_of_child, group, verbose=verbose)

    #or a bool
    elif isinstance(address_of_child, bool) or isinstance(address_of_child, np.bool_):
        # netcdf4 can't handle bool types like True/False so we convert all to int 1/0 and add an attribute named units with value 'bool'
        variable = group.createVariable(variable_name, int, ('int',))
        variable[:] = int(address_of_child)
        variable.units = "bool"
        
    # or an empty list
    elif isinstance(address_of_child, list) and len(address_of_child)==0:
        variable = group.createVariable(variable_name, int, ('int',))

    # or a list of strings -- this needs work as it can only handle a list of 1 string
    elif isinstance(address_of_child,list) and isinstance(address_of_child[0],str):
        for string in address_of_child:
            write_string_to_netcdf(variable_name, string, group, list=True, verbose=verbose)

    # or a regular list
    elif isinstance(address_of_child, list):
        variable = group.createVariable(variable_name, type(address_of_child[0]), ('Unlim',))
        variable[:] = address_of_child

    # anything else... (will likely need to add more cases; ie helpers.OrderedStruct)
    else:
        try:
            variable = group.createVariable(variable_name, type(address_of_child), ('Unlim',))
            variable[:] = address_of_child
            print(f'Unrecognized variable was saved {variable_name}')
        except TypeError: pass # this would mean that we have an object, so we just let this continue to feed thru the recursive function above
        except Exception as e:
            print(f'There was error with {variable_name} in {group}')
            print("The error message is:")
            print(e)
            print('Datatype given: ' + str(type(address_of_child)))
    
    if verbose:
        print(f'Successfully transferred data from {variable_name} to the NetCDF')
    

def write_string_to_netcdf(variable_name, address_of_child, group, list=False, verbose = False):
    # netcdf and strings dont get along.. we have to do it 'custom':
    # if we hand it an address we need to do it this way:
    if list == True:
        """
        Save a list of strings to a NetCDF file.
    
        Convert a list of strings to a numpy.char_array with utf-8 encoded elements
        and size rows x cols with each row the same # of cols and save to NetCDF
        as char array.
        """
        try:
            strings = address_of_child
            # get dims of array to save
            rows = len(strings)
            cols = len(max(strings, key = len))
    
            # Define dimensions for the strings
            rows_name = 'rows' + str(rows)
            cols_name = 'cols' + str(cols)
            try:
                group.createDimension(rows_name, rows)
            except: pass

            try:
                group.createDimension(cols_name, cols)
            except: pass
                
            # Create a variable to store the strings
            string_var = group.createVariable(str(variable_name), 'S1', (rows_name, cols_name))
    
            # break the list into a list of lists of words with the same length as the longest word:
            # make words same sizes by adding spaces 
            modded_strings = [word + ' ' * (len(max(strings, key=len)) - len(word)) for word in strings]
            # encoded words into list of encoded lists
            new_list = [[s.encode('utf-8') for s in word] for word in modded_strings]
    
            # make numpy char array with dims rows x cols
            arr = np.chararray((rows, cols))
    
            # fill array with list of encoded lists
            for i in range(len(new_list)):
                arr[i] = new_list[i]
    
            # save array to netcdf file
            string_var[:] = arr

            if verbose:
                print(f'Saved {len(modded_strings)} strings to {variable_name}')
    
        except Exception as e:
            print(f'Error: {e}')
        
    else:
        the_string_to_save = address_of_child
        length_of_the_string = len(the_string_to_save)
        numpy_datatype = 'S' + str(length_of_the_string)
        str_out = netCDF4.stringtochar(np.array([the_string_to_save], dtype=numpy_datatype))        
    
        # we'll need to make a new dimension for the string if it doesn't already exist
        name_of_dimension = 'char' + str(length_of_the_string)
        try: 
            group.createDimension(name_of_dimension, length_of_the_string)
        except: pass
        # this is another band-aid to the results sub classes...
        try:
            # now we can make a variable in this dimension:
            string = group.createVariable(variable_name, 'S1', (name_of_dimension))
            #finally we can write the variable:
            string[:] = str_out
        #except RuntimeError: pass
        except Exception as e:
            print(f'There was an error saving a string from {variable_name}')
            print(e)


def write_numpy_array_to_netcdf(variable_name, address_of_child, group, verbose = False):
    # to make a nested array in netCDF, we have to get the dimensions of the array,
    # create corresponding dimensions in the netCDF file, then we can make a variable
    # in the netCDF with dimensions identical to those in the original array
    
    # start by getting the data type at the lowest level in the array:
    typeis = address_of_child.dtype

    # catch boolean arrays here
    if typeis == bool:
        # sometimes an array has just 1 element in it, we account for those cases here:
        if len(address_of_child) == 1:
            variable = group.createVariable(variable_name, int, ('int',))
            variable[:] = int(address_of_child)
            variable.units = "bool"
        else:
            # make the dimensions
            dimensions = []
            for dimension in np.shape(address_of_child):
                dimensions.append(str('dim' + str(dimension)))
                # if the dimension already exists we can't have a duplicate
                try:
                    group.createDimension(str('dim' + str(dimension)), dimension)
                except: pass # this would mean that the dimension already exists
    
            # create the variable:
            variable = group.createVariable(variable_name, int, tuple(dimensions))
            # write the variable:
            variable[:] = address_of_child.astype(int)
            variable.units = "bool"

    # handle all other datatypes here
    else:
        # sometimes an array has just 1 element in it, we account for those cases here:
        if len(address_of_child) == 1:
            if typeis is np.dtype('float64'):
                variable = group.createVariable(variable_name, typeis, ('float',))
                variable[:] = address_of_child[0]
            elif typeis is np.dtype('int64'):
                variable = group.createVariable(variable_name, typeis, ('int',))
                variable[:] = address_of_child[0]
            else:
                print(f'Encountered single datatype from {variable_name} that was not float64 or int64, saving under unlimited dimension, may cause errors.')
                variable = group.createVariable(variable_name, typeis, ('Unlim',))
                variable[:] = address_of_child[0]
    
        # This catches all arrays/lists:
        else:
            # make the dimensions
            dimensions = []
            for dimension in np.shape(address_of_child):
                dimensions.append(str('dim' + str(dimension)))
                # if the dimension already exists we can't have a duplicate
                try:
                    group.createDimension(str('dim' + str(dimension)), dimension)
                except: pass # this would mean that the dimension already exists
    
            # create the variable:
            variable = group.createVariable(variable_name, typeis, tuple(dimensions))
    
            # write the variable:
            variable[:] = address_of_child

            