Pass the GPU streams to StatePropagatorDataGpu constructor
[alexxy/gromacs.git] / src / gromacs / ewald / tests / pmesplinespreadtest.cpp
index 136c08b69733a400c73d56c4831f89fb0e357d1f..4a30d4d1c52a9c6c21c6701b91514192f5cd2b56 100644 (file)
@@ -122,7 +122,6 @@ class PmeSplineAndSpreadTest : public ::testing::TestWithParam<SplineAndSpreadIn
 
             for (const auto &context : getPmeTestEnv()->getHardwareContexts())
             {
-                std::shared_ptr<StatePropagatorDataGpu> stateGpu;
                 CodePath   codePath       = context->getCodePath();
                 const bool supportedInput = pmeSupportsInputForMode(*getPmeTestEnv()->hwinfo(), &inputRec, codePath);
                 if (!supportedInput)
@@ -146,9 +145,10 @@ class PmeSplineAndSpreadTest : public ::testing::TestWithParam<SplineAndSpreadIn
 
                     /* Running the test */
 
-                    PmeSafePointer         pmeSafe  = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
-                    StatePropagatorDataGpu stateGpu = makeStatePropagatorDataGpu(*pmeSafe.get());
-                    pmeInitAtoms(pmeSafe.get(), &stateGpu, codePath, coordinates, charges);
+                    PmeSafePointer pmeSafe = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
+                    std::unique_ptr<StatePropagatorDataGpu> stateGpu = (codePath == CodePath::GPU) ? makeStatePropagatorDataGpu(*pmeSafe.get()) : nullptr;
+
+                    pmeInitAtoms(pmeSafe.get(), stateGpu.get(), codePath, coordinates, charges);
 
                     const bool     computeSplines = (option.first == PmeSplineAndSpreadOptions::SplineOnly) || (option.first == PmeSplineAndSpreadOptions::SplineAndSpreadUnified);
                     const bool     spreadCharges  = (option.first == PmeSplineAndSpreadOptions::SpreadOnly) || (option.first == PmeSplineAndSpreadOptions::SplineAndSpreadUnified);