F buffer operations in CUDA
[alexxy/gromacs.git] / src / gromacs / nbnxm / cuda / nbnxm_buffer_ops_kernels.cuh
index 172d2221488d13cb3013a54c5095484dd93be00a..09d6e3724381a060f7883edbbac0bd62272e68f4 100644 (file)
@@ -36,8 +36,7 @@
 /*! \internal \file
  *
  * \brief
- * CUDA kernel for GPU version of copy_rvec_to_nbat_real.
- * Converts coordinate data from rvec to nb format.
+ * CUDA kernels for GPU versions of copy_rvec_to_nbat_real and add_nbat_f_to_f.
  *
  *  \author Alan Gray <alang@nvidia.com>
  *  \author Jon Vincent <jvincent@nvidia.com>
@@ -137,3 +136,51 @@ __global__ void nbnxn_gpu_x_to_nbat_x_kernel(int                         numColu
     }
 
 }
+
+/*! \brief CUDA kernel to add part of the force array(s) from nbnxn_atomdata_t to f
+ *
+ * \param[in]     fnb     Force in nbat format
+ * \param[in,out] f       Force buffer to be reduced into
+ * \param[in]     cell    Cell index mapping
+ * \param[in]     a0      start atom index
+ * \param[in]     a1      end atom index
+ * \param[in]     stride  stride between atoms in memory
+ */
+template <bool accumulateForce>
+__global__ void
+nbnxn_gpu_add_nbat_f_to_f_kernel(const float3 *__restrict__ fnb,
+                                 rvec                     * f,
+                                 const int *__restrict__    cell,
+                                 const int                  atomStart,
+                                 const int                  nAtoms);
+template <bool accumulateForce>
+__global__ void
+nbnxn_gpu_add_nbat_f_to_f_kernel(const float3 *__restrict__ fnb,
+                                 rvec                     * f,
+                                 const int *__restrict__    cell,
+                                 const int                  atomStart,
+                                 const int                  nAtoms)
+{
+
+    /* map particle-level parallelism to 1D CUDA thread and block index */
+    int threadIndex = blockIdx.x*blockDim.x+threadIdx.x;
+
+    /* perform addition for each particle*/
+    if (threadIndex < nAtoms)
+    {
+
+        int     i        = cell[atomStart+threadIndex];
+        float3 *f_dest   = (float3 *)&f[atomStart+threadIndex][XX];
+
+        if (accumulateForce)
+        {
+            *f_dest += fnb[i];
+        }
+        else
+        {
+            *f_dest = fnb[i];
+        }
+
+    }
+    return;
+}