/*\file Marshalling.h
 *\brief: macros to help automate the marshalling, demarshalling, and marshalling size routines. 
 */

#ifndef _MARSHALLING_H_
#define _MARSHALLING_H_

#include <string.h>
#include "../../Exceptions/exceptions.h"
#include "../../MemOps/MemOps.h"
#include "../../Numerics/recast.h"

/*Define Marshall operation Enums first*/
enum MarshallOpEnum{
	MARSHALLING_WRITE,
	MARSHALLING_LOAD,
	MARSHALLING_SIZE,
#ifdef _HAVE_CODIPACK_
	AD_COUNTDOUBLES,
	AD_REGISTERINPUT,
	AD_REGISTEROUTPUT,
	AD_SETADJOINT,
#endif
};

/*Define virtual Marshall Handle*/
class MarshallHandle{ /*{{{*/
	public:
		MarshallOpEnum operation_enum;
		MarshallHandle(MarshallOpEnum op_in) : operation_enum(op_in){}
		~MarshallHandle(){}
		virtual void Echo(void)=0;
		int OperationNumber(){return operation_enum;}
		template<typename T> void call(T  & value);
		template<typename T> void call(T* & value,int size);
}; /*}}}*/
/* !! Make sure to pass all fields by reference !! */
class WriteCheckpointFunctor: public MarshallHandle{ /*{{{*/

	private:
		char** pmarshalled_data;

	public:
		WriteCheckpointFunctor(char** pmarshalled_data_in) : MarshallHandle(MARSHALLING_WRITE),pmarshalled_data(pmarshalled_data_in){}
		template<typename T> void call(T & value){
			memcpy(*pmarshalled_data,&value,sizeof(T));
			*pmarshalled_data+=sizeof(T);
		}
		void Echo(void){
			printf("WriteCheckpointFunctor Echo:\n");
			printf("   pmarshalled_data: %p\n",pmarshalled_data);
		}
		template<typename T> void call(T* & value,int size){
			bool pointer_null = true;
			if(value) pointer_null = false;
			call(pointer_null);
			if(value){
				memcpy(*pmarshalled_data,value,size*sizeof(T));
				*pmarshalled_data+=size*sizeof(T);
			}
		}
};/*}}}*/
class LoadCheckpointFunctor:  public MarshallHandle{ /*{{{*/

	private:
		char** pmarshalled_data;

	public:
		LoadCheckpointFunctor(char** pmarshalled_data_in) : MarshallHandle(MARSHALLING_LOAD),pmarshalled_data(pmarshalled_data_in){}
		void Echo(void){
			printf("LoadCheckpointFunctor Echo:\n");
			printf("   pmarshalled_data: %p\n",pmarshalled_data);
		}
		template<typename T> void call(T & value){
			memcpy(&value,*pmarshalled_data,sizeof(T));
			*pmarshalled_data+=sizeof(T);
		}
		template<typename T> void call(T* & value,int size){
			bool pointer_null;
			call(pointer_null);
			if(!pointer_null){
				value=xNew<T>(size);
				memcpy(value,*pmarshalled_data,size*sizeof(T));
				*pmarshalled_data+=size*sizeof(T);
			}
			else{
				value = NULL;
			}
		}
};/*}}}*/
class SizeCheckpointFunctor:  public MarshallHandle{ /*{{{*/

	private:
		int marshalled_data_size;

	public:
		SizeCheckpointFunctor(void) : MarshallHandle(MARSHALLING_SIZE),marshalled_data_size(0){}
		int MarshalledSize(void){return this->marshalled_data_size;};
		void Echo(void){
			printf("SizeCheckpointFunctor Echo:\n");
			printf("   marshalled_data_size: %i\n",marshalled_data_size);
		}
		template<typename T> void call(T & value){
			marshalled_data_size+=sizeof(T);
		}
		template<typename T> void call(T* & value,int size){
			bool pointer_null = true;
			if(value) pointer_null = false;
			this->call(pointer_null);
			if(!pointer_null){
				marshalled_data_size+=size*sizeof(T);
			}
		}
};/*}}}*/
#ifdef _HAVE_CODIPACK_
class CountDoublesFunctor:    public MarshallHandle{ /*{{{*/

	private:
		int double_count;

	public:
		CountDoublesFunctor(void) : MarshallHandle(AD_COUNTDOUBLES),double_count(0){}
		int DoubleCount(void){return this->double_count;};
		void Echo(void){
			printf("CountDoublesFunctor Echo:\n");
			printf("   double_count: %i\n",double_count);
		}
		template<typename T> void call(T & value){
			/*General case: do nothing*/
		}
		template<typename T> void call(T* & value,int size){
			/*General case: do nothing*/
		}
		void call(IssmDouble value){
			this->double_count++;
		}
		void call(IssmDouble* value,int size){
			if(value) this->double_count+= size;
		}
}; /*}}}*/
class RegisterInputFunctor:   public MarshallHandle{ /*{{{*/

	private:
		int   double_count;
		int*  identifiers;
		IssmDouble::TapeType* tape_codi;

	public:
		RegisterInputFunctor(int* identifiers_in) : MarshallHandle(AD_REGISTERINPUT){
			this->double_count = 0;
			this->identifiers  = identifiers_in;
			this->tape_codi    = &(IssmDouble::getGlobalTape());
		}
		void Echo(void){
			printf("RegisterInputFunctor Echo:\n");
			printf("   double_count: %i\n",double_count);
		}
		template<typename T> void call(T & value){
			/*General case: do nothing*/
		}
		template<typename T> void call(T* & value,int size){
			/*General case: do nothing*/
		}
		void call(IssmDouble value){
			this->tape_codi->registerInput(value);
			this->identifiers[this->double_count] = value.getGradientData();
			this->double_count++;
		}
		void call(IssmDouble* value,int size){
			if(value){
				for(int i=0;i<size;i++){
					this->tape_codi->registerInput(value[i]);
					this->identifiers[this->double_count] = value[i].getGradientData();
					this->double_count++;
				}
			}
		}
}; /*}}}*/
class RegisterOutputFunctor:  public MarshallHandle{ /*{{{*/

	private:
		int   double_count;
		IssmDouble::TapeType* tape_codi;

	public:
		RegisterOutputFunctor(void) : MarshallHandle(AD_REGISTEROUTPUT){
			this->double_count = 0;
			this->tape_codi    = &(IssmDouble::getGlobalTape());
		}
		void Echo(void){
			printf("RegisterOutputFunctor Echo:\n");
			printf("   double_count: %i\n",double_count);
		}
		template<typename T> void call(T & value){
			/*General case: do nothing*/
		}
		template<typename T> void call(T* & value,int size){
			/*General case: do nothing*/
		}
		void call(IssmDouble value){
			this->tape_codi->registerOutput(value);
			this->double_count++;
		}
		void call(IssmDouble* value,int size){
			if(value){
				for(int i=0;i<size;i++){
					this->tape_codi->registerOutput(value[i]);
					this->double_count++;
				}
			}
		}
}; /*}}}*/
class SetAdjointFunctor:      public MarshallHandle{ /*{{{*/

	private:
		int                   double_count;
		IssmDouble::TapeType* tape_codi;
		double*               adjoint;

	public:
		SetAdjointFunctor(double* adjoint_in) : MarshallHandle(AD_SETADJOINT){
			this->double_count = 0;
			this->tape_codi    = &(IssmDouble::getGlobalTape());
			this->adjoint      = adjoint_in;
		}
		void Echo(void){
			printf("SetAdjointFunctor Echo:\n");
			printf("   double_count: %i\n",double_count);
		}
		template<typename T> void call(T & value){
			/*General case: do nothing*/
		}
		template<typename T> void call(T* & value,int size){
			/*General case: do nothing*/
		}
		void call(IssmDouble value){
			value.gradient() = this->adjoint[this->double_count];
			this->double_count++;
		}
		void call(IssmDouble* value,int size){
			if(value){
				for(int i=0;i<size;i++){
					value[i].gradient() = this->adjoint[this->double_count];
					this->double_count++;
				}
			}
		}
}; /*}}}*/
#endif

template<typename T> void MarshallHandle::call(T & value){
	switch(OperationNumber()){
		case MARSHALLING_WRITE:{WriteCheckpointFunctor* temp = xDynamicCast<WriteCheckpointFunctor*>(this); temp->call(value); break;}
		case MARSHALLING_LOAD: {LoadCheckpointFunctor*  temp = xDynamicCast<LoadCheckpointFunctor*>(this);  temp->call(value); break;}
		case MARSHALLING_SIZE: {SizeCheckpointFunctor*  temp = xDynamicCast<SizeCheckpointFunctor*>(this);  temp->call(value); break;}
#ifdef _HAVE_CODIPACK_
		case AD_COUNTDOUBLES:  {CountDoublesFunctor*   temp = xDynamicCast<CountDoublesFunctor*>(this);    temp->call(value); break;}
		case AD_REGISTERINPUT: {RegisterInputFunctor*  temp = xDynamicCast<RegisterInputFunctor*>(this);   temp->call(value); break;}
		case AD_REGISTEROUTPUT:{RegisterOutputFunctor* temp = xDynamicCast<RegisterOutputFunctor*>(this);  temp->call(value); break;}
		case AD_SETADJOINT:    {SetAdjointFunctor*     temp = xDynamicCast<SetAdjointFunctor*>(this);      temp->call(value); break;}
#endif
		default: _error_("Operation "<<OperationNumber()<<" not supported yet");
	}
}
template<typename T> void MarshallHandle::call(T* & value,int size){
	switch(OperationNumber()){
		case MARSHALLING_WRITE:{WriteCheckpointFunctor* temp = xDynamicCast<WriteCheckpointFunctor*>(this); temp->call(value,size); break;}
		case MARSHALLING_LOAD: {LoadCheckpointFunctor*  temp = xDynamicCast<LoadCheckpointFunctor*>(this);  temp->call(value,size); break;}
		case MARSHALLING_SIZE: {SizeCheckpointFunctor*  temp = xDynamicCast<SizeCheckpointFunctor*>(this);  temp->call(value,size); break;}
#ifdef _HAVE_CODIPACK_
		case AD_COUNTDOUBLES:  {CountDoublesFunctor*   temp = xDynamicCast<CountDoublesFunctor*>(this);    temp->call(value,size); break;}
		case AD_REGISTERINPUT: {RegisterInputFunctor*  temp = xDynamicCast<RegisterInputFunctor*>(this);   temp->call(value,size); break;}
		case AD_REGISTEROUTPUT:{RegisterOutputFunctor* temp = xDynamicCast<RegisterOutputFunctor*>(this);  temp->call(value,size); break;}
		case AD_SETADJOINT:    {SetAdjointFunctor*     temp = xDynamicCast<SetAdjointFunctor*>(this);      temp->call(value,size); break;}
#endif
		default: _error_("Operation "<<OperationNumber() <<" not supported yet");
	}
}

#endif	
