Pad RVec force buffer in ThreadForceBuffer
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_free_energy.cpp
index ad316bd5da5020a31631387cbb416f019b398993..dbb17010c2ea80573c290e2328d9c0802eb39cef 100644 (file)
@@ -248,11 +248,11 @@ static void nb_free_energy_kernel(const t_nblist&
                                   int                                              flags,
                                   gmx::ArrayRef<const real>                        lambda,
                                   t_nrnb* gmx_restrict                             nrnb,
-                                  gmx::RVec*          threadForceBuffer,
-                                  rvec*               threadForceShiftBuffer,
-                                  gmx::ArrayRef<real> threadVc,
-                                  gmx::ArrayRef<real> threadVv,
-                                  gmx::ArrayRef<real> threadDvdl)
+                                  gmx::ArrayRefWithPadding<gmx::RVec> threadForceBuffer,
+                                  rvec*                               threadForceShiftBuffer,
+                                  gmx::ArrayRef<real>                 threadVc,
+                                  gmx::ArrayRef<real>                 threadVv,
+                                  gmx::ArrayRef<real>                 threadDvdl)
 {
 #define STATE_A 0
 #define STATE_B 1
@@ -391,8 +391,9 @@ static void nb_free_energy_kernel(const t_nblist&
         dlFacVdw[i]  = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFV[i]) : 1);
     }
 
-    // TODO: We should get rid of using pointers to real
-    const real* gmx_restrict x = coords.paddedConstArrayRef().data()[0];
+    // We need pointers to real for SIMD access
+    const real* gmx_restrict x            = coords.paddedConstArrayRef().data()[0];
+    real* gmx_restrict       forceRealPtr = threadForceBuffer.paddedArrayRef().data()[0];
 
     const real rlistSquared = gmx::square(rlist);
 
@@ -914,8 +915,7 @@ static void nb_free_energy_kernel(const t_nblist&
                 fIY               = fIY + tY;
                 fIZ               = fIZ + tZ;
 
-                gmx::transposeScatterDecrU<3>(
-                        reinterpret_cast<real*>(threadForceBuffer), preloadJnr, tX, tY, tZ);
+                gmx::transposeScatterDecrU<3>(forceRealPtr, preloadJnr, tX, tY, tZ);
             }
         } // end for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
 
@@ -923,8 +923,7 @@ static void nb_free_energy_kernel(const t_nblist&
         {
             if (doForces)
             {
-                gmx::transposeScatterIncrU<3>(
-                        reinterpret_cast<real*>(threadForceBuffer), preloadIi, fIX, fIY, fIZ);
+                gmx::transposeScatterIncrU<3>(forceRealPtr, preloadIi, fIX, fIY, fIZ);
             }
             if (doShiftForces)
             {
@@ -985,7 +984,7 @@ typedef void (*KernelFunction)(const t_nblist&
                                int                                              flags,
                                gmx::ArrayRef<const real>                        lambda,
                                t_nrnb* gmx_restrict                             nrnb,
-                               gmx::RVec*                                       threadForceBuffer,
+                               gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
                                rvec*               threadForceShiftBuffer,
                                gmx::ArrayRef<real> threadVc,
                                gmx::ArrayRef<real> threadVv,
@@ -1119,7 +1118,7 @@ void gmx_nb_free_energy_kernel(const t_nblist&
                                int                                              flags,
                                gmx::ArrayRef<const real>                        lambda,
                                t_nrnb*                                          nrnb,
-                               gmx::RVec*                                       threadForceBuffer,
+                               gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
                                rvec*               threadForceShiftBuffer,
                                gmx::ArrayRef<real> threadVc,
                                gmx::ArrayRef<real> threadVv,
@@ -1129,6 +1128,11 @@ void gmx_nb_free_energy_kernel(const t_nblist&
                "Unsupported eeltype with free energy");
     GMX_ASSERT(ic.softCoreParameters, "We need soft-core parameters");
 
+    // Not all SIMD implementations need padding, but we provide padding anyhow so we can assert
+    GMX_ASSERT(!GMX_SIMD_HAVE_REAL || threadForceBuffer.empty()
+                       || threadForceBuffer.size() > threadForceBuffer.unpaddedArrayRef().ssize(),
+               "We need actual padding with at least one element for SIMD scatter operations");
+
     const auto& scParams                   = *ic.softCoreParameters;
     const bool  vdwInteractionTypeIsEwald  = (EVDW_PME(ic.vdwtype));
     const bool  elecInteractionTypeIsEwald = (EEL_PME_EWALD(ic.eeltype));