Decouple coordinates buffer management from buffer ops in NBNXM
[alexxy/gromacs.git] / src / gromacs / nbnxm / atomdata.cpp
index e78a65945e0df3ea74579344149693a7402eed21..6424b1558e4d958aa566b88db8872d970965cba1 100644 (file)
@@ -998,116 +998,148 @@ void nbnxn_atomdata_copy_shiftvec(gmx_bool          bDynamicBox,
     }
 }
 
-/* Copies (and reorders) the coordinates to nbnxn_atomdata_t */
-template <bool useGpu>
-void nbnxn_atomdata_copy_x_to_nbat_x(const Nbnxm::GridSet     &gridSet,
-                                     const Nbnxm::AtomLocality locality,
-                                     gmx_bool                  FillLocal,
-                                     const rvec               *x,
-                                     nbnxn_atomdata_t         *nbat,
-                                     gmx_nbnxn_gpu_t          *gpu_nbv,
-                                     void                     *xPmeDevicePtr)
+// Sets the grid index range [*gridBegin, *gridEnd) that corresponds to the given locality.
+// Slightly different from nbnxn_get_atom_range(...) at the end of the file. TODO: Combine if possible
+static void getAtomRanges(const Nbnxm::GridSet      &gridSet,
+                          const Nbnxm::AtomLocality  locality,
+                          int                       *gridBegin,
+                          int                       *gridEnd)
 {
-    int gridBegin = 0;
-    int gridEnd   = 0;
-
     switch (locality)
     {
         case Nbnxm::AtomLocality::All:
-            gridBegin = 0;
-            gridEnd   = gridSet.grids().size();
+            *gridBegin = 0;
+            *gridEnd   = gridSet.grids().size();
             break;
         case Nbnxm::AtomLocality::Local:
-            gridBegin = 0;
-            gridEnd   = 1;
+            *gridBegin = 0;
+            *gridEnd   = 1;
             break;
         case Nbnxm::AtomLocality::NonLocal:
-            gridBegin = 1;
-            gridEnd   = gridSet.grids().size();
+            *gridBegin = 1;
+            *gridEnd   = gridSet.grids().size();
             break;
         case Nbnxm::AtomLocality::Count:
             GMX_ASSERT(false, "Count is invalid locality specifier");
             break;
     }
+}
 
-    if (FillLocal)
+/* Copies (and reorders) the coordinates to nbnxn_atomdata_t */
+void nbnxn_atomdata_copy_x_to_nbat_x(const Nbnxm::GridSet     &gridSet,
+                                     const Nbnxm::AtomLocality locality,
+                                     bool                      fillLocal,
+                                     const rvec               *coordinates,
+                                     nbnxn_atomdata_t         *nbat)
+{
+
+    int gridBegin = 0;
+    int gridEnd   = 0;
+    getAtomRanges(gridSet, locality, &gridBegin, &gridEnd);
+
+    if (fillLocal)
     {
         nbat->natoms_local = gridSet.grids()[0].atomIndexEnd();
     }
 
-    if (useGpu)
-    {
-        for (int g = gridBegin; g < gridEnd; g++)
-        {
-            nbnxn_gpu_x_to_nbat_x(gridSet.grids()[g],
-                                  FillLocal && g == 0,
-                                  gpu_nbv,
-                                  xPmeDevicePtr,
-                                  locality,
-                                  x, g, gridSet.numColumnsMax());
-        }
-    }
-    else
-    {
-        const int nth = gmx_omp_nthreads_get(emntPairsearch);
+    const int nth = gmx_omp_nthreads_get(emntPairsearch);
 #pragma omp parallel for num_threads(nth) schedule(static)
-        for (int th = 0; th < nth; th++)
+    for (int th = 0; th < nth; th++)
+    {
+        try
         {
-            try
+            for (int g = gridBegin; g < gridEnd; g++)
             {
-                for (int g = gridBegin; g < gridEnd; g++)
-                {
-                    const Nbnxm::Grid  &grid       = gridSet.grids()[g];
-                    const int           numCellsXY = grid.numColumns();
+                const Nbnxm::Grid  &grid       = gridSet.grids()[g];
+                const int           numCellsXY = grid.numColumns();
 
-                    const int           cxy0 = (numCellsXY* th      + nth - 1)/nth;
-                    const int           cxy1 = (numCellsXY*(th + 1) + nth - 1)/nth;
+                const int           cxy0 = (numCellsXY* th      + nth - 1)/nth;
+                const int           cxy1 = (numCellsXY*(th + 1) + nth - 1)/nth;
 
-                    for (int cxy = cxy0; cxy < cxy1; cxy++)
-                    {
-                        const int na  = grid.numAtomsInColumn(cxy);
-                        const int ash = grid.firstAtomInColumn(cxy);
+                for (int cxy = cxy0; cxy < cxy1; cxy++)
+                {
+                    const int na  = grid.numAtomsInColumn(cxy);
+                    const int ash = grid.firstAtomInColumn(cxy);
 
-                        int       na_fill;
-                        if (g == 0 && FillLocal)
-                        {
-                            na_fill = grid.paddedNumAtomsInColumn(cxy);
-                        }
-                        else
-                        {
-                            /* We fill only the real particle locations.
-                             * We assume the filling entries at the end have been
-                             * properly set before during pair-list generation.
-                             */
-                            na_fill = na;
-                        }
-                        copy_rvec_to_nbat_real(gridSet.atomIndices().data() + ash,
-                                               na, na_fill, x,
-                                               nbat->XFormat, nbat->x().data(), ash);
+                    int       na_fill;
+                    if (g == 0 && fillLocal)
+                    {
+                        na_fill = grid.paddedNumAtomsInColumn(cxy);
+                    }
+                    else
+                    {
+                        /* We fill only the real particle locations.
+                         * We assume the filling entries at the end were
+                         * already set properly during pair-list generation.
+                         */
+                        na_fill = na;
                     }
+                    copy_rvec_to_nbat_real(gridSet.atomIndices().data() + ash,
+                                           na, na_fill, coordinates,
+                                           nbat->XFormat, nbat->x().data(), ash);
                 }
             }
-            GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
         }
+        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
     }
 }
 
-template
-void nbnxn_atomdata_copy_x_to_nbat_x<true>(const Nbnxm::GridSet &,
-                                           const Nbnxm::AtomLocality,
-                                           gmx_bool,
-                                           const rvec*,
-                                           nbnxn_atomdata_t *,
-                                           gmx_nbnxn_gpu_t*,
-                                           void *);
-template
-void nbnxn_atomdata_copy_x_to_nbat_x<false>(const Nbnxm::GridSet &,
-                                            const Nbnxm::AtomLocality,
-                                            gmx_bool,
-                                            const rvec*,
-                                            nbnxn_atomdata_t *,
-                                            gmx_nbnxn_gpu_t*,
-                                            void *);
+void nbnxn_atomdata_copy_x_to_gpu(const Nbnxm::GridSet     &gridSet,
+                                  const Nbnxm::AtomLocality locality,
+                                  bool                      fillLocal,
+                                  nbnxn_atomdata_t         *nbat,
+                                  gmx_nbnxn_gpu_t          *gpu_nbv,
+                                  const rvec               *coordinatesHost)
+{
+    int gridBegin = 0;
+    int gridEnd   = 0;
+    getAtomRanges(gridSet, locality, &gridBegin, &gridEnd);
+
+    if (fillLocal)
+    {
+        nbat->natoms_local = gridSet.grids()[0].atomIndexEnd();
+    }
+
+    for (int g = gridBegin; g < gridEnd; g++)
+    {
+        nbnxn_gpu_copy_x_to_gpu(gridSet.grids()[g],
+                                fillLocal && g == 0,
+                                gpu_nbv,
+                                locality,
+                                coordinatesHost,
+                                g,
+                                gridSet.numColumnsMax());
+    }
+}
+
+DeviceBuffer<float> nbnxn_atomdata_get_x_gpu(gmx_nbnxn_gpu_t *gpu_nbv)
+{
+    return Nbnxm::nbnxn_gpu_get_x_gpu(gpu_nbv);
+}
+
+/* Copies (and reorders) the coordinates to nbnxn_atomdata_t on the GPU */
+void nbnxn_atomdata_x_to_nbat_x_gpu(const Nbnxm::GridSet     &gridSet,
+                                    const Nbnxm::AtomLocality locality,
+                                    bool                      fillLocal,
+                                    gmx_nbnxn_gpu_t          *gpu_nbv,
+                                    DeviceBuffer<float>       coordinatesDevice)
+{
+
+    int gridBegin = 0;
+    int gridEnd   = 0;
+    getAtomRanges(gridSet, locality, &gridBegin, &gridEnd);
+
+    for (int g = gridBegin; g < gridEnd; g++)
+    {
+        nbnxn_gpu_x_to_nbat_x(gridSet.grids()[g],
+                              fillLocal && g == 0,
+                              gpu_nbv,
+                              coordinatesDevice,
+                              locality,
+                              g,
+                              gridSet.numColumnsMax());
+    }
+}
 
 static void
 nbnxn_atomdata_clear_reals(gmx::ArrayRef<real> dest,