Implement PME solve in SYCL
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_program_impl_sycl.cpp
index 196cff26ae3881a4bbee79fe6f1e5296971bf349..a65a2828b074859802b30735aaafa7da323e21f2 100644 (file)
@@ -50,6 +50,7 @@
 
 #include "pme_gpu_program_impl.h"
 #include "pme_gather_sycl.h"
+#include "pme_solve_sycl.h"
 #include "pme_spread_sycl.h"
 
 #include "pme_gpu_constants.h"
@@ -62,6 +63,9 @@ constexpr int c_pmeOrder = 4;
 constexpr bool c_wrapX = true;
 constexpr bool c_wrapY = true;
 
+constexpr int c_stateA = 0;
+constexpr int c_stateB = 1;
+
 static int subGroupSizeFromVendor(const DeviceInformation& deviceInfo)
 {
     switch (deviceInfo.deviceVendor)
@@ -96,9 +100,20 @@ static int subGroupSizeFromVendor(const DeviceInformation& deviceInfo)
     INSTANTIATE_##x(order, 2, ThreadsPerAtom::Order, subGroupSize);        \
     INSTANTIATE_##x(order, 2, ThreadsPerAtom::OrderSquared, subGroupSize);
 
+#define INSTANTIATE_SOLVE(subGroupSize)                                                     \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, false, c_stateA, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, true, c_stateA, subGroupSize>;  \
+    extern template class PmeSolveKernel<GridOrdering::YZX, false, c_stateA, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::YZX, true, c_stateA, subGroupSize>;  \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, false, c_stateB, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, true, c_stateB, subGroupSize>;  \
+    extern template class PmeSolveKernel<GridOrdering::YZX, false, c_stateB, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::YZX, true, c_stateB, subGroupSize>;
+
 #define INSTANTIATE(order, subGroupSize)        \
     INSTANTIATE_X(SPREAD, order, subGroupSize); \
-    INSTANTIATE_X(GATHER, order, subGroupSize);
+    INSTANTIATE_X(GATHER, order, subGroupSize); \
+    INSTANTIATE_SOLVE(subGroupSize);
 
 #if GMX_SYCL_DPCPP
 INSTANTIATE(4, 16);
@@ -107,7 +122,6 @@ INSTANTIATE(4, 32);
 INSTANTIATE(4, 64);
 #endif
 
-
 //! Helper function to set proper kernel functor pointers
 template<int subGroupSize>
 static void setKernelPointers(struct PmeGpuProgramImpl* pmeGpuProgram)
@@ -164,6 +178,22 @@ static void setKernelPointers(struct PmeGpuProgramImpl* pmeGpuProgram)
             new PmeGatherKernel<c_pmeOrder, c_wrapX, c_wrapY, 2, true, ThreadsPerAtom::OrderSquared, subGroupSize>();
     pmeGpuProgram->gatherKernelReadSplinesThPerAtom4Dual =
             new PmeGatherKernel<c_pmeOrder, c_wrapX, c_wrapY, 2, true, ThreadsPerAtom::Order, subGroupSize>();
+    pmeGpuProgram->solveXYZKernelA =
+            new PmeSolveKernel<GridOrdering::XYZ, false, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveXYZEnergyKernelA =
+            new PmeSolveKernel<GridOrdering::XYZ, true, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveYZXKernelA =
+            new PmeSolveKernel<GridOrdering::YZX, false, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveYZXEnergyKernelA =
+            new PmeSolveKernel<GridOrdering::YZX, true, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveXYZKernelB =
+            new PmeSolveKernel<GridOrdering::XYZ, false, c_stateB, subGroupSize>();
+    pmeGpuProgram->solveXYZEnergyKernelB =
+            new PmeSolveKernel<GridOrdering::XYZ, true, c_stateB, subGroupSize>();
+    pmeGpuProgram->solveYZXKernelB =
+            new PmeSolveKernel<GridOrdering::YZX, false, c_stateB, subGroupSize>();
+    pmeGpuProgram->solveYZXEnergyKernelB =
+            new PmeSolveKernel<GridOrdering::YZX, true, c_stateB, subGroupSize>();
 }
 
 PmeGpuProgramImpl::PmeGpuProgramImpl(const DeviceContext& deviceContext) :
@@ -205,4 +235,20 @@ PmeGpuProgramImpl::~PmeGpuProgramImpl()
     delete splineAndSpreadKernelThPerAtom4Dual;
     delete splineAndSpreadKernelWriteSplinesDual;
     delete splineAndSpreadKernelWriteSplinesThPerAtom4Dual;
+    delete gatherKernelSingle;
+    delete gatherKernelThPerAtom4Single;
+    delete gatherKernelReadSplinesSingle;
+    delete gatherKernelReadSplinesThPerAtom4Single;
+    delete gatherKernelDual;
+    delete gatherKernelThPerAtom4Dual;
+    delete gatherKernelReadSplinesDual;
+    delete gatherKernelReadSplinesThPerAtom4Dual;
+    delete solveYZXKernelA;
+    delete solveXYZKernelA;
+    delete solveYZXEnergyKernelA;
+    delete solveXYZEnergyKernelA;
+    delete solveYZXKernelB;
+    delete solveXYZKernelB;
+    delete solveYZXEnergyKernelB;
+    delete solveXYZEnergyKernelB;
 }