Move computeSlowForces into stepWork
[alexxy/gromacs.git] / src / gromacs / mdlib / forcerec.cpp
index 6144b3a2faaa7edeedf7cb0bdef726972987532d..0a9f0dc89f6d760a31fcef32f4cdacd9a7982daf 100644 (file)
@@ -82,9 +82,8 @@
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/interaction_const.h"
 #include "gromacs/mdtypes/md_enums.h"
-#include "gromacs/nbnxm/gpu_data_mgmt.h"
+#include "gromacs/mdtypes/multipletimestepping.h"
 #include "gromacs/nbnxm/nbnxm.h"
-#include "gromacs/nbnxm/nbnxm_geometry.h"
 #include "gromacs/pbcutil/ishift.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/tables/forcetable.h"
@@ -615,7 +614,10 @@ void forcerec_set_ranges(t_forcerec* fr, int natoms_force, int natoms_force_cons
     fr->natoms_force        = natoms_force;
     fr->natoms_force_constr = natoms_force_constr;
 
-    fr->forceHelperBuffers->resize(natoms_f_novirsum);
+    for (auto& forceHelperBuffers : fr->forceHelperBuffers)
+    {
+        forceHelperBuffers.resize(natoms_f_novirsum);
+    }
 }
 
 static real cutoff_inf(real cutoff)
@@ -1156,11 +1158,26 @@ void init_forcerec(FILE*                            fp,
     /* 1-4 interaction electrostatics */
     fr->fudgeQQ = mtop->ffparams.fudgeQQ;
 
-    const bool haveDirectVirialContributions =
-            (EEL_FULL(ic->eeltype) || EVDW_PME(ic->vdwtype) || fr->forceProviders->hasForceProvider()
-             || gmx_mtop_ftype_count(mtop, F_POSRES) > 0 || gmx_mtop_ftype_count(mtop, F_FBPOSRES) > 0
-             || ir->nwall > 0 || ir->bPull || ir->bRot || ir->bIMD);
-    fr->forceHelperBuffers = std::make_unique<ForceHelperBuffers>(haveDirectVirialContributions);
+    // Multiple time stepping
+    fr->useMts = ir->useMts;
+
+    if (fr->useMts)
+    {
+        gmx::assertMtsRequirements(*ir);
+    }
+
+    const bool haveDirectVirialContributionsFast =
+            fr->forceProviders->hasForceProvider() || gmx_mtop_ftype_count(mtop, F_POSRES) > 0
+            || gmx_mtop_ftype_count(mtop, F_FBPOSRES) > 0 || ir->nwall > 0 || ir->bPull || ir->bRot
+            || ir->bIMD;
+    const bool haveDirectVirialContributionsSlow = EEL_FULL(ic->eeltype) || EVDW_PME(ic->vdwtype);
+    for (int i = 0; i < (fr->useMts ? 2 : 1); i++)
+    {
+        bool haveDirectVirialContributions =
+                (((!fr->useMts || i == 0) && haveDirectVirialContributionsFast)
+                 || ((!fr->useMts || i == 1) && haveDirectVirialContributionsSlow));
+        fr->forceHelperBuffers.emplace_back(haveDirectVirialContributions);
+    }
 
     if (fr->shift_vec == nullptr)
     {
@@ -1264,9 +1281,43 @@ void init_forcerec(FILE*                            fp,
     }
 
     /* Initialize the thread working data for bonded interactions */
-    fr->listedForces.emplace_back(
-            mtop->ffparams, mtop->groups.groups[SimulationAtomGroupType::EnergyOutput].size(),
-            gmx_omp_nthreads_get(emntBonded), ListedForces::interactionSelectionAll(), fp);
+    if (fr->useMts)
+    {
+        // Add one ListedForces object for each MTS level
+        bool isFirstLevel = true;
+        for (const auto& mtsLevel : ir->mtsLevels)
+        {
+            ListedForces::InteractionSelection interactionSelection;
+            const auto&                        forceGroups = mtsLevel.forceGroups;
+            if (forceGroups[static_cast<int>(gmx::MtsForceGroups::Pair)])
+            {
+                interactionSelection.set(static_cast<int>(ListedForces::InteractionGroup::Pairs));
+            }
+            if (forceGroups[static_cast<int>(gmx::MtsForceGroups::Dihedral)])
+            {
+                interactionSelection.set(static_cast<int>(ListedForces::InteractionGroup::Dihedrals));
+            }
+            if (forceGroups[static_cast<int>(gmx::MtsForceGroups::Angle)])
+            {
+                interactionSelection.set(static_cast<int>(ListedForces::InteractionGroup::Angles));
+            }
+            if (isFirstLevel)
+            {
+                interactionSelection.set(static_cast<int>(ListedForces::InteractionGroup::Rest));
+                isFirstLevel = false;
+            }
+            fr->listedForces.emplace_back(
+                    mtop->ffparams, mtop->groups.groups[SimulationAtomGroupType::EnergyOutput].size(),
+                    gmx_omp_nthreads_get(emntBonded), interactionSelection, fp);
+        }
+    }
+    else
+    {
+        // Add one ListedForces object with all listed interactions
+        fr->listedForces.emplace_back(
+                mtop->ffparams, mtop->groups.groups[SimulationAtomGroupType::EnergyOutput].size(),
+                gmx_omp_nthreads_get(emntBonded), ListedForces::interactionSelectionAll(), fp);
+    }
 
     // QM/MM initialization if requested
     if (ir->bQMMM)