Expand signatures for nblib listed forces calculator
author Joe Jordan <ejjordan12@gmail.com>
Thu, 21 Oct 2021 14:19:32 +0000 (14:19 +0000)
committer Mark Abraham <mark.j.abraham@gmail.com>
Thu, 21 Oct 2021 14:19:32 +0000 (14:19 +0000)
api/nblib/listed_forces/CMakeLists.txt
api/nblib/listed_forces/calculator.cpp
api/nblib/listed_forces/calculator.h
api/nblib/listed_forces/helpers.hpp
api/nblib/listed_forces/tests/calculator.cpp
api/nblib/listed_forces/tests/helpers.cpp
api/nblib/samples/methane-water-integration.cpp

index 55507984d59ab16ec145e918a2dd05a7c91c8290..67f6b35ffd728ea8749a82f16b1eb14fb346b42e 100644 (file)
@@ -1,7 +1,7 @@
 #
 # This file is part of the GROMACS molecular simulation package.
 #
-# Copyright (c) 2020, by the GROMACS development team, led by
+# Copyright (c) 2020,2021, by the GROMACS development team, led by
 # Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 # and including many others, as listed in the AUTHORS file in the
 # top-level source directory and at http://www.gromacs.org.
@@ -49,7 +49,7 @@ if(GMX_INSTALL_NBLIB_API)
             bondtypes.h
             calculator.h
             definitions.h
-            DESTINATION include/nblib)
+            DESTINATION include/nblib/listed_forces)
 endif()
 
 if(BUILD_TESTING)
index 5ded958e743d7edd5bef5af185ee3c53fca0d0dc..fde3ee81f8f9933717cb2da6268bb97b0e9ff596 100644 (file)
@@ -42,6 +42,8 @@
  * \author Sebastian Keller <keller@cscs.ch>
  * \author Artem Zhmurov <zhmurov@gmail.com>
  */
+#include <algorithm>
+
 #include "nblib/box.h"
 #include "nblib/exception.h"
 #include "nblib/pbc.hpp"
@@ -59,14 +61,16 @@ ListedForceCalculator::ListedForceCalculator(const ListedInteractionData& intera
                                              int                          nthr,
                                              const Box&                   box) :
     numThreads(nthr),
-    masterForceBuffer_(bufferSize, Vec3{ 0, 0, 0 }),
+    threadedForceBuffers_(numThreads),
+    threadedShiftForceBuffers_(numThreads),
     pbcHolder_(std::make_unique<PbcHolder>(PbcType::Xyz, box))
 {
     // split up the work
     threadedInteractions_ = splitListedWork(interactions, bufferSize, numThreads);
 
     // set up the reduction buffers
-    int threadRange = bufferSize / numThreads;
+    int threadRange = (bufferSize + numThreads - 1) / numThreads;
+#pragma omp parallel for num_threads(numThreads) schedule(static)
     for (int i = 0; i < numThreads; ++i)
     {
         int rangeStart = i * threadRange;
@@ -77,28 +81,78 @@ ListedForceCalculator::ListedForceCalculator(const ListedInteractionData& intera
             rangeEnd = bufferSize;
         }
 
-        threadedForceBuffers_.push_back(std::make_unique<ForceBuffer<Vec3>>(
-                masterForceBuffer_.data(), rangeStart, rangeEnd));
+        threadedForceBuffers_[i]      = ForceBufferProxy<Vec3>(rangeStart, rangeEnd);
+        threadedShiftForceBuffers_[i] = std::vector<Vec3>(gmx::c_numShiftVectors);
     }
 }
 
-void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x, bool usePbc)
+template<class ShiftForce>
+void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x,
+                                                     gmx::ArrayRef<Vec3>       forces,
+                                                     [[maybe_unused]] gmx::ArrayRef<ShiftForce> shiftForces,
+                                                     bool usePbc)
 {
+    if (x.size() != forces.size())
+    {
+        throw InputException("Coordinates array and force buffer size mismatch");
+    }
+
     energyBuffer_.fill(0);
     std::vector<std::array<real, std::tuple_size<ListedInteractionData>::value>> energiesPerThread(numThreads);
 
+    constexpr bool haveShiftForces = !std::is_same_v<ShiftForce, std::nullptr_t>;
+    if constexpr (haveShiftForces)
+    {
+        if (shiftForces.size() != gmx::c_numShiftVectors)
+        {
+            throw InputException("Shift vectors array size mismatch");
+        }
+    }
+
 #pragma omp parallel for num_threads(numThreads) schedule(static)
     for (int thread = 0; thread < numThreads; ++thread)
     {
+        std::conditional_t<haveShiftForces, gmx::ArrayRef<Vec3>, gmx::ArrayRef<std::nullptr_t>> shiftForceBuffer;
+        if constexpr (haveShiftForces)
+        {
+            shiftForceBuffer = gmx::ArrayRef<Vec3>(threadedShiftForceBuffers_[thread]);
+            std::fill(shiftForceBuffer.begin(), shiftForceBuffer.end(), Vec3{ 0, 0, 0 });
+        }
+
+        ForceBufferProxy<Vec3>* threadBuffer = &threadedForceBuffers_[thread];
+
+        // forces in range of this thread are directly written into the output buffer
+        threadBuffer->setMasterBuffer(forces);
+
+        // zero out the outliers in the thread buffer
+        threadBuffer->clearOutliers();
+
         if (usePbc)
         {
             energiesPerThread[thread] = reduceListedForces(
-                    threadedInteractions_[thread], x, threadedForceBuffers_[thread].get(), *pbcHolder_);
+                    threadedInteractions_[thread], x, threadBuffer, shiftForceBuffer, *pbcHolder_);
         }
         else
         {
             energiesPerThread[thread] = reduceListedForces(
-                    threadedInteractions_[thread], x, threadedForceBuffers_[thread].get(), NoPbc{});
+                    threadedInteractions_[thread], x, threadBuffer, shiftForceBuffer, NoPbc{});
+        }
+    }
+
+    // reduce shift forces
+    // This is a potential candidate for OMP parallelization, but attention should be paid to the
+    // relative costs of thread synchronization overhead vs reduction cost in contexts where the
+    // number of threads could be large vs where number of threads could be small
+    if constexpr (haveShiftForces)
+    {
+        for (int i = 0; i < gmx::c_numShiftVectors; ++i)
+        {
+            Vec3 threadSum{ 0, 0, 0 };
+            for (int thread = 0; thread < numThreads; ++thread)
+            {
+                threadSum += (threadedShiftForceBuffers_[thread])[i];
+            }
+            shiftForces[i] += threadSum;
         }
     }
 
@@ -110,16 +164,15 @@ void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x
             energyBuffer_[type] += energiesPerThread[thread][type];
         }
     }
-
     // reduce forces
 #pragma omp parallel for num_threads(numThreads) schedule(static)
     for (int thread = 0; thread < numThreads; ++thread)
     {
-        auto& thisBuffer = *threadedForceBuffers_[thread];
+        auto& thisBuffer = threadedForceBuffers_[thread];
         // access outliers from other threads
         for (int otherThread = 0; otherThread < numThreads; ++otherThread)
         {
-            auto& otherBuffer = *threadedForceBuffers_[otherThread];
+            auto& otherBuffer = threadedForceBuffers_[otherThread];
             for (const auto& outlier : otherBuffer)
             {
                 int index = outlier.first;
@@ -127,43 +180,36 @@ void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x
                 if (thisBuffer.inRange(index))
                 {
                     auto force = outlier.second;
-                    masterForceBuffer_[index] += force;
+                    forces[index] += force;
                 }
             }
         }
     }
 }
 
-void ListedForceCalculator::compute(gmx::ArrayRef<const Vec3> coordinates, gmx::ArrayRef<Vec3> forces, bool usePbc)
+void ListedForceCalculator::compute(gmx::ArrayRef<const Vec3> coordinates,
+                                    gmx::ArrayRef<Vec3>       forces,
+                                    gmx::ArrayRef<Vec3>       shiftForces,
+                                    gmx::ArrayRef<real>       energies,
+                                    bool                      usePbc)
 {
-    if (coordinates.size() != forces.size())
-    {
-        throw InputException("Coordinates array and force buffer size mismatch");
-    }
-
-    // check if the force buffers have the same size
-    if (masterForceBuffer_.size() != forces.size())
-    {
-        throw InputException("Input force buffer size mismatch with listed forces buffer");
-    }
-
-    // compute forces and fill in local buffers
-    computeForcesAndEnergies(coordinates, usePbc);
-
-    // add forces to output force buffers
-    for (int pIndex = 0; pIndex < int(forces.size()); pIndex++)
+    computeForcesAndEnergies(coordinates, forces, shiftForces, usePbc);
+    if (!energies.empty())
     {
-        forces[pIndex] += masterForceBuffer_[pIndex];
+        std::copy(energyBuffer_.begin(), energyBuffer_.end(), energies.begin());
     }
 }
+
 void ListedForceCalculator::compute(gmx::ArrayRef<const Vec3> coordinates,
                                     gmx::ArrayRef<Vec3>       forces,
-                                    EnergyType&               energies,
+                                    gmx::ArrayRef<real>       energies,
                                     bool                      usePbc)
 {
-    compute(coordinates, forces, usePbc);
-
-    energies = energyBuffer_;
+    computeForcesAndEnergies(coordinates, forces, gmx::ArrayRef<std::nullptr_t>{}, usePbc);
+    if (!energies.empty())
+    {
+        std::copy(energyBuffer_.begin(), energyBuffer_.end(), energies.begin());
+    }
 }
 
 } // namespace nblib
index 7017862aa690dcb98a09608e82dc3b9d53df2101..70097bf35a324c4cb0ae8b4c270e11acf1275f0b 100644 (file)
@@ -65,7 +65,7 @@ namespace nblib
 class Box;
 class PbcHolder;
 template<class T>
-class ForceBuffer;
+class ForceBufferProxy;
 
 /*! \internal \brief Object to calculate forces and energies of listed forces
  *
@@ -90,15 +90,22 @@ public:
      * This function also stores the forces and energies from listed interactions in the internal
      * buffer of the ListedForceCalculator object
      *
-     * \param[in] coordinates to be used for the force calculation
-     * \param[out] forces buffer to store the output forces
+     * \param[in]  coordinates     input coordinates for the force calculation
+     * \param[inout] forces        output for adding the forces
+     * \param[inout] shiftForces   output for adding shift forces
+     * \param[out] energies        output for potential energies
+     * \param[in]  usePbc          whether or not to consider periodic boundary conditions
      */
-    void compute(gmx::ArrayRef<const Vec3> coordinates, gmx::ArrayRef<Vec3> forces, bool usePbc = false);
+    void compute(gmx::ArrayRef<const Vec3> coordinates,
+                 gmx::ArrayRef<Vec3>       forces,
+                 gmx::ArrayRef<Vec3>       shiftForces,
+                 gmx::ArrayRef<real>       energies,
+                 bool                      usePbc = false);
 
-    //! \brief Alternative overload with the energies in an output buffer
+    //! \brief Alternative overload without shift forces
     void compute(gmx::ArrayRef<const Vec3> coordinates,
                  gmx::ArrayRef<Vec3>       forces,
-                 EnergyType&               energies,
+                 gmx::ArrayRef<real>       energies,
                  bool                      usePbc = false);
 
     //! \brief default, but moved to separate compilation unit
@@ -107,9 +114,6 @@ public:
 private:
     int numThreads;
 
-    //! the main buffer to hold the final listed forces
-    std::vector<gmx::RVec> masterForceBuffer_;
-
     //! holds the array of energies computed
     EnergyType energyBuffer_;
 
@@ -117,13 +121,20 @@ private:
     std::vector<ListedInteractionData> threadedInteractions_;
 
     //! reduction force buffers
-    std::vector<std::unique_ptr<ForceBuffer<gmx::RVec>>> threadedForceBuffers_;
+    std::vector<ForceBufferProxy<Vec3>> threadedForceBuffers_;
+
+    //! reduction shift force buffers
+    std::vector<std::vector<Vec3>> threadedShiftForceBuffers_;
 
     //! PBC objects
     std::unique_ptr<PbcHolder> pbcHolder_;
 
     //! compute listed forces and energies, overwrites the internal buffers
-    void computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x, bool usePbc = false);
+    template<class ShiftForce>
+    void computeForcesAndEnergies(gmx::ArrayRef<const Vec3>                  x,
+                                  gmx::ArrayRef<Vec3>                        forces,
+                                  [[maybe_unused]] gmx::ArrayRef<ShiftForce> shiftForces,
+                                  bool                                       usePbc = false);
 };
 
 } // namespace nblib
index b5e89ba4aa9126b843dfa2f5503209c0d3c13823..b514089abaccf820168129928e9463c452b00a12 100644 (file)
@@ -49,6 +49,8 @@
 
 #include <unordered_map>
 
+#include "gromacs/utility/arrayref.h"
+
 #include "nblib/pbc.hpp"
 #include "definitions.h"
 #include "nblib/util/util.hpp"
@@ -75,29 +77,29 @@ inline void gmxRVecZeroWorkaround<gmx::RVec>(gmx::RVec& value)
 }
 } // namespace detail
 
-/*! \internal \brief object to store forces for multithreaded listed forces computation
+/*! \internal \brief proxy object to access forces in an underlying buffer
+ *
+ * Depending on the index, either the underlying master buffer, or local
+ * storage for outliers is accessed. This object does not own the master buffer.
  *
  */
 template<class T>
-class ForceBuffer
+class ForceBufferProxy
 {
     using HashMap = std::unordered_map<int, T>;
 
 public:
-    ForceBuffer() : rangeStart(0), rangeEnd(0) { }
+    ForceBufferProxy() : rangeStart_(0), rangeEnd_(0) { }
 
-    ForceBuffer(T* mbuf, int rs, int re) :
-        masterForceBuffer(mbuf),
-        rangeStart(rs),
-        rangeEnd(re)
+    ForceBufferProxy(int rangeStart, int rangeEnd) : rangeStart_(rangeStart), rangeEnd_(rangeEnd)
     {
     }
 
-    void clear() { outliers.clear(); }
+    void clearOutliers() { outliers.clear(); }
 
     inline NBLIB_ALWAYS_INLINE T& operator[](int i)
     {
-        if (i >= rangeStart && i < rangeEnd)
+        if (i >= rangeStart_ && i < rangeEnd_)
         {
             return masterForceBuffer[i];
         }
@@ -117,12 +119,14 @@ public:
     typename HashMap::const_iterator begin() { return outliers.begin(); }
     typename HashMap::const_iterator end() { return outliers.end(); }
 
-    [[nodiscard]] bool inRange(int index) const { return (index >= rangeStart && index < rangeEnd); }
+    [[nodiscard]] bool inRange(int index) const { return (index >= rangeStart_ && index < rangeEnd_); }
+
+    void setMasterBuffer(gmx::ArrayRef<T> buffer) { masterForceBuffer = buffer; }
 
 private:
-    T*  masterForceBuffer;
-    int rangeStart;
-    int rangeEnd;
+    gmx::ArrayRef<T> masterForceBuffer;
+    int rangeStart_;
+    int rangeEnd_;
 
     HashMap outliers;
 };
index 447ab2385d46d8acab094066bb112fc0c0e39e0c..1c7bff1b4e4b07ae9510665b0b983d35f7da3de0 100644 (file)
@@ -206,7 +206,6 @@ protected:
         interactions = data.interactions;
         box          = data.box;
         refForces    = data.forces;
-        // pbc.reset(new PbcHolder(*box));
 
         refEnergies = reduceListedForces(interactions, x, &refForces, NoPbc{});
     }
@@ -224,7 +223,6 @@ protected:
     std::vector<gmx::RVec> x;
     ListedInteractionData  interactions;
     std::shared_ptr<Box>   box;
-    // std::shared_ptr<PbcHolder> pbc;
 
 private:
     std::vector<gmx::RVec>            refForces;
index cb91699c8feb7f9c063e0efe7733064d92aa6b09..e705c16978226d76f8e357b10ce109fb9ee31287 100644 (file)
@@ -106,12 +106,13 @@ TEST(NBlibTest, ListedForceBuffer)
     T              vzero{ 0, 0, 0 };
     std::vector<T> masterBuffer(ncoords, vzero);
 
-    // the ForceBuffer is going to access indices [10-15) through the masterBuffer
+    // the ForceBufferProxy is going to access indices [10-15) through the masterBuffer
     // and the outliers internally
     int rangeStart = 10;
     int rangeEnd   = 15;
 
-    ForceBuffer<T> forceBuffer(masterBuffer.data(), rangeStart, rangeEnd);
+    ForceBufferProxy<T> forceBuffer(rangeStart, rangeEnd);
+    forceBuffer.setMasterBuffer(masterBuffer);
 
     // in range
     T internal1{ 1, 2, 3 };
index 06bc92b6680341e5050e4fe11b7e3da2cfd3b7e5..4a565735bf538f17e6c20707381be287eaf504f2 100644 (file)
@@ -203,7 +203,8 @@ int main()
         forceCalculator->compute(
                 simulationState.coordinates(), simulationState.box(), simulationState.forces());
 
-        listedForceCalculator.compute(simulationState.coordinates(), simulationState.forces());
+        listedForceCalculator.compute(
+                simulationState.coordinates(), simulationState.forces(), gmx::ArrayRef<real>{});
 
         // Integrate with a time step of 1 fs, positions, velocities and forces
         integrator.integrate(