Use ObservablesReducer for LINCS RMSD computation
[alexxy/gromacs.git] / src / gromacs / mdlib / lincs.cpp
index bc476f1d7d2f16be0635db1e6218ce6d1129b156..3c2f2d2547a9a859a0866596dafbe4dc6ce8cc86 100644 (file)
@@ -53,6 +53,7 @@
 #include <cstdlib>
 
 #include <algorithm>
+#include <optional>
 #include <vector>
 
 #include "gromacs/domdec/domdec.h"
@@ -68,6 +69,7 @@
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/mdtypes/observablesreducer.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/pbcutil/pbc_simd.h"
 #include "gromacs/simd/simd.h"
@@ -207,8 +209,24 @@ public:
     /*! @} */
     //! The Lagrange multipliers times -1.
     std::vector<real, AlignedAllocator<real>> mlambda;
-    //! Storage for the constraint RMS relative deviation output.
-    std::array<real, 2> rmsdData = { { 0 } };
+    /*! \brief Callback used after constraining to require reduction
+     * of values later used to compute the constraint RMS relative
+     * deviation, so the latter can be output. */
+    std::optional<ObservablesReducerBuilder::CallbackToRequireReduction> callbackToRequireReduction;
+    /*! \brief View used for reducing the components of the global
+     * relative RMS constraint deviation.
+     *
+     * Can be written any time, but that is only useful when followed
+     * by a call of the callbackToRequireReduction. Useful to read
+     * only from the callback that the ObservablesReducer will later
+     * make after reduction. */
+    ArrayRef<double> rmsdReductionBuffer;
+    /*! \brief The value of the constraint RMS deviation after it has
+     * been computed.
+     *
+     * When DD is active, filled by the ObservablesReducer, otherwise
+     * filled directly here. */
+    std::optional<double> constraintRmsDeviation;
 };
 
 /*! \brief Define simd_width for memory allocation used for SIMD code */
@@ -218,16 +236,11 @@ static const int simd_width = GMX_SIMD_REAL_WIDTH;
 static const int simd_width = 1;
 #endif
 
-ArrayRef<real> lincs_rmsdData(Lincs* lincsd)
-{
-    return lincsd->rmsdData;
-}
-
 real lincs_rmsd(const Lincs* lincsd)
 {
-    if (lincsd->rmsdData[0] > 0)
+    if (lincsd->constraintRmsDeviation.has_value())
     {
-        return std::sqrt(lincsd->rmsdData[1] / lincsd->rmsdData[0]);
+        return real(lincsd->constraintRmsDeviation.value());
     }
     else
     {
@@ -1439,7 +1452,8 @@ Lincs* init_lincs(FILE*                            fplog,
                   ArrayRef<const ListOfLists<int>> atomToConstraintsPerMolType,
                   bool                             bPLINCS,
                   int                              nIter,
-                  int                              nProjOrder)
+                  int                              nProjOrder,
+                  ObservablesReducerBuilder*       observablesReducerBuilder)
 {
     // TODO this should become a unique_ptr
     Lincs* li;
@@ -1548,6 +1562,29 @@ Lincs* init_lincs(FILE*                            fplog,
         }
     }
 
+    if (observablesReducerBuilder)
+    {
+        ObservablesReducerBuilder::CallbackFromBuilder callbackFromBuilder =
+                [li](ObservablesReducerBuilder::CallbackToRequireReduction c, gmx::ArrayRef<double> v) {
+                    li->callbackToRequireReduction = std::move(c);
+                    li->rmsdReductionBuffer        = v;
+                };
+
+        // Make the callback that runs afer reduction.
+        ObservablesReducerBuilder::CallbackAfterReduction callbackAfterReduction = [li](gmx::Step /*step*/) {
+            if (li->rmsdReductionBuffer[0] > 0)
+            {
+                li->constraintRmsDeviation =
+                        std::sqrt(li->rmsdReductionBuffer[1] / li->rmsdReductionBuffer[0]);
+            }
+        };
+
+        const int reductionValuesRequired = 2;
+        observablesReducerBuilder->addSubscriber(reductionValuesRequired,
+                                                 std::move(callbackFromBuilder),
+                                                 std::move(callbackAfterReduction));
+    }
+
     return li;
 }
 
@@ -2136,9 +2173,6 @@ void set_lincs(const InteractionDefinitions& idef,
     }
 
     set_lincs_matrix(li, invmass, lambda);
-
-    li->rmsdData[0] = 0.0;
-    li->rmsdData[1] = 0.0;
 }
 
 //! Issues a warning when LINCS constraints cannot be satisfied.
@@ -2300,11 +2334,6 @@ bool constrain_lincs(bool                            computeRmsd,
 
     if (lincsd->nc == 0 && cr->dd == nullptr)
     {
-        if (computeRmsd)
-        {
-            lincsd->rmsdData = { { 0 } };
-        }
-
         return bOK;
     }
 
@@ -2410,15 +2439,27 @@ bool constrain_lincs(bool                            computeRmsd,
 
             if (computeRmsd)
             {
-                // This is reduced across domains in compute_globals and
-                // reported to the log file.
-                lincsd->rmsdData[0] = deviations.numConstraints;
-                lincsd->rmsdData[1] = deviations.sumSquaredDeviation;
-            }
-            else
-            {
-                // This is never read
-                lincsd->rmsdData = { { 0 } };
+                if (lincsd->callbackToRequireReduction.has_value())
+                {
+                    // This is reduced across domains in compute_globals and
+                    // reported to the log file.
+                    lincsd->rmsdReductionBuffer[0] = deviations.numConstraints;
+                    lincsd->rmsdReductionBuffer[1] = deviations.sumSquaredDeviation;
+
+                    // Call the ObservablesReducer via the callback it
+                    // gave us for the purpose.
+                    ObservablesReducerStatus status =
+                            lincsd->callbackToRequireReduction.value()(ReductionRequirement::Soon);
+                    GMX_RELEASE_ASSERT(status == ObservablesReducerStatus::ReadyToReduce,
+                                       "The LINCS RMSD is computed after observables have been "
+                                       "reduced, please reorder them.");
+                }
+                else
+                {
+                    // Compute the deviation directly
+                    lincsd->constraintRmsDeviation =
+                            std::sqrt(deviations.sumSquaredDeviation / deviations.numConstraints);
+                }
             }
             if (printDebugOutput)
             {