Pass the GPU streams to StatePropagatorDataGpu constructor
[alexxy/gromacs.git] / src / gromacs / ewald / tests / pmegathertest.cpp
index d2afc1e324f2022165d1ec2e3c66ce86491edc23..077970b4a942295ebbf882e91bc311e21254c30c 100644 (file)
@@ -407,9 +407,10 @@ class PmeGatherTest : public ::testing::TestWithParam<GatherInputParameters>
                                           (inputForceTreatment == PmeForceOutputHandling::ReduceWithInput) ? "with reduction" : "without reduction"
                                           ));
 
-                PmeSafePointer         pmeSafe  = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
-                StatePropagatorDataGpu stateGpu = makeStatePropagatorDataGpu(*pmeSafe.get());
-                pmeInitAtoms(pmeSafe.get(), &stateGpu, codePath, inputAtomData.coordinates, inputAtomData.charges);
+                PmeSafePointer pmeSafe = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
+                std::unique_ptr<StatePropagatorDataGpu> stateGpu = (codePath == CodePath::GPU) ? makeStatePropagatorDataGpu(*pmeSafe.get()) : nullptr;
+
+                pmeInitAtoms(pmeSafe.get(), stateGpu.get(), codePath, inputAtomData.coordinates, inputAtomData.charges);
 
                 /* Setting some more inputs */
                 pmeSetRealGrid(pmeSafe.get(), codePath, nonZeroGridValues);