prepareGpuKernelArguments() and launchGpuKernel() are added
[alexxy/gromacs.git] / src / gromacs / ewald / pme-spread.cu
index 65748941467d26aab305b9a6a5657a0790423091..89df2c77f9f2eebff514d93c9f088caba66e7656 100644 (file)
@@ -489,7 +489,6 @@ void pme_gpu_spread(const PmeGpu    *pmeGpu,
                     bool             spreadCharges)
 {
     GMX_ASSERT(computeSplines || spreadCharges, "PME spline/spread kernel has invalid input (nothing to do)");
-    cudaStream_t  stream          = pmeGpu->archSpecific->pmeStream;
     const auto   *kernelParamsPtr = pmeGpu->kernelParams.get();
     GMX_ASSERT(kernelParamsPtr->atoms.nAtoms > 0, "No atom data in PME GPU spread");
 
@@ -503,50 +502,53 @@ void pme_gpu_spread(const PmeGpu    *pmeGpu,
     //(for spline data mostly, together with varying PME_GPU_PARALLEL_SPLINE define)
     GMX_ASSERT(!c_usePadding || !(PME_ATOM_DATA_ALIGNMENT % atomsPerBlock), "inconsistent atom data padding vs. spreading block size");
 
-    const int blockCount = pmeGpu->nAtomsPadded / atomsPerBlock;
-    auto      dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
-    dim3 dimBlock(order, order, atomsPerBlock);
+    const int          blockCount = pmeGpu->nAtomsPadded / atomsPerBlock;
+    auto               dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);
+
+    KernelLaunchConfig config;
+    config.blockSize[0] = config.blockSize[1] = order;
+    config.blockSize[2] = atomsPerBlock;
+    config.gridSize[0]  = dimGrid.x;
+    config.gridSize[1]  = dimGrid.y;
+    config.stream       = pmeGpu->archSpecific->pmeStream;
+
+    if (order != 4)
+    {
+        GMX_THROW(gmx::NotImplementedError("The code for pme_order != 4 was not implemented!"));
+    }
 
     // These should later check for PME decomposition
-    const bool wrapX = true;
-    const bool wrapY = true;
+    constexpr bool wrapX = true;
+    constexpr bool wrapY = true;
     GMX_UNUSED_VALUE(wrapX);
     GMX_UNUSED_VALUE(wrapY);
-    switch (order)
+
+    int  timingId;
+    void (*kernelPtr)(const PmeGpuCudaKernelParams) = nullptr;
+    if (computeSplines)
     {
-        case 4:
+        if (spreadCharges)
         {
-            // TODO: cleaner unroll with some template trick?
-            if (computeSplines)
-            {
-                if (spreadCharges)
-                {
-                    pme_gpu_start_timing(pmeGpu, gtPME_SPLINEANDSPREAD);
-                    pme_spline_and_spread_kernel<4, true, true, wrapX, wrapY> <<< dimGrid, dimBlock, 0, stream>>> (*kernelParamsPtr);
-                    CU_LAUNCH_ERR("pme_spline_and_spread_kernel");
-                    pme_gpu_stop_timing(pmeGpu, gtPME_SPLINEANDSPREAD);
-                }
-                else
-                {
-                    pme_gpu_start_timing(pmeGpu, gtPME_SPLINE);
-                    pme_spline_and_spread_kernel<4, true, false, wrapX, wrapY> <<< dimGrid, dimBlock, 0, stream>>> (*kernelParamsPtr);
-                    CU_LAUNCH_ERR("pme_spline_and_spread_kernel");
-                    pme_gpu_stop_timing(pmeGpu, gtPME_SPLINE);
-                }
-            }
-            else
-            {
-                pme_gpu_start_timing(pmeGpu, gtPME_SPREAD);
-                pme_spline_and_spread_kernel<4, false, true, wrapX, wrapY> <<< dimGrid, dimBlock, 0, stream>>> (*kernelParamsPtr);
-                CU_LAUNCH_ERR("pme_spline_and_spread_kernel");
-                pme_gpu_stop_timing(pmeGpu, gtPME_SPREAD);
-            }
+            timingId  = gtPME_SPLINEANDSPREAD;
+            kernelPtr = pme_spline_and_spread_kernel<4, true, true, wrapX, wrapY>;
+        }
+        else
+        {
+            timingId  = gtPME_SPLINE;
+            kernelPtr = pme_spline_and_spread_kernel<4, true, false, wrapX, wrapY>;
         }
-        break;
-
-        default:
-            GMX_THROW(gmx::NotImplementedError("The code for pme_order != 4 was not tested!"));
     }
+    else
+    {
+        timingId  = gtPME_SPREAD;
+        kernelPtr = pme_spline_and_spread_kernel<4, false, true, wrapX, wrapY>;
+    }
+
+    pme_gpu_start_timing(pmeGpu, timingId);
+    auto      *timingEvent = pme_gpu_fetch_timing_event(pmeGpu, timingId);
+    const auto kernelArgs  = prepareGpuKernelArguments(kernelPtr, config, kernelParamsPtr);
+    launchGpuKernel(kernelPtr, config, timingEvent, "PME spline/spread", kernelArgs);
+    pme_gpu_stop_timing(pmeGpu, timingId);
 
     const bool copyBackGrid = spreadCharges && (pme_gpu_is_testing(pmeGpu) || !pme_gpu_performs_FFT(pmeGpu));
     if (copyBackGrid)