Enable splitting of listed interaction calculation
[alexxy/gromacs.git] / src / gromacs / listed_forces / listed_forces.cpp
index c35a06a895fdf25fb27a2df3a8b90a25f60e6294..9548ba44338c6ab72491e053ec5c203f43f78c18 100644 (file)
 #include "manage_threading.h"
 #include "utilities.h"
 
-ListedForces::ListedForces(const int numEnergyGroups, const int numThreads, FILE* fplog) :
+ListedForces::ListedForces(const gmx_ffparams_t&      ffparams,
+                           const int                  numEnergyGroups,
+                           const int                  numThreads,
+                           const InteractionSelection interactionSelection,
+                           FILE*                      fplog) :
+    idefSelection_(ffparams),
     threading_(std::make_unique<bonded_threading_t>(numThreads, numEnergyGroups, fplog)),
-    fcdata_(std::make_unique<t_fcdata>())
+    interactionSelection_(interactionSelection)
 {
 }
 
+ListedForces::ListedForces(ListedForces&& o) noexcept = default;
+
 ListedForces::~ListedForces() = default;
 
-void ListedForces::setup(const InteractionDefinitions& idef, const int numAtomsForce, const bool useGpu)
+//! Copies the selection interactions from \p idefSrc to \p idef
+static void selectInteractions(InteractionDefinitions*                  idef,
+                               const InteractionDefinitions&            idefSrc,
+                               const ListedForces::InteractionSelection interactionSelection)
 {
-    idef_ = &idef;
+    const bool selectPairs =
+            interactionSelection.test(static_cast<int>(ListedForces::InteractionGroup::Pairs));
+    const bool selectDihedrals =
+            interactionSelection.test(static_cast<int>(ListedForces::InteractionGroup::Dihedrals));
+    const bool selectRest =
+            interactionSelection.test(static_cast<int>(ListedForces::InteractionGroup::Rest));
+
+    for (int ftype = 0; ftype < F_NRE; ftype++)
+    {
+        const t_interaction_function& ifunc = interaction_function[ftype];
+        if (ifunc.flags & IF_BOND)
+        {
+            bool assign = false;
+            if (ifunc.flags & IF_PAIR)
+            {
+                assign = selectPairs;
+            }
+            else if (ifunc.flags & IF_DIHEDRAL)
+            {
+                assign = selectDihedrals;
+            }
+            else
+            {
+                assign = selectRest;
+            }
+            if (assign)
+            {
+                idef->il[ftype] = idefSrc.il[ftype];
+            }
+            else
+            {
+                idef->il[ftype].clear();
+            }
+        }
+    }
+}
+
+void ListedForces::setup(const InteractionDefinitions& domainIdef, const int numAtomsForce, const bool useGpu)
+{
+    if (interactionSelection_.all())
+    {
+        // Avoid the overhead of copying all interaction lists by simply setting the reference to the domain idef
+        idef_ = &domainIdef;
+    }
+    else
+    {
+        idef_ = &idefSelection_;
+
+        selectInteractions(&idefSelection_, domainIdef, interactionSelection_);
+
+        idefSelection_.ilsort = domainIdef.ilsort;
+    }
 
     setup_bonded_threading(threading_.get(), numAtomsForce, useGpu, *idef_);
 
@@ -478,13 +539,12 @@ static void calcBondedForces(const InteractionDefinitions& idef,
     }
 }
 
-bool ListedForces::haveRestraints() const
+bool ListedForces::haveRestraints(const t_fcdata& fcdata) const
 {
-    GMX_ASSERT(fcdata_, "Need valid fcdata");
-    GMX_ASSERT(fcdata_->orires && fcdata_->disres, "NMR restraints objects should be set up");
+    GMX_ASSERT(fcdata.orires && fcdata.disres, "NMR restraints objects should be set up");
 
-    return (!idef_->il[F_POSRES].empty() || !idef_->il[F_FBPOSRES].empty()
-            || fcdata_->orires->nr > 0 || fcdata_->disres->nres > 0);
+    return (!idef_->il[F_POSRES].empty() || !idef_->il[F_FBPOSRES].empty() || fcdata.orires->nr > 0
+            || fcdata.disres->nres > 0);
 }
 
 bool ListedForces::haveCpuBondeds() const
@@ -492,9 +552,9 @@ bool ListedForces::haveCpuBondeds() const
     return threading_->haveBondeds;
 }
 
-bool ListedForces::haveCpuListedForces() const
+bool ListedForces::haveCpuListedForces(const t_fcdata& fcdata) const
 {
-    return haveCpuBondeds() || haveRestraints();
+    return haveCpuBondeds() || haveRestraints(fcdata);
 }
 
 namespace
@@ -625,6 +685,7 @@ void ListedForces::calculate(struct gmx_wallcycle*          wcycle,
                              const gmx_multisim_t*          ms,
                              const rvec                     x[],
                              gmx::ArrayRef<const gmx::RVec> xWholeMolecules,
+                             t_fcdata*                      fcdata,
                              history_t*                     hist,
                              gmx::ForceOutputs*             forceOutputs,
                              const t_forcerec*              fr,
@@ -636,16 +697,15 @@ void ListedForces::calculate(struct gmx_wallcycle*          wcycle,
                              int*                           global_atom_index,
                              const gmx::StepWorkload&       stepWork)
 {
-    if (!stepWork.computeListedForces)
+    if (interactionSelection_.none() || !stepWork.computeListedForces)
     {
         return;
     }
 
-    const InteractionDefinitions& idef   = *idef_;
-    t_fcdata&                     fcdata = *fcdata_;
+    const InteractionDefinitions& idef = *idef_;
 
     t_pbc pbc_full; /* Full PBC is needed for position restraints */
-    if (haveRestraints())
+    if (haveRestraints(*fcdata))
     {
         if (!idef.il[F_POSRES].empty() || !idef.il[F_FBPOSRES].empty())
         {
@@ -671,24 +731,24 @@ void ListedForces::calculate(struct gmx_wallcycle*          wcycle,
         }
 
         /* Do pre force calculation stuff which might require communication */
-        if (fcdata.orires->nr > 0)
+        if (fcdata->orires->nr > 0)
         {
             GMX_ASSERT(!xWholeMolecules.empty(), "Need whole molecules for orienation restraints");
             enerd->term[F_ORIRESDEV] = calc_orires_dev(
                     ms, idef.il[F_ORIRES].size(), idef.il[F_ORIRES].iatoms.data(), idef.iparams.data(),
-                    md, xWholeMolecules, x, fr->bMolPBC ? pbc : nullptr, fcdata.orires, hist);
+                    md, xWholeMolecules, x, fr->bMolPBC ? pbc : nullptr, fcdata->orires, hist);
         }
-        if (fcdata.disres->nres > 0)
+        if (fcdata->disres->nres > 0)
         {
             calc_disres_R_6(cr, ms, idef.il[F_DISRES].size(), idef.il[F_DISRES].iatoms.data(), x,
-                            fr->bMolPBC ? pbc : nullptr, fcdata.disres, hist);
+                            fr->bMolPBC ? pbc : nullptr, fcdata->disres, hist);
         }
 
         wallcycle_sub_stop(wcycle, ewcsRESTRAINTS);
     }
 
     calc_listed(wcycle, idef, threading_.get(), x, forceOutputs, fr, pbc, enerd, nrnb, lambda, md,
-                &fcdata, global_atom_index, stepWork);
+                fcdata, global_atom_index, stepWork);
 
     /* Check if we have to determine energy differences
      * at foreign lambda's.
@@ -718,7 +778,7 @@ void ListedForces::calculate(struct gmx_wallcycle*          wcycle,
                 }
                 calc_listed_lambda(idef, threading_.get(), x, fr, pbc, forceBufferLambda_,
                                    shiftForceBufferLambda_, &(enerd->foreign_grpp), enerd->foreign_term,
-                                   dvdl, nrnb, lam_i, md, &fcdata, global_atom_index);
+                                   dvdl, nrnb, lam_i, md, fcdata, global_atom_index);
                 sum_epot(enerd->foreign_grpp, enerd->foreign_term);
                 const double dvdlSum = std::accumulate(std::begin(dvdl), std::end(dvdl), 0.);
                 std::fill(std::begin(dvdl), std::end(dvdl), 0.0);