Fix SYCL PME Solve kernel
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve_sycl.cpp
index 633cf31e8e690b5cfacadcdb8726dcb5699a5156..46883060e9b1367d8f34e48a145b530d250b1309 100644 (file)
@@ -65,12 +65,15 @@ template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int subGroupSiz
 auto makeSolveKernel(cl::sycl::handler&                            cgh,
                      DeviceAccessor<float, mode::read>             a_splineModuli,
                      DeviceAccessor<SolveKernelParams, mode::read> a_solveKernelParams,
-                     DeviceAccessor<float, mode::read_write>       a_virialAndEnergy,
-                     DeviceAccessor<float, mode::read_write>       a_fourierGrid)
+                     OptionalAccessor<float, mode::read_write, computeEnergyAndVirial> a_virialAndEnergy,
+                     DeviceAccessor<float, mode::read_write> a_fourierGrid)
 {
     cgh.require(a_splineModuli);
     cgh.require(a_solveKernelParams);
-    cgh.require(a_virialAndEnergy);
+    if constexpr (computeEnergyAndVirial)
+    {
+        cgh.require(a_virialAndEnergy);
+    }
     cgh.require(a_fourierGrid);
 
     /* Reduce 7 outputs per warp in the shared memory */
@@ -281,7 +284,7 @@ auto makeSolveKernel(cl::sycl::handler&                            cgh,
         }
 
         /* Optional energy/virial reduction */
-        if (computeEnergyAndVirial)
+        if constexpr (computeEnergyAndVirial)
         {
             /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
              * The idea is to reduce 7 energy/virial components into a single variable (aligned by