Prepare legacy checkpoint for modular simulator checkpointing
authorPascal Merz <pascal.merz@me.com>
Sun, 6 Sep 2020 18:15:55 +0000 (18:15 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Sun, 6 Sep 2020 18:15:55 +0000 (18:15 +0000)
* Extend legacy checkpointing functionality to accept a CheckpointDataHolder
  for reading and writing
* Bump checkpoint version to reflect above change
* Turn off some checkpoint sanity checks when using modular simulator
* Pass CheckpointDataHolder object into checkpoint reading in runner, and
  move this object in SimulatorBuilder and then ModularSimulator for element
  setup

Refs #3517
Refs #3422
Refs #3419

14 files changed:
src/gromacs/fileio/checkpoint.cpp
src/gromacs/fileio/checkpoint.h
src/gromacs/mdlib/mdoutf.cpp
src/gromacs/mdlib/mdoutf.h
src/gromacs/mdlib/trajectory_writing.cpp
src/gromacs/mdrun/minimize.cpp
src/gromacs/mdrun/runner.cpp
src/gromacs/mdrun/simulatorbuilder.cpp
src/gromacs/mdrun/simulatorbuilder.h
src/gromacs/modularsimulator/checkpointhelper.cpp
src/gromacs/modularsimulator/modularsimulator.cpp
src/gromacs/modularsimulator/modularsimulator.h
src/gromacs/modularsimulator/statepropagatordata.cpp
src/gromacs/modularsimulator/statepropagatordata.h

index a46510d817d28006b60b14a6097281c61d89446a..5a9e548c812969605dea7d2d8a3072e4f6b884c3 100644 (file)
@@ -60,6 +60,7 @@
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/awh_correlation_history.h"
 #include "gromacs/mdtypes/awh_history.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/df_history.h"
 #include "gromacs/mdtypes/edsamhistory.h"
@@ -157,9 +158,10 @@ enum cptv
     cptv_Unknown = 17,                  /**< Version before numbering scheme */
     cptv_RemoveBuildMachineInformation, /**< remove functionality that makes mdrun builds non-reproducible */
     cptv_ComPrevStepAsPullGroupReference, /**< Allow using COM of previous step as pull group PBC reference */
-    cptv_PullAverage, /**< Added possibility to output average pull force and position */
-    cptv_MdModules,   /**< Added checkpointing for MdModules */
-    cptv_Count        /**< the total number of cptv versions */
+    cptv_PullAverage,      /**< Added possibility to output average pull force and position */
+    cptv_MdModules,        /**< Added checkpointing for MdModules */
+    cptv_ModularSimulator, /**< Added checkpointing for modular simulator */
+    cptv_Count             /**< the total number of cptv versions */
 };
 
 /*! \brief Version number of the file format written to checkpoint
@@ -1228,6 +1230,16 @@ static void do_cpt_header(XDR* xd, gmx_bool bRead, FILE* list, CheckpointHeaderC
     {
         contents->flagsPullHistory = 0;
     }
+
+    if (contents->file_version >= cptv_ModularSimulator)
+    {
+        do_cpt_bool_err(xd, "Is modular simulator checkpoint",
+                        &contents->isModularSimulatorCheckpoint, list);
+    }
+    else
+    {
+        contents->isModularSimulatorCheckpoint = false;
+    }
 }
 
 static int do_cpt_footer(XDR* xd, int file_version)
@@ -2225,7 +2237,8 @@ void write_checkpoint_data(t_fileio*                         fp,
                            t_state*                          state,
                            ObservablesHistory*               observablesHistory,
                            const gmx::MdModulesNotifier&     mdModulesNotifier,
-                           std::vector<gmx_file_position_t>* outputfiles)
+                           std::vector<gmx_file_position_t>* outputfiles,
+                           gmx::WriteCheckpointDataHolder*   modularSimulatorCheckpointData)
 {
     headerContents.flags_eks = 0;
     if (state->ekinstate.bUpToDate)
@@ -2234,6 +2247,7 @@ void write_checkpoint_data(t_fileio*                         fp,
                                     | (1 << eeksEKINO) | (1 << eeksEKINSCALEF) | (1 << eeksEKINSCALEH)
                                     | (1 << eeksVSCALE) | (1 << eeksDEKINDL) | (1 << eeksMVCOS));
     }
+    headerContents.isModularSimulatorCheckpoint = !modularSimulatorCheckpointData->empty();
 
     energyhistory_t* enerhist = observablesHistory->energyHistory.get();
     headerContents.flags_enh  = 0;
@@ -2328,6 +2342,12 @@ void write_checkpoint_data(t_fileio*                         fp,
         gmx::serializeKeyValueTree(tree, &serializer);
     }
 
+    // Checkpointing modular simulator
+    {
+        gmx::FileIOXdrSerializer serializer(fp);
+        modularSimulatorCheckpointData->serialize(&serializer);
+    }
+
     do_cpt_footer(gmx_fio_getxdr(fp), headerContents.file_version);
 }
 
@@ -2467,17 +2487,19 @@ static void check_match(FILE*                           fplog,
     }
 }
 
-static void read_checkpoint(const char*                   fn,
-                            t_fileio*                     logfio,
-                            const t_commrec*              cr,
-                            const ivec                    dd_nc,
-                            int                           eIntegrator,
-                            int*                          init_fep_state,
-                            CheckpointHeaderContents*     headerContents,
-                            t_state*                      state,
-                            ObservablesHistory*           observablesHistory,
-                            gmx_bool                      reproducibilityRequested,
-                            const gmx::MdModulesNotifier& mdModulesNotifier)
+static void read_checkpoint(const char*                    fn,
+                            t_fileio*                      logfio,
+                            const t_commrec*               cr,
+                            const ivec                     dd_nc,
+                            int                            eIntegrator,
+                            int*                           init_fep_state,
+                            CheckpointHeaderContents*      headerContents,
+                            t_state*                       state,
+                            ObservablesHistory*            observablesHistory,
+                            gmx_bool                       reproducibilityRequested,
+                            const gmx::MdModulesNotifier&  mdModulesNotifier,
+                            gmx::ReadCheckpointDataHolder* modularSimulatorCheckpointData,
+                            bool                           useModularSimulator)
 {
     t_fileio* fp;
     char      buf[STEPSTRSIZE];
@@ -2548,7 +2570,8 @@ static void read_checkpoint(const char*                   fn,
                   fn);
     }
 
-    if (headerContents->flags_state != state->flags)
+    // For modular simulator, no state object is populated, so we cannot do this check here!
+    if (headerContents->flags_state != state->flags && !useModularSimulator)
     {
         gmx_fatal(FARGS,
                   "Cannot change a simulation algorithm during a checkpoint restart. Perhaps you "
@@ -2658,6 +2681,11 @@ static void read_checkpoint(const char*                   fn,
         cp_error();
     }
     do_cpt_mdmodules(headerContents->file_version, fp, mdModulesNotifier);
+    if (headerContents->file_version >= cptv_ModularSimulator)
+    {
+        gmx::FileIOXdrSerializer serializer(fp);
+        modularSimulatorCheckpointData->deserialize(&serializer);
+    }
     ret = do_cpt_footer(gmx_fio_getxdr(fp), headerContents->file_version);
     if (ret)
     {
@@ -2670,22 +2698,25 @@ static void read_checkpoint(const char*                   fn,
 }
 
 
-void load_checkpoint(const char*                   fn,
-                     t_fileio*                     logfio,
-                     const t_commrec*              cr,
-                     const ivec                    dd_nc,
-                     t_inputrec*                   ir,
-                     t_state*                      state,
-                     ObservablesHistory*           observablesHistory,
-                     gmx_bool                      reproducibilityRequested,
-                     const gmx::MdModulesNotifier& mdModulesNotifier)
+void load_checkpoint(const char*                    fn,
+                     t_fileio*                      logfio,
+                     const t_commrec*               cr,
+                     const ivec                     dd_nc,
+                     t_inputrec*                    ir,
+                     t_state*                       state,
+                     ObservablesHistory*            observablesHistory,
+                     gmx_bool                       reproducibilityRequested,
+                     const gmx::MdModulesNotifier&  mdModulesNotifier,
+                     gmx::ReadCheckpointDataHolder* modularSimulatorCheckpointData,
+                     bool                           useModularSimulator)
 {
     CheckpointHeaderContents headerContents;
     if (SIMMASTER(cr))
     {
         /* Read the state from the checkpoint file */
-        read_checkpoint(fn, logfio, cr, dd_nc, ir->eI, &(ir->fepvals->init_fep_state), &headerContents,
-                        state, observablesHistory, reproducibilityRequested, mdModulesNotifier);
+        read_checkpoint(fn, logfio, cr, dd_nc, ir->eI, &(ir->fepvals->init_fep_state),
+                        &headerContents, state, observablesHistory, reproducibilityRequested,
+                        mdModulesNotifier, modularSimulatorCheckpointData, useModularSimulator);
     }
     if (PAR(cr))
     {
@@ -2816,6 +2847,14 @@ static CheckpointHeaderContents read_checkpoint_data(t_fileio*
     }
     gmx::MdModulesNotifier mdModuleNotifier;
     do_cpt_mdmodules(headerContents.file_version, fp, mdModuleNotifier);
+    if (headerContents.file_version >= cptv_ModularSimulator)
+    {
+        // In the scope of the current function, we can just throw away the content
+        // of the modular checkpoint, but we need to read it to move the file pointer
+        gmx::FileIOXdrSerializer      serializer(fp);
+        gmx::ReadCheckpointDataHolder modularSimulatorCheckpointData;
+        modularSimulatorCheckpointData.deserialize(&serializer);
+    }
     ret = do_cpt_footer(gmx_fio_getxdr(fp), headerContents.file_version);
     if (ret)
     {
index 863f9e82bd546f4103485328a70ed275b41e1d12..c3dbc3c1077ee06c3eb055250e74c70823d6e0b3 100644 (file)
@@ -63,6 +63,8 @@ namespace gmx
 
 struct MdModulesNotifier;
 class KeyValueTreeObject;
+class ReadCheckpointDataHolder;
+class WriteCheckpointDataHolder;
 
 /*! \brief Read to a key-value-tree value used for checkpointing.
  *
@@ -236,6 +238,8 @@ struct CheckpointHeaderContents
     int nED;
     //! Enum for coordinate swapping.
     int eSwapCoords;
+    //! Whether the checkpoint was written by modular simulator.
+    bool isModularSimulatorCheckpoint = false;
 };
 
 /*! \brief Low-level checkpoint writing function */
@@ -246,7 +250,8 @@ void write_checkpoint_data(t_fileio*                         fp,
                            t_state*                          state,
                            ObservablesHistory*               observablesHistory,
                            const gmx::MdModulesNotifier&     notifier,
-                           std::vector<gmx_file_position_t>* outputfiles);
+                           std::vector<gmx_file_position_t>* outputfiles,
+                           gmx::WriteCheckpointDataHolder*   modularSimulatorCheckpointData);
 
 /* Loads a checkpoint from fn for run continuation.
  * Generates a fatal error on system size mismatch.
@@ -255,15 +260,17 @@ void write_checkpoint_data(t_fileio*                         fp,
  * but not the state itself.
  * With reproducibilityRequested warns about version, build, #ranks differences.
  */
-void load_checkpoint(const char*                   fn,
-                     t_fileio*                     logfio,
-                     const t_commrec*              cr,
-                     const ivec                    dd_nc,
-                     t_inputrec*                   ir,
-                     t_state*                      state,
-                     ObservablesHistory*           observablesHistory,
-                     gmx_bool                      reproducibilityRequested,
-                     const gmx::MdModulesNotifier& mdModulesNotifier);
+void load_checkpoint(const char*                    fn,
+                     t_fileio*                      logfio,
+                     const t_commrec*               cr,
+                     const ivec                     dd_nc,
+                     t_inputrec*                    ir,
+                     t_state*                       state,
+                     ObservablesHistory*            observablesHistory,
+                     gmx_bool                       reproducibilityRequested,
+                     const gmx::MdModulesNotifier&  mdModulesNotifier,
+                     gmx::ReadCheckpointDataHolder* modularSimulatorCheckpointData,
+                     bool                           useModularSimulator);
 
 /* Read everything that can be stored in t_trxframe from a checkpoint file */
 void read_checkpoint_trxframe(struct t_fileio* fp, t_trxframe* fr);
index c337ae1960f725c38abc1d33ec46ed64a0ab3cf6..8c5ddb6482a05002cbf4f1e9e2a2ec2079a813c5 100644 (file)
@@ -283,23 +283,24 @@ static void mpiBarrierBeforeRename(const bool applyMpiBarrierBeforeRename, MPI_C
  * Appends the _step<step>.cpt with bNumberAndKeep, otherwise moves
  * the previous checkpoint filename with suffix _prev.cpt.
  */
-static void write_checkpoint(const char*                   fn,
-                             gmx_bool                      bNumberAndKeep,
-                             FILE*                         fplog,
-                             const t_commrec*              cr,
-                             ivec                          domdecCells,
-                             int                           nppnodes,
-                             int                           eIntegrator,
-                             int                           simulation_part,
-                             gmx_bool                      bExpanded,
-                             int                           elamstats,
-                             int64_t                       step,
-                             double                        t,
-                             t_state*                      state,
-                             ObservablesHistory*           observablesHistory,
-                             const gmx::MdModulesNotifier& mdModulesNotifier,
-                             bool                          applyMpiBarrierBeforeRename,
-                             MPI_Comm                      mpiBarrierCommunicator)
+static void write_checkpoint(const char*                     fn,
+                             gmx_bool                        bNumberAndKeep,
+                             FILE*                           fplog,
+                             const t_commrec*                cr,
+                             ivec                            domdecCells,
+                             int                             nppnodes,
+                             int                             eIntegrator,
+                             int                             simulation_part,
+                             gmx_bool                        bExpanded,
+                             int                             elamstats,
+                             int64_t                         step,
+                             double                          t,
+                             t_state*                        state,
+                             ObservablesHistory*             observablesHistory,
+                             const gmx::MdModulesNotifier&   mdModulesNotifier,
+                             gmx::WriteCheckpointDataHolder* modularSimulatorCheckpointData,
+                             bool                            applyMpiBarrierBeforeRename,
+                             MPI_Comm                        mpiBarrierCommunicator)
 {
     t_fileio* fp;
     char*     fntemp; /* the temporary checkpoint file name */
@@ -383,7 +384,8 @@ static void write_checkpoint(const char*                   fn,
                                                 0,
                                                 0,
                                                 nED,
-                                                eSwapCoords };
+                                                eSwapCoords,
+                                                false };
     std::strcpy(headerContents.version, gmx_version());
     std::strcpy(headerContents.fprog, gmx::getProgramContext().fullBinaryPath());
     std::strcpy(headerContents.ftime, timebuf.c_str());
@@ -393,7 +395,7 @@ static void write_checkpoint(const char*                   fn,
     }
 
     write_checkpoint_data(fp, headerContents, bExpanded, elamstats, state, observablesHistory,
-                          mdModulesNotifier, &outputfiles);
+                          mdModulesNotifier, &outputfiles, modularSimulatorCheckpointData);
 
     /* we really, REALLY, want to make sure to physically write the checkpoint,
        and all the files it depends on, out to disk. Because we've
@@ -476,17 +478,18 @@ static void write_checkpoint(const char*                   fn,
 #endif /* end GMX_FAHCORE block */
 }
 
-void mdoutf_write_to_trajectory_files(FILE*                          fplog,
-                                      const t_commrec*               cr,
-                                      gmx_mdoutf_t                   of,
-                                      int                            mdof_flags,
-                                      int                            natoms,
-                                      int64_t                        step,
-                                      double                         t,
-                                      t_state*                       state_local,
-                                      t_state*                       state_global,
-                                      ObservablesHistory*            observablesHistory,
-                                      gmx::ArrayRef<const gmx::RVec> f_local)
+void mdoutf_write_to_trajectory_files(FILE*                           fplog,
+                                      const t_commrec*                cr,
+                                      gmx_mdoutf_t                    of,
+                                      int                             mdof_flags,
+                                      int                             natoms,
+                                      int64_t                         step,
+                                      double                          t,
+                                      t_state*                        state_local,
+                                      t_state*                        state_global,
+                                      ObservablesHistory*             observablesHistory,
+                                      gmx::ArrayRef<const gmx::RVec>  f_local,
+                                      gmx::WriteCheckpointDataHolder* modularSimulatorCheckpointData)
 {
     const rvec* f_global;
 
@@ -544,7 +547,7 @@ void mdoutf_write_to_trajectory_files(FILE*                          fplog,
                              DOMAINDECOMP(cr) ? cr->dd->nnodes : cr->nnodes, of->eIntegrator,
                              of->simulation_part, of->bExpanded, of->elamstats, step, t,
                              state_global, observablesHistory, *(of->mdModulesNotifier),
-                             of->simulationsShareState, of->mastersComm);
+                             modularSimulatorCheckpointData, of->simulationsShareState, of->mastersComm);
         }
 
         if (mdof_flags & (MDOF_X | MDOF_V | MDOF_F))
index fe534fab60e1ab2149a1b7f131a93b9cfed09d4a..546042ee387bc825e7263b007de01e28d46d7205 100644 (file)
@@ -59,6 +59,7 @@ enum class StartingBehavior;
 class IMDOutputProvider;
 struct MdModulesNotifier;
 struct MdrunOptions;
+class WriteCheckpointDataHolder;
 } // namespace gmx
 
 typedef struct gmx_mdoutf* gmx_mdoutf_t;
@@ -109,17 +110,18 @@ void done_mdoutf(gmx_mdoutf_t of);
  * the master node only when necessary. Without domain decomposition
  * only data from state_local is used and state_global is ignored.
  *
- * \param[in] fplog              File handler to log file.
- * \param[in] cr                 Communication record.
- * \param[in] of                 File handler to trajectory file.
- * \param[in] mdof_flags         Flags indicating what data is written.
- * \param[in] natoms             The total number of atoms in the system.
- * \param[in] step               The current time step.
- * \param[in] t                  The current time.
- * \param[in] state_local        Pointer to the local state object.
- * \param[in] state_global       Pointer to the global state object.
- * \param[in] observablesHistory Pointer to the ObservableHistory object.
- * \param[in] f_local            The local forces.
+ * \param[in] fplog                           File handler to log file.
+ * \param[in] cr                              Communication record.
+ * \param[in] of                              File handler to trajectory file.
+ * \param[in] mdof_flags                      Flags indicating what data is written.
+ * \param[in] natoms                          The total number of atoms in the system.
+ * \param[in] step                            The current time step.
+ * \param[in] t                               The current time.
+ * \param[in] state_local                     Pointer to the local state object.
+ * \param[in] state_global                    Pointer to the global state object.
+ * \param[in] observablesHistory              Pointer to the ObservableHistory object.
+ * \param[in] f_local                         The local forces.
+ * \param[in] modularSimulatorCheckpointData  CheckpointData object used by modular simulator.
  */
 void mdoutf_write_to_trajectory_files(FILE*                          fplog,
                                       const t_commrec*               cr,
@@ -131,7 +133,8 @@ void mdoutf_write_to_trajectory_files(FILE*                          fplog,
                                       t_state*                       state_local,
                                       t_state*                       state_global,
                                       ObservablesHistory*            observablesHistory,
-                                      gmx::ArrayRef<const gmx::RVec> f_local);
+                                      gmx::ArrayRef<const gmx::RVec> f_local,
+                                      gmx::WriteCheckpointDataHolder* modularSimulatorCheckpointData);
 
 /*! \brief Get the output interval of box size of uncompressed TNG output.
  * Returns 0 if no uncompressed TNG file is open.
index 6d250ceeb2b30a8eb2c03a90596a92603e81301c..d9ef7191ede87741b866305ab8c0225f337cfe29 100644 (file)
@@ -44,6 +44,7 @@
 #include "gromacs/mdlib/mdoutf.h"
 #include "gromacs/mdlib/stat.h"
 #include "gromacs/mdlib/update.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/forcerec.h"
 #include "gromacs/mdtypes/inputrec.h"
@@ -160,11 +161,15 @@ void do_md_trajectory_writing(FILE*                          fplog,
                 energyOutput.fillEnergyHistory(observablesHistory->energyHistory.get());
             }
         }
+        // The current function is only called by legacy code, while
+        // mdoutf_write_to_trajectory_files is also called from modular simulator. Create a dummy
+        // modular simulator checkpointing object for compatibility.
+        gmx::WriteCheckpointDataHolder checkpointDataHolder;
         // Note that part of the following code is duplicated in StatePropagatorData::trajectoryWriterTeardown.
         // This duplication is needed while both legacy and modular code paths are in use.
         // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
-        mdoutf_write_to_trajectory_files(fplog, cr, outf, mdof_flags, top_global->natoms, step, t,
-                                         state, state_global, observablesHistory, f);
+        mdoutf_write_to_trajectory_files(fplog, cr, outf, mdof_flags, top_global->natoms, step, t, state,
+                                         state_global, observablesHistory, f, &checkpointDataHolder);
         if (bLastStep && step_rel == ir->nsteps && bDoConfOut && MASTER(cr) && !bRerunMD)
         {
             if (fr->bMolPBC && state == state_global)
index 24013faeffbe3a5d6fb100dcf17e6b0069217f4e..95815d24b7af85284d152793e38ec01272a6286e 100644 (file)
@@ -90,6 +90,7 @@
 #include "gromacs/mdlib/vsite.h"
 #include "gromacs/mdrunutility/handlerestart.h"
 #include "gromacs/mdrunutility/printtime.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/forcebuffers.h"
 #include "gromacs/mdtypes/forcerec.h"
@@ -532,9 +533,10 @@ static void write_em_traj(FILE*               fplog,
         mdof_flags |= MDOF_IMD;
     }
 
+    gmx::WriteCheckpointDataHolder checkpointDataHolder;
     mdoutf_write_to_trajectory_files(fplog, cr, outf, mdof_flags, top_global->natoms, step,
                                      static_cast<double>(step), &state->s, state_global,
-                                     observablesHistory, state->f.view().force());
+                                     observablesHistory, state->f.view().force(), &checkpointDataHolder);
 
     if (confout != nullptr)
     {
@@ -1857,9 +1859,10 @@ void LegacySimulator::do_lbfgs()
             mdof_flags |= MDOF_IMD;
         }
 
+        gmx::WriteCheckpointDataHolder checkpointDataHolder;
         mdoutf_write_to_trajectory_files(fplog, cr, outf, mdof_flags, top_global->natoms, step,
-                                         static_cast<real>(step), &ems.s, state_global,
-                                         observablesHistory, ems.f.view().force());
+                                         static_cast<real>(step), &ems.s, state_global, observablesHistory,
+                                         ems.f.view().force(), &checkpointDataHolder);
 
         /* Do the linesearching in the direction dx[point][0..(n-1)] */
 
index 8bc7dc71f6f8bd9dbeacb8ed10669b040e5664fc..ef536514d14d21693f4f2d451adcf5975d435956 100644 (file)
 #include "gromacs/mdrunutility/multisim.h"
 #include "gromacs/mdrunutility/printtime.h"
 #include "gromacs/mdrunutility/threadaffinity.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/fcdata.h"
@@ -1070,6 +1071,7 @@ int Mdrunner::mdrunner()
 
     ObservablesHistory observablesHistory = {};
 
+    auto modularSimulatorCheckpointData = std::make_unique<ReadCheckpointDataHolder>();
     if (startingBehavior != StartingBehavior::NewSimulation)
     {
         /* Check if checkpoint file exists before doing continuation.
@@ -1086,7 +1088,8 @@ int Mdrunner::mdrunner()
 
         load_checkpoint(opt2fn_master("-cpi", filenames.size(), filenames.data(), cr),
                         logFileHandle, cr, domdecOptions.numCells, inputrec.get(), globalState.get(),
-                        &observablesHistory, mdrunOptions.reproducible, mdModules_->notifier());
+                        &observablesHistory, mdrunOptions.reproducible, mdModules_->notifier(),
+                        modularSimulatorCheckpointData.get(), useModularSimulator);
 
         if (startingBehavior == StartingBehavior::RestartWithAppending && logFileHandle)
         {
@@ -1679,6 +1682,7 @@ int Mdrunner::mdrunner()
         simulatorBuilder.add(IonSwapping(swap));
         simulatorBuilder.add(TopologyData(&mtop, mdAtoms.get()));
         simulatorBuilder.add(BoxDeformationHandle(deform.get()));
+        simulatorBuilder.add(std::move(modularSimulatorCheckpointData));
 
         // build and run simulator object based on user-input
         auto simulator = simulatorBuilder.build(useModularSimulator);
index 0474223a293e23fe05a5f9f44834337a2735a622..f1ed995d53f15bcc67a7af90d9f5243b4e95f99e 100644 (file)
@@ -46,6 +46,7 @@
 #include <memory>
 
 #include "gromacs/mdlib/vsite.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/mdrunoptions.h"
 #include "gromacs/mdtypes/state.h"
 #include "gromacs/modularsimulator/modularsimulator.h"
@@ -126,20 +127,24 @@ std::unique_ptr<ISimulator> SimulatorBuilder::build(bool useModularSimulator)
     if (useModularSimulator)
     {
         // NOLINTNEXTLINE(modernize-make-unique): make_unique does not work with private constructor
-        return std::unique_ptr<ModularSimulator>(new ModularSimulator(std::make_unique<LegacySimulatorData>(
-                simulatorEnv_->fplog_, simulatorEnv_->commRec_, simulatorEnv_->multisimCommRec_,
-                simulatorEnv_->logger_, legacyInput_->numFile, legacyInput_->filenames,
-                simulatorEnv_->outputEnv_, simulatorConfig_->mdrunOptions_,
-                simulatorConfig_->startingBehavior_, constraintsParam_->vsite,
-                constraintsParam_->constr, constraintsParam_->enforcedRotation, boxDeformation_->deform,
-                simulatorModules_->outputProvider, simulatorModules_->mdModulesNotifier,
-                legacyInput_->inputrec, interactiveMD_->imdSession, centerOfMassPulling_->pull_work,
-                ionSwapping_->ionSwap, topologyData_->top_global, simulatorStateData_->globalState_p,
-                simulatorStateData_->observablesHistory_p, topologyData_->mdAtoms, profiling_->nrnb,
-                profiling_->wallCycle, legacyInput_->forceRec, simulatorStateData_->enerdata_p,
-                simulatorStateData_->ekindata_p, simulatorConfig_->runScheduleWork_,
-                *replicaExchangeParameters_, membedHolder_->membed(), profiling_->walltimeAccounting,
-                std::move(stopHandlerBuilder_), simulatorConfig_->mdrunOptions_.rerun)));
+        return std::unique_ptr<ModularSimulator>(new ModularSimulator(
+                std::make_unique<LegacySimulatorData>(
+                        simulatorEnv_->fplog_, simulatorEnv_->commRec_, simulatorEnv_->multisimCommRec_,
+                        simulatorEnv_->logger_, legacyInput_->numFile, legacyInput_->filenames,
+                        simulatorEnv_->outputEnv_, simulatorConfig_->mdrunOptions_,
+                        simulatorConfig_->startingBehavior_, constraintsParam_->vsite,
+                        constraintsParam_->constr, constraintsParam_->enforcedRotation,
+                        boxDeformation_->deform, simulatorModules_->outputProvider,
+                        simulatorModules_->mdModulesNotifier, legacyInput_->inputrec,
+                        interactiveMD_->imdSession, centerOfMassPulling_->pull_work, ionSwapping_->ionSwap,
+                        topologyData_->top_global, simulatorStateData_->globalState_p,
+                        simulatorStateData_->observablesHistory_p, topologyData_->mdAtoms,
+                        profiling_->nrnb, profiling_->wallCycle, legacyInput_->forceRec,
+                        simulatorStateData_->enerdata_p, simulatorStateData_->ekindata_p,
+                        simulatorConfig_->runScheduleWork_, *replicaExchangeParameters_,
+                        membedHolder_->membed(), profiling_->walltimeAccounting,
+                        std::move(stopHandlerBuilder_), simulatorConfig_->mdrunOptions_.rerun),
+                std::move(modularSimulatorCheckpointData_)));
     }
     // NOLINTNEXTLINE(modernize-make-unique): make_unique does not work with private constructor
     return std::unique_ptr<LegacySimulator>(new LegacySimulator(
@@ -168,5 +173,10 @@ void SimulatorBuilder::add(ReplicaExchangeParameters&& replicaExchangeParameters
     replicaExchangeParameters_ = std::make_unique<ReplicaExchangeParameters>(replicaExchangeParameters);
 }
 
+void SimulatorBuilder::add(std::unique_ptr<ReadCheckpointDataHolder> modularSimulatorCheckpointData)
+{
+    modularSimulatorCheckpointData_ = std::move(modularSimulatorCheckpointData);
+}
+
 
 } // namespace gmx
index 76309136c0fb1299e943e418995bd5fe0764d3f4..74b56cec00247b5e378ade22f5828a3e4d20f5c8 100644 (file)
@@ -77,6 +77,7 @@ class MDAtoms;
 class MDLogger;
 struct MdModulesNotifier;
 struct MdrunOptions;
+class ReadCheckpointDataHolder;
 enum class StartingBehavior;
 class StopHandlerBuilder;
 class VirtualSitesHandler;
@@ -334,6 +335,9 @@ public:
         boxDeformation_ = std::make_unique<BoxDeformationHandle>(boxDeformation);
     }
 
+    //! Pass the read checkpoint data for modular simulator
+    void add(std::unique_ptr<ReadCheckpointDataHolder> modularSimulatorCheckpointData);
+
     /*! \brief Build a Simulator object based on input data
      *
      * Return a pointer to a simulation object. The use of a parameter
@@ -364,6 +368,8 @@ private:
     std::unique_ptr<IonSwapping>               ionSwapping_;
     std::unique_ptr<TopologyData>              topologyData_;
     std::unique_ptr<BoxDeformationHandle>      boxDeformation_;
+    //! Contains checkpointing data for the modular simulator
+    std::unique_ptr<ReadCheckpointDataHolder> modularSimulatorCheckpointData_;
 };
 
 } // namespace gmx
index 64e50cc8682f9ab6505810124440f8eb914b12aa..3246aebe7aad21cc9eb7f401a07ce67b0482d433 100644 (file)
@@ -45,6 +45,7 @@
 
 #include "gromacs/domdec/domdec.h"
 #include "gromacs/mdlib/mdoutf.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/state.h"
 
@@ -120,14 +121,16 @@ void CheckpointHelper::scheduleTask(Step step, Time time, const RegisterRunFunct
 void CheckpointHelper::writeCheckpoint(Step step, Time time)
 {
     localStateInstance_->flags = 0;
+
+    WriteCheckpointDataHolder checkpointDataHolder;
     for (const auto& client : clients_)
     {
         client->writeCheckpoint(localStateInstance_, state_global_);
     }
 
     mdoutf_write_to_trajectory_files(fplog_, cr_, trajectoryElement_->outf_, MDOF_CPT,
-                                     globalNumAtoms_, step, time, localStateInstance_,
-                                     state_global_, observablesHistory_, ArrayRef<RVec>());
+                                     globalNumAtoms_, step, time, localStateInstance_, state_global_,
+                                     observablesHistory_, ArrayRef<RVec>(), &checkpointDataHolder);
 }
 
 std::optional<SignallerCallback> CheckpointHelper::registerLastStepCallback()
index 5604a8dec2b6ff0ec8278d752ae6381771e067b4..3b5205f7a2121f25f2b88d9491d293ae83dd5647 100644 (file)
@@ -350,8 +350,10 @@ bool ModularSimulator::isInputCompatible(bool                             exitOn
     return isInputCompatible;
 }
 
-ModularSimulator::ModularSimulator(std::unique_ptr<LegacySimulatorData> legacySimulatorData) :
-    legacySimulatorData_(std::move(legacySimulatorData))
+ModularSimulator::ModularSimulator(std::unique_ptr<LegacySimulatorData>      legacySimulatorData,
+                                   std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder) :
+    legacySimulatorData_(std::move(legacySimulatorData)),
+    checkpointDataHolder_(std::move(checkpointDataHolder))
 {
     checkInputForDisabledFunctionality();
 }
index 6a0dcdcb8b4e8ee742ad8b50166f7c3bf4da02b4..a46b52ebdef235bc457101076b00528463f82bc5 100644 (file)
@@ -57,6 +57,7 @@ struct t_fcdata;
 namespace gmx
 {
 class ModularSimulatorAlgorithmBuilder;
+class ReadCheckpointDataHolder;
 
 /*! \libinternal
  * \ingroup module_modularsimulator
@@ -90,7 +91,8 @@ public:
 
 private:
     //! Constructor
-    explicit ModularSimulator(std::unique_ptr<LegacySimulatorData> legacySimulatorData);
+    ModularSimulator(std::unique_ptr<LegacySimulatorData>      legacySimulatorData,
+                     std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder);
 
     //! Populate algorithm builder with elements
     void addIntegrationElements(ModularSimulatorAlgorithmBuilder* builder);
@@ -98,8 +100,10 @@ private:
     //! Check for disabled functionality (during construction time)
     void checkInputForDisabledFunctionality();
 
-    //! Pointer to legacy simulator data
+    //! Pointer to legacy simulator data (TODO: Can we avoid using unique_ptr? #3628)
     std::unique_ptr<LegacySimulatorData> legacySimulatorData_;
+    //! Input checkpoint data
+    std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder_;
 };
 
 /*!
index 67088ca078988c703262057846370fae512e1e85..93fba49cb36a08c93754f4e00ac305705f0ad205 100644 (file)
@@ -53,6 +53,7 @@
 #include "gromacs/mdlib/mdoutf.h"
 #include "gromacs/mdlib/stat.h"
 #include "gromacs/mdlib/update.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/forcebuffers.h"
 #include "gromacs/mdtypes/forcerec.h"
@@ -424,10 +425,10 @@ void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Ti
     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
     ObservablesHistory* observablesHistory = nullptr;
 
-    mdoutf_write_to_trajectory_files(fplog_, cr_, outf, static_cast<int>(mdof_flags),
-                                     statePropagatorData_->totalNumAtoms_, currentStep, currentTime,
-                                     localStateBackup_.get(), statePropagatorData_->globalState_,
-                                     observablesHistory, statePropagatorData_->f_.view().force());
+    mdoutf_write_to_trajectory_files(
+            fplog_, cr_, outf, static_cast<int>(mdof_flags), statePropagatorData_->totalNumAtoms_,
+            currentStep, currentTime, localStateBackup_.get(), statePropagatorData_->globalState_,
+            observablesHistory, statePropagatorData_->f_.view().force(), &dummyCheckpointDataHolder_);
 
     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
     {
index 38713f2629d118117f35b8b0a2fe58e0225cf2d1..2bafcfad8ad6ee4a5a373d3c83b8e7f3fa211acc 100644 (file)
@@ -47,6 +47,7 @@
 #include "gromacs/gpu_utils/hostallocator.h"
 #include "gromacs/math/paddedvector.h"
 #include "gromacs/math/vectypes.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/forcebuffers.h"
 
 #include "modularsimulatorinterfaces.h"
@@ -341,6 +342,8 @@ private:
     void trajectoryWriterSetup(gmx_mdoutf gmx_unused* outf) override {}
     //! Trajectory writer teardown - write final coordinates
     void trajectoryWriterTeardown(gmx_mdoutf* outf) override;
+    //! A dummy CheckpointData - remove when we stop using the legacy trajectory writing function
+    WriteCheckpointDataHolder dummyCheckpointDataHolder_;
 
     //! Whether planned total number of steps was reached (used for final output only)
     bool isRegularSimulationEnd_;