Unify gpu_init_atomdata(...) function
[alexxy/gromacs.git] / src / gromacs / nbnxm / nbnxm_gpu_data_mgmt.cpp
index 50519ced6d63d38f7cabfc4cce7c13be23834a6c..b86b785b94238e764c042adad5dfbf4e459dc30a 100644 (file)
@@ -69,6 +69,7 @@
 #include "gromacs/nbnxm/gpu_data_mgmt.h"
 #include "gromacs/pbcutil/ishift.h"
 #include "gromacs/timing/gpu_timing.h"
+#include "gromacs/pbcutil/ishift.h"
 #include "gromacs/utility/cstringutil.h"
 #include "gromacs/utility/exceptions.h"
 #include "gromacs/utility/fatalerror.h"
@@ -429,6 +430,22 @@ void gpu_init_atomdata(NbnxmGpu* nb, const nbnxn_atomdata_t* nbat)
     issueClFlushInStream(localStream);
 }
 
+void gpu_clear_outputs(NbnxmGpu* nb, bool computeVirial)
+{
+    NBAtomData*         adat        = nb->atdat;
+    const DeviceStream& localStream = *nb->deviceStreams[InteractionLocality::Local];
+    // Clear forces
+    clearDeviceBufferAsync(&adat->f, 0, nb->atdat->numAtoms, localStream);
+    // Clear shift force array and energies if the outputs were used in the current step
+    if (computeVirial)
+    {
+        clearDeviceBufferAsync(&adat->fShift, 0, SHIFTS, localStream);
+        clearDeviceBufferAsync(&adat->eLJ, 0, 1, localStream);
+        clearDeviceBufferAsync(&adat->eElec, 0, 1, localStream);
+    }
+    issueClFlushInStream(localStream);
+}
+
 //! This function is documented in the header file
 gmx_wallclock_gpu_nbnxn_t* gpu_get_timings(NbnxmGpu* nb)
 {