Use getAtomRanges(...) function in NBNXM more
[alexxy/gromacs.git] / src / gromacs / nbnxm / opencl / nbnxm_ocl.cpp
index e00874a30b6a78761a3c4b12c83ef9185b52243e..50e7b9d8d4f167e1a36eef7134a2047e8b4b7bd5 100644 (file)
@@ -527,9 +527,6 @@ void gpu_copy_xq_to_gpu(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom, const Atom
 
     const InteractionLocality iloc = gpuAtomToInteractionLocality(atomLocality);
 
-    /* local/nonlocal offset and length used for xq and f */
-    int adat_begin, adat_len;
-
     NBAtomData*         adat         = nb->atdat;
     gpu_plist*          plist        = nb->plist[iloc];
     cl_timers_t*        t            = nb->timers;
@@ -558,17 +555,8 @@ void gpu_copy_xq_to_gpu(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom, const Atom
         return;
     }
 
-    /* calculate the atom data index range based on locality */
-    if (atomLocality == AtomLocality::Local)
-    {
-        adat_begin = 0;
-        adat_len   = adat->numAtomsLocal;
-    }
-    else
-    {
-        adat_begin = adat->numAtomsLocal;
-        adat_len   = adat->numAtoms - adat->numAtomsLocal;
-    }
+    /* local/nonlocal offset and length used for xq and f */
+    auto atomsRange = getGpuAtomRange(adat, atomLocality);
 
     /* beginning of timed HtoD section */
     if (bDoTime)
@@ -580,9 +568,9 @@ void gpu_copy_xq_to_gpu(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom, const Atom
     static_assert(sizeof(float) == sizeof(*nbatom->x().data()),
                   "The size of the xyzq buffer element should be equal to the size of float4.");
     copyToDeviceBuffer(&adat->xq,
-                       reinterpret_cast<const Float4*>(nbatom->x().data()) + adat_begin,
-                       adat_begin,
-                       adat_len,
+                       reinterpret_cast<const Float4*>(nbatom->x().data()) + atomsRange.begin(),
+                       atomsRange.begin(),
+                       atomsRange.size(),
                        deviceStream,
                        GpuApiCallBehavior::Async,
                        bDoTime ? t->xf[atomLocality].nb_h2d.fetchNextEvent() : nullptr);
@@ -931,15 +919,14 @@ void gpu_launch_kernel_pruneonly(NbnxmGpu* nb, const InteractionLocality iloc, c
 void gpu_launch_cpyback(NbnxmGpu*                nb,
                         struct nbnxn_atomdata_t* nbatom,
                         const gmx::StepWorkload& stepWork,
-                        const AtomLocality       aloc)
+                        const AtomLocality       atomLocality)
 {
     GMX_ASSERT(nb, "Need a valid nbnxn_gpu object");
 
     cl_int gmx_unused cl_error;
-    int               adat_begin, adat_len; /* local/nonlocal offset and length used for xq and f */
 
     /* determine interaction locality from atom locality */
-    const InteractionLocality iloc = gpuAtomToInteractionLocality(aloc);
+    const InteractionLocality iloc = gpuAtomToInteractionLocality(atomLocality);
     GMX_ASSERT(iloc == InteractionLocality::Local
                        || (iloc == InteractionLocality::NonLocal && nb->bNonLocalStreamDoneMarked == false),
                "Non-local stream is indicating that the copy back event is enqueued at the "
@@ -965,12 +952,13 @@ void gpu_launch_cpyback(NbnxmGpu*                nb,
         return;
     }
 
-    getGpuAtomRange(adat, aloc, &adat_begin, &adat_len);
+    /* local/nonlocal offset and length used for xq and f */
+    auto atomsRange = getGpuAtomRange(adat, atomLocality);
 
     /* beginning of timed D2H section */
     if (bDoTime)
     {
-        t->xf[aloc].nb_d2h.openTimingRegion(deviceStream);
+        t->xf[atomLocality].nb_d2h.openTimingRegion(deviceStream);
     }
 
     /* With DD the local D2H transfer can only start after the non-local
@@ -984,13 +972,13 @@ void gpu_launch_cpyback(NbnxmGpu*                nb,
     /* DtoH f */
     GMX_ASSERT(sizeof(*nbatom->out[0].f.data()) == sizeof(float),
                "The host force buffer should be in single precision to match device data size.");
-    copyFromDeviceBuffer(reinterpret_cast<Float3*>(nbatom->out[0].f.data()) + adat_begin,
+    copyFromDeviceBuffer(reinterpret_cast<Float3*>(nbatom->out[0].f.data()) + atomsRange.begin(),
                          &adat->f,
-                         adat_begin,
-                         adat_len,
+                         atomsRange.begin(),
+                         atomsRange.size(),
                          deviceStream,
                          GpuApiCallBehavior::Async,
-                         bDoTime ? t->xf[aloc].nb_d2h.fetchNextEvent() : nullptr);
+                         bDoTime ? t->xf[atomLocality].nb_d2h.fetchNextEvent() : nullptr);
 
     /* kick off work */
     cl_error = clFlush(deviceStream.stream());
@@ -1021,7 +1009,7 @@ void gpu_launch_cpyback(NbnxmGpu*                nb,
                                  SHIFTS,
                                  deviceStream,
                                  GpuApiCallBehavior::Async,
-                                 bDoTime ? t->xf[aloc].nb_d2h.fetchNextEvent() : nullptr);
+                                 bDoTime ? t->xf[atomLocality].nb_d2h.fetchNextEvent() : nullptr);
         }
 
         /* DtoH energies */
@@ -1035,7 +1023,7 @@ void gpu_launch_cpyback(NbnxmGpu*                nb,
                                  1,
                                  deviceStream,
                                  GpuApiCallBehavior::Async,
-                                 bDoTime ? t->xf[aloc].nb_d2h.fetchNextEvent() : nullptr);
+                                 bDoTime ? t->xf[atomLocality].nb_d2h.fetchNextEvent() : nullptr);
             static_assert(sizeof(*nb->nbst.eElec) == sizeof(float),
                           "Sizes of host- and device-side electrostatic energy terms should be the "
                           "same.");
@@ -1045,13 +1033,13 @@ void gpu_launch_cpyback(NbnxmGpu*                nb,
                                  1,
                                  deviceStream,
                                  GpuApiCallBehavior::Async,
-                                 bDoTime ? t->xf[aloc].nb_d2h.fetchNextEvent() : nullptr);
+                                 bDoTime ? t->xf[atomLocality].nb_d2h.fetchNextEvent() : nullptr);
         }
     }
 
     if (bDoTime)
     {
-        t->xf[aloc].nb_d2h.closeTimingRegion(deviceStream);
+        t->xf[atomLocality].nb_d2h.closeTimingRegion(deviceStream);
     }
 }