Use enum class for nbnxm locality
[alexxy/gromacs.git] / src / gromacs / nbnxm / gpu_common.h
index cd64c8ca1a8b31761c3e89e61dc3e2afe68bbdc8..1624c56c8c5bb91aba3248330e2e78857ab82148 100644 (file)
@@ -67,6 +67,9 @@
 #include "gpu_common_utils.h"
 #include "nbnxm_gpu.h"
 
+namespace Nbnxm
+{
+
 /*! \brief Check that atom locality values are valid for the GPU module.
  *
  *  In the GPU module atom locality "all" is not supported, the local and
  *  non-local ranges are treated separately.
  *
  *  \param[in] atomLocality atom locality specifier
  */
-static inline void validateGpuAtomLocality(int atomLocality)
+static inline void
+validateGpuAtomLocality(const AtomLocality atomLocality)
 {
     std::string str = gmx::formatString("Invalid atom locality passed (%d); valid here is only "
-                                        "local (%d) or nonlocal (%d)", atomLocality, eatLocal, eatNonlocal);
+                                        "local (%d) or nonlocal (%d)",
+                                        static_cast<int>(atomLocality),
+                                        static_cast<int>(AtomLocality::Local),
+                                        static_cast<int>(AtomLocality::NonLocal));
 
-    GMX_ASSERT(LOCAL_OR_NONLOCAL_A(atomLocality), str.c_str());
+    GMX_ASSERT(atomLocality == AtomLocality::Local || atomLocality == AtomLocality::NonLocal, str.c_str());
 }
 
 /*! \brief Convert atom locality to interaction locality.
@@ -90,18 +97,19 @@ static inline void validateGpuAtomLocality(int atomLocality)
  *  \param[in] atomLocality Atom locality specifier
  *  \returns                Interaction locality corresponding to the atom locality passed.
  */
-static inline int gpuAtomToInteractionLocality(int atomLocality)
+static inline InteractionLocality
+gpuAtomToInteractionLocality(const AtomLocality atomLocality)
 {
     validateGpuAtomLocality(atomLocality);
 
     /* determine interaction locality from atom locality */
-    if (LOCAL_A(atomLocality))
+    if (atomLocality == AtomLocality::Local)
     {
-        return eintLocal;
+        return InteractionLocality::Local;
     }
-    else if (NONLOCAL_A(atomLocality))
+    else if (atomLocality == AtomLocality::NonLocal)
     {
-        return eintNonlocal;
+        return InteractionLocality::NonLocal;
     }
     else
     {
@@ -117,16 +125,17 @@ static inline int gpuAtomToInteractionLocality(int atomLocality)
  * \param[out] atomRangeLen Atom range length in the atom data array.
  */
 template <typename AtomDataT>
-static inline void getGpuAtomRange(const AtomDataT *atomData,
-                                   int              atomLocality,
-                                   int             *atomRangeBegin,
-                                   int             *atomRangeLen)
+static inline void
+getGpuAtomRange(const AtomDataT    *atomData,
+                const AtomLocality  atomLocality,
+                int                *atomRangeBegin,
+                int                *atomRangeLen)
 {
     assert(atomData);
     validateGpuAtomLocality(atomLocality);
 
     /* calculate the atom data index range based on locality */
-    if (LOCAL_A(atomLocality))
+    if (atomLocality == AtomLocality::Local)
     {
         *atomRangeBegin  = 0;
         *atomRangeLen    = atomData->natoms_local;
@@ -155,24 +164,27 @@ static inline void getGpuAtomRange(const AtomDataT *atomData,
 template <typename GpuTimers>
 static void countPruneKernelTime(GpuTimers                 *timers,
                                  gmx_wallclock_gpu_nbnxn_t *timings,
-                                 const int                  iloc)
+                                 const InteractionLocality  iloc)
 {
+    gpu_timers_t::Interaction &iTimers = timers->interaction[iloc];
+
     // We might not have done any pruning (e.g. if we skipped with empty domains).
-    if (!timers->didPrune[iloc] && !timers->didRollingPrune[iloc])
+    if (!iTimers.didPrune &&
+        !iTimers.didRollingPrune)
     {
         return;
     }
 
-    if (timers->didPrune[iloc])
+    if (iTimers.didPrune)
     {
         timings->pruneTime.c++;
-        timings->pruneTime.t += timers->prune_k[iloc].getLastRangeTime();
+        timings->pruneTime.t += iTimers.prune_k.getLastRangeTime();
     }
 
-    if (timers->didRollingPrune[iloc])
+    if (iTimers.didRollingPrune)
     {
         timings->dynamicPruneTime.c++;
-        timings->dynamicPruneTime.t += timers->rollingPrune_k[iloc].getLastRangeTime();
+        timings->dynamicPruneTime.t += iTimers.rollingPrune_k.getLastRangeTime();
     }
 }
 
@@ -195,16 +207,17 @@ static void countPruneKernelTime(GpuTimers                 *timers,
  * \param[out] fshift         Pointer to the array of shift forces to accumulate into
  */
 template <typename StagingData>
-static inline void nbnxn_gpu_reduce_staged_outputs(const StagingData &nbst,
-                                                   int                iLocality,
-                                                   bool               reduceEnergies,
-                                                   bool               reduceFshift,
-                                                   real              *e_lj,
-                                                   real              *e_el,
-                                                   rvec              *fshift)
+static inline void
+gpu_reduce_staged_outputs(const StagingData         &nbst,
+                          const InteractionLocality  iLocality,
+                          const bool                 reduceEnergies,
+                          const bool                 reduceFshift,
+                          real                      *e_lj,
+                          real                      *e_el,
+                          rvec                      *fshift)
 {
     /* add up energies and shift forces (only once at local F wait) */
-    if (LOCAL_I(iLocality))
+    if (iLocality == InteractionLocality::Local)
     {
         if (reduceEnergies)
         {
@@ -244,12 +257,13 @@ static inline void nbnxn_gpu_reduce_staged_outputs(const StagingData &nbst,
  *
  */
 template <typename GpuTimers, typename GpuPairlist>
-static inline void nbnxn_gpu_accumulate_timings(gmx_wallclock_gpu_nbnxn_t *timings,
-                                                GpuTimers                 *timers,
-                                                const GpuPairlist         *plist,
-                                                int                        atomLocality,
-                                                bool                       didEnergyKernels,
-                                                bool                       doTiming)
+static inline void
+gpu_accumulate_timings(gmx_wallclock_gpu_nbnxn_t *timings,
+                       GpuTimers                 *timers,
+                       const GpuPairlist         *plist,
+                       AtomLocality               atomLocality,
+                       bool                       didEnergyKernels,
+                       bool                       doTiming)
 {
     /* timing data accumulation */
     if (!doTiming)
@@ -258,10 +272,10 @@ static inline void nbnxn_gpu_accumulate_timings(gmx_wallclock_gpu_nbnxn_t *timin
     }
 
     /* determine interaction locality from atom locality */
-    int iLocality = gpuAtomToInteractionLocality(atomLocality);
+    const InteractionLocality iLocality = gpuAtomToInteractionLocality(atomLocality);
 
     /* only increase counter once (at local F wait) */
-    if (LOCAL_I(iLocality))
+    if (iLocality == InteractionLocality::Local)
     {
         timings->nb_c++;
         timings->ktime[plist->haveFreshList ? 1 : 0][didEnergyKernels ? 1 : 0].c += 1;
@@ -269,11 +283,11 @@ static inline void nbnxn_gpu_accumulate_timings(gmx_wallclock_gpu_nbnxn_t *timin
 
     /* kernel timings */
     timings->ktime[plist->haveFreshList ? 1 : 0][didEnergyKernels ? 1 : 0].t +=
-        timers->nb_k[iLocality].getLastRangeTime();
+        timers->interaction[iLocality].nb_k.getLastRangeTime();
 
     /* X/q H2D and F D2H timings */
-    timings->nb_h2d_t += timers->nb_h2d[iLocality].getLastRangeTime();
-    timings->nb_d2h_t += timers->nb_d2h[iLocality].getLastRangeTime();
+    timings->nb_h2d_t += timers->xf[atomLocality].nb_h2d.getLastRangeTime();
+    timings->nb_d2h_t += timers->xf[atomLocality].nb_d2h.getLastRangeTime();
 
     /* Count the pruning kernel times for both cases: 1st pass (at search step)
        and rolling pruning (if called at the previous step).
@@ -284,39 +298,41 @@ static inline void nbnxn_gpu_accumulate_timings(gmx_wallclock_gpu_nbnxn_t *timin
     countPruneKernelTime(timers, timings, iLocality);
 
     /* only count atdat and pair-list H2D at pair-search step */
-    if (timers->didPairlistH2D[iLocality])
+    if (timers->interaction[iLocality].didPairlistH2D)
     {
         /* atdat transfer timing (add only once, at local F wait) */
-        if (LOCAL_A(atomLocality))
+        if (atomLocality == AtomLocality::Local)
         {
             timings->pl_h2d_c++;
             timings->pl_h2d_t += timers->atdat.getLastRangeTime();
         }
 
-        timings->pl_h2d_t += timers->pl_h2d[iLocality].getLastRangeTime();
+        timings->pl_h2d_t += timers->interaction[iLocality].pl_h2d.getLastRangeTime();
 
         /* Clear the timing flag for the next step */
-        timers->didPairlistH2D[iLocality] = false;
+        timers->interaction[iLocality].didPairlistH2D = false;
     }
 }
 
 //TODO: move into shared source file with gmx_compile_cpp_as_cuda
 //NOLINTNEXTLINE(misc-definitions-in-headers)
-bool nbnxn_gpu_try_finish_task(gmx_nbnxn_gpu_t  *nb,
-                               int               flags,
-                               int               aloc,
-                               bool              haveOtherWork,
-                               real             *e_lj,
-                               real             *e_el,
-                               rvec             *fshift,
-                               GpuTaskCompletion completionKind)
+bool gpu_try_finish_task(gmx_nbnxn_gpu_t    *nb,
+                         const int           flags,
+                         const AtomLocality  aloc,
+                         const bool          haveOtherWork,
+                         real               *e_lj,
+                         real               *e_el,
+                         rvec               *fshift,
+                         GpuTaskCompletion   completionKind)
 {
+    GMX_ASSERT(nb, "Need a valid nbnxn_gpu object");
+
     /* determine interaction locality from atom locality */
-    int iLocality = gpuAtomToInteractionLocality(aloc);
+    const InteractionLocality iLocality = gpuAtomToInteractionLocality(aloc);
 
     //  We skip when there was actually no work to do during the non-local phase.
     //  This is consistent with nbnxn_gpu_launch_kernel.
-    if (haveOtherWork || !canSkipWork(nb, iLocality))
+    if (haveOtherWork || !canSkipWork(*nb, iLocality))
     {
         // Query the state of the GPU stream and return early if we're not done
         if (completionKind == GpuTaskCompletion::Check)
@@ -336,14 +352,14 @@ bool nbnxn_gpu_try_finish_task(gmx_nbnxn_gpu_t  *nb,
         bool calcEner   = (flags & GMX_FORCE_ENERGY) != 0;
         bool calcFshift = (flags & GMX_FORCE_VIRIAL) != 0;
 
-        nbnxn_gpu_accumulate_timings(nb->timings, nb->timers, nb->plist[iLocality], aloc, calcEner,
-                                     nb->bDoTime != 0);
+        gpu_accumulate_timings(nb->timings, nb->timers, nb->plist[iLocality], aloc, calcEner,
+                               nb->bDoTime != 0);
 
-        nbnxn_gpu_reduce_staged_outputs(nb->nbst, iLocality, calcEner, calcFshift, e_lj, e_el, fshift);
+        gpu_reduce_staged_outputs(nb->nbst, iLocality, calcEner, calcFshift, e_lj, e_el, fshift);
     }
 
     /* Always reset both pruning flags (doesn't hurt doing it even when timing is off). */
-    nb->timers->didPrune[iLocality] = nb->timers->didRollingPrune[iLocality] = false;
+    nb->timers->interaction[iLocality].didPrune = nb->timers->interaction[iLocality].didRollingPrune = false;
 
     /* Turn off initial list pruning (doesn't hurt if this is not pair-search step). */
     nb->plist[iLocality]->haveFreshList = false;
@@ -368,16 +384,18 @@ bool nbnxn_gpu_try_finish_task(gmx_nbnxn_gpu_t  *nb,
  * \param[out] fshift Pointer to the shift force buffer to accumulate into
  */
 //NOLINTNEXTLINE(misc-definitions-in-headers) TODO: move into source file
-void nbnxn_gpu_wait_finish_task(gmx_nbnxn_gpu_t *nb,
-                                int              flags,
-                                int              aloc,
-                                bool             haveOtherWork,
-                                real            *e_lj,
-                                real            *e_el,
-                                rvec            *fshift)
+void gpu_wait_finish_task(gmx_nbnxn_gpu_t *nb,
+                          int              flags,
+                          AtomLocality     aloc,
+                          bool             haveOtherWork,
+                          real            *e_lj,
+                          real            *e_el,
+                          rvec            *fshift)
 {
-    nbnxn_gpu_try_finish_task(nb, flags, aloc, haveOtherWork, e_lj, e_el, fshift,
-                              GpuTaskCompletion::Wait);
+    gpu_try_finish_task(nb, flags, aloc, haveOtherWork, e_lj, e_el, fshift,
+                        GpuTaskCompletion::Wait);
 }
 
+} // namespace Nbnxm
+
 #endif