Pad RVec force buffer in ThreadForceBuffer
author Berk Hess <hess@kth.se>
Thu, 30 Sep 2021 13:40:53 +0000 (13:40 +0000)
committer Szilárd Páll <pall.szilard@gmail.com>
Thu, 30 Sep 2021 13:40:53 +0000 (13:40 +0000)
src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp
src/gromacs/gmxlib/nonbonded/nb_free_energy.h
src/gromacs/gmxlib/nonbonded/tests/nb_free_energy.cpp
src/gromacs/listed_forces/listed_forces.cpp
src/gromacs/mdtypes/threaded_force_buffer.cpp
src/gromacs/mdtypes/threaded_force_buffer.h
src/gromacs/nbnxm/freeenergydispatch.cpp

index ad316bd5da5020a31631387cbb416f019b398993..dbb17010c2ea80573c290e2328d9c0802eb39cef 100644 (file)
@@ -248,11 +248,11 @@ static void nb_free_energy_kernel(const t_nblist&
                                   int                                              flags,
                                   gmx::ArrayRef<const real>                        lambda,
                                   t_nrnb* gmx_restrict                             nrnb,
-                                  gmx::RVec*          threadForceBuffer,
-                                  rvec*               threadForceShiftBuffer,
-                                  gmx::ArrayRef<real> threadVc,
-                                  gmx::ArrayRef<real> threadVv,
-                                  gmx::ArrayRef<real> threadDvdl)
+                                  gmx::ArrayRefWithPadding<gmx::RVec> threadForceBuffer,
+                                  rvec*                               threadForceShiftBuffer,
+                                  gmx::ArrayRef<real>                 threadVc,
+                                  gmx::ArrayRef<real>                 threadVv,
+                                  gmx::ArrayRef<real>                 threadDvdl)
 {
 #define STATE_A 0
 #define STATE_B 1
@@ -391,8 +391,9 @@ static void nb_free_energy_kernel(const t_nblist&
         dlFacVdw[i]  = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFV[i]) : 1);
     }
 
-    // TODO: We should get rid of using pointers to real
-    const real* gmx_restrict x = coords.paddedConstArrayRef().data()[0];
+    // We need pointers to real for SIMD access
+    const real* gmx_restrict x            = coords.paddedConstArrayRef().data()[0];
+    real* gmx_restrict       forceRealPtr = threadForceBuffer.paddedArrayRef().data()[0];
 
     const real rlistSquared = gmx::square(rlist);
 
@@ -914,8 +915,7 @@ static void nb_free_energy_kernel(const t_nblist&
                 fIY               = fIY + tY;
                 fIZ               = fIZ + tZ;
 
-                gmx::transposeScatterDecrU<3>(
-                        reinterpret_cast<real*>(threadForceBuffer), preloadJnr, tX, tY, tZ);
+                gmx::transposeScatterDecrU<3>(forceRealPtr, preloadJnr, tX, tY, tZ);
             }
         } // end for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
 
@@ -923,8 +923,7 @@ static void nb_free_energy_kernel(const t_nblist&
         {
             if (doForces)
             {
-                gmx::transposeScatterIncrU<3>(
-                        reinterpret_cast<real*>(threadForceBuffer), preloadIi, fIX, fIY, fIZ);
+                gmx::transposeScatterIncrU<3>(forceRealPtr, preloadIi, fIX, fIY, fIZ);
             }
             if (doShiftForces)
             {
@@ -985,7 +984,7 @@ typedef void (*KernelFunction)(const t_nblist&
                                int                                              flags,
                                gmx::ArrayRef<const real>                        lambda,
                                t_nrnb* gmx_restrict                             nrnb,
-                               gmx::RVec*                                       threadForceBuffer,
+                               gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
                                rvec*               threadForceShiftBuffer,
                                gmx::ArrayRef<real> threadVc,
                                gmx::ArrayRef<real> threadVv,
@@ -1119,7 +1118,7 @@ void gmx_nb_free_energy_kernel(const t_nblist&
                                int                                              flags,
                                gmx::ArrayRef<const real>                        lambda,
                                t_nrnb*                                          nrnb,
-                               gmx::RVec*                                       threadForceBuffer,
+                               gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
                                rvec*               threadForceShiftBuffer,
                                gmx::ArrayRef<real> threadVc,
                                gmx::ArrayRef<real> threadVv,
@@ -1129,6 +1128,11 @@ void gmx_nb_free_energy_kernel(const t_nblist&
                "Unsupported eeltype with free energy");
     GMX_ASSERT(ic.softCoreParameters, "We need soft-core parameters");
 
+    // Not all SIMD implementations need padding, but we provide padding anyhow so we can assert
+    GMX_ASSERT(!GMX_SIMD_HAVE_REAL || threadForceBuffer.empty()
+                       || threadForceBuffer.size() > threadForceBuffer.unpaddedArrayRef().ssize(),
+               "We need actual padding with at least one element for SIMD scatter operations");
+
     const auto& scParams                   = *ic.softCoreParameters;
     const bool  vdwInteractionTypeIsEwald  = (EVDW_PME(ic.vdwtype));
     const bool  elecInteractionTypeIsEwald = (EEL_PME_EWALD(ic.eeltype));
index e942afbf1bd2c03983e33b569cceed0d329f21c9..fb97c67bfa027a3a88dab4c25b333ba895db77e2 100644 (file)
@@ -69,7 +69,7 @@ void gmx_nb_free_energy_kernel(const t_nblist&
                                int                                              flags,
                                gmx::ArrayRef<const real>                        lambda,
                                t_nrnb* gmx_restrict                             nrnb,
-                               gmx::RVec*                                       threadForceBuffer,
+                               gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
                                rvec*               threadForceShiftBuffer,
                                gmx::ArrayRef<real> threadVc,
                                gmx::ArrayRef<real> threadVv,
index d25b98c1ddb703ee66890dfa03ee992eb56aa633..a4414ee71622a65ced93a0161e1dcfd41c8529e8 100644 (file)
@@ -469,7 +469,7 @@ protected:
                                   doNBFlags,
                                   lambdas,
                                   &nrnb,
-                                  output.f.arrayRefWithPadding().paddedArrayRef().data(),
+                                  output.f.arrayRefWithPadding(),
                                   as_rvec_array(output.fShift.data()),
                                   output.energy.energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR],
                                   output.energy.energyGroupPairTerms[NonBondedEnergyTerms::LJSR],
index 25900e838d0f102ebb28769000b28915d9d9ecef..7a90343212ab631541aa26780235bfc66cec0c61 100644 (file)
@@ -398,7 +398,7 @@ static void calcBondedForces(const InteractionDefinitions& idef,
 
             threadBuffer.clearForcesAndEnergies();
 
-            rvec4* ft = threadBuffer.forceBuffer();
+            rvec4* ft = threadBuffer.forceBuffer().data();
 
             /* Thread 0 writes directly to the main output buffers.
              * We might want to reconsider this.
index 0aca081f2c4945690b8101e60b07b452d21a49ae..515d1934b70ffc779b3eba7ff947092c0df6268f 100644 (file)
@@ -109,7 +109,16 @@ void ThreadForceBuffer<ForceBufferElementType>::resizeBufferAndClearMask(const i
     const int numBlocks = (numAtoms + s_reductionBlockSize - 1) >> s_numReductionBlockBits;
 
     reductionMask_.resize(numBlocks);
-    forceBuffer_.resize(numBlocks * s_reductionBlockSize * sizeof(ForceBufferElementType) / sizeof(real));
+
+    constexpr size_t c_numComponentsInElement = sizeof(ForceBufferElementType) / sizeof(real);
+    int              newNumElements           = numBlocks * s_reductionBlockSize;
+    if (c_numComponentsInElement != 4 && newNumElements == numAtoms)
+    {
+        // Pad with one element to allow 4-wide SIMD loads and stores.
+        // Note that actually only one real is needed, but we need a whole element for the ArrayRef.
+        newNumElements += 1;
+    }
+    forceBuffer_.resize(newNumElements * c_numComponentsInElement);
 
     for (gmx_bitmask_t& mask : reductionMask_)
     {
@@ -175,7 +184,8 @@ void reduceThreadForceBuffers(ArrayRef<gmx::RVec> force,
             {
                 if (bitmask_is_set(masks[blockIndex], ft))
                 {
-                    fp[numContributingBuffers++] = threadForceBuffers[ft]->forceBuffer();
+                    fp[numContributingBuffers++] =
+                            threadForceBuffers[ft]->forceBufferWithPadding().paddedArrayRef().data();
                 }
             }
             if (numContributingBuffers > 0)
index 6eeae6008f501bc4363de959ca48d30fe931a891..d6b948c88a78c68192a3de82e611e54e1ccaa08c 100644 (file)
@@ -61,6 +61,7 @@
 
 #include <memory>
 
+#include "gromacs/math/arrayrefwithpadding.h"
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/simulation_workload.h"
@@ -118,10 +119,25 @@ public:
     //! Clears all force and energy buffers
     void clearForcesAndEnergies();
 
-    //! Returns a plain pointer to the force buffer
-    ForceBufferElementType* forceBuffer()
+    //! Returns an array reference to the force buffer which is aligned for SIMD access
+    ArrayRef<ForceBufferElementType> forceBuffer()
     {
-        return reinterpret_cast<ForceBufferElementType*>(forceBuffer_.data());
+        return ArrayRef<ForceBufferElementType>(
+                reinterpret_cast<ForceBufferElementType*>(forceBuffer_.data()),
+                reinterpret_cast<ForceBufferElementType*>(forceBuffer_.data()) + numAtoms_);
+    }
+
+    /*! \brief Returns an array reference with padding to the force buffer which is aligned for SIMD access
+     *
+     * For RVec there is padding of one real for 4-wide SIMD access.
+     * For both RVec and rvec4 there is padding up to the block size for use in ThreadedForceBuffer.
+     */
+    ArrayRefWithPadding<ForceBufferElementType> forceBufferWithPadding()
+    {
+        return ArrayRefWithPadding<ForceBufferElementType>(
+                reinterpret_cast<ForceBufferElementType*>(forceBuffer_.data()),
+                reinterpret_cast<ForceBufferElementType*>(forceBuffer_.data()) + numAtoms_,
+                reinterpret_cast<ForceBufferElementType*>(forceBuffer_.data() + forceBuffer_.size()));
     }
 
     //! Returns a view of the shift force buffer
@@ -140,9 +156,9 @@ public:
     ArrayRef<const gmx_bitmask_t> reductionMask() const { return reductionMask_; }
 
 private:
-    //! Force array buffer
+    //! Force array buffer, aligned to enable aligned SIMD access
     std::vector<real, AlignedAllocator<real>> forceBuffer_;
-    //! Mask for marking which parts of f are filled, working array for constructing mask in bonded_threading_t
+    //! Mask for marking which parts of f are filled, working array for constructing mask in setupReduction()
     std::vector<gmx_bitmask_t> reductionMask_;
     //! Index to touched blocks
     std::vector<int> usedBlockIndices_;
index 826bd9ff61baba8b8244a0754979a70414dfb4c7..f5c0667bcf1fed48d44d4f23ccc2d3c8ce815f7f 100644 (file)
@@ -191,7 +191,7 @@ void dispatchFreeEnergyKernel(gmx::ArrayRef<const std::unique_ptr<t_nblist>>   n
                 threadForceBuffer.clearForcesAndEnergies();
             }
 
-            gmx::RVec* threadForces      = threadForceBuffer.forceBuffer();
+            auto  threadForces           = threadForceBuffer.forceBufferWithPadding();
             rvec* threadForceShiftBuffer = as_rvec_array(threadForceBuffer.shiftForces().data());
             gmx::ArrayRef<real> threadVc =
                     threadForceBuffer.groupPairEnergies().energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR];
@@ -281,7 +281,7 @@ void dispatchFreeEnergyKernel(gmx::ArrayRef<const std::unique_ptr<t_nblist>>   n
                                               kernelFlags,
                                               lam_i,
                                               nrnb,
-                                              nullptr,
+                                              gmx::ArrayRefWithPadding<gmx::RVec>(),
                                               nullptr,
                                               threadVc,
                                               threadVv,