import shelve
import os.path
import numpy as np
from netCDF4 import Dataset
from netCDF4 import chartostring
from os import path
from whichdb import whichdb
from model import *

def loadvars(*args):
	"""
	LOADVARS - function to load variables to a file.

	This function loads one or more variables from a file.  The names of the variables
	must be supplied.  If more than one variable is specified, it may be done with
	a list of names or a dictionary of name as keys.  The output type will correspond
	to the input type.  All the variables in the file may be loaded by specifying only
	the file name.

	Usage:
	   a=loadvars('shelve.dat','a')
	   [a,b]=loadvars('shelve.dat',['a','b'])
	   nvdict=loadvars('shelve.dat',{'a':None,'b':None})
	   nvdict=loadvars('shelve.dat')

	"""

	filename=''
	nvdict={}

	if len(args) >= 1 and isinstance(args[0],(str,unicode)):
		filename=args[0]
		if not filename:
			filename='/tmp/shelve.dat'

	else:
		raise TypeError("Missing file name.")

	if   len(args) >= 2 and isinstance(args[1],(str,unicode)):    # (filename,name)
		for name in args[1:]:
			nvdict[name]=None

	elif len(args) == 2 and isinstance(args[1],list):    # (filename,[names])
		for name in args[1]:
			nvdict[name]=None

	elif len(args) == 2 and isinstance(args[1],dict):    # (filename,{names:values})
		nvdict=args[1]

	elif len(args) == 1:    #  (filename)
		pass

	else:
		raise TypeError("Unrecognized input arguments.")

	if whichdb(filename):
		print "Loading variables from file '%s'." % filename
		
		my_shelf = shelve.open(filename,'r') # 'r' for read-only
		if nvdict:
			for name in nvdict.iterkeys():
				try:
					nvdict[name] = my_shelf[name]
					print "Variable '%s' loaded." % name
				except KeyError:
					value = None
					print "Variable '%s' not found." % name

		else:
			for name in my_shelf.iterkeys():
				nvdict[name] = my_shelf[name]
				print "Variable '%s' loaded." % name

		my_shelf.close()

	else:
		try:
			NCFile=Dataset(filename,mode='r')
			NCFile.close()
			classtype,classtree=netCDFread(filename)
			nvdict['md']=model()
			module=map(__import__,dict.values(classtype))
			for i,mod in enumerate(dict.keys(classtype)):
#				print('treating md.{}'.format(mod))
				if np.size(classtree[mod])>1:
					if classtree[mod][0]=='results':
						#treating results (Dimension4 is time)
						resdim=len(NCFile.dimensions['Dimension4'])
						curclass=NCFile.groups[classtree[mod][0]].groups[classtree[mod][1]]
						nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = [getattr(module[i],classtype[mod])()]
						if resdim>1:
							for t in range(1,resdim):
								nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]].append(getattr(module[i],classtype[mod])())
						Tree=nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]][:]
					else:
						curclass=NCFile.groups[classtree[mod][0]].groups[classtree[mod][1]]
						nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]] = getattr(module[i],classtype[mod])()
						Tree=nvdict['md'].__dict__[classtree[mod][0]].__dict__[classtree[mod][1]]
				else:
					curclass=NCFile.groups[classtree[mod][0]]
					nvdict['md'].__dict__[mod] = getattr(module[i],classtype[mod])()
					Tree=nvdict['md'].__dict__[classtree[mod][0]]
				for var in curclass.variables:
					#print('    treating {}'.format(var))
					varval=curclass.variables[str(var)]
					vardim=varval.ndim
					try:
						val_type=str(varval.dtype)
					except AttributeError:
						val_type=type(varval)
					if vardim==0:
						try:
							Tree.__dict__[str(var)]=varval.getValue()
							if varval.getValue()=='True':
								Tree.__dict__[str(var)]=True
							elif varval.getValue()=='False':
								Tree.__dict__[str(var)]=False
						except IndexError:
							Tree.__dict__[str(var)]=[]
					elif vardim==1:
						if varval.dtype==str:
							if varval.shape==1:
								Tree.__dict__[str(var)]=str(varval[0])
							if 'True' in varval[:] or 'False' in varval[:]:
								Tree.__dict__[str(var)]=np.asarray(varval[:],dtype=bool)
						else:
							if classtree[mod][0]=='results' and resdim>1:
								for t in range(0,resdim):
									Tree[t].__dict__[str(var)]=varval[t]
							else:
								Tree.__dict__[str(var)]=varval[:]
					elif vardim==2:
						#dealling with dict
						if varval.dtype==str:
							Tree.__dict__[str(var)]=dict(zip(varval[:,0], varval[:,1]))
						else:
							if classtree[mod][0]=='results' and resdim>1:
								for t in range(0,resdim):
									Tree[t].__dict__[str(var)]=varval[:,t]
							else:
								Tree.__dict__[str(var)]=varval[:,:]
					elif vardim==3:
						if classtree[mod][0]=='results' and resdim>1:
							for t in range(0,resdim):
								Tree[t].__dict__[str(var)]=varval[:,:,t]
						else:
							Tree.__dict__[str(var)]=varval[:,:,:]
					else:
						print 'table dimension greater than 3 not implemented yet'
				for attr in curclass.ncattrs():
					#print('    treating {}'.format(attr))
					if classtree[mod][0]!='results' and attr!='classtype': #no attributes in results
						Tree.__dict__[str(attr)]=str(curclass.getncattr(attr))
						if curclass.getncattr(attr)=='True':
							Tree.__dict__[str(attr)]=True
						elif curclass.getncattr(attr)=='False':
							Tree.__dict__[str(attr)]=False

		except RuntimeError:
			raise IOError("File '%s' not found." % filename)

	if   len(args) >= 2 and isinstance(args[1],(str,unicode)):    # (value)
		value=[nvdict[name] for name in args[1:]]
		return value

	elif len(args) == 2 and isinstance(args[1],list):    # ([values])
		value=[nvdict[name] for name in args[1]]
		return value

	elif (len(args) == 2 and isinstance(args[1],dict)) or (len(args) == 1):    # ({names:values})
		return nvdict


def netCDFread(filename):
	def walktree(data):
		keys = data.groups.keys()
		for key in keys:
			yield [str(key)]
			for children in walktree(data.groups[str(key)]):
				child=[str(key)]
				child.append(str(children[0]))
				yield child
	print ('Opening {} for reading '.format(filename))
	NCData=Dataset(filename, 'r')
	class_dict={}
	class_tree={}
	
	for children in walktree(NCData):
		classe=str(children[0])
		if np.size(children)>1:
			for name in children[1:]:
				classe=classe+'.'+name
			class_dict[classe]=str(getattr(NCData.groups[children[0]].groups[children[1]],'classtype'))
		else:
			class_dict[classe]=str(getattr(NCData.groups[classe],'classtype'))
		class_tree[classe]=children

	return class_dict,class_tree
