Expand signatures for nblib listed forces calculator
[alexxy/gromacs.git] / api / nblib / listed_forces / calculator.cpp
index 5ded958e743d7edd5bef5af185ee3c53fca0d0dc..fde3ee81f8f9933717cb2da6268bb97b0e9ff596 100644 (file)
@@ -42,6 +42,8 @@
  * \author Sebastian Keller <keller@cscs.ch>
  * \author Artem Zhmurov <zhmurov@gmail.com>
  */
+#include <algorithm>
+
 #include "nblib/box.h"
 #include "nblib/exception.h"
 #include "nblib/pbc.hpp"
@@ -59,14 +61,16 @@ ListedForceCalculator::ListedForceCalculator(const ListedInteractionData& intera
                                              int                          nthr,
                                              const Box&                   box) :
     numThreads(nthr),
-    masterForceBuffer_(bufferSize, Vec3{ 0, 0, 0 }),
+    threadedForceBuffers_(numThreads),
+    threadedShiftForceBuffers_(numThreads),
     pbcHolder_(std::make_unique<PbcHolder>(PbcType::Xyz, box))
 {
     // split up the work
     threadedInteractions_ = splitListedWork(interactions, bufferSize, numThreads);
 
     // set up the reduction buffers
-    int threadRange = bufferSize / numThreads;
+    int threadRange = (bufferSize + numThreads - 1) / numThreads;
+#pragma omp parallel for num_threads(numThreads) schedule(static)
     for (int i = 0; i < numThreads; ++i)
     {
         int rangeStart = i * threadRange;
@@ -77,28 +81,78 @@ ListedForceCalculator::ListedForceCalculator(const ListedInteractionData& intera
             rangeEnd = bufferSize;
         }
 
-        threadedForceBuffers_.push_back(std::make_unique<ForceBuffer<Vec3>>(
-                masterForceBuffer_.data(), rangeStart, rangeEnd));
+        threadedForceBuffers_[i]      = ForceBufferProxy<Vec3>(rangeStart, rangeEnd);
+        threadedShiftForceBuffers_[i] = std::vector<Vec3>(gmx::c_numShiftVectors);
     }
 }
 
-void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x, bool usePbc)
+template<class ShiftForce>
+void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x,
+                                                     gmx::ArrayRef<Vec3>       forces,
+                                                     [[maybe_unused]] gmx::ArrayRef<ShiftForce> shiftForces,
+                                                     bool usePbc)
 {
+    if (x.size() != forces.size())
+    {
+        throw InputException("Coordinates array and force buffer size mismatch");
+    }
+
     energyBuffer_.fill(0);
     std::vector<std::array<real, std::tuple_size<ListedInteractionData>::value>> energiesPerThread(numThreads);
 
+    constexpr bool haveShiftForces = !std::is_same_v<ShiftForce, std::nullptr_t>;
+    if constexpr (haveShiftForces)
+    {
+        if (shiftForces.size() != gmx::c_numShiftVectors)
+        {
+            throw InputException("Shift vectors array size mismatch");
+        }
+    }
+
 #pragma omp parallel for num_threads(numThreads) schedule(static)
     for (int thread = 0; thread < numThreads; ++thread)
     {
+        std::conditional_t<haveShiftForces, gmx::ArrayRef<Vec3>, gmx::ArrayRef<std::nullptr_t>> shiftForceBuffer;
+        if constexpr (haveShiftForces)
+        {
+            shiftForceBuffer = gmx::ArrayRef<Vec3>(threadedShiftForceBuffers_[thread]);
+            std::fill(shiftForceBuffer.begin(), shiftForceBuffer.end(), Vec3{ 0, 0, 0 });
+        }
+
+        ForceBufferProxy<Vec3>* threadBuffer = &threadedForceBuffers_[thread];
+
+        // forces in range of this thread are directly written into the output buffer
+        threadBuffer->setMasterBuffer(forces);
+
+        // zero out the outliers in the thread buffer
+        threadBuffer->clearOutliers();
+
         if (usePbc)
         {
             energiesPerThread[thread] = reduceListedForces(
-                    threadedInteractions_[thread], x, threadedForceBuffers_[thread].get(), *pbcHolder_);
+                    threadedInteractions_[thread], x, threadBuffer, shiftForceBuffer, *pbcHolder_);
         }
         else
         {
             energiesPerThread[thread] = reduceListedForces(
-                    threadedInteractions_[thread], x, threadedForceBuffers_[thread].get(), NoPbc{});
+                    threadedInteractions_[thread], x, threadBuffer, shiftForceBuffer, NoPbc{});
+        }
+    }
+
+    // reduce shift forces
+    // This is a potential candidate for OMP parallelization, but attention should be paid to the
+    // relative costs of thread synchronization overhead vs reduction cost in contexts where the
+    // number of threads could be large vs where number of threads could be small
+    if constexpr (haveShiftForces)
+    {
+        for (int i = 0; i < gmx::c_numShiftVectors; ++i)
+        {
+            Vec3 threadSum{ 0, 0, 0 };
+            for (int thread = 0; thread < numThreads; ++thread)
+            {
+                threadSum += (threadedShiftForceBuffers_[thread])[i];
+            }
+            shiftForces[i] += threadSum;
         }
     }
 
@@ -110,16 +164,15 @@ void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x
             energyBuffer_[type] += energiesPerThread[thread][type];
         }
     }
-
     // reduce forces
 #pragma omp parallel for num_threads(numThreads) schedule(static)
     for (int thread = 0; thread < numThreads; ++thread)
     {
-        auto& thisBuffer = *threadedForceBuffers_[thread];
+        auto& thisBuffer = threadedForceBuffers_[thread];
         // access outliers from other threads
         for (int otherThread = 0; otherThread < numThreads; ++otherThread)
         {
-            auto& otherBuffer = *threadedForceBuffers_[otherThread];
+            auto& otherBuffer = threadedForceBuffers_[otherThread];
             for (const auto& outlier : otherBuffer)
             {
                 int index = outlier.first;
@@ -127,43 +180,36 @@ void ListedForceCalculator::computeForcesAndEnergies(gmx::ArrayRef<const Vec3> x
                 if (thisBuffer.inRange(index))
                 {
                     auto force = outlier.second;
-                    masterForceBuffer_[index] += force;
+                    forces[index] += force;
                 }
             }
         }
     }
 }
 
-void ListedForceCalculator::compute(gmx::ArrayRef<const Vec3> coordinates, gmx::ArrayRef<Vec3> forces, bool usePbc)
+void ListedForceCalculator::compute(gmx::ArrayRef<const Vec3> coordinates,
+                                    gmx::ArrayRef<Vec3>       forces,
+                                    gmx::ArrayRef<Vec3>       shiftForces,
+                                    gmx::ArrayRef<real>       energies,
+                                    bool                      usePbc)
 {
-    if (coordinates.size() != forces.size())
-    {
-        throw InputException("Coordinates array and force buffer size mismatch");
-    }
-
-    // check if the force buffers have the same size
-    if (masterForceBuffer_.size() != forces.size())
-    {
-        throw InputException("Input force buffer size mismatch with listed forces buffer");
-    }
-
-    // compute forces and fill in local buffers
-    computeForcesAndEnergies(coordinates, usePbc);
-
-    // add forces to output force buffers
-    for (int pIndex = 0; pIndex < int(forces.size()); pIndex++)
+    computeForcesAndEnergies(coordinates, forces, shiftForces, usePbc);
+    if (!energies.empty())
     {
-        forces[pIndex] += masterForceBuffer_[pIndex];
+        std::copy(energyBuffer_.begin(), energyBuffer_.end(), energies.begin());
     }
 }
+
 void ListedForceCalculator::compute(gmx::ArrayRef<const Vec3> coordinates,
                                     gmx::ArrayRef<Vec3>       forces,
-                                    EnergyType&               energies,
+                                    gmx::ArrayRef<real>       energies,
                                     bool                      usePbc)
 {
-    compute(coordinates, forces, usePbc);
-
-    energies = energyBuffer_;
+    computeForcesAndEnergies(coordinates, forces, gmx::ArrayRef<std::nullptr_t>{}, usePbc);
+    if (!energies.empty())
+    {
+        std::copy(energyBuffer_.begin(), energyBuffer_.end(), energies.begin());
+    }
 }
 
 } // namespace nblib