Modernize PME GPU timing enums
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_internal.cpp
index d66865569ba1bc44de549472839482cd3cda7288..5f4a4fd28986214931ceb68c33acfe16989e0459 100644 (file)
@@ -1016,23 +1016,23 @@ void pme_gpu_reinit_atoms(PmeGpu* pmeGpu, const int nAtoms, const real* chargesA
  * In CUDA result can be nullptr stub, per GpuRegionTimer implementation.
  *
  * \param[in] pmeGpu         The PME GPU data structure.
- * \param[in] PMEStageId     The PME GPU stage gtPME_ index from the enum in src/gromacs/timing/gpu_timing.h
+ * \param[in] pmeStageId     The PME GPU stage gtPME_ index from the enum in src/gromacs/timing/gpu_timing.h
  */
-static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, size_t PMEStageId)
+static CommandEvent* pme_gpu_fetch_timing_event(const PmeGpu* pmeGpu, PmeStage pmeStageId)
 {
     CommandEvent* timingEvent = nullptr;
     if (pme_gpu_timings_enabled(pmeGpu))
     {
-        GMX_ASSERT(PMEStageId < pmeGpu->archSpecific->timingEvents.size(),
-                   "Wrong PME GPU timing event index");
-        timingEvent = pmeGpu->archSpecific->timingEvents[PMEStageId].fetchNextEvent();
+        GMX_ASSERT(pmeStageId < PmeStage::Count, "Wrong PME GPU timing event index");
+        timingEvent = pmeGpu->archSpecific->timingEvents[pmeStageId].fetchNextEvent();
     }
     return timingEvent;
 }
 
 void pme_gpu_3dfft(const PmeGpu* pmeGpu, gmx_fft_direction dir, const int grid_index)
 {
-    int timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? gtPME_FFT_R2C : gtPME_FFT_C2R;
+    PmeStage timerId = (dir == GMX_FFT_REAL_TO_COMPLEX) ? PmeStage::FftTransformR2C
+                                                        : PmeStage::FftTransformC2R;
 
     pme_gpu_start_timing(pmeGpu, timerId);
     pmeGpu->archSpecific->fftSetup[grid_index]->perform3dFft(
@@ -1313,13 +1313,13 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     config.gridSize[0]  = dimGrid.first;
     config.gridSize[1]  = dimGrid.second;
 
-    int                                timingId;
+    PmeStage                           timingId;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (computeSplines)
     {
         if (spreadCharges)
         {
-            timingId  = gtPME_SPLINEANDSPREAD;
+            timingId  = PmeStage::SplineAndSpread;
             kernelPtr = selectSplineAndSpreadKernelPtr(pmeGpu,
                                                        pmeGpu->settings.threadsPerAtom,
                                                        writeGlobal || (!recalculateSplines),
@@ -1327,7 +1327,7 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
         }
         else
         {
-            timingId  = gtPME_SPLINE;
+            timingId  = PmeStage::Spline;
             kernelPtr = selectSplineKernelPtr(pmeGpu,
                                               pmeGpu->settings.threadsPerAtom,
                                               writeGlobal || (!recalculateSplines),
@@ -1336,7 +1336,7 @@ void pme_gpu_spread(const PmeGpu*         pmeGpu,
     }
     else
     {
-        timingId  = gtPME_SPREAD;
+        timingId  = PmeStage::Spread;
         kernelPtr = selectSpreadKernelPtr(pmeGpu,
                                           pmeGpu->settings.threadsPerAtom,
                                           writeGlobal || (!recalculateSplines),
@@ -1463,7 +1463,7 @@ void pme_gpu_solve(const PmeGpu* pmeGpu,
                          / gridLinesPerBlock;
     config.gridSize[2] = pmeGpu->kernelParams->grid.complexGridSize[majorDim];
 
-    int                                timingId  = gtPME_SOLVE;
+    PmeStage                           timingId  = PmeStage::Solve;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr = nullptr;
     if (gridOrdering == GridOrdering::YZX)
     {
@@ -1655,7 +1655,7 @@ void pme_gpu_gather(PmeGpu* pmeGpu, real** h_grids, const float lambda)
 
     // TODO test different cache configs
 
-    int                                timingId = gtPME_GATHER;
+    PmeStage                           timingId = PmeStage::Gather;
     PmeGpuProgramImpl::PmeKernelHandle kernelPtr =
             selectGatherKernelPtr(pmeGpu,
                                   pmeGpu->settings.threadsPerAtom,