Remove hardcoded warp_size == 32 assumption from PME GPU
[alexxy/gromacs.git] / src / gromacs / ewald / pme-gpu-utils.h
index bd8b36fe541211224365fd432681d98bbfb63553..c2424a37219b977fa83f273db59d05bb72b83635 100644 (file)
  * Removing warp dependency would also be nice (and would probably coincide with removing PME_SPREADGATHER_ATOMS_PER_WARP).
  *
  * \tparam order               PME order
+ * \tparam atomsPerWarp        Number of atoms processed by a warp
  * \param[in] warpIndex        Warp index wrt the block.
- * \param[in] atomWarpIndex    Atom index wrt the warp (from 0 to PME_SPREADGATHER_ATOMS_PER_WARP - 1).
+ * \param[in] atomWarpIndex    Atom index wrt the warp (from 0 to atomsPerWarp - 1).
  *
  * \returns Index into theta or dtheta array using GPU layout.
  */
-template <int order>
+template <int order, int atomsPerWarp>
 int INLINE_EVERYWHERE getSplineParamIndexBase(int warpIndex, int atomWarpIndex)
 {
-    assert((atomWarpIndex >= 0) && (atomWarpIndex < PME_SPREADGATHER_ATOMS_PER_WARP));
+    assert((atomWarpIndex >= 0) && (atomWarpIndex < atomsPerWarp));
     const int dimIndex    = 0;
     const int splineIndex = 0;
     // The zeroes are here to preserve the full index formula for reference
-    return (((splineIndex + order * warpIndex) * DIM + dimIndex) * PME_SPREADGATHER_ATOMS_PER_WARP + atomWarpIndex);
+    return (((splineIndex + order * warpIndex) * DIM + dimIndex) * atomsPerWarp + atomWarpIndex);
 }
 
 /*! \internal \brief
@@ -86,18 +87,19 @@ int INLINE_EVERYWHERE getSplineParamIndexBase(int warpIndex, int atomWarpIndex)
  * This function consumes result of getSplineParamIndexBase() and adjusts it for \p dimIndex and \p splineIndex.
  *
  * \tparam order               PME order
+ * \tparam atomsPerWarp        Number of atoms processed by a warp
  * \param[in] paramIndexBase   Must be result of getSplineParamIndexBase().
  * \param[in] dimIndex         Dimension index (from 0 to 2)
  * \param[in] splineIndex      Spline contribution index (from 0 to \p order - 1)
  *
  * \returns Index into theta or dtheta array using GPU layout.
  */
-template <int order>
+template <int order, int atomsPerWarp>
 int INLINE_EVERYWHERE getSplineParamIndex(int paramIndexBase, int dimIndex, int splineIndex)
 {
     assert((dimIndex >= XX) && (dimIndex < DIM));
     assert((splineIndex >= 0) && (splineIndex < order));
-    return (paramIndexBase + (splineIndex * DIM + dimIndex) * PME_SPREADGATHER_ATOMS_PER_WARP);
+    return (paramIndexBase + (splineIndex * DIM + dimIndex) * atomsPerWarp);
 }
 
 #endif