Prepare ThreadedForceBuffer for FE kernel use
author Berk Hess <hess@kth.se>
Tue, 21 Sep 2021 07:52:15 +0000 (09:52 +0200)
committer Berk Hess <hess@kth.se>
Tue, 21 Sep 2021 07:52:15 +0000 (09:52 +0200)
Made the energy terms optional.
Skip reductions and allow nullptr/empty input for buffers that are
not used according to the stepwork flags.
Instantiate RVec versions of the templated types.

src/gromacs/listed_forces/manage_threading.cpp
src/gromacs/mdtypes/threaded_force_buffer.cpp
src/gromacs/mdtypes/threaded_force_buffer.h

index 8b1b68eb2b62d4eb6f27add05d40032524bd4278..5a3e88929e231047794efe118243ee12a8f08370 100644 (file)
@@ -413,7 +413,7 @@ void setup_bonded_threading(bonded_threading_t*           bt,
 
 bonded_threading_t::bonded_threading_t(const int numThreads, const int numEnergyGroups, FILE* fplog) :
     nthreads(numThreads),
-    threadedForceBuffer(numThreads, numEnergyGroups),
+    threadedForceBuffer(numThreads, true, numEnergyGroups),
     haveBondeds(false),
     workDivision(nthreads),
     foreignLambdaWorkDivision(1)
index c3dc2d92b0de97aa1f06b0761156a755bd2d3252..0aca081f2c4945690b8101e60b07b452d21a49ae 100644 (file)
@@ -60,10 +60,15 @@ namespace gmx
 static constexpr int s_maxNumThreadsForReduction = 256;
 
 template<typename ForceBufferElementType>
-ThreadForceBuffer<ForceBufferElementType>::ThreadForceBuffer(const int threadIndex,
-                                                             const int numEnergyGroups) :
+ThreadForceBuffer<ForceBufferElementType>::ThreadForceBuffer(const int  threadIndex,
+                                                             const bool useEnergyTerms,
+                                                             const int  numEnergyGroups) :
     threadIndex_(threadIndex), shiftForces_(c_numShiftVectors), groupPairEnergies_(numEnergyGroups)
 {
+    if (useEnergyTerms)
+    {
+        energyTerms_.resize(F_NRE);
+    }
 }
 
 template<typename ForceBufferElementType>
@@ -207,7 +212,10 @@ void reduceThreadForceBuffers(ArrayRef<gmx::RVec> force,
 } // namespace
 
 template<typename ForceBufferElementType>
-ThreadedForceBuffer<ForceBufferElementType>::ThreadedForceBuffer(const int numThreads, const int numEnergyGroups)
+ThreadedForceBuffer<ForceBufferElementType>::ThreadedForceBuffer(const int  numThreads,
+                                                                 const bool useEnergyTerms,
+                                                                 const int  numEnergyGroups) :
+    useEnergyTerms_(useEnergyTerms)
 {
     threadForceBuffers_.resize(numThreads);
 #pragma omp parallel for num_threads(numThreads) schedule(static)
@@ -218,8 +226,8 @@ ThreadedForceBuffer<ForceBufferElementType>::ThreadedForceBuffer(const int numTh
             /* Note that thread 0 uses the global fshift and energy arrays,
              * but to keep the code simple, we initialize all data here.
              */
-            threadForceBuffers_[t] =
-                    std::make_unique<ThreadForceBuffer<ForceBufferElementType>>(t, numEnergyGroups);
+            threadForceBuffers_[t] = std::make_unique<ThreadForceBuffer<ForceBufferElementType>>(
+                    t, useEnergyTerms_, numEnergyGroups);
         }
         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
     }
@@ -307,15 +315,15 @@ void ThreadedForceBuffer<ForceBufferElementType>::reduce(gmx::ForceWithShiftForc
                                                          const gmx::StepWorkload& stepWork,
                                                          const int reductionBeginIndex)
 {
-    if (!usedBlockIndices_.empty())
+    if (stepWork.computeForces && !usedBlockIndices_.empty())
     {
-        /* Reduce the bonded force buffer */
+        /* Reduce the force buffer */
+        GMX_ASSERT(forceWithShiftForces, "Need a valid force buffer for reduction");
+
         reduceThreadForceBuffers<ForceBufferElementType>(
                 forceWithShiftForces->force(), threadForceBuffers_, reductionMask_, usedBlockIndices_);
     }
 
-    rvec* gmx_restrict fshift = as_rvec_array(forceWithShiftForces->shiftForces().data());
-
     const int numBuffers = numThreadBuffers();
 
     /* When necessary, reduce energy and virial using one thread only */
@@ -327,6 +335,10 @@ void ThreadedForceBuffer<ForceBufferElementType>::reduce(gmx::ForceWithShiftForc
 
         if (stepWork.computeVirial)
         {
+            GMX_ASSERT(forceWithShiftForces, "Need a valid force buffer for reduction");
+
+            rvec* gmx_restrict fshift = as_rvec_array(forceWithShiftForces->shiftForces().data());
+
             for (int i = 0; i < gmx::c_numShiftVectors; i++)
             {
                 for (int t = reductionBeginIndex; t < numBuffers; t++)
@@ -335,8 +347,10 @@ void ThreadedForceBuffer<ForceBufferElementType>::reduce(gmx::ForceWithShiftForc
                 }
             }
         }
-        if (stepWork.computeEnergy)
+        if (stepWork.computeEnergy && useEnergyTerms_)
         {
+            GMX_ASSERT(ener, "Need a valid energy buffer for reduction");
+
             for (int i = 0; i < F_NRE; i++)
             {
                 for (int t = reductionBeginIndex; t < numBuffers; t++)
@@ -344,6 +358,12 @@ void ThreadedForceBuffer<ForceBufferElementType>::reduce(gmx::ForceWithShiftForc
                     ener[i] += f_t[t]->energyTerms()[i];
                 }
             }
+        }
+
+        if (stepWork.computeEnergy)
+        {
+            GMX_ASSERT(grpp, "Need a valid group pair energy buffer for reduction");
+
             for (int i = 0; i < static_cast<int>(NonBondedEnergyTerms::Count); i++)
             {
                 for (int j = 0; j < f_t[0]->groupPairEnergies().nener; j++)
@@ -358,6 +378,8 @@ void ThreadedForceBuffer<ForceBufferElementType>::reduce(gmx::ForceWithShiftForc
         }
         if (stepWork.computeDhdl)
         {
+            GMX_ASSERT(!dvdl.empty(), "Need a valid dV/dl buffer for reduction");
+
             for (auto i : keysOf(f_t[0]->dvdl()))
             {
 
@@ -370,6 +392,9 @@ void ThreadedForceBuffer<ForceBufferElementType>::reduce(gmx::ForceWithShiftForc
     }
 }
 
+template class ThreadForceBuffer<RVec>;
+template class ThreadedForceBuffer<RVec>;
+
 template class ThreadForceBuffer<rvec4>;
 template class ThreadedForceBuffer<rvec4>;
 
index f468c1dafef19723413ecf896384340fdc94796e..6eeae6008f501bc4363de959ca48d30fe931a891 100644 (file)
@@ -64,7 +64,6 @@
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/simulation_workload.h"
-#include "gromacs/topology/idef.h"
 #include "gromacs/topology/ifunc.h"
 #include "gromacs/utility/alignedallocator.h"
 #include "gromacs/utility/arrayref.h"
@@ -95,8 +94,12 @@ public:
     //! Force buffer block size in atoms
     static constexpr int s_reductionBlockSize = (1 << s_numReductionBlockBits);
 
-    //! Constructor
-    ThreadForceBuffer(int threadIndex, int numEnergyGroups);
+    /*! \brief Constructor
+     * \param[in] threadIndex  The index of the thread that will fill the buffers in this object
+     * \param[in] useEnergyTerms   Whether the list of energy terms will be used
+     * \param[in] numEnergyGroups  The number of non-bonded energy groups
+     */
+    ThreadForceBuffer(int threadIndex, bool useEnergyTerms, int numEnergyGroups);
 
     //! Resizes the buffer to \p numAtoms and clears the mask
     void resizeBufferAndClearMask(int numAtoms);
@@ -150,8 +153,8 @@ private:
 
     //! Shift force array, size c_numShiftVectors
     std::vector<RVec> shiftForces_;
-    //! Energy array
-    std::array<real, F_NRE> energyTerms_;
+    //! Energy array, can be empty
+    std::vector<real> energyTerms_;
     //! Group pair energy data for pairs
     gmx_grppairener_t groupPairEnergies_;
     //! Free-energy dV/dl output
@@ -170,8 +173,12 @@ template<typename ForceBufferElementType>
 class ThreadedForceBuffer
 {
 public:
-    //! Constructor
-    ThreadedForceBuffer(int numThreads, int numEnergyGroups);
+    /*! \brief Constructor
+     * \param[in] numThreads       The number of threads that will use the buffers and reduce
+     * \param[in] useEnergyTerms   Whether the list of energy terms will be used
+     * \param[in] numEnergyGroups  The number of non-bonded energy groups
+     */
+    ThreadedForceBuffer(int numThreads, bool useEnergyTerms, int numEnergyGroups);
 
     //! Returns the number of thread buffers
     int numThreadBuffers() const { return threadForceBuffers_.size(); }
@@ -189,6 +196,9 @@ public:
      *
      * The reduction of all output starts at the output from thread \p reductionBeginIndex,
      * except for the normal force buffer, which always starts at 0.
+     *
+     * Buffers that will not be used as indicated by the flags in \p stepWork
+     * are allowed to be nullptr or empty.
      */
     void reduce(gmx::ForceWithShiftForces* forceWithShiftForces,
                 real*                      ener,
@@ -198,6 +208,8 @@ public:
                 int                        reductionBeginIndex);
 
 private:
+    //! Whether the energy buffer is used
+    bool useEnergyTerms_;
     //! Force/energy data per thread, size nthreads, stored in unique_ptr to allow thread local allocation
     std::vector<std::unique_ptr<ThreadForceBuffer<ForceBufferElementType>>> threadForceBuffers_;
     //! Indices of blocks that are used, i.e. have force contributions.
@@ -211,7 +223,11 @@ private:
     GMX_DISALLOW_COPY_MOVE_AND_ASSIGN(ThreadedForceBuffer);
 };
 
-// Instantiate for rvec4. Can also be instantiated for rvec.
+// Instantiate for RVec
+extern template class ThreadForceBuffer<RVec>;
+extern template class ThreadedForceBuffer<RVec>;
+
+// Instantiate for rvec4
 extern template class ThreadForceBuffer<rvec4>;
 extern template class ThreadedForceBuffer<rvec4>;