Decouple GPU force buffer management from buffer ops in NBNXM
[alexxy/gromacs.git] / src / gromacs / nbnxm / cuda / nbnxm_buffer_ops_kernels.cuh
index bc0d2c2c5e81e1dda9f0ec8bbb66b4bc4ea9296f..3c8f7b1cb3877568852c527bd99d386ded9f9700 100644 (file)
@@ -137,56 +137,59 @@ __global__ void nbnxn_gpu_x_to_nbat_x_kernel(int                         numColu
 
 }
 
-/*! \brief CUDA kernel to add part of the force array(s) from nbnxn_atomdata_t to f
+/*! \brief CUDA kernel to sum up the force components
  *
- * \param[in]     fnb              Force in nbat format
- * \param[in]     fPmeDeviceBuffer PME force
- * \param[in,out] f                Force buffer to be reduced into
- * \param[in]     cell             Cell index mapping
- * \param[in]     atomStart        Start atom index
- * \param[in]     nAtoms           Number of Atoms
+ * \tparam        accumulateForce  If the initial forces in \p d_fTotal should be saved.
+ * \tparam        addPmeForce      Whether the PME force should be added to the total.
+ *
+ * \param[in]     d_fNB            Non-bonded forces in nbat format.
+ * \param[in]     d_fPme           PME forces.
+ * \param[in,out] d_fTotal         Force buffer to be reduced into.
+ * \param[in]     cell             Cell index mapping.
+ * \param[in]     atomStart        Start atom index.
+ * \param[in]     numAtoms         Number of atoms.
  */
-template <bool accumulateForce, bool addPmeF>
+template <bool accumulateForce, bool addPmeForce>
 __global__ void
-nbnxn_gpu_add_nbat_f_to_f_kernel(const float3 *__restrict__ fnb,
-                                 const float3 *__restrict__ fPmeDeviceBuffer,
-                                 float3                   * f,
-                                 const int *__restrict__    cell,
-                                 const int                  atomStart,
-                                 const int                  nAtoms);
-template <bool accumulateForce, bool addPmeF>
+nbnxn_gpu_add_nbat_f_to_f_kernel(const float3 *__restrict__  d_fNB,
+                                 const float3 *__restrict__  d_fPme,
+                                 float3                     *d_fTotal,
+                                 const int *__restrict__     d_cell,
+                                 const int                   atomStart,
+                                 const int                   numAtoms);
+template <bool accumulateForce, bool addPmeForce>
 __global__ void
-nbnxn_gpu_add_nbat_f_to_f_kernel(const float3 *__restrict__ fnb,
-                                 const float3 *__restrict__ fPmeDeviceBuffer,
-                                 float3                   * f,
-                                 const int *__restrict__    cell,
-                                 const int                  atomStart,
-                                 const int                  nAtoms)
+nbnxn_gpu_add_nbat_f_to_f_kernel(const float3 *__restrict__  d_fNB,
+                                 const float3 *__restrict__  d_fPme,
+                                 float3                     *d_fTotal,
+                                 const int *__restrict__     d_cell,
+                                 const int                   atomStart,
+                                 const int                   numAtoms)
 {
 
     /* map particle-level parallelism to 1D CUDA thread and block index */
     int threadIndex = blockIdx.x*blockDim.x+threadIdx.x;
 
     /* perform addition for each particle*/
-    if (threadIndex < nAtoms)
+    if (threadIndex < numAtoms)
     {
 
-        int     i        = cell[atomStart+threadIndex];
-        float3 *fDest    = (float3 *)&f[atomStart+threadIndex];
+        int     i        = d_cell[atomStart+threadIndex];
+        float3 *fDest    = (float3 *)&d_fTotal[atomStart+threadIndex];
         float3  temp;
 
         if (accumulateForce)
         {
             temp  = *fDest;
-            temp += fnb[i];
+            temp += d_fNB[i];
         }
         else
         {
-            temp = fnb[i];
+            temp = d_fNB[i];
         }
-        if (addPmeF)
+        if (addPmeForce)
         {
-            temp += fPmeDeviceBuffer[atomStart+threadIndex];
+            temp += d_fPme[atomStart+threadIndex];
         }
         *fDest = temp;