Unify gpu_init_atomdata(...) function
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_data_mgmt.cpp
index cc4f9f3a6bfbb62637bb97bcdef05609d97f1e68..2f37a0c011c90f8a9b7e7575b62d06ed1b92bad1 100644 (file)
 namespace Nbnxm
 {
 
-//! This function is documented in the header file
-void gpu_clear_outputs(NbnxmGpu* nb, bool computeVirial)
-{
-    NBAtomData*         adat        = nb->atdat;
-    const DeviceStream& localStream = *nb->deviceStreams[InteractionLocality::Local];
-    // Clear forces
-    clearDeviceBufferAsync(&adat->f, 0, nb->atdat->numAtoms, localStream);
-    // Clear shift force array and energies if the outputs were used in the current step
-    if (computeVirial)
-    {
-        clearDeviceBufferAsync(&adat->fShift, 0, SHIFTS, localStream);
-        clearDeviceBufferAsync(&adat->eLJ, 0, 1, localStream);
-        clearDeviceBufferAsync(&adat->eElec, 0, 1, localStream);
-    }
-}
-
 /*! \brief Initialize \p atomdata first time; it only gets filled at pair-search. */
-static void initAtomdataFirst(NbnxmGpu* nb, int numTypes, const DeviceContext& deviceContext)
+static void initAtomdataFirst(NBAtomData*          atomdata,
+                              int                  numTypes,
+                              const DeviceContext& deviceContext,
+                              const DeviceStream&  localStream)
 {
-    const DeviceStream& localStream = *nb->deviceStreams[InteractionLocality::Local];
-    NBAtomData*         atomdata    = nb->atdat;
-    atomdata->numTypes              = numTypes;
+    atomdata->numTypes = numTypes;
     allocateDeviceBuffer(&atomdata->shiftVec, SHIFTS, deviceContext);
     atomdata->shiftVecUploaded = false;
 
@@ -179,8 +164,8 @@ NbnxmGpu* gpu_init(const gmx::DeviceStreamManager& deviceStreamManager,
     /* local/non-local GPU streams */
     GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedLocal),
                        "Local non-bonded stream should be initialized to use GPU for non-bonded.");
-    nb->deviceStreams[InteractionLocality::Local] =
-            &deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal);
+    const DeviceStream& localStream = deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal);
+    nb->deviceStreams[InteractionLocality::Local] = &localStream;
     // In general, it's not strictly necessary to use 2 streams for SYCL, since they are
     // out-of-order. But for the time being, it will be less disruptive to keep them.
     if (nb->bUseTwoStreams)
@@ -200,7 +185,7 @@ NbnxmGpu* gpu_init(const gmx::DeviceStreamManager& deviceStreamManager,
     const DeviceContext&            deviceContext = *nb->deviceContext_;
 
     initNbparam(nb->nbparam, *ic, listParams, nbatParams, deviceContext);
-    initAtomdataFirst(nb, nbatParams.numTypes, deviceContext);
+    initAtomdataFirst(nb->atdat, nbatParams.numTypes, deviceContext, localStream);
 
     return nb;
 }