Add checkpointing for MdModules
[alexxy/gromacs.git] / src / gromacs / fileio / checkpoint.cpp
index 83d661c219a80164adbbe0a71c9a8b46979dbbb1..6250e76ba0f4fb518e21f2f1b5e7a1afc920de5a 100644 (file)
 #include "gromacs/utility/futil.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/int64_to_int.h"
+#include "gromacs/utility/keyvaluetree.h"
+#include "gromacs/utility/keyvaluetreebuilder.h"
+#include "gromacs/utility/keyvaluetreeserializer.h"
+#include "gromacs/utility/mdmodulenotification.h"
 #include "gromacs/utility/programcontext.h"
 #include "gromacs/utility/smalloc.h"
 #include "gromacs/utility/sysinfo.h"
@@ -105,6 +109,7 @@ enum cptv {
     cptv_RemoveBuildMachineInformation,                      /**< remove functionality that makes mdrun builds non-reproducible */
     cptv_ComPrevStepAsPullGroupReference,                    /**< Allow using COM of previous step as pull group PBC reference */
     cptv_PullAverage,                                        /**< Added possibility to output average pull force and position */
+    cptv_MdModules,                                          /**< Added checkpointing for MdModules */
     cptv_Count                                               /**< the total number of cptv versions */
 };
 
@@ -262,6 +267,39 @@ enum class StatePart
     pullHistory         //!< Pull history statistics (sums since last written output)
 };
 
+namespace gmx
+{
+
+struct MdModulesCheckpointReadingDataOnMaster
+{
+    //! The data of the MdModules that is stored in the checkpoint file
+    const KeyValueTreeObject &checkpointedData;
+    //! The version of the read ceckpoint file
+    int                       checkpointFileVersion_;
+};
+
+/*! \libinternal
+ * \brief Provides the MdModules with the communication record to broadcast.
+ */
+struct MdModulesCheckpointReadingBroadcast
+{
+    //! The communication record
+    const t_commrec &cr;
+    //! The version of the read file version
+    int              checkpointFileVersion_;
+};
+
+/*! \libinternal \brief Writing the MdModules data to a checkpoint file.
+ */
+struct MdModulesWriteCheckpointData
+{
+    //! Builder for the Key-Value-Tree to store the MdModule checkpoint data
+    KeyValueTreeObjectBuilder builder;
+    //! The version of the read file version
+    int                       checkpointFileVersion_;
+};
+} // namespace gmx
+
 //! \brief Return the name of a checkpoint entry based on part and part entry
 static const char *entryName(StatePart part, int ecpt)
 {
@@ -1865,6 +1903,18 @@ static int do_cpt_awh(XDR *xd, gmx_bool bRead,
     return ret;
 }
 
+static void do_cpt_mdmodules(int fileVersion, t_fileio *checkpointFileHandle, const gmx::MdModulesNotifier &mdModulesNotifier)
+{
+    if (fileVersion >= cptv_MdModules)
+    {
+        gmx::FileIOXdrSerializer                    serializer(checkpointFileHandle);
+        gmx::KeyValueTreeObject                     mdModuleCheckpointParameterTree = gmx::deserializeKeyValueTree(&serializer);
+        gmx::MdModulesCheckpointReadingDataOnMaster mdModuleCheckpointReadingDataOnMaster
+            = { mdModuleCheckpointParameterTree, fileVersion };
+        mdModulesNotifier.notifier_.notify(mdModuleCheckpointReadingDataOnMaster);
+    }
+}
+
 static int do_cpt_files(XDR *xd, gmx_bool bRead,
                         std::vector<gmx_file_position_t> *outputfiles,
                         FILE *list, int file_version)
@@ -1956,7 +2006,8 @@ void write_checkpoint(const char *fn, gmx_bool bNumberAndKeep,
                       int eIntegrator, int simulation_part,
                       gmx_bool bExpanded, int elamstats,
                       int64_t step, double t,
-                      t_state *state, ObservablesHistory *observablesHistory)
+                      t_state *state, ObservablesHistory *observablesHistory,
+                      const gmx::MdModulesNotifier &mdModulesNotifier)
 {
     t_fileio            *fp;
     char                *fntemp; /* the temporary checkpoint file name */
@@ -2127,6 +2178,16 @@ void write_checkpoint(const char *fn, gmx_bool bNumberAndKeep,
         gmx_file("Cannot read/write checkpoint; corrupt file, or maybe you are out of disk space?");
     }
 
+    // Checkpointing MdModules
+    {
+        gmx::KeyValueTreeBuilder          builder;
+        gmx::MdModulesWriteCheckpointData mdModulesWriteCheckpoint = {builder.rootObject(), headerContents.file_version};
+        mdModulesNotifier.notifier_.notify(mdModulesWriteCheckpoint);
+        auto                              tree = builder.build();
+        gmx::FileIOXdrSerializer          serializer(fp);
+        gmx::serializeKeyValueTree(tree, &serializer);
+    }
+
     do_cpt_footer(gmx_fio_getxdr(fp), headerContents.file_version);
 
     /* we really, REALLY, want to make sure to physically write the checkpoint,
@@ -2347,7 +2408,8 @@ static void read_checkpoint(const char *fn, t_fileio *logfio,
                             CheckpointHeaderContents *headerContents,
                             t_state *state,
                             ObservablesHistory *observablesHistory,
-                            gmx_bool reproducibilityRequested)
+                            gmx_bool reproducibilityRequested,
+                            const gmx::MdModulesNotifier &mdModulesNotifier)
 {
     t_fileio            *fp;
     char                 buf[STEPSTRSIZE];
@@ -2505,7 +2567,7 @@ static void read_checkpoint(const char *fn, t_fileio *logfio,
     {
         cp_error();
     }
-
+    do_cpt_mdmodules(headerContents->file_version, fp, mdModulesNotifier);
     ret = do_cpt_footer(gmx_fio_getxdr(fp), headerContents->file_version);
     if (ret)
     {
@@ -2522,7 +2584,8 @@ void load_checkpoint(const char *fn, t_fileio *logfio,
                      const t_commrec *cr, const ivec dd_nc,
                      t_inputrec *ir, t_state *state,
                      ObservablesHistory *observablesHistory,
-                     gmx_bool reproducibilityRequested)
+                     gmx_bool reproducibilityRequested,
+                     const gmx::MdModulesNotifier &mdModulesNotifier)
 {
     CheckpointHeaderContents headerContents;
     if (SIMMASTER(cr))
@@ -2533,11 +2596,13 @@ void load_checkpoint(const char *fn, t_fileio *logfio,
                         ir->eI, &(ir->fepvals->init_fep_state),
                         &headerContents,
                         state, observablesHistory,
-                        reproducibilityRequested);
+                        reproducibilityRequested, mdModulesNotifier);
     }
     if (PAR(cr))
     {
         gmx_bcast(sizeof(headerContents.step), &headerContents.step, cr);
+        gmx::MdModulesCheckpointReadingBroadcast broadcastCheckPointData = {*cr, headerContents.file_version};
+        mdModulesNotifier.notifier_.notify(broadcastCheckPointData);
     }
     ir->bContinuation    = TRUE;
     // TODO Should the following condition be <=? Currently if you
@@ -2661,7 +2726,8 @@ read_checkpoint_data(t_fileio                         *fp,
     {
         cp_error();
     }
-
+    gmx::MdModulesNotifier mdModuleNotifier;
+    do_cpt_mdmodules(headerContents.file_version, fp, mdModuleNotifier);
     ret = do_cpt_footer(gmx_fio_getxdr(fp), headerContents.file_version);
     if (ret)
     {