Use ObservablesReducer for check of DD bonded interaction count.
[alexxy/gromacs.git] / src / gromacs / domdec / localtopologychecker.cpp
index c4af1d1d320e7666d899da05cfd43075840ecc6d..8db5e261be92d565b3bc36a21b1c55b253ea4268 100644 (file)
@@ -54,6 +54,8 @@
 #include "gromacs/domdec/reversetopology.h"
 #include "gromacs/gmxlib/network.h"
 #include "gromacs/mdtypes/commrec.h"
+#include "gromacs/mdtypes/observablesreducer.h"
+#include "gromacs/mdtypes/state.h"
 #include "gromacs/topology/idef.h"
 #include "gromacs/topology/ifunc.h"
 #include "gromacs/topology/mtop_util.h"
@@ -257,7 +259,7 @@ static void printMissingInteractionsAtoms(const MDLogger&               mdlog,
                                                        const int numBondedInteractionsOverAllDomains,
                                                        const int expectedNumGlobalBondedInteractions,
                                                        const gmx_mtop_t&     top_global,
-                                                       const gmx_localtop_t* top_local,
+                                                       const gmx_localtop_t& top_local,
                                                        ArrayRef<const RVec>  x,
                                                        const matrix          box)
 {
@@ -274,7 +276,7 @@ static void printMissingInteractionsAtoms(const MDLogger&               mdlog,
     for (int ftype = 0; ftype < F_NRE; ftype++)
     {
         const int nral = NRAL(ftype);
-        cl[ftype]      = top_local->idef.il[ftype].size() / (1 + nral);
+        cl[ftype]      = top_local.idef.il[ftype].size() / (1 + nral);
     }
 
     gmx_sumi(F_NRE, cl, cr);
@@ -318,7 +320,7 @@ static void printMissingInteractionsAtoms(const MDLogger&               mdlog,
         }
     }
 
-    printMissingInteractionsAtoms(mdlog, cr, top_global, top_local->idef);
+    printMissingInteractionsAtoms(mdlog, cr, top_global, top_local.idef);
     write_dd_pdb("dd_dump_err", 0, "dump", top_global, cr, -1, as_rvec_array(x.data()), box);
 
     std::string errorMessage;
@@ -357,8 +359,12 @@ class LocalTopologyChecker::Impl
 {
 public:
     //! Constructor
-    Impl(const MDLogger& mdlog, const t_commrec* cr, const gmx_mtop_t& mtop, bool useUpdateGroups);
-
+    Impl(const MDLogger&       mdlog,
+         const t_commrec*      cr,
+         const gmx_mtop_t&     mtop,
+         const gmx_localtop_t& localTopology,
+         const t_state&        localState,
+         bool                  useUpdateGroups);
     //! Objects used when reporting that interactions are missing
     //! {
     //! Logger
@@ -367,25 +373,25 @@ public:
     const t_commrec* cr_;
     //! Global system topology
     const gmx_mtop_t& mtop_;
+    //! Local topology
+    const gmx_localtop_t& localTopology_;
+    //! Local state
+    const t_state& localState_;
     //! }
 
-    /*! \brief Number of bonded interactions found in the local
-     * topology for this domain. */
-    int numBondedInteractionsToReduce_ = 0;
-    /*! \brief Whether to check at the next global communication
-     * stage the total number of bonded interactions found.
-     *
-     * Cleared after that number is found. */
-    bool shouldCheckNumberOfBondedInteractions_ = false;
-    /*! \brief The total number of bonded interactions found in
-     * the local topology across all domains.
+    /*! \brief View used for computing the global number of bonded interactions.
      *
-     * Only has a value after reduction across all ranks, which is
-     * removed when it is again time to check after a new
-     * partition. */
-    std::optional<int> numBondedInteractionsOverAllDomains_;
-    //! The number of bonded interactions computed from the full system topology
-    int expectedNumGlobalBondedInteractions_ = 0;
+     * Can be written any time, but that is only useful when followed
+     * by a call of the callbackToRequireReduction. Useful to read
+     * only from the callback that the ObservablesReducer will later
+     * make after reduction. */
+    gmx::ArrayRef<double> reductionBuffer_;
+    /*! \brief Callback used after repartitioning to require reduction
+     * of numBondedInteractionsToReduce so that the total number of
+     * bonded interactions can be checked. */
+    gmx::ObservablesReducerBuilder::CallbackToRequireReduction callbackToRequireReduction_;
+    /*! \brief The expected number of global bonded interactions from the system topology */
+    int expectedNumGlobalBondedInteractions_;
 };
 
 
@@ -409,23 +415,57 @@ static int computeExpectedNumGlobalBondedInteractions(const gmx_mtop_t& mtop, co
     return expectedNumGlobalBondedInteractions;
 }
 
-LocalTopologyChecker::Impl::Impl(const MDLogger&   mdlog,
-                                 const t_commrec*  cr,
-                                 const gmx_mtop_t& mtop,
-                                 const bool        useUpdateGroups) :
+LocalTopologyChecker::Impl::Impl(const MDLogger&       mdlog,
+                                 const t_commrec*      cr,
+                                 const gmx_mtop_t&     mtop,
+                                 const gmx_localtop_t& localTopology,
+                                 const t_state&        localState,
+                                 bool                  useUpdateGroups) :
     mdlog_(mdlog),
     cr_(cr),
     mtop_(mtop),
+    localTopology_(localTopology),
+    localState_(localState),
     expectedNumGlobalBondedInteractions_(computeExpectedNumGlobalBondedInteractions(mtop, useUpdateGroups))
 {
 }
 
-LocalTopologyChecker::LocalTopologyChecker(const MDLogger&   mdlog,
-                                           const t_commrec*  cr,
-                                           const gmx_mtop_t& mtop,
-                                           const bool        useUpdateGroups) :
-    impl_(std::make_unique<Impl>(mdlog, cr, mtop, useUpdateGroups))
+LocalTopologyChecker::LocalTopologyChecker(const MDLogger&            mdlog,
+                                           const t_commrec*           cr,
+                                           const gmx_mtop_t&          mtop,
+                                           const gmx_localtop_t&      localTopology,
+                                           const t_state&             localState,
+                                           const bool                 useUpdateGroups,
+                                           ObservablesReducerBuilder* observablesReducerBuilder) :
+    impl_(std::make_unique<Impl>(mdlog, cr, mtop, localTopology, localState, useUpdateGroups))
 {
+    Impl*                                          impl = impl_.get();
+    ObservablesReducerBuilder::CallbackFromBuilder callbackFromBuilder =
+            [impl](ObservablesReducerBuilder::CallbackToRequireReduction c, gmx::ArrayRef<double> v) {
+                impl->callbackToRequireReduction_ = std::move(c);
+                impl->reductionBuffer_            = v;
+            };
+
+    // Make the callback that runs afer reduction.
+    ObservablesReducerBuilder::CallbackAfterReduction callbackAfterReduction = [impl](gmx::Step /*step*/) {
+        // Get the total after reduction
+        int numTotalBondedInteractionsFound = impl->reductionBuffer_[0];
+        if (numTotalBondedInteractionsFound != impl->expectedNumGlobalBondedInteractions_)
+        {
+            // Give error and exit
+            dd_print_missing_interactions(impl->mdlog_,
+                                          impl->cr_,
+                                          numTotalBondedInteractionsFound,
+                                          impl->expectedNumGlobalBondedInteractions_,
+                                          impl->mtop_,
+                                          impl->localTopology_,
+                                          impl->localState_.x,
+                                          impl->localState_.box); // Does not return
+        }
+    };
+
+    observablesReducerBuilder->addSubscriber(
+            1, std::move(callbackFromBuilder), std::move(callbackAfterReduction));
 }
 
 LocalTopologyChecker::~LocalTopologyChecker() = default;
@@ -440,59 +480,25 @@ LocalTopologyChecker& LocalTopologyChecker::operator=(LocalTopologyChecker&& oth
 
 void LocalTopologyChecker::scheduleCheckOfLocalTopology(const int numBondedInteractionsToReduce)
 {
-    impl_->numBondedInteractionsToReduce_ = numBondedInteractionsToReduce;
-    // Note that it's possible for this to still be true from the last
-    // time it was set, e.g. if repartitioning was triggered before
-    // global communication that would have acted on the true
-    // value. This could happen for example when replica exchange took
-    // place soon after a partition.
-    impl_->shouldCheckNumberOfBondedInteractions_ = true;
-    // Clear the old global value, which is now invalid
-    impl_->numBondedInteractionsOverAllDomains_.reset();
-}
-
-bool LocalTopologyChecker::shouldCheckNumberOfBondedInteractions() const
-{
-    return impl_->shouldCheckNumberOfBondedInteractions_;
-}
-
-int LocalTopologyChecker::numBondedInteractions() const
-{
-    return impl_->numBondedInteractionsToReduce_;
-}
-
-void LocalTopologyChecker::setNumberOfBondedInteractionsOverAllDomains(const int newValue)
-{
-    GMX_RELEASE_ASSERT(!impl_->numBondedInteractionsOverAllDomains_.has_value(),
-                       "Cannot set number of bonded interactions because it is already set");
-    impl_->numBondedInteractionsOverAllDomains_.emplace(newValue);
-}
-
-void LocalTopologyChecker::checkNumberOfBondedInteractions(const gmx_localtop_t* top_local,
-                                                           ArrayRef<const RVec>  x,
-                                                           const matrix          box)
-{
-    if (impl_->shouldCheckNumberOfBondedInteractions_)
-    {
-        GMX_RELEASE_ASSERT(impl_->numBondedInteractionsOverAllDomains_.has_value(),
-                           "The check for the total number of bonded interactions requires the "
-                           "value to have been reduced across all domains");
-        if (impl_->numBondedInteractionsOverAllDomains_.value() != impl_->expectedNumGlobalBondedInteractions_)
-        {
-            dd_print_missing_interactions(impl_->mdlog_,
-                                          impl_->cr_,
-                                          impl_->numBondedInteractionsOverAllDomains_.value(),
-                                          impl_->expectedNumGlobalBondedInteractions_,
-                                          impl_->mtop_,
-                                          top_local,
-                                          x,
-                                          box); // Does not return
-        }
-        // Now that the value is set and the check complete, future
-        // global communication should not compute the value until
-        // after the next partitioning.
-        impl_->shouldCheckNumberOfBondedInteractions_ = false;
-    }
+    // Fill the reduction buffer with the value from this domain to reduce
+    impl_->reductionBuffer_[0] = double(numBondedInteractionsToReduce);
+
+    // Pass the post-reduction callback to the ObservablesReducer via
+    // the callback it gave us for the purpose.
+    //
+    // Note that it's possible that the callbackAfterReduction is already
+    // outstanding, e.g. if repartitioning was triggered before
+    // observables were reduced. This could happen for example when
+    // replica exchange took place soon after a partition. If so, the
+    // callback will be called again. So long as there is no race
+    // between the calls to this function and the calls to
+    // ObservablesReducer for reduction, this will work correctly. It
+    // could be made safer e.g. with checks against duplicate
+    // callbacks, but there is no problem to solve.
+    //
+    // There is no need to check the return value from this callback,
+    // as it is not an error to request reduction at a future step.
+    impl_->callbackToRequireReduction_(ReductionRequirement::Eventually);
 }
 
 } // namespace gmx