Change MPI setup to communicate TPR as buffer
[alexxy/gromacs.git] / src / gromacs / fileio / tpxio.cpp
index 302af1dc5b7f2214e6d020eb6b425b5fcc0de248..33279bd528215a7daaa3f90bb5521523afa40c47 100644 (file)
@@ -2651,7 +2651,6 @@ static void do_tpxheader(gmx::FileIOXdrSerializer *serializer,
                          t_fileio                 *fio,
                          bool                      TopOnlyOK)
 {
-    gmx_bool  bDouble;
     int       precision;
     int       idum = 0;
     real      rdum = 0;
@@ -2670,24 +2669,28 @@ static void do_tpxheader(gmx::FileIOXdrSerializer *serializer,
                       "             Make a new one with grompp or use a gro or pdb file, if possible",
                       filename);
         }
+        // We need to know the precision used to write the TPR file, to match it
+        // to the precision of the currently running binary. If the precisions match
+        // there is no problem, but mismatching precision needs to be accounted for
+        // by reading into temporary variables of the correct precision instead
+        // of the desired target datastructures.
         serializer->doInt(&precision);
-        bDouble = (precision == sizeof(double));
-        if ((precision != sizeof(float)) && !bDouble)
+        tpx->isDouble = (precision == sizeof(double));
+        if ((precision != sizeof(float)) && !tpx->isDouble)
         {
             gmx_fatal(FARGS, "Unknown precision in file %s: real is %d bytes "
                       "instead of %zu or %zu",
                       filename, precision, sizeof(float), sizeof(double));
         }
-        gmx_fio_setprecision(fio, bDouble);
+        gmx_fio_setprecision(fio, tpx->isDouble);
         fprintf(stderr, "Reading file %s, %s (%s precision)\n",
-                filename, buf.c_str(), bDouble ? "double" : "single");
+                filename, buf.c_str(), tpx->isDouble ? "double" : "single");
     }
     else
     {
         buf = gmx::formatString("VERSION %s", gmx_version());
         serializer->doString(&buf);
-        bDouble = (precision == sizeof(double));
-        gmx_fio_setprecision(fio, bDouble);
+        gmx_fio_setprecision(fio, tpx->isDouble);
         serializer->doInt(&precision);
         fileTag        = gmx::formatString("%s", tpx_tag);
     }
@@ -2783,6 +2786,21 @@ static void do_tpxheader(gmx::FileIOXdrSerializer *serializer,
 
 #define do_test(serializer, b, p) if ((serializer)->reading() && ((p) != nullptr) && !(b)) gmx_fatal(FARGS, "No %s in input file",#p)
 
+/*! \brief
+ * Process the first part of the TPR into the state datastructure.
+ *
+ * Due to the structure of the legacy code, it is necessary
+ * to split up the state reading into two parts, with the
+ * box and legacy temperature coupling processed before the
+ * topology datastructures.
+ *
+ * See the documentation for do_tpx_body for the correct order of
+ * the operations for reading a tpr file.
+ *
+ * \param[in] serializer Abstract serializer used to read/write data.
+ * \param[in] tpx The file header data.
+ * \param[in, out] state Global state data.
+ */
 static void do_tpx_state_first(gmx::ISerializer *serializer,
                                TpxFileHeader    *tpx,
                                t_state          *state)
@@ -2827,6 +2845,16 @@ static void do_tpx_state_first(gmx::ISerializer *serializer,
     }
 }
 
+/*! \brief
+ * Process global topology data.
+ *
+ * See the documentation for do_tpx_body for the correct order of
+ * the operations for reading a tpr file.
+ *
+ * \param[in] serializer Abstract serializer  used to read/write data.
+ * \param[in] tpx The file header data.
+ * \param[in,out] mtop Global topology.
+ */
 static void do_tpx_mtop(gmx::ISerializer *serializer,
                         TpxFileHeader    *tpx,
                         gmx_mtop_t       *mtop)
@@ -2847,7 +2875,20 @@ static void do_tpx_mtop(gmx::ISerializer *serializer,
         }
     }
 }
-
+/*! \brief
+ * Process coordinate vectors for state data.
+ *
+ * Main part of state gets processed here.
+ *
+ * See the documentation for do_tpx_body for the correct order of
+ * the operations for reading a tpr file.
+ *
+ * \param[in] serializer Abstract serializer used to read/write data.
+ * \param[in] tpx The file header data.
+ * \param[in,out] state Global state data.
+ * \param[in,out] x Individual coordinates for processing, deprecated.
+ * \param[in,out] v Individual velocities for processing, deprecated.
+ */
 static void do_tpx_state_second(gmx::ISerializer *serializer,
                                 TpxFileHeader    *tpx,
                                 t_state          *state,
@@ -2922,7 +2963,16 @@ static void do_tpx_state_second(gmx::ISerializer *serializer,
         serializer->doRvecArray(as_rvec_array(dummyForces.data()), tpx->natoms);
     }
 }
-
+/*! \brief
+ * Process simulation parameters.
+ *
+ * See the documentation for do_tpx_body for the correct order of
+ * the operations for reading a tpr file.
+ *
+ * \param[in] serializer Abstract serializer used to read/write data.
+ * \param[in] tpx The file header data.
+ * \param[in,out] ir Datastructure with simulation parameters.
+ */
 static int do_tpx_ir(gmx::ISerializer *serializer,
                      TpxFileHeader    *tpx,
                      t_inputrec       *ir)
@@ -2976,7 +3026,10 @@ static int do_tpx_ir(gmx::ISerializer *serializer,
 /*! \brief
  * Correct and finalize read information.
  *
- * Moved here from previous code because this is done after reading files.
+ * If \p state is nullptr, skip the parts dependent on it.
+ *
+ * See the documentation for do_tpx_body for the correct order of
+ * the operations for reading a tpr file.
  *
  * \param[in] tpx The file header used to check version numbers.
  * \param[out] ir Input rec that needs correction.
@@ -2988,13 +3041,13 @@ static void do_tpx_finalize(TpxFileHeader *tpx,
                             t_state       *state,
                             gmx_mtop_t    *mtop)
 {
-    if (tpx->fileVersion < 51)
+    if (tpx->fileVersion < 51 && state)
     {
         set_box_rel(ir, state);
     }
     if (tpx->bIr && ir)
     {
-        if (state->ngtc == 0)
+        if (state && state->ngtc == 0)
         {
             /* Reading old version without tcoupl state data: set it */
             init_gtc_state(state, ir->opts.ngtc, 0, ir->opts.nhchainlength);
@@ -3016,6 +3069,26 @@ static void do_tpx_finalize(TpxFileHeader *tpx,
     }
 }
 
+/*! \brief
+ * Process TPR data for file reading/writing.
+ *
+ * The TPR file gets processed in in four stages due to the organization
+ * of the data within it.
+ *
+ * First, state data for the box is processed in do_tpx_state_first.
+ * This is followed by processing the topology in do_tpx_mtop.
+ * Coordinate and velocity vectors are handled next in do_tpx_state_second.
+ * The last file information processed is the collection of simulation parameters in do_tpx_ir.
+ * When reading, a final processing step is undertaken at the end.
+ *
+ * \param[in] serializer Abstract serializer used to read/write data.
+ * \param[in] tpx The file header data.
+ * \param[in,out] ir Datastructures with simulation parameters.
+ * \param[in,out] state Global state data.
+ * \param[in,out] x Individual coordinates for processing, deprecated.
+ * \param[in,out] v Individual velocities for processing, deprecated.
+ * \param[in,out] mtop Global topology.
+ */
 static int do_tpx_body(gmx::ISerializer *serializer,
                        TpxFileHeader    *tpx,
                        t_inputrec       *ir,
@@ -3024,9 +3097,15 @@ static int do_tpx_body(gmx::ISerializer *serializer,
                        rvec             *v,
                        gmx_mtop_t       *mtop)
 {
-    do_tpx_state_first(serializer, tpx, state);
+    if (state)
+    {
+        do_tpx_state_first(serializer, tpx, state);
+    }
     do_tpx_mtop(serializer, tpx, mtop);
-    do_tpx_state_second(serializer, tpx, state, x, v);
+    if (state)
+    {
+        do_tpx_state_second(serializer, tpx, state, x, v);
+    }
     int ePBC = do_tpx_ir(serializer, tpx, ir);
     if (serializer->reading())
     {
@@ -3035,6 +3114,22 @@ static int do_tpx_body(gmx::ISerializer *serializer,
     return ePBC;
 }
 
+/*! \brief
+ * Overload for do_tpx_body that defaults to state vectors being nullptr.
+ *
+ * \param[in] serializer Abstract serializer used to read/write data.
+ * \param[in] tpx The file header data.
+ * \param[in,out] ir Datastructures with simulation parameters.
+ * \param[in,out] mtop Global topology.
+ */
+static int do_tpx_body(gmx::ISerializer *serializer,
+                       TpxFileHeader    *tpx,
+                       t_inputrec       *ir,
+                       gmx_mtop_t       *mtop)
+{
+    return do_tpx_body(serializer, tpx, ir, nullptr, nullptr, nullptr, mtop);
+}
+
 static t_fileio *open_tpx(const char *fn, const char *mode)
 {
     return gmx_fio_open(fn, mode);
@@ -3073,13 +3168,100 @@ static TpxFileHeader populateTpxHeader(const t_state    &state,
     header.bBox           = true;
     header.fileVersion    = tpx_version;
     header.fileGeneration = tpx_generation;
+    header.isDouble       = (sizeof(real) == sizeof(double));
 
     return header;
 }
 
-static void doTpxBodyBuffer(gmx::ISerializer *topologySerializer, gmx::ArrayRef<char> buffer)
+/*! \brief
+ * Process the body of a TPR file as char buffer.
+ *
+ * Reads/writes the information in \p buffer from/to the \p serializer
+ * provided to the function. Does not interact with the actual
+ * TPR datastructures but with an in memory representation of the
+ * data, so that this data can be efficiently read or written from/to
+ * an original source.
+ *
+ * \param[in] serializer The abstract serializer used for reading or writing
+ *                       the information in \p buffer.
+ * \param[in,out] buffer Information from TPR file as char buffer.
+ */
+static void doTpxBodyBuffer(gmx::ISerializer *serializer, gmx::ArrayRef<char> buffer)
+{
+    serializer->doCharArray(buffer.data(), buffer.size());
+}
+
+/*! \brief
+ * Populates simulation datastructures.
+ *
+ * Here the information from the serialization interface \p serializer
+ * is used to first populate the datastructures containing the simulation
+ * information. Depending on the version found in the header \p tpx,
+ * this is done using the new reading of the data as one block from disk,
+ * followed by complete deserialization of the information read from there.
+ * Otherwise, the datastructures are populated as before one by one from disk.
+ * The second version is the default for the legacy tools that read the
+ * coordinates and velocities separate from the state.
+ *
+ * After reading in the data, a separate buffer is populated from them
+ * containing only \p ir and \p mtop that can be communicated directly
+ * to nodes needing the information to set up a simulation.
+ *
+ * \param[in] tpx The file header.
+ * \param[in] serializer The Serialization interface used to read the TPR.
+ * \param[out] ir Input rec to populate.
+ * \param[out] state State vectors to populate.
+ * \param[out] x Coordinates to populate if needed.
+ * \param[out] v Velocities to populate if needed.
+ * \param[out] mtop Global topology to populate.
+ *
+ * \returns Partial de-serialized TPR used for communication to nodes.
+ */
+static PartialDeserializedTprFile readTpxBody(TpxFileHeader *tpx,
+                                              gmx::ISerializer *serializer,
+                                              t_inputrec *ir,
+                                              t_state *state,
+                                              rvec *x, rvec *v,
+                                              gmx_mtop_t *mtop)
 {
-    topologySerializer->doCharArray(buffer.data(), buffer.size());
+    PartialDeserializedTprFile partialDeserializedTpr;
+    if (tpx->fileVersion >= tpxv_AddSizeField && tpx->fileGeneration >= 27)
+    {
+        partialDeserializedTpr.body.resize(tpx->sizeOfTprBody);
+        partialDeserializedTpr.header = *tpx;
+        doTpxBodyBuffer(serializer, partialDeserializedTpr.body);
+
+        partialDeserializedTpr.ePBC =
+            completeTprDeserialization(&partialDeserializedTpr,
+                                       ir,
+                                       state,
+                                       x, v,
+                                       mtop);
+    }
+    else
+    {
+        partialDeserializedTpr.ePBC =
+            do_tpx_body(serializer,
+                        tpx,
+                        ir,
+                        state,
+                        x,
+                        v,
+                        mtop);
+    }
+    // Update header to system info for communication to nodes.
+    // As we only need to communicate the inputrec and mtop to other nodes,
+    // we prepare a new char buffer with the information we have already read
+    // in on master.
+    partialDeserializedTpr.header = populateTpxHeader(*state, ir, mtop);
+    gmx::InMemorySerializer tprBodySerializer;
+    do_tpx_body(&tprBodySerializer,
+                &partialDeserializedTpr.header,
+                ir,
+                mtop);
+    partialDeserializedTpr.body = tprBodySerializer.finishAndGetBuffer();
+
+    return partialDeserializedTpr;
 }
 
 /************************************************************
@@ -3139,73 +3321,54 @@ void write_tpx_state(const char *fn,
     close_tpx(fio);
 }
 
-/*! \brief
- * Wraps reading of header before and after introduction of size field.
- *
- * \param[in] tpx The file header.
- * \param[in] serializer The Serialization interface used to read the TPR.
- * \param[in] isDouble Whether the file is double or single precision.
- * \param[out] ir Input rec to populate.
- * \param[out] state State vectors to populate.
- * \param[out] x Coordinates to populate if needed.
- * \param[out] v Velocities to populate if needed.
- * \param[out] mtop Global topology to populate.
- *
- * \returns Flag for pbc.
- */
-static int do_tpx_body_dispatcher(TpxFileHeader *tpx,
-                                  gmx::ISerializer *serializer,
-                                  bool isDouble,
-                                  t_inputrec *ir,
-                                  t_state *state,
-                                  rvec *x, rvec *v,
-                                  gmx_mtop_t *mtop)
+int completeTprDeserialization(PartialDeserializedTprFile *partialDeserializedTpr,
+                               t_inputrec                 *ir,
+                               t_state                    *state,
+                               rvec                       *x,
+                               rvec                       *v,
+                               gmx_mtop_t                 *mtop)
 {
-    int ePBC;
-    if (tpx->fileVersion >= tpxv_AddSizeField && tpx->fileGeneration >= 27)
-    {
-        std::vector<char>         tprBody(tpx->sizeOfTprBody);
-        doTpxBodyBuffer(serializer, tprBody);
-        gmx::InMemoryDeserializer tprBodyDeserializer(tprBody, isDouble);
-
-        ePBC = do_tpx_body(&tprBodyDeserializer,
-                           tpx,
-                           ir,
-                           state,
-                           x,
-                           v,
-                           mtop);
-    }
-    else
-    {
-        ePBC = do_tpx_body(serializer,
-                           tpx,
-                           ir,
-                           state,
-                           x,
-                           v,
-                           mtop);
-    }
-    return ePBC;
+    gmx::InMemoryDeserializer tprBodyDeserializer(partialDeserializedTpr->body,
+                                                  partialDeserializedTpr->header.isDouble);
+    return do_tpx_body(&tprBodyDeserializer,
+                       &partialDeserializedTpr->header,
+                       ir,
+                       state,
+                       x,
+                       v,
+                       mtop);
 }
 
-
-void read_tpx_state(const char *fn,
-                    t_inputrec *ir, t_state *state, gmx_mtop_t *mtop)
+int completeTprDeserialization(PartialDeserializedTprFile *partialDeserializedTpr,
+                               t_inputrec                 *ir,
+                               gmx_mtop_t                 *mtop)
 {
-    t_fileio                *fio;
+    return completeTprDeserialization(partialDeserializedTpr, ir, nullptr, nullptr, nullptr, mtop);
+}
 
-    TpxFileHeader            tpx;
+PartialDeserializedTprFile read_tpx_state(const char *fn,
+                                          t_inputrec *ir,
+                                          t_state    *state,
+                                          gmx_mtop_t *mtop)
+{
+    t_fileio                   *fio;
     fio = open_tpx(fn, "r");
-    gmx::FileIOXdrSerializer serializer(fio);
+    gmx::FileIOXdrSerializer    serializer(fio);
+    PartialDeserializedTprFile  partialDeserializedTpr;
     do_tpxheader(&serializer,
-                 &tpx,
+                 &partialDeserializedTpr.header,
                  fn,
                  fio,
                  ir == nullptr);
-    do_tpx_body_dispatcher(&tpx, &serializer, gmx_fio_is_double(fio),
-                           ir, state, nullptr, nullptr, mtop);
+    partialDeserializedTpr = readTpxBody(&partialDeserializedTpr.header,
+                                         &serializer,
+                                         ir,
+                                         state,
+                                         nullptr,
+                                         nullptr,
+                                         mtop);
     close_tpx(fio);
+    return partialDeserializedTpr;
 }
 
 int read_tpx(const char *fn,
@@ -3214,7 +3377,6 @@ int read_tpx(const char *fn,
 {
     t_fileio                *fio;
     t_state                  state;
-    int                      ePBC;
 
     TpxFileHeader            tpx;
     fio     = open_tpx(fn, "r");
@@ -3224,8 +3386,9 @@ int read_tpx(const char *fn,
                  fn,
                  fio,
                  ir == nullptr);
-    ePBC = do_tpx_body_dispatcher(&tpx, &serializer, gmx_fio_is_double(fio),
-                                  ir, &state, x, v, mtop);
+    PartialDeserializedTprFile partialDeserializedTpr
+        = readTpxBody(&tpx, &serializer,
+                      ir, &state, x, v, mtop);
     close_tpx(fio);
     if (mtop != nullptr && natoms != nullptr)
     {
@@ -3235,7 +3398,7 @@ int read_tpx(const char *fn,
     {
         copy_mat(state.box, box);
     }
-    return ePBC;
+    return partialDeserializedTpr.ePBC;
 }
 
 int read_tpx_top(const char *fn,