ArrayRef and const ref in gmx_nb_free_energy_kernel
authorejjordan <ejjordan@kth.se>
Wed, 24 Mar 2021 22:33:40 +0000 (23:33 +0100)
committerArtem Zhmurov <zhmurov@gmail.com>
Thu, 25 Mar 2021 12:30:54 +0000 (12:30 +0000)
Use const ref where possible in nonbonded free energy kernels. Also
use ArrayRef for passing coordinates.

src/gromacs/gmxlib/nonbonded/nb_free_energy.cpp
src/gromacs/gmxlib/nonbonded/nb_free_energy.h
src/gromacs/mdlib/sim_util.cpp
src/gromacs/nbnxm/kerneldispatch.cpp
src/gromacs/nbnxm/nbnxm.h

index 66adc9fc3e2b182a7b0902963e760fe4b09c7b33..5fb29200b803afac67fe445ca99453916bdc9747 100644 (file)
@@ -199,16 +199,16 @@ static inline RealType potSwitchPotentialMod(const RealType potentialInp,
 
 //! Templated free-energy non-bonded kernel
 template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
-static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
-                                  rvec* gmx_restrict         xx,
-                                  gmx::ForceWithShiftForces* forceWithShiftForces,
-                                  const t_forcerec* gmx_restrict fr,
-                                  const t_mdatoms* gmx_restrict mdatoms,
-                                  int                           flags,
-                                  gmx::ArrayRef<const real>     lambda,
-                                  gmx::ArrayRef<real>           dvdl,
-                                  gmx::ArrayRef<real>           energygrp_elec,
-                                  gmx::ArrayRef<real>           energygrp_vdw,
+static void nb_free_energy_kernel(const t_nblist&                nlist,
+                                  gmx::ArrayRef<const gmx::RVec> coords,
+                                  gmx::ForceWithShiftForces*     forceWithShiftForces,
+                                  const t_forcerec&              fr,
+                                  const t_mdatoms&               mdatoms,
+                                  int                            flags,
+                                  gmx::ArrayRef<const real>      lambda,
+                                  gmx::ArrayRef<real>            dvdl,
+                                  gmx::ArrayRef<real>            energygrp_elec,
+                                  gmx::ArrayRef<real>            energygrp_vdw,
                                   t_nrnb* gmx_restrict nrnb)
 {
 #define STATE_A 0
@@ -228,24 +228,24 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
     constexpr real six        = 6.0;
 
     /* Extract pointer to non-bonded interaction constants */
-    const interaction_const_t* ic = fr->ic.get();
+    const interaction_const_t* ic = fr.ic.get();
 
     // Extract pair list data
-    const int                nri    = nlist->nri;
-    gmx::ArrayRef<const int> iinr   = nlist->iinr;
-    gmx::ArrayRef<const int> jindex = nlist->jindex;
-    gmx::ArrayRef<const int> jjnr   = nlist->jjnr;
-    gmx::ArrayRef<const int> shift  = nlist->shift;
-    gmx::ArrayRef<const int> gid    = nlist->gid;
-
-    const real*               shiftvec  = fr->shift_vec[0];
-    const real*               chargeA   = mdatoms->chargeA;
-    const real*               chargeB   = mdatoms->chargeB;
-    const int*                typeA     = mdatoms->typeA;
-    const int*                typeB     = mdatoms->typeB;
-    const int                 ntype     = fr->ntype;
-    gmx::ArrayRef<const real> nbfp      = fr->nbfp;
-    gmx::ArrayRef<const real> nbfp_grid = fr->ljpme_c6grid;
+    const int                nri    = nlist.nri;
+    gmx::ArrayRef<const int> iinr   = nlist.iinr;
+    gmx::ArrayRef<const int> jindex = nlist.jindex;
+    gmx::ArrayRef<const int> jjnr   = nlist.jjnr;
+    gmx::ArrayRef<const int> shift  = nlist.shift;
+    gmx::ArrayRef<const int> gid    = nlist.gid;
+
+    const real*               shiftvec  = fr.shift_vec[0];
+    const real*               chargeA   = mdatoms.chargeA;
+    const real*               chargeB   = mdatoms.chargeB;
+    const int*                typeA     = mdatoms.typeA;
+    const int*                typeB     = mdatoms.typeB;
+    const int                 ntype     = fr.ntype;
+    gmx::ArrayRef<const real> nbfp      = fr.nbfp;
+    gmx::ArrayRef<const real> nbfp_grid = fr.ljpme_c6grid;
 
     const real  lambda_coul   = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
     const real  lambda_vdw    = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
@@ -374,11 +374,11 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
     }
 
     // TODO: We should get rid of using pointers to real
-    const real* x             = xx[0];
+    const real* x             = coords[0];
     real* gmx_restrict f      = &(forceWithShiftForces->force()[0][0]);
     real* gmx_restrict fshift = &(forceWithShiftForces->shiftForces()[0][0]);
 
-    const real rlistSquared = gmx::square(fr->rlist);
+    const real rlistSquared = gmx::square(fr.rlist);
 
     int numExcludedPairsBeyondRlist = 0;
 
@@ -421,7 +421,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
             const RealType rSq = dX * dX + dY * dY + dZ * dZ;
             RealType       fScalC[NSTATES], fScalV[NSTATES];
             /* Check if this pair on the exlusions list.*/
-            const bool bPairIncluded = nlist->excl_fep.empty() || nlist->excl_fep[k];
+            const bool bPairIncluded = nlist.excl_fep.empty() || nlist.excl_fep[k];
 
             if (rSq >= rcutoff_max2 && bPairIncluded)
             {
@@ -849,7 +849,7 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
      * 12  flops per outer iteration
      * 150 flops per inner iteration
      */
-    atomicNrnbIncrement(nrnb, eNR_NBKERNEL_FREE_ENERGY, nlist->nri * 12 + nlist->jindex[nri] * 150);
+    atomicNrnbIncrement(nrnb, eNR_NBKERNEL_FREE_ENERGY, nlist.nri * 12 + nlist.jindex[nri] * 150);
 
     if (numExcludedPairsBeyondRlist > 0)
     {
@@ -862,20 +862,20 @@ static void nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
                   "The error is likely triggered by the use of couple-intramol=no "
                   "and the maximal distance in the decoupled molecule exceeding rlist.",
                   numExcludedPairsBeyondRlist,
-                  fr->rlist);
+                  fr.rlist);
     }
 }
 
-typedef void (*KernelFunction)(const t_nblist* gmx_restrict nlist,
-                               rvec* gmx_restrict         xx,
-                               gmx::ForceWithShiftForces* forceWithShiftForces,
-                               const t_forcerec* gmx_restrict fr,
-                               const t_mdatoms* gmx_restrict mdatoms,
-                               int                           flags,
-                               gmx::ArrayRef<const real>     lambda,
-                               gmx::ArrayRef<real>           dvdl,
-                               gmx::ArrayRef<real>           energygrp_elec,
-                               gmx::ArrayRef<real>           energygrp_vdw,
+typedef void (*KernelFunction)(const t_nblist&                nlist,
+                               gmx::ArrayRef<const gmx::RVec> coords,
+                               gmx::ForceWithShiftForces*     forceWithShiftForces,
+                               const t_forcerec&              fr,
+                               const t_mdatoms&               mdatoms,
+                               int                            flags,
+                               gmx::ArrayRef<const real>      lambda,
+                               gmx::ArrayRef<real>            dvdl,
+                               gmx::ArrayRef<real>            energygrp_elec,
+                               gmx::ArrayRef<real>            energygrp_vdw,
                                t_nrnb* gmx_restrict nrnb);
 
 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
@@ -991,19 +991,19 @@ static KernelFunction dispatchKernel(const bool                 scLambdasOrAlpha
 }
 
 
-void gmx_nb_free_energy_kernel(const t_nblist*            nlist,
-                               rvec*                      xx,
-                               gmx::ForceWithShiftForces* ff,
-                               const t_forcerec*          fr,
-                               const t_mdatoms*           mdatoms,
-                               int                        flags,
-                               gmx::ArrayRef<const real>  lambda,
-                               gmx::ArrayRef<real>        dvdl,
-                               gmx::ArrayRef<real>        energygrp_elec,
-                               gmx::ArrayRef<real>        energygrp_vdw,
-                               t_nrnb*                    nrnb)
+void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
+                               gmx::ArrayRef<const gmx::RVec> coords,
+                               gmx::ForceWithShiftForces*     ff,
+                               const t_forcerec&              fr,
+                               const t_mdatoms&               mdatoms,
+                               int                            flags,
+                               gmx::ArrayRef<const real>      lambda,
+                               gmx::ArrayRef<real>            dvdl,
+                               gmx::ArrayRef<real>            energygrp_elec,
+                               gmx::ArrayRef<real>            energygrp_vdw,
+                               t_nrnb*                        nrnb)
 {
-    const interaction_const_t& ic = *fr->ic;
+    const interaction_const_t& ic = *fr.ic;
     GMX_ASSERT(EEL_PME_EWALD(ic.eeltype) || ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype),
                "Unsupported eeltype with free energy");
     GMX_ASSERT(ic.softCoreParameters, "We need soft-core parameters");
@@ -1013,7 +1013,7 @@ void gmx_nb_free_energy_kernel(const t_nblist*            nlist,
     const bool  elecInteractionTypeIsEwald = (EEL_PME_EWALD(ic.eeltype));
     const bool  vdwModifierIsPotSwitch     = (ic.vdw_modifier == InteractionModifiers::PotSwitch);
     bool        scLambdasOrAlphasDiffer    = true;
-    const bool  useSimd                    = fr->use_simd_kernels;
+    const bool  useSimd                    = fr.use_simd_kernels;
 
     if (scParams.alphaCoulomb == 0 && scParams.alphaVdw == 0)
     {
@@ -1036,5 +1036,5 @@ void gmx_nb_free_energy_kernel(const t_nblist*            nlist,
                                 vdwModifierIsPotSwitch,
                                 useSimd,
                                 ic);
-    kernelFunc(nlist, xx, ff, fr, mdatoms, flags, lambda, dvdl, energygrp_elec, energygrp_vdw, nrnb);
+    kernelFunc(nlist, coords, ff, fr, mdatoms, flags, lambda, dvdl, energygrp_elec, energygrp_vdw, nrnb);
 }
index d3cf9f546ed1bdaa9dca7d04f569da0ba018a005..67920fc4ae8712d42975f52ad5fbbba9c309494a 100644 (file)
@@ -52,16 +52,16 @@ template<typename>
 class ArrayRef;
 } // namespace gmx
 
-void gmx_nb_free_energy_kernel(const t_nblist* gmx_restrict nlist,
-                               rvec* gmx_restrict         xx,
-                               gmx::ForceWithShiftForces* forceWithShiftForces,
-                               const t_forcerec* gmx_restrict fr,
-                               const t_mdatoms* gmx_restrict mdatoms,
-                               int                           flags,
-                               gmx::ArrayRef<const real>     lambda,
-                               gmx::ArrayRef<real>           dvdl,
-                               gmx::ArrayRef<real>           energygrp_elec,
-                               gmx::ArrayRef<real>           energygrp_vdw,
+void gmx_nb_free_energy_kernel(const t_nblist&                nlist,
+                               gmx::ArrayRef<const gmx::RVec> coords,
+                               gmx::ForceWithShiftForces*     forceWithShiftForces,
+                               const t_forcerec&              fr,
+                               const t_mdatoms&               mdatoms,
+                               int                            flags,
+                               gmx::ArrayRef<const real>      lambda,
+                               gmx::ArrayRef<real>            dvdl,
+                               gmx::ArrayRef<real>            energygrp_elec,
+                               gmx::ArrayRef<real>            energygrp_vdw,
                                t_nrnb* gmx_restrict nrnb);
 
 #endif
index 362cbaccf4b90f517577d22ea9506cbba3c0442d..abb68c4a765906d30c5a60e2090fa37f61549860 100644 (file)
@@ -1756,8 +1756,8 @@ void do_force(FILE*                               fplog,
          * Happens here on the CPU both with and without GPU.
          */
         nbv->dispatchFreeEnergyKernel(InteractionLocality::Local,
-                                      fr,
-                                      as_rvec_array(x.unpaddedArrayRef().data()),
+                                      *fr,
+                                      x.unpaddedArrayRef(),
                                       &forceOutNonbonded->forceWithShiftForces(),
                                       *mdatoms,
                                       inputrec.fepvals.get(),
@@ -1769,8 +1769,8 @@ void do_force(FILE*                               fplog,
         if (havePPDomainDecomposition(cr))
         {
             nbv->dispatchFreeEnergyKernel(InteractionLocality::NonLocal,
-                                          fr,
-                                          as_rvec_array(x.unpaddedArrayRef().data()),
+                                          *fr,
+                                          x.unpaddedArrayRef(),
                                           &forceOutNonbonded->forceWithShiftForces(),
                                           *mdatoms,
                                           inputrec.fepvals.get(),
index ca88b0a5b6834b1ab2de4bf40b6e974b413c8907..e40f04f8bba01545ada1630fd217fe3d23e6eb24 100644 (file)
@@ -493,9 +493,9 @@ void nonbonded_verlet_t::dispatchNonbondedKernel(gmx::InteractionLocality   iLoc
     accountFlops(nrnb, pairlistSet, *this, ic, stepWork);
 }
 
-void nonbonded_verlet_t::dispatchFreeEnergyKernel(gmx::InteractionLocality   iLocality,
-                                                  const t_forcerec*          fr,
-                                                  rvec                       x[],
+void nonbonded_verlet_t::dispatchFreeEnergyKernel(gmx::InteractionLocality       iLocality,
+                                                  const t_forcerec&              fr,
+                                                  gmx::ArrayRef<const gmx::RVec> coords,
                                                   gmx::ForceWithShiftForces* forceWithShiftForces,
                                                   const t_mdatoms&           mdatoms,
                                                   t_lambda*                  fepvals,
@@ -546,11 +546,11 @@ void nonbonded_verlet_t::dispatchFreeEnergyKernel(gmx::InteractionLocality   iLo
     {
         try
         {
-            gmx_nb_free_energy_kernel(nbl_fep[th].get(),
-                                      x,
+            gmx_nb_free_energy_kernel(*nbl_fep[th],
+                                      coords,
                                       forceWithShiftForces,
                                       fr,
-                                      &mdatoms,
+                                      mdatoms,
                                       kernelFlags,
                                       kernelLambda,
                                       kernelDvdl,
@@ -604,11 +604,11 @@ void nonbonded_verlet_t::dispatchFreeEnergyKernel(gmx::InteractionLocality   iLo
             {
                 try
                 {
-                    gmx_nb_free_energy_kernel(nbl_fep[th].get(),
-                                              x,
+                    gmx_nb_free_energy_kernel(*nbl_fep[th],
+                                              coords,
                                               forceWithShiftForces,
                                               fr,
-                                              &mdatoms,
+                                              mdatoms,
                                               kernelFlags,
                                               kernelLambda,
                                               kernelDvdl,
index a36efad0233237db92deef65047ea47846db95ec..37f09824e7ca74f3d1ef56937677d61985d3c615 100644 (file)
@@ -367,16 +367,16 @@ public:
                                  t_nrnb*                    nrnb);
 
     //! Executes the non-bonded free-energy kernel, always runs on the CPU
-    void dispatchFreeEnergyKernel(gmx::InteractionLocality   iLocality,
-                                  const t_forcerec*          fr,
-                                  rvec                       x[],
-                                  gmx::ForceWithShiftForces* forceWithShiftForces,
-                                  const t_mdatoms&           mdatoms,
-                                  t_lambda*                  fepvals,
-                                  gmx::ArrayRef<const real>  lambda,
-                                  gmx_enerdata_t*            enerd,
-                                  const gmx::StepWorkload&   stepWork,
-                                  t_nrnb*                    nrnb);
+    void dispatchFreeEnergyKernel(gmx::InteractionLocality       iLocality,
+                                  const t_forcerec&              fr,
+                                  gmx::ArrayRef<const gmx::RVec> coords,
+                                  gmx::ForceWithShiftForces*     forceWithShiftForces,
+                                  const t_mdatoms&               mdatoms,
+                                  t_lambda*                      fepvals,
+                                  gmx::ArrayRef<const real>      lambda,
+                                  gmx_enerdata_t*                enerd,
+                                  const gmx::StepWorkload&       stepWork,
+                                  t_nrnb*                        nrnb);
 
     /*! \brief Add the forces stored in nbat to f, zeros the forces in nbat
      * \param [in] locality         Local or non-local