Pass the new gmx::ForceFlags to the nbnxm module
[alexxy/gromacs.git] / src / gromacs / nbnxm / cuda / nbnxm_cuda.cu
index 74925428ac79994726cdf9d441186172b152b1d4..0b30c010b0042a9dacf7194e24f158492bdcdd89 100644 (file)
@@ -56,7 +56,7 @@
 #include "gromacs/gpu_utils/cudautils.cuh"
 #include "gromacs/gpu_utils/gpueventsynchronizer.cuh"
 #include "gromacs/gpu_utils/vectype_ops.cuh"
-#include "gromacs/mdlib/force_flags.h"
+#include "gromacs/mdlib/ppforceworkload.h"
 #include "gromacs/nbnxm/atomdata.h"
 #include "gromacs/nbnxm/gpu_common.h"
 #include "gromacs/nbnxm/gpu_common_utils.h"
@@ -403,7 +403,7 @@ void gpu_copy_xq_to_gpu(gmx_nbnxn_cuda_t       *nb,
    with this event in the non-local stream before launching the non-bonded kernel.
  */
 void gpu_launch_kernel(gmx_nbnxn_cuda_t          *nb,
-                       const int                  flags,
+                       const gmx::ForceFlags     &forceFlags,
                        const InteractionLocality  iloc)
 {
     cu_atomdata_t       *adat    = nb->atdat;
@@ -412,8 +412,6 @@ void gpu_launch_kernel(gmx_nbnxn_cuda_t          *nb,
     cu_timers_t         *t       = nb->timers;
     cudaStream_t         stream  = nb->stream[iloc];
 
-    bool                 bCalcEner   = flags & GMX_FORCE_ENERGY;
-    bool                 bCalcFshift = flags & GMX_FORCE_VIRIAL;
     bool                 bDoTime     = nb->bDoTime;
 
     /* Don't launch the non-local kernel if there is no work to do.
@@ -488,10 +486,10 @@ void gpu_launch_kernel(gmx_nbnxn_cuda_t          *nb,
     auto       *timingEvent = bDoTime ? t->interaction[iloc].nb_k.fetchNextEvent() : nullptr;
     const auto  kernel      = select_nbnxn_kernel(nbp->eeltype,
                                                   nbp->vdwtype,
-                                                  bCalcEner,
+                                                  forceFlags.computeEnergy,
                                                   (plist->haveFreshList && !nb->timers->interaction[iloc].didPrune),
                                                   nb->dev_info);
-    const auto kernelArgs  = prepareGpuKernelArguments(kernel, config, adat, nbp, plist, &bCalcFshift);
+    const auto kernelArgs  = prepareGpuKernelArguments(kernel, config, adat, nbp, plist, &forceFlags.computeVirial);
     launchGpuKernel(kernel, config, timingEvent, "k_calc_nb", kernelArgs);
 
     if (bDoTime)
@@ -645,7 +643,7 @@ void gpu_launch_kernel_pruneonly(gmx_nbnxn_cuda_t          *nb,
 
 void gpu_launch_cpyback(gmx_nbnxn_cuda_t       *nb,
                         nbnxn_atomdata_t       *nbatom,
-                        const int               flags,
+                        const gmx::ForceFlags  &forceFlags,
                         const AtomLocality      atomLocality,
                         const bool              copyBackNbForce)
 {
@@ -663,9 +661,6 @@ void gpu_launch_cpyback(gmx_nbnxn_cuda_t       *nb,
     bool             bDoTime = nb->bDoTime;
     cudaStream_t     stream  = nb->stream[iloc];
 
-    bool             bCalcEner   = flags & GMX_FORCE_ENERGY;
-    bool             bCalcFshift = flags & GMX_FORCE_VIRIAL;
-
     /* don't launch non-local copy-back if there was no non-local work to do */
     if ((iloc == InteractionLocality::NonLocal) && !haveGpuShortRangeWork(*nb, iloc))
     {
@@ -708,15 +703,15 @@ void gpu_launch_cpyback(gmx_nbnxn_cuda_t       *nb,
     /* only transfer energies in the local stream */
     if (iloc == InteractionLocality::Local)
     {
-        /* DtoH fshift */
-        if (bCalcFshift)
+        /* DtoH fshift when virial is needed */
+        if (forceFlags.computeVirial)
         {
             cu_copy_D2H_async(nb->nbst.fshift, adat->fshift,
                               SHIFTS * sizeof(*nb->nbst.fshift), stream);
         }
 
         /* DtoH energies */
-        if (bCalcEner)
+        if (forceFlags.computeEnergy)
         {
             cu_copy_D2H_async(nb->nbst.e_lj, adat->e_lj,
                               sizeof(*nb->nbst.e_lj), stream);