/*!\file StochasticForcingx
 * \brief: compute noise terms for the StochasticForcing fields
 */

#include "./StochasticForcingx.h"
#include "../../shared/shared.h"
#include "../../toolkits/toolkits.h"
#include "../../shared/Random/random.h"

void StochasticForcingx(FemModel* femmodel){/*{{{*/

   /*Retrieve parameters*/
   bool randomflag;
   int M,N,numfields,my_rank;
   int* fields            = NULL;
   int* dimensions        = NULL;
   IssmDouble* covariance = NULL;
   femmodel->parameters->FindParam(&randomflag,StochasticForcingRandomflagEnum);
   femmodel->parameters->FindParam(&numfields,StochasticForcingNumFieldsEnum);
   femmodel->parameters->FindParam(&fields,&N,StochasticForcingFieldsEnum);    _assert_(N==numfields);
   femmodel->parameters->FindParam(&dimensions,&N,StochasticForcingDimensionsEnum);    _assert_(N==numfields);
   int dimtot=0;
   for(int i=0;i<numfields;i++) dimtot = dimtot+dimensions[i];
   femmodel->parameters->FindParam(&covariance,&M,&N,StochasticForcingCovarianceEnum); _assert_(M==dimtot); _assert_(N==dimtot);

   /*Compute noise terms*/
   IssmDouble* noiseterms = xNew<IssmDouble>(dimtot);
   my_rank=IssmComm::GetRank();
   if(my_rank==0){
      int fixedseed;
      IssmDouble time,dt,starttime;
      femmodel->parameters->FindParam(&time,TimeEnum);
      femmodel->parameters->FindParam(&dt,TimesteppingTimeStepEnum);
      femmodel->parameters->FindParam(&starttime,TimesteppingStartTimeEnum);
		/*Determine whether random seed is fixed to time step (randomflag==false) or random seed truly random (randomflag==true)*/
      if(randomflag) fixedseed=-1;
      else fixedseed = reCast<int,IssmDouble>((time-starttime)/dt);
		/*multivariateNormal needs to be passed a NULL pointer to avoid memory leak issues*/
      IssmDouble* temparray = NULL;
      multivariateNormal(&temparray,dimtot,0.0,covariance,fixedseed);
      for(int i=0;i<dimtot;i++) noiseterms[i]=temparray[i];
      xDelete<IssmDouble>(temparray);
   }
   ISSM_MPI_Bcast(noiseterms,dimtot,ISSM_MPI_DOUBLE,0,IssmComm::GetComm());
   
	int i=0;
   for(int j=0;j<numfields;j++){
      int dimenum_type,noiseenum_type;
      IssmDouble* noisefield = xNew<IssmDouble>(dimensions[j]);
      for(int k=0;k<dimensions[j];k++){
         noisefield[k]=noiseterms[i+k];
      }
     
		int dimensionid;

		/*Deal with the autoregressive models*/
		if(fields[j]==SMBautoregressionEnum || fields[j]==FrontalForcingsRignotAutoregressionEnum){
			switch(fields[j]){
				case SMBautoregressionEnum:
					dimenum_type   = SmbBasinsIdEnum;
					noiseenum_type = SmbAutoregressionNoiseEnum;
					break;
				case FrontalForcingsRignotAutoregressionEnum:
					dimenum_type   = FrontalForcingsBasinIdEnum;
					noiseenum_type = ThermalforcingAutoregressionNoiseEnum;
					break;
			}
			for(Object* &object:femmodel->elements->objects){
            Element* element = xDynamicCast<Element*>(object);
            int numvertices  = element->GetNumberOfVertices();
            IssmDouble* noise_element = xNew<IssmDouble>(numvertices);
            element->GetInputValue(&dimensionid,dimenum_type);
            for(int i=0;i<numvertices;i++) noise_element[i] = noisefield[dimensionid];
            element->AddInput(noiseenum_type,noise_element,P0Enum);
            xDelete<IssmDouble>(noise_element);
			}
		}
		else{
			switch(fields[j]){
				case SMBautoregressionEnum:
				case FrontalForcingsRignotAutoregressionEnum:
					/*Already done above*/
					break;
				case DefaultCalvingEnum:
					/*Delete CalvingCalvingrateEnum at previous time step (required if it is transient)*/
					femmodel->inputs->DeleteInput(CalvingCalvingrateEnum);
					for(Object* &object:femmodel->elements->objects){
						Element* element = xDynamicCast<Element*>(object);
						int numvertices  = element->GetNumberOfVertices();
						IssmDouble baselinecalvingrate;
						IssmDouble calvingrate_tot[numvertices];
						Input* baselinecalvingrate_input  = NULL;
						baselinecalvingrate_input = element->GetInput(BaselineCalvingCalvingrateEnum); _assert_(baselinecalvingrate_input);
						element->GetInputValue(&dimensionid,StochasticForcingDefaultIdEnum);
						Gauss* gauss = element->NewGauss();
						for(int i=0;i<numvertices;i++){
							gauss->GaussVertex(i);
							baselinecalvingrate_input->GetInputValue(&baselinecalvingrate,gauss);
							calvingrate_tot[i] = max(0.0,baselinecalvingrate+noisefield[dimensionid]);
						}
						element->AddInput(CalvingCalvingrateEnum,&calvingrate_tot[0],P1DGEnum);
						delete gauss;
					}
					break;
				default:
					_error_("Field "<<EnumToStringx(fields[j])<<" does not support stochasticity yet.");
			}
		}
		i=i+dimensions[j];
      xDelete<IssmDouble>(noisefield);
   }

	/*Cleanup*/
   xDelete<int>(fields);
   xDelete<int>(dimensions);
   xDelete<IssmDouble>(covariance);
   xDelete<IssmDouble>(noiseterms);
}/*}}}*/
