Make AWH bias sharing more flexible
[alexxy/gromacs.git] / src / gromacs / applied_forces / awh / biasstate.cpp
index 55442b10fdc0578f712b40a4524ba8b219803cec..3a43f29db1ff8993ba3d026d39ff3d46d8d70248 100644 (file)
@@ -74,6 +74,7 @@
 #include "gromacs/utility/stringutil.h"
 
 #include "biasgrid.h"
+#include "biassharing.h"
 #include "pointstate.h"
 
 namespace gmx
@@ -95,70 +96,24 @@ void BiasState::getPmf(gmx::ArrayRef<float> pmf) const
 namespace
 {
 
-/*! \brief
- * Sum an array over all simulations on the master rank of each simulation.
- *
- * \param[in,out] arrayRef      The data to sum.
- * \param[in]     multiSimComm  Struct for multi-simulation communication.
- */
-void sumOverSimulations(gmx::ArrayRef<int> arrayRef, const gmx_multisim_t* multiSimComm)
-{
-    gmx_sumi_sim(arrayRef.size(), arrayRef.data(), multiSimComm);
-}
-
-/*! \brief
- * Sum an array over all simulations on the master rank of each simulation.
- *
- * \param[in,out] arrayRef      The data to sum.
- * \param[in]     multiSimComm  Struct for multi-simulation communication.
- */
-void sumOverSimulations(gmx::ArrayRef<double> arrayRef, const gmx_multisim_t* multiSimComm)
-{
-    gmx_sumd_sim(arrayRef.size(), arrayRef.data(), multiSimComm);
-}
-
-/*! \brief
- * Sum an array over all simulations on all ranks of each simulation.
- *
- * This assumes the data is identical on all ranks within each simulation.
- *
- * \param[in,out] arrayRef      The data to sum.
- * \param[in]     commRecord    Struct for intra-simulation communication.
- * \param[in]     multiSimComm  Struct for multi-simulation communication.
- */
-template<typename T>
-void sumOverSimulations(gmx::ArrayRef<T> arrayRef, const t_commrec* commRecord, const gmx_multisim_t* multiSimComm)
-{
-    if (MASTER(commRecord))
-    {
-        sumOverSimulations(arrayRef, multiSimComm);
-    }
-    if (commRecord->nnodes > 1)
-    {
-        gmx_bcast(arrayRef.size() * sizeof(T), arrayRef.data(), commRecord->mpi_comm_mygroup);
-    }
-}
-
 /*! \brief
  * Sum PMF over multiple simulations, when requested.
  *
  * \param[in,out] pointState         The state of the points in the bias.
  * \param[in]     numSharedUpdate    The number of biases sharing the histogram.
- * \param[in]     commRecord         Struct for intra-simulation communication.
- * \param[in]     multiSimComm       Struct for multi-simulation communication.
+ * \param[in]     biasSharing        Object for sharing bias data over multiple simulations
+ * \param[in]     biasIndex          Index of this bias in the total list of biases in this simulation
  */
-void sumPmf(gmx::ArrayRef<PointState> pointState,
-            int                       numSharedUpdate,
-            const t_commrec*          commRecord,
-            const gmx_multisim_t*     multiSimComm)
+void sumPmf(gmx::ArrayRef<PointState> pointState, int numSharedUpdate, const BiasSharing* biasSharing, const int biasIndex)
 {
     if (numSharedUpdate == 1)
     {
         return;
     }
-    GMX_ASSERT(multiSimComm != nullptr && numSharedUpdate % multiSimComm->numSimulations_ == 0,
+    GMX_ASSERT(biasSharing != nullptr
+                       && numSharedUpdate % biasSharing->numSharingSimulations(biasIndex) == 0,
                "numSharedUpdate should be a multiple of multiSimComm->numSimulations_");
-    GMX_ASSERT(numSharedUpdate == multiSimComm->numSimulations_,
+    GMX_ASSERT(numSharedUpdate == biasSharing->numSharingSimulations(biasIndex),
                "Sharing within a simulation is not implemented (yet)");
 
     std::vector<double> buffer(pointState.size());
@@ -169,7 +124,7 @@ void sumPmf(gmx::ArrayRef<PointState> pointState,
         buffer[i] = pointState[i].inTargetRegion() ? std::exp(pointState[i].logPmfSum()) : 0;
     }
 
-    sumOverSimulations(gmx::ArrayRef<double>(buffer), commRecord, multiSimComm);
+    biasSharing->sum(gmx::ArrayRef<double>(buffer), biasIndex);
 
     /* Take log again to get (non-normalized) PMF */
     double normFac = 1.0 / numSharedUpdate;
@@ -697,13 +652,13 @@ namespace
  *
  * \param[in,out] updateList    Update list for this simulation (assumed >= npoints long).
  * \param[in]     numPoints     Total number of points.
- * \param[in]     commRecord    Struct for intra-simulation communication.
- * \param[in]     multiSimComm  Struct for multi-simulation communication.
+ * \param[in]     biasSharing   Object for sharing bias data over multiple simulations
+ * \param[in]     biasIndex     Index of this bias in the total list of biases in this simulation
  */
-void mergeSharedUpdateLists(std::vector<int>*     updateList,
-                            int                   numPoints,
-                            const t_commrec*      commRecord,
-                            const gmx_multisim_t* multiSimComm)
+void mergeSharedUpdateLists(std::vector<int>*  updateList,
+                            int                numPoints,
+                            const BiasSharing& biasSharing,
+                            const int          biasIndex)
 {
     std::vector<int> numUpdatesOfPoint;
 
@@ -716,7 +671,7 @@ void mergeSharedUpdateLists(std::vector<int>*     updateList,
     }
 
     /* Sum over the sims to get all the flagged points */
-    sumOverSimulations(arrayRefFromArray(numUpdatesOfPoint.data(), numPoints), commRecord, multiSimComm);
+    biasSharing.sum(arrayRefFromArray(numUpdatesOfPoint.data(), numPoints), biasIndex);
 
     /* Collect the indices of the flagged points in place. The resulting array will be the merged update list.*/
     updateList->clear();
@@ -796,15 +751,15 @@ namespace
  * \param[in,out] pointState         The state of the points in the bias.
  * \param[in,out] weightSumCovering  The weights for checking covering.
  * \param[in]     numSharedUpdate    The number of biases sharing the histrogram.
- * \param[in]     commRecord         Struct for intra-simulation communication.
- * \param[in]     multiSimComm       Struct for multi-simulation communication.
- * \param[in]     localUpdateList    List of points with data.
+ * \param[in]     biasSharing        Object for sharing bias data over multiple simulations
+ * \param[in]     biasIndex          Index of this bias in the total list of biases in this
+ * simulation \param[in]     localUpdateList    List of points with data.
  */
 void sumHistograms(gmx::ArrayRef<PointState> pointState,
                    gmx::ArrayRef<double>     weightSumCovering,
                    int                       numSharedUpdate,
-                   const t_commrec*          commRecord,
-                   const gmx_multisim_t*     multiSimComm,
+                   const BiasSharing*        biasSharing,
+                   const int                 biasIndex,
                    const std::vector<int>&   localUpdateList)
 {
     /* The covering checking histograms are added before summing over simulations, so that the
@@ -817,7 +772,7 @@ void sumHistograms(gmx::ArrayRef<PointState> pointState,
     /* Sum histograms over multiple simulations if needed. */
     if (numSharedUpdate > 1)
     {
-        GMX_ASSERT(numSharedUpdate == multiSimComm->numSimulations_,
+        GMX_ASSERT(numSharedUpdate == biasSharing->numSharingSimulations(biasIndex),
                    "Sharing within a simulation is not implemented (yet)");
 
         /* Collect the weights and counts in linear arrays to be able to use gmx_sumd_sim. */
@@ -835,8 +790,8 @@ void sumHistograms(gmx::ArrayRef<PointState> pointState,
             coordVisits[localIndex] = ps.numVisitsIteration();
         }
 
-        sumOverSimulations(gmx::ArrayRef<double>(weightSum), commRecord, multiSimComm);
-        sumOverSimulations(gmx::ArrayRef<double>(coordVisits), commRecord, multiSimComm);
+        biasSharing->sum(gmx::ArrayRef<double>(weightSum), biasIndex);
+        biasSharing->sum(gmx::ArrayRef<double>(coordVisits), biasIndex);
 
         /* Transfer back the result */
         for (size_t localIndex = 0; localIndex < localUpdateList.size(); localIndex++)
@@ -964,9 +919,7 @@ void labelCoveredPoints(const std::vector<bool>& visited,
 
 bool BiasState::isSamplingRegionCovered(const BiasParams&         params,
                                         ArrayRef<const DimParams> dimParams,
-                                        const BiasGrid&           grid,
-                                        const t_commrec*          commRecord,
-                                        const gmx_multisim_t*     multiSimComm) const
+                                        const BiasGrid&           grid) const
 {
     /* Allocate and initialize arrays: one for checking visits along each dimension,
        one for keeping track of which points to check and one for the covered points.
@@ -1072,10 +1025,8 @@ bool BiasState::isSamplingRegionCovered(const BiasParams&         params,
         /* For multiple dimensions this may not be the best way to do it. */
         for (int d = 0; d < grid.numDimensions(); d++)
         {
-            sumOverSimulations(
-                    gmx::arrayRefFromArray(checkDim[d].covered.data(), grid.axis(d).numPoints()),
-                    commRecord,
-                    multiSimComm);
+            biasSharing_->sum(gmx::arrayRefFromArray(checkDim[d].covered.data(), grid.axis(d).numPoints()),
+                              params.biasIndex);
         }
     }
 
@@ -1110,8 +1061,6 @@ static void normalizeFreeEnergyAndPmfSum(std::vector<PointState>* pointState)
 void BiasState::updateFreeEnergyAndAddSamplesToHistogram(ArrayRef<const DimParams> dimParams,
                                                          const BiasGrid&           grid,
                                                          const BiasParams&         params,
-                                                         const t_commrec*          commRecord,
-                                                         const gmx_multisim_t*     multiSimComm,
                                                          double                    t,
                                                          int64_t                   step,
                                                          FILE*                     fplog,
@@ -1130,16 +1079,16 @@ void BiasState::updateFreeEnergyAndAddSamplesToHistogram(ArrayRef<const DimParam
     makeLocalUpdateList(grid, points_, originUpdatelist_, endUpdatelist_, updateList);
     if (params.numSharedUpdate > 1)
     {
-        mergeSharedUpdateLists(updateList, points_.size(), commRecord, multiSimComm);
+        mergeSharedUpdateLists(updateList, points_.size(), *biasSharing_, params.biasIndex);
     }
 
     /* Reset the range for the next update */
     resetLocalUpdateRange(grid);
 
     /* Add samples to histograms for all local points and sync simulations if needed */
-    sumHistograms(points_, weightSumCovering_, params.numSharedUpdate, commRecord, multiSimComm, *updateList);
+    sumHistograms(points_, weightSumCovering_, params.numSharedUpdate, biasSharing_, params.biasIndex, *updateList);
 
-    sumPmf(points_, params.numSharedUpdate, commRecord, multiSimComm);
+    sumPmf(points_, params.numSharedUpdate, biasSharing_, params.biasIndex);
 
     /* Renormalize the free energy if values are too large. */
     bool needToNormalizeFreeEnergy = false;
@@ -1166,8 +1115,7 @@ void BiasState::updateFreeEnergyAndAddSamplesToHistogram(ArrayRef<const DimParam
     if (inInitialStage())
     {
         detectedCovering =
-                (params.isCheckCoveringStep(step)
-                 && isSamplingRegionCovered(params, dimParams, grid, commRecord, multiSimComm));
+                (params.isCheckCoveringStep(step) && isSamplingRegionCovered(params, dimParams, grid));
     }
 
     /* The weighthistogram size after this update. */
@@ -1873,11 +1821,13 @@ void BiasState::initGridPointState(const AwhBiasParams&      awhBiasParams,
 BiasState::BiasState(const AwhBiasParams&      awhBiasParams,
                      double                    histogramSizeInitial,
                      ArrayRef<const DimParams> dimParams,
-                     const BiasGrid&           grid) :
+                     const BiasGrid&           grid,
+                     const BiasSharing*        biasSharing) :
     coordState_(awhBiasParams, dimParams, grid),
     points_(grid.numPoints()),
     weightSumCovering_(grid.numPoints()),
-    histogramSize_(awhBiasParams, histogramSizeInitial)
+    histogramSize_(awhBiasParams, histogramSizeInitial),
+    biasSharing_(biasSharing)
 {
     /* The minimum and maximum multidimensional point indices that are affected by the next update */
     for (size_t d = 0; d < dimParams.size(); d++)