Decouple coordinates buffer management from buffer ops in NBNXM
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index 724160a4722568a7a3f266fb9bb316ef77e3fc88..853f97b021063e196bd11bdf89422f78ee511f37 100644 (file)
@@ -1120,8 +1120,22 @@ void do_force(FILE                                     *fplog,
     }
     else
     {
-        nbv->setCoordinates(Nbnxm::AtomLocality::Local, false,
-                            x.unpaddedArrayRef(), useGpuXBufOps, pme_gpu_get_device_x(fr->pmedata));
+        if (useGpuXBufOps == BufferOpsUseGpu::True)
+        {
+            // The condition here was (pme != nullptr && pme_gpu_get_device_x(fr->pmedata) != nullptr)
+            if (!useGpuPme)
+            {
+                nbv->copyCoordinatesToGpu(Nbnxm::AtomLocality::Local, false,
+                                          x.unpaddedArrayRef());
+            }
+            nbv->convertCoordinatesGpu(Nbnxm::AtomLocality::Local, false,
+                                       useGpuPme ? pme_gpu_get_device_x(fr->pmedata) : nbv->getDeviceCoordinates());
+        }
+        else
+        {
+            nbv->convertCoordinates(Nbnxm::AtomLocality::Local, false,
+                                    x.unpaddedArrayRef());
+        }
     }
 
     if (bUseGPU)
@@ -1187,8 +1201,22 @@ void do_force(FILE                                     *fplog,
         {
             dd_move_x(cr->dd, box, x.unpaddedArrayRef(), wcycle);
 
-            nbv->setCoordinates(Nbnxm::AtomLocality::NonLocal, false,
-                                x.unpaddedArrayRef(), useGpuXBufOps, pme_gpu_get_device_x(fr->pmedata));
+            if (useGpuXBufOps == BufferOpsUseGpu::True)
+            {
+                // The condition here was (pme != nullptr && pme_gpu_get_device_x(fr->pmedata) != nullptr)
+                if (!useGpuPme)
+                {
+                    nbv->copyCoordinatesToGpu(Nbnxm::AtomLocality::NonLocal, false,
+                                              x.unpaddedArrayRef());
+                }
+                nbv->convertCoordinatesGpu(Nbnxm::AtomLocality::NonLocal, false,
+                                           useGpuPme ? pme_gpu_get_device_x(fr->pmedata) : nbv->getDeviceCoordinates());
+            }
+            else
+            {
+                nbv->convertCoordinates(Nbnxm::AtomLocality::NonLocal, false,
+                                        x.unpaddedArrayRef());
+            }
 
         }