Move responsibility for checking bondeds in reverse topology
[alexxy/gromacs.git] / src / gromacs / domdec / domdec_topology.cpp
index c4a5ac2cd82016b721712755274359af3b669881..2380a813e83311e05df29b45d4952b78065aef1d 100644 (file)
@@ -52,6 +52,7 @@
 
 #include <algorithm>
 #include <memory>
+#include <optional>
 #include <string>
 
 #include "gromacs/domdec/domdec.h"
@@ -123,9 +124,9 @@ struct thread_work_t
 
     InteractionDefinitions         idef;               /**< Partial local topology */
     std::unique_ptr<gmx::VsitePbc> vsitePbc = nullptr; /**< vsite PBC structure */
-    int                            nbonded  = 0;       /**< The number of bondeds in this struct */
-    ListOfLists<int>               excl;               /**< List of exclusions */
-    int                            excl_count = 0;     /**< The total exclusion count for \p excl */
+    int              numBondedInteractions  = 0; /**< The number of bonded interactions observed */
+    ListOfLists<int> excl;                       /**< List of exclusions */
+    int              excl_count = 0;             /**< The total exclusion count for \p excl */
 };
 
 /*! \brief Options for setting up gmx_reverse_top_t */
@@ -174,8 +175,34 @@ struct gmx_reverse_top_t::Impl
     //! \brief Intermolecular reverse ilist
     reverse_ilist_t ril_intermol;
 
-    //! The interaction count for the interactions that have to be present
-    int numInteractionsToCheck;
+    /*! \brief Data to help check reverse topology construction
+     *
+     * Partitioning could incorrectly miss a bonded interaction.
+     * However, checking for that requires a global communication
+     * stage, which does not otherwise happen during partitioning. So,
+     * for performance, we do that alongside the first global energy
+     * reduction after a new DD is made. These variables handle
+     * whether the check happens, its input for this domain, output
+     * across all domains, and the expected value it should match. */
+    /*! \{ */
+    /*! \brief Number of bonded interactions found in the reverse
+     * topology for this domain. */
+    int numBondedInteractions = 0;
+    /*! \brief Whether to check at the next global communication
+     * stage the total number of bonded interactions found.
+     *
+     * Cleared after that number is found. */
+    bool shouldCheckNumberOfBondedInteractions = false;
+    /*! \brief The total number of bonded interactions found in
+     * the reverse topology across all domains.
+     *
+     * Only has a value after reduction across all ranks, which is
+     * removed when it is again time to check after a new
+     * partition. */
+    std::optional<int> numBondedInteractionsOverAllDomains;
+    //! The number of bonded interactions computed from the full topology
+    int expectedNumGlobalBondedInteractions = 0;
+    /*! \} */
 
     /* Work data structures for multi-threading */
     //! \brief Thread work array for local topology generation
@@ -386,11 +413,11 @@ static void printMissingInteractionsAtoms(const gmx::MDLogger&          mdlog,
     }
 }
 
-void dd_print_missing_interactions(const gmx::MDLogger&           mdlog,
-                                   t_commrec*                     cr,
-                                   int                            local_count,
-                                   const gmx_mtop_t&              top_global,
-                                   const gmx_localtop_t*          top_local,
+void dd_print_missing_interactions(const gmx::MDLogger&  mdlog,
+                                   t_commrec*            cr,
+                                   int                   numBondedInteractionsOverAllDomains,
+                                   const gmx_mtop_t&     top_global,
+                                   const gmx_localtop_t* top_local,
                                    gmx::ArrayRef<const gmx::RVec> x,
                                    const matrix                   box)
 {
@@ -402,7 +429,8 @@ void dd_print_missing_interactions(const gmx::MDLogger&           mdlog,
                     "Not all bonded interactions have been properly assigned to the domain "
                     "decomposition cells");
 
-    const int ndiff_tot = local_count - dd->nbonded_global;
+    const int ndiff_tot = numBondedInteractionsOverAllDomains
+                          - dd->reverse_top->impl_->expectedNumGlobalBondedInteractions;
 
     for (int ftype = 0; ftype < F_NRE; ftype++)
     {
@@ -415,8 +443,8 @@ void dd_print_missing_interactions(const gmx::MDLogger&           mdlog,
     if (DDMASTER(dd))
     {
         GMX_LOG(mdlog.warning).appendText("A list of missing interactions:");
-        int rest_global = dd->nbonded_global;
-        int rest_local  = local_count;
+        int rest_global = dd->reverse_top->impl_->expectedNumGlobalBondedInteractions;
+        int rest        = numBondedInteractionsOverAllDomains;
         for (int ftype = 0; ftype < F_NRE; ftype++)
         {
             /* In the reverse and local top all constraints are merged
@@ -440,11 +468,11 @@ void dd_print_missing_interactions(const gmx::MDLogger&           mdlog,
                                                  -ndiff);
                 }
                 rest_global -= n;
-                rest_local -= cl[ftype];
+                rest -= cl[ftype];
             }
         }
 
-        int ndiff = rest_local - rest_global;
+        int ndiff = rest - rest_global;
         if (ndiff != 0)
         {
             GMX_LOG(mdlog.warning).appendTextFormatted("%20s of %6d missing %6d", "exclusions", rest_global, -ndiff);
@@ -470,7 +498,7 @@ void dd_print_missing_interactions(const gmx::MDLogger&           mdlog,
                 "two-body cut-off distance (%g nm), see option -rdd, for pairs and tabulated bonds "
                 "also see option -ddcheck",
                 -ndiff_tot,
-                cr->dd->nbonded_global,
+                dd->reverse_top->impl_->expectedNumGlobalBondedInteractions,
                 dd_cutoff_multibody(dd),
                 dd_cutoff_twobody(dd));
     }
@@ -693,10 +721,10 @@ gmx_reverse_top_t::Impl::Impl(const gmx_mtop_t&        mtop,
         fprintf(debug, "The total size of the atom to interaction index is %d integers\n", ril_mt_tot_size);
     }
 
-    numInteractionsToCheck = 0;
+    expectedNumGlobalBondedInteractions = 0;
     for (const gmx_molblock_t& molblock : mtop.molblock)
     {
-        numInteractionsToCheck += molblock.nmol * nint_mt[molblock.type];
+        expectedNumGlobalBondedInteractions += molblock.nmol * nint_mt[molblock.type];
     }
 
     /* Make an intermolecular reverse top, if necessary */
@@ -711,7 +739,7 @@ gmx_reverse_top_t::Impl::Impl(const gmx_mtop_t&        mtop,
         GMX_RELEASE_ASSERT(mtop.intermolecular_ilist,
                            "We should have an ilist when intermolecular interactions are on");
 
-        numInteractionsToCheck += make_reverse_ilist(
+        expectedNumGlobalBondedInteractions += make_reverse_ilist(
                 *mtop.intermolecular_ilist, &atoms_global, options, AtomLinkRule::FirstAtom, &ril_intermol);
     }
 
@@ -769,8 +797,6 @@ void dd_make_reverse_top(FILE*                           fplog,
     dd->reverse_top = std::make_unique<gmx_reverse_top_t>(
             mtop, inputrec.efep != FreeEnergyPerturbationType::No, rtOptions);
 
-    dd->nbonded_global = dd->reverse_top->impl_->numInteractionsToCheck;
-
     dd->haveExclusions = false;
     for (const gmx_molblock_t& molb : mtop.molblock)
     {
@@ -1068,35 +1094,40 @@ static void combine_idef(InteractionDefinitions* dest, gmx::ArrayRef<const threa
     }
 }
 
-/*! \brief Check and when available assign bonded interactions for local atom i
+/*! \brief Determine whether the local domain has responsibility for
+ * any of the bonded interactions for local atom i
+ *
+ * \returns The total number of bonded interactions for this atom for
+ * which this domain is responsible.
  */
-static inline void check_assign_interactions_atom(int                       i,
-                                                  int                       i_gl,
-                                                  int                       mol,
-                                                  int                       i_mol,
-                                                  int                       numAtomsInMolecule,
-                                                  gmx::ArrayRef<const int>  index,
-                                                  gmx::ArrayRef<const int>  rtil,
-                                                  gmx_bool                  bInterMolInteractions,
-                                                  int                       ind_start,
-                                                  int                       ind_end,
-                                                  const gmx_ga2la_t&        ga2la,
-                                                  const gmx_domdec_zones_t* zones,
-                                                  const gmx_molblock_t*     molb,
-                                                  gmx_bool                  bRCheckMB,
-                                                  const ivec                rcheck,
-                                                  gmx_bool                  bRCheck2B,
-                                                  real                      rc2,
-                                                  t_pbc*                    pbc_null,
-                                                  rvec*                     cg_cm,
-                                                  const t_iparams*          ip_in,
-                                                  InteractionDefinitions*   idef,
-                                                  int                       iz,
-                                                  const DDBondedChecking    ddBondedChecking,
-                                                  int*                      nbonded_local)
+static inline int assign_interactions_atom(int                       i,
+                                           int                       i_gl,
+                                           int                       mol,
+                                           int                       i_mol,
+                                           int                       numAtomsInMolecule,
+                                           gmx::ArrayRef<const int>  index,
+                                           gmx::ArrayRef<const int>  rtil,
+                                           gmx_bool                  bInterMolInteractions,
+                                           int                       ind_start,
+                                           int                       ind_end,
+                                           const gmx_ga2la_t&        ga2la,
+                                           const gmx_domdec_zones_t* zones,
+                                           const gmx_molblock_t*     molb,
+                                           gmx_bool                  bRCheckMB,
+                                           const ivec                rcheck,
+                                           gmx_bool                  bRCheck2B,
+                                           real                      rc2,
+                                           t_pbc*                    pbc_null,
+                                           rvec*                     cg_cm,
+                                           const t_iparams*          ip_in,
+                                           InteractionDefinitions*   idef,
+                                           int                       iz,
+                                           const DDBondedChecking    ddBondedChecking)
 {
     gmx::ArrayRef<const DDPairInteractionRanges> iZones = zones->iZones;
 
+    int numBondedInteractions = 0;
+
     int j = ind_start;
     while (j < ind_end)
     {
@@ -1185,7 +1216,7 @@ static inline void check_assign_interactions_atom(int                       i,
             else
             {
                 /* Assign this multi-body bonded interaction to
-                 * the local node if we have all the atoms involved
+                 * the local domain if we have all the atoms involved
                  * (local or communicated) and the minimum zone shift
                  * in each dimension is zero, for dimensions
                  * with 2 DD cells an extra check may be necessary.
@@ -1253,12 +1284,14 @@ static inline void check_assign_interactions_atom(int                       i,
                 if (ddBondedChecking == DDBondedChecking::All
                     || !(interaction_function[ftype].flags & IF_LIMZERO))
                 {
-                    (*nbonded_local)++;
+                    numBondedInteractions++;
                 }
             }
         }
         j += 1 + nral_rt(ftype);
     }
+
+    return numBondedInteractions;
 }
 
 /*! \brief This function looks up and assigns bonded interactions for zone iz.
@@ -1289,7 +1322,7 @@ static int make_bondeds_zone(gmx_reverse_top_t*                 rt,
 
     const auto ddBondedChecking = rt->impl_->options.ddBondedChecking;
 
-    int nbonded_local = 0;
+    int numBondedInteractions = 0;
 
     for (int i : atomRange)
     {
@@ -1300,30 +1333,29 @@ static int make_bondeds_zone(gmx_reverse_top_t*                 rt,
         gmx::ArrayRef<const int>     index = rt->impl_->ril_mt[mt].index;
         gmx::ArrayRef<const t_iatom> rtil  = rt->impl_->ril_mt[mt].il;
 
-        check_assign_interactions_atom(i,
-                                       i_gl,
-                                       mol,
-                                       i_mol,
-                                       rt->impl_->ril_mt[mt].numAtomsInMolecule,
-                                       index,
-                                       rtil,
-                                       FALSE,
-                                       index[i_mol],
-                                       index[i_mol + 1],
-                                       ga2la,
-                                       zones,
-                                       &molb[mb],
-                                       bRCheckMB,
-                                       rcheck,
-                                       bRCheck2B,
-                                       rc2,
-                                       pbc_null,
-                                       cg_cm,
-                                       ip_in,
-                                       idef,
-                                       izone,
-                                       ddBondedChecking,
-                                       &nbonded_local);
+        numBondedInteractions += assign_interactions_atom(i,
+                                                          i_gl,
+                                                          mol,
+                                                          i_mol,
+                                                          rt->impl_->ril_mt[mt].numAtomsInMolecule,
+                                                          index,
+                                                          rtil,
+                                                          FALSE,
+                                                          index[i_mol],
+                                                          index[i_mol + 1],
+                                                          ga2la,
+                                                          zones,
+                                                          &molb[mb],
+                                                          bRCheckMB,
+                                                          rcheck,
+                                                          bRCheck2B,
+                                                          rc2,
+                                                          pbc_null,
+                                                          cg_cm,
+                                                          ip_in,
+                                                          idef,
+                                                          izone,
+                                                          ddBondedChecking);
 
 
         if (rt->impl_->bIntermolecularInteractions)
@@ -1332,34 +1364,33 @@ static int make_bondeds_zone(gmx_reverse_top_t*                 rt,
             index = rt->impl_->ril_intermol.index;
             rtil  = rt->impl_->ril_intermol.il;
 
-            check_assign_interactions_atom(i,
-                                           i_gl,
-                                           mol,
-                                           i_mol,
-                                           rt->impl_->ril_mt[mt].numAtomsInMolecule,
-                                           index,
-                                           rtil,
-                                           TRUE,
-                                           index[i_gl],
-                                           index[i_gl + 1],
-                                           ga2la,
-                                           zones,
-                                           &molb[mb],
-                                           bRCheckMB,
-                                           rcheck,
-                                           bRCheck2B,
-                                           rc2,
-                                           pbc_null,
-                                           cg_cm,
-                                           ip_in,
-                                           idef,
-                                           izone,
-                                           ddBondedChecking,
-                                           &nbonded_local);
+            numBondedInteractions += assign_interactions_atom(i,
+                                                              i_gl,
+                                                              mol,
+                                                              i_mol,
+                                                              rt->impl_->ril_mt[mt].numAtomsInMolecule,
+                                                              index,
+                                                              rtil,
+                                                              TRUE,
+                                                              index[i_gl],
+                                                              index[i_gl + 1],
+                                                              ga2la,
+                                                              zones,
+                                                              &molb[mb],
+                                                              bRCheckMB,
+                                                              rcheck,
+                                                              bRCheck2B,
+                                                              rc2,
+                                                              pbc_null,
+                                                              cg_cm,
+                                                              ip_in,
+                                                              idef,
+                                                              izone,
+                                                              ddBondedChecking);
         }
     }
 
-    return nbonded_local;
+    return numBondedInteractions;
 }
 
 /*! \brief Set the exclusion data for i-zone \p iz */
@@ -1438,19 +1469,19 @@ static void make_exclusions_zone(ArrayRef<const int>               globalAtomInd
 }
 
 /*! \brief Generate and store all required local bonded interactions in \p idef and local exclusions in \p lexcls */
-static int make_local_bondeds_excls(gmx_domdec_t*           dd,
-                                    gmx_domdec_zones_t*     zones,
-                                    const gmx_mtop_t&       mtop,
-                                    const int*              cginfo,
-                                    gmx_bool                bRCheckMB,
-                                    ivec                    rcheck,
-                                    gmx_bool                bRCheck2B,
-                                    real                    rc,
-                                    t_pbc*                  pbc_null,
-                                    rvec*                   cg_cm,
-                                    InteractionDefinitions* idef,
-                                    ListOfLists<int>*       lexcls,
-                                    int*                    excl_count)
+static void make_local_bondeds_excls(gmx_domdec_t*           dd,
+                                     gmx_domdec_zones_t*     zones,
+                                     const gmx_mtop_t&       mtop,
+                                     const int*              cginfo,
+                                     gmx_bool                bRCheckMB,
+                                     ivec                    rcheck,
+                                     gmx_bool                bRCheck2B,
+                                     real                    rc,
+                                     t_pbc*                  pbc_null,
+                                     rvec*                   cg_cm,
+                                     InteractionDefinitions* idef,
+                                     ListOfLists<int>*       lexcls,
+                                     int*                    excl_count)
 {
     int nzone_bondeds = 0;
 
@@ -1475,7 +1506,7 @@ static int make_local_bondeds_excls(gmx_domdec_t*           dd,
 
     /* Clear the counts */
     idef->clear();
-    int nbonded_local = 0;
+    dd->reverse_top->impl_->numBondedInteractions = 0;
 
     lexcls->clear();
     *excl_count = 0;
@@ -1506,21 +1537,22 @@ static int make_local_bondeds_excls(gmx_domdec_t*           dd,
                     idef_t->clear();
                 }
 
-                rt->impl_->th_work[thread].nbonded = make_bondeds_zone(rt,
-                                                                       dd->globalAtomIndices,
-                                                                       *dd->ga2la,
-                                                                       zones,
-                                                                       mtop.molblock,
-                                                                       bRCheckMB,
-                                                                       rcheck,
-                                                                       bRCheck2B,
-                                                                       rc2,
-                                                                       pbc_null,
-                                                                       cg_cm,
-                                                                       idef->iparams.data(),
-                                                                       idef_t,
-                                                                       izone,
-                                                                       gmx::Range<int>(cg0t, cg1t));
+                rt->impl_->th_work[thread].numBondedInteractions =
+                        make_bondeds_zone(rt,
+                                          dd->globalAtomIndices,
+                                          *dd->ga2la,
+                                          zones,
+                                          mtop.molblock,
+                                          bRCheckMB,
+                                          rcheck,
+                                          bRCheck2B,
+                                          rc2,
+                                          pbc_null,
+                                          cg_cm,
+                                          idef->iparams.data(),
+                                          idef_t,
+                                          izone,
+                                          gmx::Range<int>(cg0t, cg1t));
 
                 if (izone < numIZonesForExclusions)
                 {
@@ -1561,7 +1593,7 @@ static int make_local_bondeds_excls(gmx_domdec_t*           dd,
 
         for (const thread_work_t& th_work : rt->impl_->th_work)
         {
-            nbonded_local += th_work.nbonded;
+            dd->reverse_top->impl_->numBondedInteractions += th_work.numBondedInteractions;
         }
 
         if (izone < numIZonesForExclusions)
@@ -1577,12 +1609,70 @@ static int make_local_bondeds_excls(gmx_domdec_t*           dd,
         }
     }
 
+    // Note that it's possible for this to still be true from the last
+    // time it was set, e.g. if repartitioning was triggered before
+    // global communication that would have acted on the true
+    // value. This could happen for example when replica exchange took
+    // place soon after a partition.
+    dd->reverse_top->impl_->shouldCheckNumberOfBondedInteractions = true;
+    // Clear the old global value, which is now invalid
+    dd->reverse_top->impl_->numBondedInteractionsOverAllDomains.reset();
+
     if (debug)
     {
         fprintf(debug, "We have %d exclusions, check count %d\n", lexcls->numElements(), *excl_count);
     }
+}
+
+bool shouldCheckNumberOfBondedInteractions(const gmx_domdec_t& dd)
+{
+    return dd.reverse_top->impl_->shouldCheckNumberOfBondedInteractions;
+}
 
-    return nbonded_local;
+int numBondedInteractions(const gmx_domdec_t& dd)
+{
+    return dd.reverse_top->impl_->numBondedInteractions;
+}
+
+void setNumberOfBondedInteractionsOverAllDomains(const gmx_domdec_t& dd, int newValue)
+{
+    GMX_RELEASE_ASSERT(!dd.reverse_top->impl_->numBondedInteractionsOverAllDomains.has_value(),
+                       "Cannot set number of bonded interactions because it is already set");
+    dd.reverse_top->impl_->numBondedInteractionsOverAllDomains.emplace(newValue);
+}
+
+void checkNumberOfBondedInteractions(const gmx::MDLogger&           mdlog,
+                                     t_commrec*                     cr,
+                                     const gmx_mtop_t&              top_global,
+                                     const gmx_localtop_t*          top_local,
+                                     gmx::ArrayRef<const gmx::RVec> x,
+                                     const matrix                   box)
+{
+    GMX_RELEASE_ASSERT(
+            DOMAINDECOMP(cr),
+            "No need to check number of bonded interactions when not using domain decomposition");
+    if (cr->dd->reverse_top->impl_->shouldCheckNumberOfBondedInteractions)
+    {
+        GMX_RELEASE_ASSERT(cr->dd->reverse_top->impl_->numBondedInteractionsOverAllDomains.has_value(),
+                           "The check for the total number of bonded interactions requires the "
+                           "value to have been reduced across all domains");
+        if (cr->dd->reverse_top->impl_->numBondedInteractionsOverAllDomains.value()
+            != cr->dd->reverse_top->impl_->expectedNumGlobalBondedInteractions)
+        {
+            dd_print_missing_interactions(
+                    mdlog,
+                    cr,
+                    cr->dd->reverse_top->impl_->numBondedInteractionsOverAllDomains.value(),
+                    top_global,
+                    top_local,
+                    x,
+                    box); // Does not return
+        }
+        // Now that the value is set and the check complete, future
+        // global communication should not compute the value until
+        // after the next partitioning.
+        cr->dd->reverse_top->impl_->shouldCheckNumberOfBondedInteractions = false;
+    }
 }
 
 void dd_make_local_top(gmx_domdec_t*       dd,
@@ -1663,19 +1753,19 @@ void dd_make_local_top(gmx_domdec_t*       dd,
         }
     }
 
-    dd->nbonded_local = make_local_bondeds_excls(dd,
-                                                 zones,
-                                                 mtop,
-                                                 fr->cginfo.data(),
-                                                 bRCheckMB,
-                                                 rcheck,
-                                                 bRCheck2B,
-                                                 rc,
-                                                 pbc_null,
-                                                 cgcm_or_x,
-                                                 &ltop->idef,
-                                                 &ltop->excls,
-                                                 &nexcl);
+    make_local_bondeds_excls(dd,
+                             zones,
+                             mtop,
+                             fr->cginfo.data(),
+                             bRCheckMB,
+                             rcheck,
+                             bRCheck2B,
+                             rc,
+                             pbc_null,
+                             cgcm_or_x,
+                             &ltop->idef,
+                             &ltop->excls,
+                             &nexcl);
 
     /* The ilist is not sorted yet,
      * we can only do this when we have the charge arrays.