Pass the new gmx::ForceFlags to the nbnxm module
[alexxy/gromacs.git] / src / gromacs / nbnxm / opencl / nbnxm_ocl.cpp
index 99a7829e29df983e782571f6fab14c6f9e339ba9..365b404ad60aeb135838d9b7f77a5bc80d9ac2a0 100644 (file)
@@ -72,7 +72,7 @@
 #include "gromacs/gpu_utils/gputraits_ocl.h"
 #include "gromacs/gpu_utils/oclutils.h"
 #include "gromacs/hardware/hw_info.h"
-#include "gromacs/mdlib/force_flags.h"
+#include "gromacs/mdlib/ppforceworkload.h"
 #include "gromacs/nbnxm/atomdata.h"
 #include "gromacs/nbnxm/gpu_common.h"
 #include "gromacs/nbnxm/gpu_common_utils.h"
@@ -468,7 +468,7 @@ void gpu_copy_xq_to_gpu(gmx_nbnxn_ocl_t        *nb,
    are finished and synchronize with this event in the non-local stream.
  */
 void gpu_launch_kernel(gmx_nbnxn_ocl_t                  *nb,
-                       const int                         flags,
+                       const gmx::ForceFlags            &forceFlags,
                        const Nbnxm::InteractionLocality  iloc)
 {
     cl_atomdata_t       *adat    = nb->atdat;
@@ -477,8 +477,6 @@ void gpu_launch_kernel(gmx_nbnxn_ocl_t                  *nb,
     cl_timers_t         *t       = nb->timers;
     cl_command_queue     stream  = nb->stream[iloc];
 
-    bool                 bCalcEner   = (flags & GMX_FORCE_ENERGY) != 0;
-    int                  bCalcFshift = flags & GMX_FORCE_VIRIAL;
     bool                 bDoTime     = (nb->bDoTime) != 0;
 
     cl_nbparam_params_t  nbparams_params;
@@ -548,17 +546,20 @@ void gpu_launch_kernel(gmx_nbnxn_ocl_t                  *nb,
     const auto     kernel       = select_nbnxn_kernel(nb,
                                                       nbp->eeltype,
                                                       nbp->vdwtype,
-                                                      bCalcEner,
+                                                      forceFlags.computeEnergy,
                                                       (plist->haveFreshList && !nb->timers->interaction[iloc].didPrune));
 
 
+    // The OpenCL kernel takes int as second to last argument because bool is
+    // not supported as a kernel argument type (sizeof(bool) is implementation defined).
+    const int computeFshift = forceFlags.computeVirial;
     if (useLjCombRule(nb->nbparam->vdwtype))
     {
         const auto kernelArgs = prepareGpuKernelArguments(kernel, config,
                                                           &nbparams_params, &adat->xq, &adat->f, &adat->e_lj, &adat->e_el, &adat->fshift,
                                                           &adat->lj_comb,
                                                           &adat->shift_vec, &nbp->nbfp_climg2d, &nbp->nbfp_comb_climg2d, &nbp->coulomb_tab_climg2d,
-                                                          &plist->sci, &plist->cj4, &plist->excl, &bCalcFshift);
+                                                          &plist->sci, &plist->cj4, &plist->excl, &computeFshift);
 
         launchGpuKernel(kernel, config, timingEvent, kernelName, kernelArgs);
     }
@@ -569,7 +570,7 @@ void gpu_launch_kernel(gmx_nbnxn_ocl_t                  *nb,
                                                           &nbparams_params, &adat->xq, &adat->f, &adat->e_lj, &adat->e_el, &adat->fshift,
                                                           &adat->atom_types,
                                                           &adat->shift_vec, &nbp->nbfp_climg2d, &nbp->nbfp_comb_climg2d, &nbp->coulomb_tab_climg2d,
-                                                          &plist->sci, &plist->cj4, &plist->excl, &bCalcFshift);
+                                                          &plist->sci, &plist->cj4, &plist->excl, &computeFshift);
         launchGpuKernel(kernel, config, timingEvent, kernelName, kernelArgs);
     }
 
@@ -733,9 +734,9 @@ void gpu_launch_kernel_pruneonly(gmx_nbnxn_gpu_t           *nb,
  */
 void gpu_launch_cpyback(gmx_nbnxn_ocl_t                          *nb,
                         struct nbnxn_atomdata_t                  *nbatom,
-                        const int                                 flags,
+                        const gmx::ForceFlags                    &forceFlags,
                         const AtomLocality                        aloc,
-                        const bool                    gmx_unused  copyBackNbForce)
+                        const bool                     gmx_unused copyBackNbForce)
 {
     GMX_ASSERT(nb, "Need a valid nbnxn_gpu object");
 
@@ -750,10 +751,6 @@ void gpu_launch_cpyback(gmx_nbnxn_ocl_t                          *nb,
     bool                      bDoTime = nb->bDoTime == CL_TRUE;
     cl_command_queue          stream  = nb->stream[iloc];
 
-    bool                      bCalcEner   = (flags & GMX_FORCE_ENERGY) != 0;
-    int                       bCalcFshift = flags & GMX_FORCE_VIRIAL;
-
-
     /* don't launch non-local copy-back if there was no non-local work to do */
     if ((iloc == InteractionLocality::NonLocal) && !haveGpuShortRangeWork(*nb, iloc))
     {
@@ -806,15 +803,15 @@ void gpu_launch_cpyback(gmx_nbnxn_ocl_t                          *nb,
     /* only transfer energies in the local stream */
     if (iloc == InteractionLocality::Local)
     {
-        /* DtoH fshift */
-        if (bCalcFshift)
+        /* DtoH fshift when virial is needed */
+        if (forceFlags.computeVirial)
         {
             ocl_copy_D2H_async(nb->nbst.fshift, adat->fshift, 0,
                                SHIFTS * adat->fshift_elem_size, stream, bDoTime ? t->xf[aloc].nb_d2h.fetchNextEvent() : nullptr);
         }
 
         /* DtoH energies */
-        if (bCalcEner)
+        if (forceFlags.computeEnergy)
         {
             ocl_copy_D2H_async(nb->nbst.e_lj, adat->e_lj, 0,
                                sizeof(float), stream, bDoTime ? t->xf[aloc].nb_d2h.fetchNextEvent() : nullptr);