Refactor md_enums
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index b6b21dad6b138ae46f8bb98520da2ecb9aa08592..a18add27538ca959c091944fd5dba7ea4b036cce 100644 (file)
@@ -203,9 +203,17 @@ static void pull_potential_wrapper(const t_commrec*               cr,
     wallcycle_start(wcycle, ewcPULLPOT);
     set_pbc(&pbc, ir.pbcType, box);
     dvdl = 0;
-    enerd->term[F_COM_PULL] += pull_potential(
-            pull_work, mdatoms->massT, &pbc, cr, t, lambda[efptRESTRAINT], as_rvec_array(x.data()), force, &dvdl);
-    enerd->dvdl_lin[efptRESTRAINT] += dvdl;
+    enerd->term[F_COM_PULL] +=
+            pull_potential(pull_work,
+                           mdatoms->massT,
+                           &pbc,
+                           cr,
+                           t,
+                           lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Restraint)],
+                           as_rvec_array(x.data()),
+                           force,
+                           &dvdl);
+    enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Restraint] += dvdl;
     wallcycle_stop(wcycle, ewcPULLPOT);
 }
 
@@ -241,8 +249,8 @@ static void pme_receive_force_ener(t_forcerec*           fr,
                       &cycles_seppme);
     enerd->term[F_COUL_RECIP] += e_q;
     enerd->term[F_LJ_RECIP] += e_lj;
-    enerd->dvdl_lin[efptCOUL] += dvdl_q;
-    enerd->dvdl_lin[efptVDW] += dvdl_lj;
+    enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Coul] += dvdl_q;
+    enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Vdw] += dvdl_lj;
 
     if (wcycle)
     {
@@ -896,7 +904,8 @@ static DomainLifetimeWorkload setupDomainLifetimeWorkload(const t_inputrec&
     }
     domainWork.haveGpuBondedWork = ((fr.gpuBonded != nullptr) && fr.gpuBonded->haveInteractions());
     // Note that haveFreeEnergyWork is constant over the whole run
-    domainWork.haveFreeEnergyWork = (fr.efep != efepNO && mdatoms.nPerturbed != 0);
+    domainWork.haveFreeEnergyWork =
+            (fr.efep != FreeEnergyPerturbationType::No && mdatoms.nPerturbed != 0);
     // We assume we have local force work if there are CPU
     // force tasks including PME or nonbondeds.
     domainWork.haveCpuLocalForceWork =
@@ -1044,8 +1053,10 @@ static void reduceAndUpdateMuTot(DipoleData*                   dipoleData,
     {
         for (int j = 0; j < DIM; j++)
         {
-            muTotal[j] = (1.0 - lambda[efptCOUL]) * dipoleData->muStateAB[0][j]
-                         + lambda[efptCOUL] * dipoleData->muStateAB[1][j];
+            muTotal[j] = (1.0 - lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)])
+                                 * dipoleData->muStateAB[0][j]
+                         + lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)]
+                                   * dipoleData->muStateAB[1][j];
         }
     }
 }
@@ -1285,8 +1296,8 @@ void do_force(FILE*                               fplog,
                                  cr,
                                  box,
                                  as_rvec_array(x.unpaddedArrayRef().data()),
-                                 lambda[efptCOUL],
-                                 lambda[efptVDW],
+                                 lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)],
+                                 lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)],
                                  (stepWork.computeVirial || stepWork.computeEnergy),
                                  step,
                                  simulationWork.useGpuPmePpCommunication,
@@ -1332,8 +1343,8 @@ void do_force(FILE*                               fplog,
                                  cr,
                                  box,
                                  as_rvec_array(x.unpaddedArrayRef().data()),
-                                 lambda[efptCOUL],
-                                 lambda[efptVDW],
+                                 lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)],
+                                 lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)],
                                  (stepWork.computeVirial || stepWork.computeEnergy),
                                  step,
                                  simulationWork.useGpuPmePpCommunication,
@@ -1345,7 +1356,12 @@ void do_force(FILE*                               fplog,
 
     if (useGpuPmeOnThisRank)
     {
-        launchPmeGpuSpread(fr->pmedata, box, stepWork, localXReadyOnDevice, lambda[efptCOUL], wcycle);
+        launchPmeGpuSpread(fr->pmedata,
+                           box,
+                           stepWork,
+                           localXReadyOnDevice,
+                           lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)],
+                           wcycle);
     }
 
     const gmx::DomainLifetimeWorkload& domainWork = runScheduleWork->domainWork;
@@ -1501,7 +1517,10 @@ void do_force(FILE*                               fplog,
         // X copy/transform to allow overlap as well as after the GPU NB
         // launch to avoid FFT launch overhead hijacking the CPU and delaying
         // the nonbonded kernel.
-        launchPmeGpuFftAndGather(fr->pmedata, lambda[efptCOUL], wcycle, stepWork);
+        launchPmeGpuFftAndGather(fr->pmedata,
+                                 lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)],
+                                 wcycle,
+                                 stepWork);
     }
 
     /* Communicate coordinates and sum dipole if necessary +
@@ -1643,7 +1662,8 @@ void do_force(FILE*                               fplog,
                 dipoleData.muStaging[0],
                 dipoleData.muStaging[1]);
 
-        reduceAndUpdateMuTot(&dipoleData, cr, (fr->efep != efepNO), lambda, muTotal, ddBalanceRegionHandler);
+        reduceAndUpdateMuTot(
+                &dipoleData, cr, (fr->efep != FreeEnergyPerturbationType::No), lambda, muTotal, ddBalanceRegionHandler);
     }
 
     /* Reset energies */
@@ -1721,7 +1741,7 @@ void do_force(FILE*                               fplog,
         do_nb_verlet(fr, ic, enerd, stepWork, InteractionLocality::Local, enbvClearFYes, step, nrnb, wcycle);
     }
 
-    if (fr->efep != efepNO && stepWork.computeNonbondedForces)
+    if (fr->efep != FreeEnergyPerturbationType::No && stepWork.computeNonbondedForces)
     {
         /* Calculate the local and non-local free energy interactions here.
          * Happens here on the CPU both with and without GPU.
@@ -1731,7 +1751,7 @@ void do_force(FILE*                               fplog,
                                       as_rvec_array(x.unpaddedArrayRef().data()),
                                       &forceOutNonbonded->forceWithShiftForces(),
                                       *mdatoms,
-                                      inputrec.fepvals,
+                                      inputrec.fepvals.get(),
                                       lambda,
                                       enerd,
                                       stepWork,
@@ -1744,7 +1764,7 @@ void do_force(FILE*                               fplog,
                                           as_rvec_array(x.unpaddedArrayRef().data()),
                                           &forceOutNonbonded->forceWithShiftForces(),
                                           *mdatoms,
-                                          inputrec.fepvals,
+                                          inputrec.fepvals.get(),
                                           lambda,
                                           enerd,
                                           stepWork,
@@ -1801,10 +1821,10 @@ void do_force(FILE*                               fplog,
                                    *mdatoms,
                                    x.unpaddedConstArrayRef(),
                                    &forceOutMtsLevel0.forceWithVirial(),
-                                   lambda[efptVDW],
+                                   lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)],
                                    enerd->grpp.ener[egLJSR].data(),
                                    nrnb);
-        enerd->dvdl_lin[efptVDW] += dvdl_walls;
+        enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Vdw] += dvdl_walls;
     }
 
     if (stepWork.computeListedForces)
@@ -1835,7 +1855,7 @@ void do_force(FILE*                               fplog,
             ForceOutputs& forceOut     = (mtsIndex == 0 ? forceOutMtsLevel0 : *forceOutMtsLevel1);
             listedForces.calculate(wcycle,
                                    box,
-                                   inputrec.fepvals,
+                                   inputrec.fepvals.get(),
                                    cr,
                                    ms,
                                    x,
@@ -1847,7 +1867,7 @@ void do_force(FILE*                               fplog,
                                    &pbc,
                                    enerd,
                                    nrnb,
-                                   lambda.data(),
+                                   lambda,
                                    mdatoms,
                                    DOMAINDECOMP(cr) ? cr->dd->globalAtomIndices.data() : nullptr,
                                    stepWork);
@@ -1878,14 +1898,14 @@ void do_force(FILE*                               fplog,
     if ((stepWork.computeEnergy || stepWork.computeVirial) && fr->dispersionCorrection && MASTER(cr))
     {
         // Calculate long range corrections to pressure and energy
-        const DispersionCorrection::Correction correction =
-                fr->dispersionCorrection->calculate(box, lambda[efptVDW]);
+        const DispersionCorrection::Correction correction = fr->dispersionCorrection->calculate(
+                box, lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)]);
 
         if (stepWork.computeEnergy)
         {
             enerd->term[F_DISPCORR] = correction.energy;
             enerd->term[F_DVDL_VDW] += correction.dvdl;
-            enerd->dvdl_lin[efptVDW] += correction.dvdl;
+            enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Vdw] += correction.dvdl;
         }
         if (stepWork.computeVirial)
         {
@@ -2050,14 +2070,24 @@ void do_force(FILE*                               fplog,
                              && !DOMAINDECOMP(cr) && !stepWork.useGpuFBufferOps);
     if (alternateGpuWait)
     {
-        alternatePmeNbGpuWaitReduce(
-                fr->nbv.get(), fr->pmedata, forceOutNonbonded, forceOutMtsLevel1, enerd, lambda[efptCOUL], stepWork, wcycle);
+        alternatePmeNbGpuWaitReduce(fr->nbv.get(),
+                                    fr->pmedata,
+                                    forceOutNonbonded,
+                                    forceOutMtsLevel1,
+                                    enerd,
+                                    lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)],
+                                    stepWork,
+                                    wcycle);
     }
 
     if (!alternateGpuWait && useGpuPmeOnThisRank)
     {
-        pme_gpu_wait_and_reduce(
-                fr->pmedata, stepWork, wcycle, &forceOutMtsLevel1->forceWithVirial(), enerd, lambda[efptCOUL]);
+        pme_gpu_wait_and_reduce(fr->pmedata,
+                                stepWork,
+                                wcycle,
+                                &forceOutMtsLevel1->forceWithVirial(),
+                                enerd,
+                                lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)]);
     }
 
     /* Wait for local GPU NB outputs on the non-alternating wait path */
@@ -2254,7 +2284,7 @@ void do_force(FILE*                               fplog,
     if (stepWork.computeEnergy)
     {
         /* Compute the final potential energy terms */
-        accumulatePotentialEnergies(enerd, lambda, inputrec.fepvals);
+        accumulatePotentialEnergies(enerd, lambda, inputrec.fepvals.get());
 
         if (!EI_TPI(inputrec.eI))
         {