Extract helper force buffers to a separate class
authorBerk Hess <hess@kth.se>
Tue, 12 May 2020 13:14:11 +0000 (13:14 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Tue, 12 May 2020 13:14:11 +0000 (13:14 +0000)
This change provides useful refactoring itself and it is needed
for multiple time stepping.

src/gromacs/domdec/partition.cpp
src/gromacs/math/vec.h
src/gromacs/mdlib/forcerec.cpp
src/gromacs/mdlib/sim_util.cpp
src/gromacs/mdtypes/forceoutput.h
src/gromacs/mdtypes/forcerec.h

index ca7c48c14d35692b3ddf67502c072529ef99b3d2..2d5d6cb0b82c978d8224574f0aed4fccd7dd3378 100644 (file)
@@ -3117,7 +3117,7 @@ void dd_partition_system(FILE*                        fplog,
 
     state_change_natoms(state_local, state_local->natoms);
 
-    if (fr->haveDirectVirialContributions)
+    if (fr->forceHelperBuffers->haveDirectVirialContributions())
     {
         if (vsite && vsite->numInterUpdategroupVirtualSites())
         {
index 61f51d8b9b3bdfd3fb0cd509bd3a17c5c127ec15..7666ec6d3309b5b8918a43be9b6f247d0b82f469 100644 (file)
@@ -292,9 +292,9 @@ static inline void clear_rvec(rvec a)
     /* The ibm compiler has problems with inlining this
      * when we use a const real variable
      */
-    a[XX] = 0.0;
-    a[YY] = 0.0;
-    a[ZZ] = 0.0;
+    a[XX] = 0.0_real;
+    a[YY] = 0.0_real;
+    a[ZZ] = 0.0_real;
 }
 
 static inline void clear_dvec(dvec a)
index 3cb7e40a93c6e173f20f7a9fd45c712063f55d92..e96755eeb63800df213dd72ef6b2e7a8859c3cb8 100644 (file)
 #include "gromacs/utility/smalloc.h"
 #include "gromacs/utility/strconvert.h"
 
+ForceHelperBuffers::ForceHelperBuffers(bool haveDirectVirialContributions) :
+    haveDirectVirialContributions_(haveDirectVirialContributions)
+{
+    shiftForces_.resize(SHIFTS);
+}
+
+void ForceHelperBuffers::resize(int numAtoms)
+{
+    if (haveDirectVirialContributions_)
+    {
+        forceBufferForDirectVirialContributions_.resize(numAtoms);
+    }
+}
+
 static std::vector<real> mk_nbfp(const gmx_ffparams_t* idef, gmx_bool bBHAM)
 {
     std::vector<real> nbfp;
@@ -602,10 +616,7 @@ void forcerec_set_ranges(t_forcerec* fr, int natoms_force, int natoms_force_cons
     fr->natoms_force        = natoms_force;
     fr->natoms_force_constr = natoms_force_constr;
 
-    if (fr->haveDirectVirialContributions)
-    {
-        fr->forceBufferForDirectVirialContributions.resize(natoms_f_novirsum);
-    }
+    fr->forceHelperBuffers->resize(natoms_f_novirsum);
 }
 
 static real cutoff_inf(real cutoff)
@@ -1185,18 +1196,17 @@ void init_forcerec(FILE*                            fp,
     /* 1-4 interaction electrostatics */
     fr->fudgeQQ = mtop->ffparams.fudgeQQ;
 
-    fr->haveDirectVirialContributions =
+    const bool haveDirectVirialContributions =
             (EEL_FULL(ic->eeltype) || EVDW_PME(ic->vdwtype) || fr->forceProviders->hasForceProvider()
              || gmx_mtop_ftype_count(mtop, F_POSRES) > 0 || gmx_mtop_ftype_count(mtop, F_FBPOSRES) > 0
              || ir->nwall > 0 || ir->bPull || ir->bRot || ir->bIMD);
+    fr->forceHelperBuffers = std::make_unique<ForceHelperBuffers>(haveDirectVirialContributions);
 
     if (fr->shift_vec == nullptr)
     {
         snew(fr->shift_vec, SHIFTS);
     }
 
-    fr->shiftForces.resize(SHIFTS);
-
     if (fr->nbfp.empty())
     {
         fr->ntype = mtop->ffparams.atnr;
index 841337b6f642351acec9a727da1ee900a4e1228f..e4d3660eef29874c5b80302ed95e10aaa3697165 100644 (file)
@@ -293,7 +293,7 @@ static void post_process_forces(const t_commrec*          cr,
 {
     rvec* f = as_rvec_array(forceOutputs->forceWithShiftForces().force().data());
 
-    if (fr->haveDirectVirialContributions)
+    if (forceOutputs->haveForceWithVirial())
     {
         auto& forceWithVirial = forceOutputs->forceWithVirial();
 
@@ -378,25 +378,25 @@ static void do_nb_verlet(t_forcerec*                fr,
     nbv->dispatchNonbondedKernel(ilocality, *ic, stepWork, clearF, *fr, enerd, nrnb);
 }
 
-static inline void clear_rvecs_omp(int n, rvec v[])
+static inline void clearRVecs(ArrayRef<RVec> v, const bool useOpenmpThreading)
 {
-    int nth = gmx_omp_nthreads_get_simple_rvec_task(emntDefault, n);
+    int nth = gmx_omp_nthreads_get_simple_rvec_task(emntDefault, v.ssize());
 
     /* Note that we would like to avoid this conditional by putting it
      * into the omp pragma instead, but then we still take the full
      * omp parallel for overhead (at least with gcc5).
      */
-    if (nth == 1)
+    if (!useOpenmpThreading || nth == 1)
     {
-        for (int i = 0; i < n; i++)
+        for (RVec& elem : v)
         {
-            clear_rvec(v[i]);
+            clear_rvec(elem);
         }
     }
     else
     {
 #pragma omp parallel for num_threads(nth) schedule(static)
-        for (int i = 0; i < n; i++)
+        for (gmx::index i = 0; i < v.ssize(); i++)
         {
             clear_rvec(v[i]);
         }
@@ -706,7 +706,7 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t* nbv,
 
 /*! \brief Set up the different force buffers; also does clearing.
  *
- * \param[in] fr        force record pointer
+ * \param[in] forceHelperBuffers  Helper force buffers
  * \param[in] pull_work The pull work object.
  * \param[in] inputrec  input record
  * \param[in] force     force array
@@ -715,7 +715,7 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t* nbv,
  *
  * \returns             Cleared force output structure
  */
-static ForceOutputs setupForceOutputs(t_forcerec*                         fr,
+static ForceOutputs setupForceOutputs(ForceHelperBuffers*                 forceHelperBuffers,
                                       pull_t*                             pull_work,
                                       const t_inputrec&                   inputrec,
                                       gmx::ArrayRefWithPadding<gmx::RVec> force,
@@ -725,12 +725,16 @@ static ForceOutputs setupForceOutputs(t_forcerec*                         fr,
     wallcycle_sub_start(wcycle, ewcsCLEAR_FORCE_BUFFER);
 
     /* NOTE: We assume fr->shiftForces is all zeros here */
-    gmx::ForceWithShiftForces forceWithShiftForces(force, stepWork.computeVirial, fr->shiftForces);
+    gmx::ForceWithShiftForces forceWithShiftForces(force, stepWork.computeVirial,
+                                                   forceHelperBuffers->shiftForces());
 
     if (stepWork.computeForces)
     {
         /* Clear the short- and long-range forces */
-        clear_rvecs_omp(fr->natoms_force_constr, as_rvec_array(forceWithShiftForces.force().data()));
+        clearRVecs(forceWithShiftForces.force(), true);
+
+        /* Clear the shift forces */
+        clearRVecs(forceWithShiftForces.shiftForces(), false);
     }
 
     /* If we need to compute the virial, we might need a separate
@@ -739,11 +743,13 @@ static ForceOutputs setupForceOutputs(t_forcerec*                         fr,
      * the same force (f in legacy calls) buffer as other algorithms.
      */
     const bool useSeparateForceWithVirialBuffer =
-            (stepWork.computeForces && (stepWork.computeVirial && fr->haveDirectVirialContributions));
+            (stepWork.computeForces
+             && (stepWork.computeVirial && forceHelperBuffers->haveDirectVirialContributions()));
     /* forceWithVirial uses the local atom range only */
-    gmx::ForceWithVirial forceWithVirial(useSeparateForceWithVirialBuffer ? fr->forceBufferForDirectVirialContributions
-                                                                          : force.unpaddedArrayRef(),
-                                         stepWork.computeVirial);
+    gmx::ForceWithVirial forceWithVirial(
+            useSeparateForceWithVirialBuffer ? forceHelperBuffers->forceBufferForDirectVirialContributions()
+                                             : force.unpaddedArrayRef(),
+            stepWork.computeVirial);
 
     if (useSeparateForceWithVirialBuffer)
     {
@@ -752,7 +758,7 @@ static ForceOutputs setupForceOutputs(t_forcerec*                         fr,
          * spread to non-local atoms, but that part of the buffer is
          * cleared separately in the vsite spreading code.
          */
-        clear_rvecs_omp(forceWithVirial.force_.size(), as_rvec_array(forceWithVirial.force_.data()));
+        clearRVecs(forceWithVirial.force_, true);
     }
 
     if (inputrec.bPull && pull_have_constraint(pull_work))
@@ -762,7 +768,8 @@ static ForceOutputs setupForceOutputs(t_forcerec*                         fr,
 
     wallcycle_sub_stop(wcycle, ewcsCLEAR_FORCE_BUFFER);
 
-    return ForceOutputs(forceWithShiftForces, forceWithVirial);
+    return ForceOutputs(forceWithShiftForces, forceHelperBuffers->haveDirectVirialContributions(),
+                        forceWithVirial);
 }
 
 
@@ -1409,12 +1416,6 @@ void do_force(FILE*                               fplog,
 
     /* Reset energies */
     reset_enerdata(enerd);
-    /* Clear the shift forces */
-    // TODO: This should be linked to the shift force buffer in use, or cleared before use instead
-    for (gmx::RVec& elem : fr->shiftForces)
-    {
-        elem = { 0.0_real, 0.0_real, 0.0_real };
-    }
 
     if (DOMAINDECOMP(cr) && !thisRankHasDuty(cr, DUTY_PME))
     {
@@ -1445,8 +1446,8 @@ void do_force(FILE*                               fplog,
 
     // Set up and clear force outputs.
     // We use std::move to keep the compiler happy, it has no effect.
-    ForceOutputs forceOut =
-            setupForceOutputs(fr, pull_work, *inputrec, std::move(force), stepWork, wcycle);
+    ForceOutputs forceOut = setupForceOutputs(fr->forceHelperBuffers.get(), pull_work, *inputrec,
+                                              std::move(force), stepWork, wcycle);
 
     /* We calculate the non-bonded forces, when done on the CPU, here.
      * We do this before calling do_force_lowlevel, because in that
@@ -1789,7 +1790,7 @@ void do_force(FILE*                               fplog,
         /* If we have NoVirSum forces, but we do not calculate the virial,
          * we sum fr->f_novirsum=forceOut.f later.
          */
-        if (vsite && !(fr->haveDirectVirialContributions && !stepWork.computeVirial))
+        if (vsite && !(fr->forceHelperBuffers->haveDirectVirialContributions() && !stepWork.computeVirial))
         {
             auto f      = forceOut.forceWithShiftForces().force();
             auto fshift = forceOut.forceWithShiftForces().shiftForces();
index f17e61deebabff8e82bf52debfbc49830b77d131..6152a68f2d7af75f38d3e4c645110fdc6319c246 100644 (file)
@@ -1,7 +1,7 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2019, by the GROMACS development team, led by
+ * Copyright (c) 2019,2020, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
@@ -83,6 +83,8 @@ public:
         computeVirial_(computeVirial),
         shiftForces_(computeVirial ? shiftForces : gmx::ArrayRef<gmx::RVec>())
     {
+        GMX_ASSERT(!computeVirial || !shiftForces.empty(),
+                   "We need a valid shift force buffer when computing the virial");
     }
 
     //! Returns an arrayref to the force buffer without padding
@@ -193,8 +195,11 @@ class ForceOutputs
 {
 public:
     //! Constructor
-    ForceOutputs(const ForceWithShiftForces& forceWithShiftForces, const ForceWithVirial& forceWithVirial) :
+    ForceOutputs(const ForceWithShiftForces& forceWithShiftForces,
+                 bool                        haveForceWithVirial,
+                 const ForceWithVirial&      forceWithVirial) :
         forceWithShiftForces_(forceWithShiftForces),
+        haveForceWithVirial_(haveForceWithVirial),
         forceWithVirial_(forceWithVirial)
     {
     }
@@ -202,12 +207,17 @@ public:
     //! Returns a reference to the force with shift forces object
     ForceWithShiftForces& forceWithShiftForces() { return forceWithShiftForces_; }
 
+    //! Return whether there are forces with direct virial contributions
+    bool haveForceWithVirial() const { return haveForceWithVirial_; }
+
     //! Returns a reference to the force with virial object
     ForceWithVirial& forceWithVirial() { return forceWithVirial_; }
 
 private:
     //! Force output buffer used by legacy modules (without SIMD padding)
     ForceWithShiftForces forceWithShiftForces_;
+    //! Whether we have forces with direct virial contributions
+    bool haveForceWithVirial_;
     //! Force with direct virial contribution (if there are any; without SIMD padding)
     ForceWithVirial forceWithVirial_;
 };
index 552580a7842eb5a07a9e5b0d0cf36e749018cb79..81238d4b32aa8b75a19f3e294f9211e00c972d28 100644 (file)
@@ -45,6 +45,7 @@
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/pbcutil/pbc.h"
+#include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/basedefinitions.h"
 #include "gromacs/utility/real.h"
 
@@ -125,6 +126,49 @@ struct gmx_ewald_tab_t;
 
 struct ewald_corr_thread_t;
 
+/*! \brief Helper force buffers for ForceOutputs
+ *
+ * This class stores intermediate force buffers that are used
+ * internally in the force calculation and which are reduced into
+ * the output force buffer passed to the force calculation.
+ */
+class ForceHelperBuffers
+{
+public:
+    /*! \brief Constructs helper buffers
+     *
+     * When the forces that will be accumulated with help of these buffers
+     * have direct virial contributions, set the parameter to true, so
+     * an extra force buffer is available for these forces to enable
+     * correct virial computation.
+     */
+    ForceHelperBuffers(bool haveDirectVirialContributions);
+
+    //! Returns whether we have a direct virial contribution force buffer
+    bool haveDirectVirialContributions() const { return haveDirectVirialContributions_; }
+
+    //! Returns the buffer for direct virial contributions
+    gmx::ArrayRef<gmx::RVec> forceBufferForDirectVirialContributions()
+    {
+        GMX_ASSERT(haveDirectVirialContributions_, "Buffer can only be requested when present");
+        return forceBufferForDirectVirialContributions_;
+    }
+
+    //! Returns the buffer for shift forces, size SHIFTS
+    gmx::ArrayRef<gmx::RVec> shiftForces() { return shiftForces_; }
+
+    //! Resizes the direct virial contribution buffer, when present
+    void resize(int numAtoms);
+
+private:
+    //! True when we have contributions that are directly added to the virial
+    bool haveDirectVirialContributions_ = false;
+    //! Force buffer for force computation with direct virial contributions
+    std::vector<gmx::RVec> forceBufferForDirectVirialContributions_;
+    //! Shift force array for computing the virial, size SHIFTS
+    std::vector<gmx::RVec> shiftForces_;
+};
+
 struct t_forcerec
 { // NOLINT (clang-analyzer-optin.performance.Padding)
     // Declare an explicit constructor and destructor, so they can be
@@ -216,13 +260,9 @@ struct t_forcerec
     int natoms_force = 0;
     /* The number of atoms participating in force calculation and constraints */
     int natoms_force_constr = 0;
-    /* Forces that should not enter into the coord x force virial summation:
-     * PPPM/PME/Ewald/posres/ForceProviders
-     */
-    /* True when we have contributions that are directly added to the virial */
-    bool haveDirectVirialContributions = false;
-    /* Force buffer for force computation with direct virial contributions */
-    std::vector<gmx::RVec> forceBufferForDirectVirialContributions;
+
+    /* Helper buffer for ForceOutputs */
+    std::unique_ptr<ForceHelperBuffers> forceHelperBuffers;
 
     /* Data for PPPM/PME/Ewald */
     struct gmx_pme_t* pmedata                = nullptr;
@@ -231,9 +271,6 @@ struct t_forcerec
     /* PME/Ewald stuff */
     struct gmx_ewald_tab_t* ewald_table = nullptr;
 
-    /* Shift force array for computing the virial, size SHIFTS */
-    std::vector<gmx::RVec> shiftForces;
-
     /* Non bonded Parameter lists */
     int               ntype = 0; /* Number of atom types */
     gmx_bool          bBHAM = FALSE;