import numpy as npy
from processmesh import processmesh
from processdata import processdata
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
try:
	from osgeo import gdal
except ImportError:
	print 'osgeo/gdal for python not installed, plot_overlay is disabled'

def plot_overlay(md,data,options,ax):
	'''
	Function for plotting a georeferenced image.  This function is called
	from within the plotmodel code.
	'''

	x,y,z,elements,is2d,isplanet=processmesh(md,[],options)

	if data=='none' or data==None:
		imageonly=1
		data=npy.float('nan')*npy.ones((md.mesh.numberofvertices,))
		datatype=1
	else:
		imageonly=0
		data,datatype=processdata(md,data,options)

	if not is2d:
		raise StandardError('overlay plot not supported for 3D meshes, project on a 2D layer first')
	if datatype==3:
		raise StandardError('overlay not yet supported for quiver plots')	

	if not options.exist('geotiff_name'):
		raise StandardError('overlay error: provide geotiff_name with path to geotiff file')
	geotiff=options.getfieldvalue('geotiff_name')

	xlim=options.getfieldvalue('xlim',[min(md.mesh.x),max(md.mesh.x)])
	ylim=options.getfieldvalue('ylim',[min(md.mesh.y),max(md.mesh.y)])

	gtif=gdal.Open(geotiff)
	trans=gtif.GetGeoTransform()
	xmin=trans[0]
	xmax=trans[0]+gtif.RasterXSize*trans[1]
	ymin=trans[3]+gtif.RasterYSize*trans[5]
	ymax=trans[3]
	
	# allow supplied geotiff to have limits smaller than basemap or model limits
	x0=max(min(xlim),xmin)
	x1=min(max(xlim),xmax)
	y0=max(min(ylim),ymin)
	y1=min(max(ylim),ymax)
	inputname='temp.tif'
	os.system('gdal_translate -quiet -projwin ' + str(x0) + ' ' + str(y1) + ' ' + str(x1) + ' ' + str(y0) + ' ' + geotiff + ' ' + inputname)
	
	gtif=gdal.Open(inputname)
	arr=gtif.ReadAsArray()
	#os.system('rm -rf ./temp.tif')
	
	if gtif.RasterCount>=3:  # RGB array
		r=gtif.GetRasterBand(1).ReadAsArray()
		g=gtif.GetRasterBand(2).ReadAsArray()
		b=gtif.GetRasterBand(3).ReadAsArray()
		arr=0.299*r+0.587*g+0.114*b

	# normalize array
	arr=arr/npy.float(npy.max(arr.ravel()))
        arr=1.-arr # somehow the values got flipped

	if options.getfieldvalue('overlayhist',0)==1:
		ax=plt.gca()
		num=2
		while True:
			if not plt.fignum_exists(num):
				break
			else:
				num+=1
		plt.figure(num)
		plt.hist(arr.flatten(),bins=256,range=(0.,1.))
		plt.title('histogram of overlay image, use for setting overlaylims')
		plt.sca(ax) # return to original axes/figure
		
	# get parameters from cropped geotiff
	trans=gtif.GetGeoTransform()
	xmin=trans[0]
	xmax=trans[0]+gtif.RasterXSize*trans[1]
	ymin=trans[3]+gtif.RasterYSize*trans[5]
	ymax=trans[3]
	dx=trans[1]
	dy=trans[5]	
	
	xarr=npy.arange(xmin,xmax,dx)
	yarr=npy.arange(ymin,ymax,-dy) # -dy since origin='upper' (not sure how robust this is)
	xg,yg=npy.meshgrid(xarr,yarr)
	if options.exist('basemap'):
		# TODO get handle to or create basemap instance 
		# create coordinate grid in map projection units (for plotting)
		lats,lons=xy2ll(xg,yg,-1,0,71)
		xgmap,ygmap=m(lons,lats) # map projection units returned by basemap instance
	else:
		xgmap=xg
		ygmap=yg
	
	overlaylims=options.getfieldvalue('overlaylims',[min(arr.ravel()),max(arr.ravel())])

	norm=mpl.colors.Normalize(vmin=overlaylims[0],vmax=overlaylims[1])

	pc=ax.pcolormesh(xgmap, ygmap, npy.flipud(arr), cmap=mpl.cm.Greys, norm=norm)
	#rasterization? 
	if options.getfieldvalue('rasterized',0):
		pc.set_rasterized(True)
