Changeset 25536


Ignore:
Timestamp:
09/04/20 20:15:03 (5 years ago)
Author:
Mathieu Morlighem
Message:

CHG: working on AD transient checkpointing

File:
1 edited

Legend:

Unmodified
Added
Removed
  • issm/trunk-jpl/src/c/cores/transient_core.cpp

    r25512 r25536  
    288288        if(isslr) sealevelrise_core_geometry(femmodel);
    289289
     290        std::vector<IssmDouble> time_all;
     291   int                     Ysize = 0;
     292        CountDoublesFunctor   *hdl_countdoubles = NULL;
     293        RegisterInputFunctor  *hdl_regin        = NULL;
     294        RegisterOutputFunctor *hdl_regout       = NULL;
     295        SetAdjointFunctor     *hdl_setadjoint   = NULL;
     296
    290297        while(time < finaltime - (yts*DBL_EPSILON)){ //make sure we run up to finaltime.
    291298
     
    298305                femmodel->parameters->SetParam(step,StepEnum);
    299306                femmodel->parameters->SetParam(false,SaveResultsEnum);
     307                time_all.push_back(time);
    300308
    301309                /*Store Model State at the beginning of the step*/
    302310                if(VerboseSolution()) _printf0_("   checkpointing model \n");
    303311                femmodel->CheckPointAD(step);
     312
     313                if(VerboseSolution()) _printf0_("   counting number of active variables\n");
     314                hdl_countdoubles = new CountDoublesFunctor();
     315                femmodel->Marshall(hdl_countdoubles);
     316      if(hdl_countdoubles->DoubleCount()>Ysize) Ysize= hdl_countdoubles->DoubleCount();
     317                delete hdl_countdoubles;
    304318
    305319                /*Run transient step!*/
     
    314328        }
    315329
     330        if(VerboseSolution()) _printf0_("   done with initial complete transient\n");
     331
    316332        /*__________________________________________________________________________________*/
    317333
     
    320336        GetVectorFromControlInputsx(&X,&Xsize,femmodel->elements,femmodel->nodes,femmodel->vertices,femmodel->loads,femmodel->materials,femmodel->parameters,"value");
    321337
    322         /*Get Y (model state) size*/
    323         CountDoublesFunctor* marshallhandle1 = new CountDoublesFunctor();
    324         femmodel->Marshall(marshallhandle1);
    325         int Ysize = marshallhandle1->DoubleCount();
    326         delete marshallhandle1;
    327 
    328         /*Initialize Xb, Yb and Yin*/
    329         double *Xb  = xNewZeroInit<double>(Xsize);
    330         double *Yb  = xNewZeroInit<double>(Ysize);
    331         int    *Yin = xNewZeroInit<int>(Ysize);
     338   /*Initialize model state adjoint (Yb)*/
     339   double *Yb  = xNewZeroInit<double>(Ysize);
     340   int    *Yin = xNewZeroInit<int>(Ysize);
     341
    332342
    333343        /*Start tracing*/
     
    336346
    337347        /*Reverse dependent (f)*/
    338         RegisterInputFunctor* marshallhandle2 = new RegisterInputFunctor(Yin);
    339         femmodel->Marshall(marshallhandle2);
    340         delete marshallhandle2;
     348   hdl_regin = new RegisterInputFunctor(Yin,Ysize);
     349   femmodel->Marshall(hdl_regin);
     350   delete hdl_regin;
    341351        for(int i=0; i < Xsize; i++) tape_codi.registerInput(X[i]);
     352
    342353        SetControlInputsFromVectorx(femmodel,X);
    343 
    344         IssmDouble J = 0.;
     354   IssmDouble J = 0.;
    345355        for(int i=0;i<dependent_objects->Size();i++){
    346356                DependentObject* dep=(DependentObject*)dependent_objects->GetObjectByOffset(i);
    347357                J += dep->GetValue();
    348358        }
    349 
    350         if(IssmComm::GetRank()==0) {
    351                 tape_codi.registerOutput(J);
    352         }
     359        if(IssmComm::GetRank()==0) tape_codi.registerOutput(J);
     360
    353361        tape_codi.setPassive();
    354 
    355362        J.gradient() = 1.0;
    356363        tape_codi.evaluate();
    357364
     365   /*Initialize Xb and Yb*/
     366   double *Xb  = xNewZeroInit<double>(Xsize);
    358367        for(int i=0;i<Xsize;i++) Xb[i] += X[i].gradient();
    359         SetAdjointFunctor* marshallhandle3 = new SetAdjointFunctor(Yb);
    360         femmodel->Marshall(marshallhandle3);
    361         delete marshallhandle3;
     368
     369   hdl_setadjoint = new SetAdjointFunctor(Yb,Ysize);
     370   femmodel->Marshall(hdl_setadjoint);
     371   delete hdl_setadjoint;
    362372
    363373        /*reverse loop for transient step (G)*/
     
    370380
    371381                /*We need to store the CoDiPack identifier here, since y is overwritten.*/
    372                 RegisterInputFunctor* marshallhandle4 = new RegisterInputFunctor(Yin);
    373                 femmodel->Marshall(marshallhandle4);
    374                 delete marshallhandle4;
     382                hdl_regin = new RegisterInputFunctor(Yin,Ysize);
     383                femmodel->Marshall(hdl_regin);
     384                delete hdl_regin;
    375385
    376386                /*Tell codipack that X is the independent*/
     
    382392
    383393                /*Register output*/
    384                 RegisterOutputFunctor* marshallhandle5 = new RegisterOutputFunctor();
    385                 femmodel->Marshall(marshallhandle5);
    386                 delete marshallhandle5;
     394                hdl_regout = new RegisterOutputFunctor();
     395                femmodel->Marshall(hdl_regout);
     396                delete hdl_regout;
    387397
    388398                /*stop tracing*/
     
    391401                /*Reverse transient step (G)*/
    392402                /* Using y_b here to seed the next reverse iteration there y_b is always overwritten*/
    393                 SetAdjointFunctor* marshallhandle6 = new SetAdjointFunctor(Yb);
    394                 femmodel->Marshall(marshallhandle6);
    395                 delete marshallhandle6;
     403                hdl_setadjoint = new SetAdjointFunctor(Yb,Ysize);
     404                femmodel->Marshall(hdl_setadjoint);
     405                delete hdl_setadjoint;
    396406
    397407                tape_codi.evaluate();
Note: See TracChangeset for help on using the changeset viewer.