Add MTS support for pull and AWH
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index ef27166978b2d7ebeb9f5a781e211fa7d18ebaa2..17a9e44bfa2b41640631b40e7930fdef6823a57e 100644 (file)
@@ -580,7 +580,8 @@ static bool haveSpecialForces(const t_inputrec&          inputrec,
  * \param[in]     mdatoms          Per atom properties
  * \param[in]     lambda           Array of free-energy lambda values
  * \param[in]     stepWork         Step schedule flags
- * \param[in,out] forceWithVirial  Force and virial buffers
+ * \param[in,out] forceWithVirialMtsLevel0  Force and virial for MTS level0 forces
+ * \param[in,out] forceWithVirialMtsLevel1  Force and virial for MTS level1 forces, can be nullptr
  * \param[in,out] enerd            Energy buffer
  * \param[in,out] ed               Essential dynamics pointer
  * \param[in]     didNeighborSearch Tells if we did neighbor searching this step, used for ED sampling
@@ -604,7 +605,8 @@ static void computeSpecialForces(FILE*                          fplog,
                                  const t_mdatoms*               mdatoms,
                                  gmx::ArrayRef<const real>      lambda,
                                  const StepWorkload&            stepWork,
-                                 gmx::ForceWithVirial*          forceWithVirial,
+                                 gmx::ForceWithVirial*          forceWithVirialMtsLevel0,
+                                 gmx::ForceWithVirial*          forceWithVirialMtsLevel1,
                                  gmx_enerdata_t*                enerd,
                                  gmx_edsam*                     ed,
                                  bool                           didNeighborSearch)
@@ -615,7 +617,7 @@ static void computeSpecialForces(FILE*                          fplog,
     if (stepWork.computeForces)
     {
         gmx::ForceProviderInput  forceProviderInput(x, *mdatoms, t, box, *cr);
-        gmx::ForceProviderOutput forceProviderOutput(forceWithVirial, enerd);
+        gmx::ForceProviderOutput forceProviderOutput(forceWithVirialMtsLevel0, enerd);
 
         /* Collect forces from modules */
         forceProviders->calculateForces(forceProviderInput, &forceProviderOutput);
@@ -623,26 +625,36 @@ static void computeSpecialForces(FILE*                          fplog,
 
     if (inputrec->bPull && pull_have_potential(pull_work))
     {
-        pull_potential_wrapper(cr, inputrec, box, x, forceWithVirial, mdatoms, enerd, pull_work,
-                               lambda.data(), t, wcycle);
+        const int mtsLevel = forceGroupMtsLevel(inputrec->mtsLevels, gmx::MtsForceGroups::Pull);
+        if (mtsLevel == 0 || stepWork.computeSlowForces)
+        {
+            auto& forceWithVirial = (mtsLevel == 0) ? forceWithVirialMtsLevel0 : forceWithVirialMtsLevel1;
+            pull_potential_wrapper(cr, inputrec, box, x, forceWithVirial, mdatoms, enerd, pull_work,
+                                   lambda.data(), t, wcycle);
+        }
     }
     if (awh)
     {
-        const bool          needForeignEnergyDifferences = awh->needForeignEnergyDifferences(step);
-        std::vector<double> foreignLambdaDeltaH, foreignLambdaDhDl;
-        if (needForeignEnergyDifferences)
+        const int mtsLevel = forceGroupMtsLevel(inputrec->mtsLevels, gmx::MtsForceGroups::Pull);
+        if (mtsLevel == 0 || stepWork.computeSlowForces)
         {
-            enerd->foreignLambdaTerms.finalizePotentialContributions(enerd->dvdl_lin, lambda,
-                                                                     *inputrec->fepvals);
-            std::tie(foreignLambdaDeltaH, foreignLambdaDhDl) = enerd->foreignLambdaTerms.getTerms(cr);
-        }
+            const bool needForeignEnergyDifferences = awh->needForeignEnergyDifferences(step);
+            std::vector<double> foreignLambdaDeltaH, foreignLambdaDhDl;
+            if (needForeignEnergyDifferences)
+            {
+                enerd->foreignLambdaTerms.finalizePotentialContributions(enerd->dvdl_lin, lambda,
+                                                                         *inputrec->fepvals);
+                std::tie(foreignLambdaDeltaH, foreignLambdaDhDl) = enerd->foreignLambdaTerms.getTerms(cr);
+            }
 
-        enerd->term[F_COM_PULL] += awh->applyBiasForcesAndUpdateBias(
-                inputrec->pbcType, mdatoms->massT, foreignLambdaDeltaH, foreignLambdaDhDl, box,
-                forceWithVirial, t, step, wcycle, fplog);
+            auto& forceWithVirial = (mtsLevel == 0) ? forceWithVirialMtsLevel0 : forceWithVirialMtsLevel1;
+            enerd->term[F_COM_PULL] += awh->applyBiasForcesAndUpdateBias(
+                    inputrec->pbcType, mdatoms->massT, foreignLambdaDeltaH, foreignLambdaDhDl, box,
+                    forceWithVirial, t, step, wcycle, fplog);
+        }
     }
 
-    rvec* f = as_rvec_array(forceWithVirial->force_.data());
+    rvec* f = as_rvec_array(forceWithVirialMtsLevel0->force_.data());
 
     /* Add the forces from enforced rotation potentials (if any) */
     if (inputrec->bRot)
@@ -1781,8 +1793,10 @@ void do_force(FILE*                               fplog,
     }
 
     computeSpecialForces(fplog, cr, inputrec, awh, enforcedRotation, imdSession, pull_work, step, t,
-                         wcycle, fr->forceProviders, box, x.unpaddedArrayRef(), mdatoms, lambda, stepWork,
-                         &forceOutMtsLevel0.forceWithVirial(), enerd, ed, stepWork.doNeighborSearch);
+                         wcycle, fr->forceProviders, box, x.unpaddedArrayRef(), mdatoms, lambda,
+                         stepWork, &forceOutMtsLevel0.forceWithVirial(),
+                         forceOutMtsLevel1 ? &forceOutMtsLevel1->forceWithVirial() : nullptr, enerd,
+                         ed, stepWork.doNeighborSearch);
 
     GMX_ASSERT(!(nonbondedAtMtsLevel1 && stepWork.useGpuFBufferOps),
                "The schedule below does not allow for nonbonded MTS with GPU buffer ops");