Two sets of coefficients for Coulomb FEP PME on GPU
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gather.cu
index 2edec547ef7d17e39365cf5d0ce07f99202b1173..52814c4ce44f372c70c1be0366c88e609779966f 100644 (file)
@@ -213,6 +213,119 @@ __device__ __forceinline__ void reduce_atom_forces(float3* __restrict__ sm_force
     }
 }
 
+/*! \brief Calculate the sum of the force partial components (in X, Y and Z)
+ *
+ * \tparam[in] order              The PME order (must be 4).
+ * \tparam[in] atomsPerWarp       The number of atoms per GPU warp.
+ * \tparam[in] wrapX              Tells if the grid is wrapped in the X dimension.
+ * \tparam[in] wrapY              Tells if the grid is wrapped in the Y dimension.
+ * \param[in,out] fx              The force partial component in the X dimension (accumulated).
+ * \param[in,out] fy              The force partial component in the Y dimension (accumulated).
+ * \param[in,out] fz              The force partial component in the Z dimension (accumulated).
+ * \param[in] ithyMin             The minimum (inclusive) thread index in the Y dimension.
+ * \param[in] ithyMax             The maximum (exclusive) thread index in the Y dimension.
+ * \param[in] ixBase              The grid line index base value in the X dimension.
+ * \param[in] iz                  The grid line index in the Z dimension.
+ * \param[in] nx                  The grid real size in the X dimension.
+ * \param[in] ny                  The grid real size in the Y dimension.
+ * \param[in] pny                 The padded grid real size in the Y dimension.
+ * \param[in] pnz                 The padded grid real size in the Z dimension.
+ * \param[in] atomIndexLocal      The atom index for this thread.
+ * \param[in] splineIndexBase     The base value of the spline parameter index.
+ * \param[in] tdz                 The theta and dtheta in the Z dimension.
+ * \param[in] sm_gridlineIndices  Shared memory array of grid line indices.
+ * \param[in] sm_theta            Shared memory array of atom theta values.
+ * \param[in] sm_dtheta           Shared memory array of atom dtheta values.
+ * \param[in] gm_grid             Global memory array of the grid to use.
+ */
+template<int order, int atomsPerWarp, bool wrapX, bool wrapY>
+__device__ __forceinline__ void sumForceComponents(float* __restrict__ fx,
+                                                   float* __restrict__ fy,
+                                                   float* __restrict__ fz,
+                                                   const int    ithyMin,
+                                                   const int    ithyMax,
+                                                   const int    ixBase,
+                                                   const int    iz,
+                                                   const int    nx,
+                                                   const int    ny,
+                                                   const int    pny,
+                                                   const int    pnz,
+                                                   const int    atomIndexLocal,
+                                                   const int    splineIndexBase,
+                                                   const float2 tdz,
+                                                   const int* __restrict__ sm_gridlineIndices,
+                                                   const float* __restrict__ sm_theta,
+                                                   const float* __restrict__ sm_dtheta,
+                                                   const float* __restrict__ gm_grid)
+{
+#pragma unroll
+    for (int ithy = ithyMin; ithy < ithyMax; ithy++)
+    {
+        const int splineIndexY = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, YY, ithy);
+        const float2 tdy       = make_float2(sm_theta[splineIndexY], sm_dtheta[splineIndexY]);
+
+        int iy = sm_gridlineIndices[atomIndexLocal * DIM + YY] + ithy;
+        if (wrapY & (iy >= ny))
+        {
+            iy -= ny;
+        }
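+        /* Offset of the (iy, iz) point in the padded, row-major real grid;
+         * the X stride (pny * pnz) is added per grid line in the loop below.
+         */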
+        const int constOffset = iy * pnz + iz;
+
+#pragma unroll
+        for (int ithx = 0; (ithx < order); ithx++)
+        {
+            int ix = ixBase + ithx;
+            if (wrapX & (ix >= nx))
+            {
+                ix -= nx;
+            }
+            const int gridIndexGlobal = ix * pny * pnz + constOffset;
+            assert(gridIndexGlobal >= 0);
+            const float gridValue = gm_grid[gridIndexGlobal];
+            assert(isfinite(gridValue));
+            const int splineIndexX = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, XX, ithx);
+            const float2 tdx       = make_float2(sm_theta[splineIndexX], sm_dtheta[splineIndexX]);
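+            /* The force is the gradient of the interpolated potential: each
+             * component pairs the spline derivative (dtheta) in its own
+             * dimension with the spline values (theta) in the other two.
+             */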
+            const float  fxy1      = tdz.x * gridValue;
+            const float  fz1       = tdz.y * gridValue;
+            *fx += tdx.y * tdy.x * fxy1;
+            *fy += tdx.x * tdy.y * fxy1;
+            *fz += tdx.x * tdy.x * fz1;
+        }
+    }
+}
+
+/*! \brief Calculate the grid forces and store them in shared memory.
+ *
+ * \param[in,out] sm_forces       Shared memory array with the output forces.
+ * \param[in] forceIndexLocal     The local (per thread) index in the sm_forces array.
+ * \param[in] forceIndexGlobal    The global atom index into the gm_coefficients array.
+ * \param[in] recipBox            The reciprocal box.
+ * \param[in] scale               The scale to use when calculating the forces. For the second
+ * (FEP state B) gather pass the caller passes (1.0 - scale).
+ * \param[in] gm_coefficients     Global memory array of the coefficients to use: those of the
+ * unperturbed system or of FEP state A on the first pass, and those of FEP state B on the
+ * second pass.
+ */
+__device__ __forceinline__ void calculateAndStoreGridForces(float3* __restrict__ sm_forces,
+                                                            const int   forceIndexLocal,
+                                                            const int   forceIndexGlobal,
+                                                            const float recipBox[DIM][DIM],
+                                                            const float scale,
+                                                            const float* __restrict__ gm_coefficients)
+{
+    const float3 atomForces     = sm_forces[forceIndexLocal];
+    float        negCoefficient = -scale * gm_coefficients[forceIndexGlobal];
+    float3       result;
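+    /* Only the upper-triangular elements of the reciprocal box are non-zero,
+     * so the fractional-to-Cartesian transform needs one, two and three terms
+     * for the X, Y and Z components, respectively.
+     */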
+    result.x = negCoefficient * recipBox[XX][XX] * atomForces.x;
+    result.y = negCoefficient * (recipBox[XX][YY] * atomForces.x + recipBox[YY][YY] * atomForces.y);
+    result.z = negCoefficient
+               * (recipBox[XX][ZZ] * atomForces.x + recipBox[YY][ZZ] * atomForces.y
+                  + recipBox[ZZ][ZZ] * atomForces.z);
+    sm_forces[forceIndexLocal] = result;
+}
+
 /*! \brief
  * A CUDA kernel which gathers the atom forces from the grid.
  * The grid is assumed to be wrapped in dimension Z.
@@ -220,19 +333,24 @@ __device__ __forceinline__ void reduce_atom_forces(float3* __restrict__ sm_force
  * \tparam[in] order                The PME order (must be 4 currently).
  * \tparam[in] wrapX                Tells if the grid is wrapped in the X dimension.
  * \tparam[in] wrapY                Tells if the grid is wrapped in the Y dimension.
+ * \tparam[in] numGrids             The number of grids to use in the kernel. Can be 1 or 2.
 * \tparam[in] readGlobal           Tells if we should read spline values from global memory.
 * \tparam[in] threadsPerAtom       How many threads work on each atom.
  *
  * \param[in]  kernelParams         All the PME GPU data.
  */
-template<int order, bool wrapX, bool wrapY, bool readGlobal, ThreadsPerAtom threadsPerAtom>
+template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom>
 __launch_bounds__(c_gatherMaxThreadsPerBlock, c_gatherMinBlocksPerMP) __global__
         void pme_gather_kernel(const PmeGpuCudaKernelParams kernelParams)
 {
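+    /* Two grids mean FEP: state A uses grid 0, state B uses grid 1. */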
+    assert(numGrids == 1 || numGrids == 2);
+
     /* Global memory pointers */
-    const float* __restrict__ gm_coefficients = kernelParams.atoms.d_coefficients;
-    const float* __restrict__ gm_grid         = kernelParams.grid.d_realGrid;
-    float* __restrict__ gm_forces             = kernelParams.atoms.d_forces;
+    const float* __restrict__ gm_coefficientsA = kernelParams.atoms.d_coefficients[0];
+    const float* __restrict__ gm_coefficientsB = kernelParams.atoms.d_coefficients[1];
+    const float* __restrict__ gm_gridA         = kernelParams.grid.d_realGrid[0];
+    const float* __restrict__ gm_gridB         = kernelParams.grid.d_realGrid[1];
+    float* __restrict__ gm_forces              = kernelParams.atoms.d_forces;
 
     /* Global memory pointers for readGlobal */
     const float* __restrict__ gm_theta         = kernelParams.atoms.d_theta;
@@ -328,7 +446,7 @@ __launch_bounds__(c_gatherMaxThreadsPerBlock, c_gatherMinBlocksPerMP) __global__
             // Coordinates
             __shared__ float3 sm_coordinates[atomsPerBlock];
             /* Staging coefficients/charges */
-            pme_gpu_stage_atom_data<float, atomsPerBlock, 1>(sm_coefficients, gm_coefficients);
+            pme_gpu_stage_atom_data<float, atomsPerBlock, 1>(sm_coefficients, gm_coefficientsA);
 
             /* Staging coordinates */
             pme_gpu_stage_atom_data<float3, atomsPerBlock, 1>(sm_coordinates, gm_coordinates);
@@ -339,7 +457,7 @@ __launch_bounds__(c_gatherMaxThreadsPerBlock, c_gatherMinBlocksPerMP) __global__
         else
         {
             atomX      = gm_coordinates[atomIndexGlobal];
-            atomCharge = gm_coefficients[atomIndexGlobal];
+            atomCharge = gm_coefficientsA[atomIndexGlobal];
         }
         calculate_splines<order, atomsPerBlock, atomsPerWarp, true, false>(
                 kernelParams, atomIndexOffset, atomX, atomCharge, sm_theta, sm_dtheta, sm_gridlineIndices);
@@ -349,70 +467,37 @@ __launch_bounds__(c_gatherMaxThreadsPerBlock, c_gatherMinBlocksPerMP) __global__
     float fy = 0.0f;
     float fz = 0.0f;
 
-    const int chargeCheck = pme_gpu_check_atom_charge(gm_coefficients[atomIndexGlobal]);
-
-    if (chargeCheck)
-    {
-        const int nx  = kernelParams.grid.realGridSize[XX];
-        const int ny  = kernelParams.grid.realGridSize[YY];
-        const int nz  = kernelParams.grid.realGridSize[ZZ];
-        const int pny = kernelParams.grid.realGridSizePadded[YY];
-        const int pnz = kernelParams.grid.realGridSizePadded[ZZ];
-
-        const int atomWarpIndex = atomIndexLocal % atomsPerWarp;
-        const int warpIndex     = atomIndexLocal / atomsPerWarp;
+    const int chargeCheck = pme_gpu_check_atom_charge(gm_coefficientsA[atomIndexGlobal]);
 
-        const int splineIndexBase = getSplineParamIndexBase<order, atomsPerWarp>(warpIndex, atomWarpIndex);
-        const int splineIndexZ = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, ZZ, ithz);
-        const float2 tdz       = make_float2(sm_theta[splineIndexZ], sm_dtheta[splineIndexZ]);
+    const int nx  = kernelParams.grid.realGridSize[XX];
+    const int ny  = kernelParams.grid.realGridSize[YY];
+    const int nz  = kernelParams.grid.realGridSize[ZZ];
+    const int pny = kernelParams.grid.realGridSizePadded[YY];
+    const int pnz = kernelParams.grid.realGridSizePadded[ZZ];
 
-        int       iz     = sm_gridlineIndices[atomIndexLocal * DIM + ZZ] + ithz;
-        const int ixBase = sm_gridlineIndices[atomIndexLocal * DIM + XX];
+    const int atomWarpIndex = atomIndexLocal % atomsPerWarp;
+    const int warpIndex     = atomIndexLocal / atomsPerWarp;
 
-        if (iz >= nz)
-        {
-            iz -= nz;
-        }
-        int constOffset, iy;
+    const int splineIndexBase = getSplineParamIndexBase<order, atomsPerWarp>(warpIndex, atomWarpIndex);
+    const int    splineIndexZ = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, ZZ, ithz);
+    const float2 tdz          = make_float2(sm_theta[splineIndexZ], sm_dtheta[splineIndexZ]);
 
-        const int ithyMin = (threadsPerAtom == ThreadsPerAtom::Order) ? 0 : threadIdx.y;
-        const int ithyMax = (threadsPerAtom == ThreadsPerAtom::Order) ? order : threadIdx.y + 1;
-        for (int ithy = ithyMin; ithy < ithyMax; ithy++)
-        {
-            const int splineIndexY = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, YY, ithy);
-            const float2 tdy       = make_float2(sm_theta[splineIndexY], sm_dtheta[splineIndexY]);
+    int       iz     = sm_gridlineIndices[atomIndexLocal * DIM + ZZ] + ithz;
+    const int ixBase = sm_gridlineIndices[atomIndexLocal * DIM + XX];
 
-            iy = sm_gridlineIndices[atomIndexLocal * DIM + YY] + ithy;
-            if (wrapY & (iy >= ny))
-            {
-                iy -= ny;
-            }
-            constOffset = iy * pnz + iz;
-
-#pragma unroll
-            for (int ithx = 0; (ithx < order); ithx++)
-            {
-                int ix = ixBase + ithx;
-                if (wrapX & (ix >= nx))
-                {
-                    ix -= nx;
-                }
-                const int gridIndexGlobal = ix * pny * pnz + constOffset;
-                assert(gridIndexGlobal >= 0);
-                const float gridValue = gm_grid[gridIndexGlobal];
-                assert(isfinite(gridValue));
-                const int splineIndexX =
-                        getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, XX, ithx);
-                const float2 tdx  = make_float2(sm_theta[splineIndexX], sm_dtheta[splineIndexX]);
-                const float  fxy1 = tdz.x * gridValue;
-                const float  fz1  = tdz.y * gridValue;
-                fx += tdx.y * tdy.x * fxy1;
-                fy += tdx.x * tdy.y * fxy1;
-                fz += tdx.x * tdy.x * fz1;
-            }
-        }
+    if (iz >= nz)
+    {
+        iz -= nz;
     }
 
+    const int ithyMin = (threadsPerAtom == ThreadsPerAtom::Order) ? 0 : threadIdx.y;
+    const int ithyMax = (threadsPerAtom == ThreadsPerAtom::Order) ? order : threadIdx.y + 1;
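+    /* ithyMin/ithyMax: with ThreadsPerAtom::Order each thread iterates over
+     * all spline nodes in Y; with ThreadsPerAtom::OrderSquared each thread
+     * handles a single ithy.
+     */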
+    if (chargeCheck)
+    {
+        sumForceComponents<order, atomsPerWarp, wrapX, wrapY>(
+                &fx, &fy, &fz, ithyMin, ithyMax, ixBase, iz, nx, ny, pny, pnz, atomIndexLocal,
+                splineIndexBase, tdz, sm_gridlineIndices, sm_theta, sm_dtheta, gm_gridA);
+    }
     // Reduction of partial force contributions
     __shared__ float3 sm_forces[atomsPerBlock];
     reduce_atom_forces<order, atomDataSize, blockSize>(sm_forces, atomIndexLocal, splineIndex, lineIndex,
@@ -420,22 +505,13 @@ __launch_bounds__(c_gatherMaxThreadsPerBlock, c_gatherMinBlocksPerMP) __global__
     __syncthreads();
 
     /* Calculating the final forces with no component branching, atomsPerBlock threads */
-    const int forceIndexLocal  = threadLocalId;
-    const int forceIndexGlobal = atomIndexOffset + forceIndexLocal;
+    const int   forceIndexLocal  = threadLocalId;
+    const int   forceIndexGlobal = atomIndexOffset + forceIndexLocal;
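+    /* Weight of the state A coefficients; with two grids, the state B pass
+     * below uses (1 - scale).
+     */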
+    const float scale            = kernelParams.current.scale;
     if (forceIndexLocal < atomsPerBlock)
     {
-        const float3 atomForces     = sm_forces[forceIndexLocal];
-        const float  negCoefficient = -gm_coefficients[forceIndexGlobal];
-        float3       result;
-        result.x = negCoefficient * kernelParams.current.recipBox[XX][XX] * atomForces.x;
-        result.y = negCoefficient
-                   * (kernelParams.current.recipBox[XX][YY] * atomForces.x
-                      + kernelParams.current.recipBox[YY][YY] * atomForces.y);
-        result.z = negCoefficient
-                   * (kernelParams.current.recipBox[XX][ZZ] * atomForces.x
-                      + kernelParams.current.recipBox[YY][ZZ] * atomForces.y
-                      + kernelParams.current.recipBox[ZZ][ZZ] * atomForces.z);
-        sm_forces[forceIndexLocal] = result;
+        calculateAndStoreGridForces(sm_forces, forceIndexLocal, forceIndexGlobal,
+                                    kernelParams.current.recipBox, scale, gm_coefficientsA);
     }
 
     __syncwarp();
@@ -450,18 +526,65 @@ __launch_bounds__(c_gatherMaxThreadsPerBlock, c_gatherMinBlocksPerMP) __global__
 #pragma unroll
         for (int i = 0; i < numIter; i++)
         {
-            int         outputIndexLocal     = i * iterThreads + threadLocalId;
-            int         outputIndexGlobal    = blockIndex * blockForcesSize + outputIndexLocal;
-            const float outputForceComponent = ((float*)sm_forces)[outputIndexLocal];
-            gm_forces[outputIndexGlobal]     = outputForceComponent;
+            int   outputIndexLocal       = i * iterThreads + threadLocalId;
+            int   outputIndexGlobal      = blockIndex * blockForcesSize + outputIndexLocal;
+            float outputForceComponent   = ((float*)sm_forces)[outputIndexLocal];
+            gm_forces[outputIndexGlobal] = outputForceComponent;
+        }
+    }
+
+    if (numGrids == 2)
+    {
+        /* We must sync here since the same shared memory is used as above. */
+        __syncthreads();
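+        /* Second pass for FEP state B: the spline parameters and gridline
+         * indices computed above are reused; only the grid, the coefficients
+         * and the (1 - scale) weight differ.
+         */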
+        fx                    = 0.0f;
+        fy                    = 0.0f;
+        fz                    = 0.0f;
+        const int chargeCheck = pme_gpu_check_atom_charge(gm_coefficientsB[atomIndexGlobal]);
+        if (chargeCheck)
+        {
+            sumForceComponents<order, atomsPerWarp, wrapX, wrapY>(
+                    &fx, &fy, &fz, ithyMin, ithyMax, ixBase, iz, nx, ny, pny, pnz, atomIndexLocal,
+                    splineIndexBase, tdz, sm_gridlineIndices, sm_theta, sm_dtheta, gm_gridB);
+        }
+        // Reduction of partial force contributions
+        reduce_atom_forces<order, atomDataSize, blockSize>(sm_forces, atomIndexLocal, splineIndex,
+                                                           lineIndex, kernelParams.grid.realGridSizeFP,
+                                                           fx, fy, fz);
+        __syncthreads();
+
+        /* Calculating the final forces with no component branching, atomsPerBlock threads */
+        if (forceIndexLocal < atomsPerBlock)
+        {
+            calculateAndStoreGridForces(sm_forces, forceIndexLocal, forceIndexGlobal,
+                                        kernelParams.current.recipBox, 1.0F - scale, gm_coefficientsB);
+        }
+
+        __syncwarp();
+
+        /* Writing or adding the final forces component-wise, single warp */
+        if (threadLocalId < iterThreads)
+        {
+#pragma unroll
+            for (int i = 0; i < numIter; i++)
+            {
+                int   outputIndexLocal     = i * iterThreads + threadLocalId;
+                int   outputIndexGlobal    = blockIndex * blockForcesSize + outputIndexLocal;
+                float outputForceComponent = ((float*)sm_forces)[outputIndexLocal];
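+                /* Accumulate on top of the state A forces written above. */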
+                gm_forces[outputIndexGlobal] += outputForceComponent;
+            }
         }
     }
 }
 
 //! Kernel instantiations
 // clang-format off
-template __global__ void pme_gather_kernel<4, true, true, true,  ThreadsPerAtom::Order>       (const PmeGpuCudaKernelParams);
-template __global__ void pme_gather_kernel<4, true, true, true,  ThreadsPerAtom::OrderSquared>(const PmeGpuCudaKernelParams);
-template __global__ void pme_gather_kernel<4, true, true, false, ThreadsPerAtom::Order>       (const PmeGpuCudaKernelParams);
-template __global__ void pme_gather_kernel<4, true, true, false, ThreadsPerAtom::OrderSquared>(const PmeGpuCudaKernelParams);
-// clang-format on
\ No newline at end of file
+template __global__ void pme_gather_kernel<4, true, true, 1, true, ThreadsPerAtom::Order>        (const PmeGpuCudaKernelParams);
+template __global__ void pme_gather_kernel<4, true, true, 1, true, ThreadsPerAtom::OrderSquared> (const PmeGpuCudaKernelParams);
+template __global__ void pme_gather_kernel<4, true, true, 1, false, ThreadsPerAtom::Order>       (const PmeGpuCudaKernelParams);
+template __global__ void pme_gather_kernel<4, true, true, 1, false, ThreadsPerAtom::OrderSquared>(const PmeGpuCudaKernelParams);
+template __global__ void pme_gather_kernel<4, true, true, 2, true, ThreadsPerAtom::Order>        (const PmeGpuCudaKernelParams);
+template __global__ void pme_gather_kernel<4, true, true, 2, true, ThreadsPerAtom::OrderSquared> (const PmeGpuCudaKernelParams);
+template __global__ void pme_gather_kernel<4, true, true, 2, false, ThreadsPerAtom::Order>       (const PmeGpuCudaKernelParams);
+template __global__ void pme_gather_kernel<4, true, true, 2, false, ThreadsPerAtom::OrderSquared>(const PmeGpuCudaKernelParams);
+// clang-format on