Remove mdatoms from pairs in listed forces
authorejjordan <ejjordan@kth.se>
Wed, 28 Apr 2021 11:14:40 +0000 (13:14 +0200)
committerArtem Zhmurov <zhmurov@gmail.com>
Mon, 3 May 2021 19:54:08 +0000 (19:54 +0000)
src/gromacs/listed_forces/listed_forces.cpp
src/gromacs/listed_forces/pairs.cpp
src/gromacs/listed_forces/pairs.h
src/gromacs/listed_forces/tests/pairs.cpp

index a35e58911cbda26533b7de417c3206291df4045a..d8511e56b434c7d9567409a8d25c620432b78a3e 100644 (file)
@@ -512,7 +512,11 @@ real calc_one_bond(int                           thread,
                  pbc,
                  lambda.data(),
                  dvdl.data(),
-                 md,
+                 gmx::arrayRefFromArray(md->chargeA, md->nr),
+                 gmx::arrayRefFromArray(md->chargeB, md->nr),
+                 md->bPerturbed ? gmx::arrayRefFromArray(md->bPerturbed, md->nr) : gmx::ArrayRef<bool>(),
+                 gmx::arrayRefFromArray(md->cENER, md->nr),
+                 md->nPerturbed,
                  fr,
                  havePerturbedInteractions,
                  stepWork,
index 9e73bfa9f4feb42a8d258a9b4980c334d03b1241..a7b18375713f2c1593c5b2c7d4cd0c4adec5dc44 100644 (file)
@@ -54,7 +54,6 @@
 #include "gromacs/mdtypes/forcerec.h"
 #include "gromacs/mdtypes/group.h"
 #include "gromacs/mdtypes/interaction_const.h"
-#include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/mdtypes/nblist.h"
 #include "gromacs/mdtypes/simulation_workload.h"
@@ -342,20 +341,24 @@ static real free_energy_evaluate_single(real
 
 /*! \brief Calculate pair interactions, supports all types and conditions. */
 template<BondedKernelFlavor flavor>
-static real do_pairs_general(int                 ftype,
-                             int                 nbonds,
-                             const t_iatom       iatoms[],
-                             const t_iparams     iparams[],
-                             const rvec          x[],
-                             rvec4               f[],
-                             rvec                fshift[],
-                             const struct t_pbc* pbc,
-                             const real*         lambda,
-                             real*               dvdl,
-                             const t_mdatoms*    md,
-                             const t_forcerec*   fr,
-                             gmx_grppairener_t*  grppener,
-                             int*                global_atom_index)
+static real do_pairs_general(int                           ftype,
+                             int                           nbonds,
+                             const t_iatom                 iatoms[],
+                             const t_iparams               iparams[],
+                             const rvec                    x[],
+                             rvec4                         f[],
+                             rvec                          fshift[],
+                             const struct t_pbc*           pbc,
+                             const real*                   lambda,
+                             real*                         dvdl,
+                             gmx::ArrayRef<real>           chargeA,
+                             gmx::ArrayRef<real>           chargeB,
+                             gmx::ArrayRef<bool>           atomIsPerturbed,
+                             gmx::ArrayRef<unsigned short> cENER,
+                             int                           numEnergyGroups,
+                             const t_forcerec*             fr,
+                             gmx_grppairener_t*            grppener,
+                             int*                          global_atom_index)
 {
     real            qq, c6, c12;
     rvec            dx;
@@ -427,29 +430,26 @@ static real do_pairs_general(int                 ftype,
 
     const real epsfac = fr->ic->epsfac;
 
-    bFreeEnergy      = FALSE;
-    auto* cENER      = md->cENER;
-    auto* bPerturbed = md->bPerturbed;
-    auto* chargeA    = md->chargeA;
-    auto* chargeB    = md->chargeB;
+    bFreeEnergy = FALSE;
     for (i = 0; (i < nbonds);)
     {
         itype = iatoms[i++];
         ai    = iatoms[i++];
         aj    = iatoms[i++];
-        gid   = GID(cENER[ai], cENER[aj], md->nenergrp);
+        gid   = GID(cENER[ai], cENER[aj], numEnergyGroups);
 
         /* Get parameters */
         switch (ftype)
         {
             case F_LJ14:
-                bFreeEnergy = (fr->efep != FreeEnergyPerturbationType::No
-                               && ((md->bPerturbed && (bPerturbed[ai] || bPerturbed[aj]))
-                                   || iparams[itype].lj14.c6A != iparams[itype].lj14.c6B
-                                   || iparams[itype].lj14.c12A != iparams[itype].lj14.c12B));
-                qq          = chargeA[ai] * chargeA[aj] * epsfac * fr->fudgeQQ;
-                c6          = iparams[itype].lj14.c6A;
-                c12         = iparams[itype].lj14.c12A;
+                bFreeEnergy =
+                        (fr->efep != FreeEnergyPerturbationType::No
+                         && ((!atomIsPerturbed.empty() && (atomIsPerturbed[ai] || atomIsPerturbed[aj]))
+                             || iparams[itype].lj14.c6A != iparams[itype].lj14.c6B
+                             || iparams[itype].lj14.c12A != iparams[itype].lj14.c12B));
+                qq  = chargeA[ai] * chargeA[aj] * epsfac * fr->fudgeQQ;
+                c6  = iparams[itype].lj14.c6A;
+                c12 = iparams[itype].lj14.c12A;
                 break;
             case F_LJC14_Q:
                 qq = iparams[itype].ljc14.qi * iparams[itype].ljc14.qj * epsfac
@@ -568,14 +568,14 @@ static real do_pairs_general(int                 ftype,
  * This function is templated for real/SimdReal and for optimization.
  */
 template<typename T, int pack_size, typename pbc_type>
-static void do_pairs_simple(int              nbonds,
-                            const t_iatom    iatoms[],
-                            const t_iparams  iparams[],
-                            const rvec       x[],
-                            rvec4            f[],
-                            const pbc_type   pbc,
-                            const t_mdatoms* md,
-                            const real       scale_factor)
+static void do_pairs_simple(int                 nbonds,
+                            const t_iatom       iatoms[],
+                            const t_iparams     iparams[],
+                            const rvec          x[],
+                            rvec4               f[],
+                            const pbc_type      pbc,
+                            gmx::ArrayRef<real> charge,
+                            const real          scale_factor)
 {
     const int nfa1 = 1 + 2;
 
@@ -592,7 +592,6 @@ static void do_pairs_simple(int              nbonds,
     std::int32_t aj[pack_size];
     real         coeff[3 * pack_size];
 #endif
-    auto* chargeA = md->chargeA;
 
     /* nbonds is #pairs*nfa1, here we step pack_size pairs */
     for (int i = 0; i < nbonds; i += pack_size * nfa1)
@@ -611,7 +610,7 @@ static void do_pairs_simple(int              nbonds,
             {
                 coeff[0 * pack_size + s] = iparams[itype].lj14.c6A;
                 coeff[1 * pack_size + s] = iparams[itype].lj14.c12A;
-                coeff[2 * pack_size + s] = chargeA[ai[s]] * chargeA[aj[s]];
+                coeff[2 * pack_size + s] = charge[ai[s]] * charge[aj[s]];
 
                 /* Avoid indexing the iatoms array out of bounds.
                  * We pad the coordinate indices with the last atom pair.
@@ -672,22 +671,26 @@ static void do_pairs_simple(int              nbonds,
 }
 
 /*! \brief Calculate all listed pair interactions */
-void do_pairs(int                      ftype,
-              int                      nbonds,
-              const t_iatom            iatoms[],
-              const t_iparams          iparams[],
-              const rvec               x[],
-              rvec4                    f[],
-              rvec                     fshift[],
-              const struct t_pbc*      pbc,
-              const real*              lambda,
-              real*                    dvdl,
-              const t_mdatoms*         md,
-              const t_forcerec*        fr,
-              const bool               havePerturbedInteractions,
-              const gmx::StepWorkload& stepWork,
-              gmx_grppairener_t*       grppener,
-              int*                     global_atom_index)
+void do_pairs(int                           ftype,
+              int                           nbonds,
+              const t_iatom                 iatoms[],
+              const t_iparams               iparams[],
+              const rvec                    x[],
+              rvec4                         f[],
+              rvec                          fshift[],
+              const struct t_pbc*           pbc,
+              const real*                   lambda,
+              real*                         dvdl,
+              gmx::ArrayRef<real>           chargeA,
+              gmx::ArrayRef<real>           chargeB,
+              gmx::ArrayRef<bool>           atomIsPerturbed,
+              gmx::ArrayRef<unsigned short> cENER,
+              const int                     numEnergyGroups,
+              const t_forcerec*             fr,
+              const bool                    havePerturbedInteractions,
+              const gmx::StepWorkload&      stepWork,
+              gmx_grppairener_t*            grppener,
+              int*                          global_atom_index)
 {
     if (ftype == F_LJ14 && fr->ic->vdwtype != VanDerWaalsType::User && !EEL_USER(fr->ic->eeltype)
         && !havePerturbedInteractions && (!stepWork.computeVirial && !stepWork.computeEnergy))
@@ -707,7 +710,7 @@ void do_pairs(int                      ftype,
             set_pbc_simd(pbc, pbc_simd);
 
             do_pairs_simple<SimdReal, GMX_SIMD_REAL_WIDTH, const real*>(
-                    nbonds, iatoms, iparams, x, f, pbc_simd, md, fr->ic->epsfac * fr->fudgeQQ);
+                    nbonds, iatoms, iparams, x, f, pbc_simd, chargeA, fr->ic->epsfac * fr->fudgeQQ);
         }
         else
 #endif
@@ -727,17 +730,49 @@ void do_pairs(int                      ftype,
             }
 
             do_pairs_simple<real, 1, const t_pbc*>(
-                    nbonds, iatoms, iparams, x, f, pbc_nonnull, md, fr->ic->epsfac * fr->fudgeQQ);
+                    nbonds, iatoms, iparams, x, f, pbc_nonnull, chargeA, fr->ic->epsfac * fr->fudgeQQ);
         }
     }
     else if (stepWork.computeVirial)
     {
-        do_pairs_general<BondedKernelFlavor::ForcesAndVirialAndEnergy>(
-                ftype, nbonds, iatoms, iparams, x, f, fshift, pbc, lambda, dvdl, md, fr, grppener, global_atom_index);
+        do_pairs_general<BondedKernelFlavor::ForcesAndVirialAndEnergy>(ftype,
+                                                                       nbonds,
+                                                                       iatoms,
+                                                                       iparams,
+                                                                       x,
+                                                                       f,
+                                                                       fshift,
+                                                                       pbc,
+                                                                       lambda,
+                                                                       dvdl,
+                                                                       chargeA,
+                                                                       chargeB,
+                                                                       atomIsPerturbed,
+                                                                       cENER,
+                                                                       numEnergyGroups,
+                                                                       fr,
+                                                                       grppener,
+                                                                       global_atom_index);
     }
     else
     {
-        do_pairs_general<BondedKernelFlavor::ForcesAndEnergy>(
-                ftype, nbonds, iatoms, iparams, x, f, fshift, pbc, lambda, dvdl, md, fr, grppener, global_atom_index);
+        do_pairs_general<BondedKernelFlavor::ForcesAndEnergy>(ftype,
+                                                              nbonds,
+                                                              iatoms,
+                                                              iparams,
+                                                              x,
+                                                              f,
+                                                              fshift,
+                                                              pbc,
+                                                              lambda,
+                                                              dvdl,
+                                                              chargeA,
+                                                              chargeB,
+                                                              atomIsPerturbed,
+                                                              cENER,
+                                                              numEnergyGroups,
+                                                              fr,
+                                                              grppener,
+                                                              global_atom_index);
     }
 }
index 34593c29336f5037097e0ea7b66828ffd4ba2a03..57c0c42aba3374239b3ebf339322fbe98b8dc242 100644 (file)
 struct gmx_grppairener_t;
 struct t_forcerec;
 struct t_pbc;
-struct t_mdatoms;
 
 namespace gmx
 {
 class StepWorkload;
+template<typename>
+class ArrayRef;
 } // namespace gmx
 
 /*! \brief Calculate VdW/charge listed pair interactions (usually 1-4
@@ -64,21 +65,25 @@ class StepWorkload;
  *
  * global_atom_index is only passed for printing error messages.
  */
-void do_pairs(int                      ftype,
-              int                      nbonds,
-              const t_iatom            iatoms[],
-              const t_iparams          iparams[],
-              const rvec               x[],
-              rvec4                    f[],
-              rvec                     fshift[],
-              const struct t_pbc*      pbc,
-              const real*              lambda,
-              real*                    dvdl,
-              const t_mdatoms*         md,
-              const t_forcerec*        fr,
-              bool                     havePerturbedPairs,
-              const gmx::StepWorkload& stepWork,
-              gmx_grppairener_t*       grppener,
-              int*                     global_atom_index);
+void do_pairs(int                           ftype,
+              int                           nbonds,
+              const t_iatom                 iatoms[],
+              const t_iparams               iparams[],
+              const rvec                    x[],
+              rvec4                         f[],
+              rvec                          fshift[],
+              const struct t_pbc*           pbc,
+              const real*                   lambda,
+              real*                         dvdl,
+              gmx::ArrayRef<real>           chargeA,
+              gmx::ArrayRef<real>           chargeB,
+              gmx::ArrayRef<bool>           atomIsPerturbed,
+              gmx::ArrayRef<unsigned short> cENER,
+              int                           numEnergyGroups,
+              const t_forcerec*             fr,
+              bool                          havePerturbedPairs,
+              const gmx::StepWorkload&      stepWork,
+              gmx_grppairener_t*            grppener,
+              int*                          global_atom_index);
 
 #endif
index 5b038f1240ee4debc6730c7cee1170fdfbbe4a78..1a734aa51c39d58ff1ae547538e407203f141a7a 100644 (file)
@@ -381,7 +381,11 @@ protected:
                      &pbc_,
                      lambdas.data(),
                      output.dvdLambda.data(),
-                     &mdatoms,
+                     gmx::arrayRefFromArray(mdatoms.chargeA, mdatoms.nr),
+                     gmx::arrayRefFromArray(mdatoms.chargeB, mdatoms.nr),
+                     gmx::arrayRefFromArray(mdatoms.bPerturbed, mdatoms.nr),
+                     gmx::arrayRefFromArray(mdatoms.cENER, mdatoms.nr),
+                     mdatoms.nPerturbed,
                      fr,
                      havePerturbedInteractions,
                      stepWork,