Extract postProcessForceWithShiftForces()
authorBerk Hess <hess@kth.se>
Thu, 14 May 2020 12:51:51 +0000 (12:51 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Thu, 14 May 2020 12:51:51 +0000 (12:51 +0000)
Also added a boolean to check whether we spread vsite forces exactly
once.

src/gromacs/mdlib/sim_util.cpp
src/gromacs/mdtypes/forceoutput.h

index e4d3660eef29874c5b80302ed95e10aaa3697165..31b0955b5b421dc503cefe7fba855ea0723d453f 100644 (file)
@@ -123,6 +123,7 @@ using gmx::ArrayRef;
 using gmx::AtomLocality;
 using gmx::DomainLifetimeWorkload;
 using gmx::ForceOutputs;
+using gmx::ForceWithShiftForces;
 using gmx::InteractionLocality;
 using gmx::RVec;
 using gmx::SimulationWorkload;
@@ -134,8 +135,9 @@ using gmx::StepWorkload;
 // PME-first ordering would suffice).
 static const bool c_disableAlternatingWait = (getenv("GMX_DISABLE_ALTERNATING_GPU_WAIT") != nullptr);
 
-static void sum_forces(rvec f[], gmx::ArrayRef<const gmx::RVec> forceToAdd)
+static void sum_forces(ArrayRef<RVec> f, ArrayRef<const RVec> forceToAdd)
 {
+    GMX_ASSERT(f.size() >= forceToAdd.size(), "Accumulation buffer should be sufficiently large");
     const int end = forceToAdd.size();
 
     int gmx_unused nt = gmx_omp_nthreads_get(emntDefault);
@@ -250,7 +252,7 @@ static void print_large_forces(FILE*                fp,
                                int64_t              step,
                                real                 forceTolerance,
                                ArrayRef<const RVec> x,
-                               const rvec*          f)
+                               ArrayRef<const RVec> f)
 {
     real       force2Tolerance = gmx::square(forceTolerance);
     gmx::index numNonFinite    = 0;
@@ -278,20 +280,60 @@ static void print_large_forces(FILE*                fp,
     }
 }
 
-static void post_process_forces(const t_commrec*          cr,
-                                int64_t                   step,
-                                t_nrnb*                   nrnb,
-                                gmx_wallcycle_t           wcycle,
-                                const matrix              box,
-                                ArrayRef<const RVec>      x,
-                                ForceOutputs*             forceOutputs,
-                                tensor                    vir_force,
-                                const t_mdatoms*          mdatoms,
-                                const t_forcerec*         fr,
-                                gmx::VirtualSitesHandler* vsite,
-                                const StepWorkload&       stepWork)
+//! When necessary, spreads forces on vsites and computes the virial for \p forceOutputs->forceWithShiftForces()
+static void postProcessForceWithShiftForces(t_nrnb*                   nrnb,
+                                            gmx_wallcycle_t           wcycle,
+                                            const matrix              box,
+                                            ArrayRef<const RVec>      x,
+                                            ForceOutputs*             forceOutputs,
+                                            tensor                    vir_force,
+                                            const t_mdatoms&          mdatoms,
+                                            const t_forcerec&         fr,
+                                            gmx::VirtualSitesHandler* vsite,
+                                            const StepWorkload&       stepWork)
 {
-    rvec* f = as_rvec_array(forceOutputs->forceWithShiftForces().force().data());
+    ForceWithShiftForces& forceWithShiftForces = forceOutputs->forceWithShiftForces();
+
+    /* If we have NoVirSum forces, but we do not calculate the virial,
+     * we later sum the forceWithShiftForces buffer together with
+     * the noVirSum buffer and spread the combined vsite forces at once.
+     */
+    if (vsite && (!forceOutputs->haveForceWithVirial() || stepWork.computeVirial))
+    {
+        using VirialHandling = gmx::VirtualSitesHandler::VirialHandling;
+
+        auto                 f      = forceWithShiftForces.force();
+        auto                 fshift = forceWithShiftForces.shiftForces();
+        const VirialHandling virialHandling =
+                (stepWork.computeVirial ? VirialHandling::Pbc : VirialHandling::None);
+        vsite->spreadForces(x, f, virialHandling, fshift, nullptr, nrnb, box, wcycle);
+        forceWithShiftForces.haveSpreadVsiteForces() = true;
+    }
+
+    if (stepWork.computeVirial)
+    {
+        /* Calculation of the virial must be done after vsites! */
+        calc_virial(0, mdatoms.homenr, as_rvec_array(x.data()), forceWithShiftForces, vir_force,
+                    box, nrnb, &fr, fr.pbcType);
+    }
+}
+
+//! Spread, compute virial for and sum forces, when necessary
+static void postProcessForces(const t_commrec*          cr,
+                              int64_t                   step,
+                              t_nrnb*                   nrnb,
+                              gmx_wallcycle_t           wcycle,
+                              const matrix              box,
+                              ArrayRef<const RVec>      x,
+                              ForceOutputs*             forceOutputs,
+                              tensor                    vir_force,
+                              const t_mdatoms*          mdatoms,
+                              const t_forcerec*         fr,
+                              gmx::VirtualSitesHandler* vsite,
+                              const StepWorkload&       stepWork)
+{
+    // Extract the final output force buffer, which is also the buffer for forces with shift forces
+    ArrayRef<RVec> f = forceOutputs->forceWithShiftForces().force();
 
     if (forceOutputs->haveForceWithVirial())
     {
@@ -303,6 +345,13 @@ static void post_process_forces(const t_commrec*          cr,
              * This is parallellized. MPI communication is performed
              * if the constructing atoms aren't local.
              */
+            GMX_ASSERT(!stepWork.computeVirial || f.data() != forceWithVirial.force_.data(),
+                       "We need separate force buffers for shift and virial forces when "
+                       "computing the virial");
+            GMX_ASSERT(!stepWork.computeVirial
+                               || forceOutputs->forceWithShiftForces().haveSpreadVsiteForces(),
+                       "We should spread the force with shift forces separately when computing "
+                       "the virial");
             const gmx::VirtualSitesHandler::VirialHandling virialHandling =
                     (stepWork.computeVirial ? gmx::VirtualSitesHandler::VirialHandling::NonLinear
                                             : gmx::VirtualSitesHandler::VirialHandling::None);
@@ -328,6 +377,11 @@ static void post_process_forces(const t_commrec*          cr,
             }
         }
     }
+    else
+    {
+        GMX_ASSERT(vsite == nullptr || forceOutputs->forceWithShiftForces().haveSpreadVsiteForces(),
+                   "We should have spread the vsite forces (earlier)");
+    }
 
     if (fr->print_force >= 0)
     {
@@ -1787,25 +1841,8 @@ void do_force(FILE*                               fplog,
 
     if (stepWork.computeForces)
     {
-        /* If we have NoVirSum forces, but we do not calculate the virial,
-         * we sum fr->f_novirsum=forceOut.f later.
-         */
-        if (vsite && !(fr->forceHelperBuffers->haveDirectVirialContributions() && !stepWork.computeVirial))
-        {
-            auto f      = forceOut.forceWithShiftForces().force();
-            auto fshift = forceOut.forceWithShiftForces().shiftForces();
-            const gmx::VirtualSitesHandler::VirialHandling virialHandling =
-                    (stepWork.computeVirial ? gmx::VirtualSitesHandler::VirialHandling::Pbc
-                                            : gmx::VirtualSitesHandler::VirialHandling::None);
-            vsite->spreadForces(x.unpaddedArrayRef(), f, virialHandling, fshift, nullptr, nrnb, box, wcycle);
-        }
-
-        if (stepWork.computeVirial)
-        {
-            /* Calculation of the virial must be done after vsites! */
-            calc_virial(0, mdatoms->homenr, as_rvec_array(x.unpaddedArrayRef().data()),
-                        forceOut.forceWithShiftForces(), vir_force, box, nrnb, fr, inputrec->pbcType);
-        }
+        postProcessForceWithShiftForces(nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut,
+                                        vir_force, *mdatoms, *fr, vsite, stepWork);
     }
 
     // TODO refactor this and unify with above GPU PME-PP / GPU update path call to the same function
@@ -1821,8 +1858,8 @@ void do_force(FILE*                               fplog,
 
     if (stepWork.computeForces)
     {
-        post_process_forces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut, vir_force,
-                            mdatoms, fr, vsite, stepWork);
+        postProcessForces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut, vir_force,
+                          mdatoms, fr, vsite, stepWork);
     }
 
     if (stepWork.computeEnergy)
index 6152a68f2d7af75f38d3e4c645110fdc6319c246..c31bcf05ba5f6536fc8cfa63aacbe824ea835f08 100644 (file)
@@ -102,6 +102,9 @@ public:
     //! Returns a const shift forces buffer
     gmx::ArrayRef<const gmx::RVec> shiftForces() const { return shiftForces_; }
 
+    //! Returns a reference to the boolean which tells whether we have spread forces on vsites
+    bool& haveSpreadVsiteForces() { return haveSpreadVsiteForces_; }
+
 private:
     //! The force buffer
     gmx::ArrayRefWithPadding<gmx::RVec> force_;
@@ -109,6 +112,8 @@ private:
     bool computeVirial_;
     //! A buffer for storing the shift forces, size SHIFTS
     gmx::ArrayRef<gmx::RVec> shiftForces_;
+    //! Tells whether we have spread the vsite forces
+    bool haveSpreadVsiteForces_ = false;
 };
 
 /*! \libinternal \brief Container for force and virial for algorithms that provide their own virial tensor contribution