Refactor virtual site interface
author    Berk Hess <hess@kth.se>              Tue, 12 May 2020 09:01:35 +0000 (09:01 +0000)
committer Paul Bauer <paul.bauer.q@gmail.com>  Tue, 12 May 2020 09:01:35 +0000 (09:01 +0000)
Replaced struct gmx_vsite_t by the PIMPL class VirtualSitesHandler.
Removed from the construction and spreading calls most arguments
that do not change between calls. Replaced rvec pointers by ArrayRefs.
This change is pure refactoring: there are no functional changes in
the calling code, but the PBC and virial handling have been simplified
and the number of parameters reduced. Templating the spreading
functions on the virial-handling mode should improve performance.
Added full doxygen documentation.
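
For review purposes, a minimal sketch of the new calling convention. The
setVirtualSites/construct/spreadForces signatures and the VirialHandling
values are taken from this change; the constructor arguments are inferred
from the Impl constructor in vsite.cpp, and the surrounding objects (mtop,
cr, ir, top, mdatoms, x, f, fshift, nrnb, box, wcycle) are assumed mdrun
context, not part of this commit:

    #include "gromacs/mdlib/vsite.h"

    // Construction captures the topology, the DD struct (nullptr without DD)
    // and the PBC type once, so they no longer travel through every call.
    gmx::VirtualSitesHandler vsite(mtop, cr->dd, ir->pbcType);

    // After each (re)partitioning: hand over the local interaction lists.
    vsite.setVirtualSites(top->idef.il, *mdatoms);

    // Position construction; pass an empty v when no velocities are needed.
    vsite.construct(x, dt, v, box);

    // Force spreading; the VirialHandling mode replaces the old
    // bool/pointer pair for shift-force and virial handling.
    vsite.spreadForces(x, f, gmx::VirtualSitesHandler::VirialHandling::Pbc,
                       fshift, nullptr, nrnb, box, wcycle);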

29 files changed:
src/gromacs/domdec/domdec.cpp
src/gromacs/domdec/domdec.h
src/gromacs/domdec/domdec_specatomcomm.cpp
src/gromacs/domdec/domdec_specatomcomm.h
src/gromacs/domdec/domdec_topology.cpp
src/gromacs/domdec/domdec_vsite.cpp
src/gromacs/domdec/mdsetup.cpp
src/gromacs/domdec/mdsetup.h
src/gromacs/domdec/partition.cpp
src/gromacs/domdec/partition.h
src/gromacs/gmxpreprocess/grompp.cpp
src/gromacs/mdlib/force.h
src/gromacs/mdlib/sim_util.cpp
src/gromacs/mdlib/vsite.cpp
src/gromacs/mdlib/vsite.h
src/gromacs/mdrun/isimulator.h
src/gromacs/mdrun/md.cpp
src/gromacs/mdrun/mimic.cpp
src/gromacs/mdrun/minimize.cpp
src/gromacs/mdrun/rerun.cpp
src/gromacs/mdrun/runner.cpp
src/gromacs/mdrun/shellfc.cpp
src/gromacs/mdrun/shellfc.h
src/gromacs/modularsimulator/domdechelper.cpp
src/gromacs/modularsimulator/domdechelper.h
src/gromacs/modularsimulator/forceelement.cpp
src/gromacs/modularsimulator/forceelement.h
src/gromacs/modularsimulator/topologyholder.cpp
src/gromacs/modularsimulator/topologyholder.h

index 3f355f71a835a42ada515934ccf1a0594c98d79e..5170ca87446c72defe1323c30ccdc82f65ecc114 100644 (file)
@@ -2158,7 +2158,7 @@ static DDSystemInfo getSystemInfo(const gmx::MDLogger&           mdlog,
 
             if (MASTER(cr))
             {
-                dd_bonded_cg_distance(mdlog, &mtop, &ir, as_rvec_array(xGlobal.data()), box,
+                dd_bonded_cg_distance(mdlog, &mtop, &ir, xGlobal, box,
                                       options.checkBondedInteractions, &r_2b, &r_mb);
             }
             gmx_bcast(sizeof(r_2b), &r_2b, cr->mpi_comm_mygroup);
@@ -2492,13 +2492,13 @@ static void set_dd_limits(const gmx::MDLogger& mdlog,
     }
 }
 
-void dd_init_bondeds(FILE*                      fplog,
-                     gmx_domdec_t*              dd,
-                     const gmx_mtop_t&          mtop,
-                     const gmx_vsite_t*         vsite,
-                     const t_inputrec*          ir,
-                     gmx_bool                   bBCheck,
-                     gmx::ArrayRef<cginfo_mb_t> cginfo_mb)
+void dd_init_bondeds(FILE*                           fplog,
+                     gmx_domdec_t*                   dd,
+                     const gmx_mtop_t&               mtop,
+                     const gmx::VirtualSitesHandler* vsite,
+                     const t_inputrec*               ir,
+                     gmx_bool                        bBCheck,
+                     gmx::ArrayRef<cginfo_mb_t>      cginfo_mb)
 {
     gmx_domdec_comm_t* comm;
 
index 82cefa18c1eafcd85dab35aae4f4fda876e61dfc..6c9908536cea4e19ce39585bbe570e943db8d705 100644 (file)
@@ -74,7 +74,6 @@ struct gmx_ddbox_t;
 struct gmx_domdec_zones_t;
 struct gmx_localtop_t;
 struct gmx_mtop_t;
-struct gmx_vsite_t;
 struct t_block;
 struct t_blocka;
 struct t_commrec;
@@ -94,6 +93,7 @@ class DeviceStreamManager;
 class ForceWithShiftForces;
 class MDLogger;
 class RangePartitioning;
+class VirtualSitesHandler;
 } // namespace gmx
 
 /*! \brief Returns the global topology atom number belonging to local atom index i.
@@ -164,13 +164,13 @@ bool ddUsesUpdateGroups(const gmx_domdec_t& dd);
 bool is1D(const gmx_domdec_t& dd);
 
 /*! \brief Initialize data structures for bonded interactions */
-void dd_init_bondeds(FILE*                      fplog,
-                     gmx_domdec_t*              dd,
-                     const gmx_mtop_t&          mtop,
-                     const gmx_vsite_t*         vsite,
-                     const t_inputrec*          ir,
-                     gmx_bool                   bBCheck,
-                     gmx::ArrayRef<cginfo_mb_t> cginfo_mb);
+void dd_init_bondeds(FILE*                           fplog,
+                     gmx_domdec_t*                   dd,
+                     const gmx_mtop_t&               mtop,
+                     const gmx::VirtualSitesHandler* vsite,
+                     const t_inputrec*               ir,
+                     gmx_bool                        bBCheck,
+                     gmx::ArrayRef<cginfo_mb_t>      cginfo_mb);
 
 /*! \brief Returns whether molecules are always whole, i.e. not broken by PBC */
 bool dd_moleculesAreAlwaysWhole(const gmx_domdec_t& dd);
@@ -235,10 +235,10 @@ void reset_dd_statistics_counters(struct gmx_domdec_t* dd);
 /* In domdec_con.c */
 
 /*! \brief Communicates the virtual site forces, reduces the shift forces when \p fshift is non-empty */
-void dd_move_f_vsites(struct gmx_domdec_t* dd, rvec* f, rvec* fshift);
+void dd_move_f_vsites(const gmx_domdec_t& dd, gmx::ArrayRef<gmx::RVec> f, gmx::ArrayRef<gmx::RVec> fshift);
 
 /*! \brief Clears the forces for virtual sites */
-void dd_clear_f_vsites(struct gmx_domdec_t* dd, rvec* f);
+void dd_clear_f_vsites(const gmx_domdec_t& dd, gmx::ArrayRef<gmx::RVec> f);
 
 /*! \brief Move x0 and also x1 when x1 != NULL. bX1IsCoord tells whether to apply PBC to x1 */
 void dd_move_x_constraints(struct gmx_domdec_t*     dd,
@@ -248,7 +248,7 @@ void dd_move_x_constraints(struct gmx_domdec_t*     dd,
                            gmx_bool                 bX1IsCoord);
 
 /*! \brief Communicates the coordinates involved in virtual sites */
-void dd_move_x_vsites(struct gmx_domdec_t* dd, const matrix box, rvec* x);
+void dd_move_x_vsites(const gmx_domdec_t& dd, const matrix box, rvec* x);
 
 /*! \brief Returns the local atom count array for all constraints
  *
@@ -272,12 +272,12 @@ gmx::ArrayRef<const int> dd_constraints_nlocalatoms(const gmx_domdec_t* dd);
                                                 const matrix                   box);
 
 /*! \brief Generate and store the reverse topology */
-void dd_make_reverse_top(FILE*              fplog,
-                         gmx_domdec_t*      dd,
-                         const gmx_mtop_t*  mtop,
-                         const gmx_vsite_t* vsite,
-                         const t_inputrec*  ir,
-                         gmx_bool           bBCheck);
+void dd_make_reverse_top(FILE*                           fplog,
+                         gmx_domdec_t*                   dd,
+                         const gmx_mtop_t*               mtop,
+                         const gmx::VirtualSitesHandler* vsite,
+                         const t_inputrec*               ir,
+                         gmx_bool                        bBCheck);
 
 /*! \brief Generate the local topology and virtual site data */
 void dd_make_local_top(struct gmx_domdec_t*       dd,
@@ -304,14 +304,14 @@ void dd_init_local_state(struct gmx_domdec_t* dd, const t_state* state_global, t
 t_blocka* makeBondedLinks(const gmx_mtop_t& mtop, gmx::ArrayRef<cginfo_mb_t> cginfo_mb);
 
 /*! \brief Calculate the maximum distance involved in 2-body and multi-body bonded interactions */
-void dd_bonded_cg_distance(const gmx::MDLogger& mdlog,
-                           const gmx_mtop_t*    mtop,
-                           const t_inputrec*    ir,
-                           const rvec*          x,
-                           const matrix         box,
-                           gmx_bool             bBCheck,
-                           real*                r_2b,
-                           real*                r_mb);
+void dd_bonded_cg_distance(const gmx::MDLogger&           mdlog,
+                           const gmx_mtop_t*              mtop,
+                           const t_inputrec*              ir,
+                           gmx::ArrayRef<const gmx::RVec> x,
+                           const matrix                   box,
+                           gmx_bool                       bBCheck,
+                           real*                          r_2b,
+                           real*                          r_mb);
 
 /*! \brief Construct the GPU halo exchange object(s).
  *
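
Several of the declarations above move from raw rvec* to
gmx::ArrayRef<gmx::RVec>. Where a converted caller still has to reach a
legacy rvec* API, the diff bridges with as_rvec_array() (see
domdec_vsite.cpp below). A small stand-alone sketch of the idiom; the
headers are the real GROMACS ones, legacy_clear_forces is hypothetical:

    #include "gromacs/math/vectypes.h"    // gmx::RVec, rvec, as_rvec_array
    #include "gromacs/utility/arrayref.h" // gmx::ArrayRef

    // New style: the view carries its own size, so callers can pass a
    // std::vector<gmx::RVec>, a subArray(), or {} for "no buffer".
    void clearForces(gmx::ArrayRef<gmx::RVec> f)
    {
        for (gmx::RVec& fi : f)
        {
            fi = { 0, 0, 0 };
        }
    }

    // Bridging to a not-yet-converted rvec* API:
    void callLegacy(gmx::ArrayRef<gmx::RVec> f)
    {
        legacy_clear_forces(as_rvec_array(f.data()), f.size()); // hypothetical callee
    }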
index d829c18e49456934eea0a5fd8efd0c6934d98dc0..2c543a034369a565317b012d29d2e262067c4926 100644 (file)
@@ -65,7 +65,7 @@
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/gmxassert.h"
 
-void dd_move_f_specat(gmx_domdec_t* dd, gmx_domdec_specat_comm_t* spac, rvec* f, rvec* fshift)
+void dd_move_f_specat(const gmx_domdec_t* dd, gmx_domdec_specat_comm_t* spac, rvec* f, rvec* fshift)
 {
     gmx_specatsend_t* spas;
     rvec*             vbuf;
@@ -173,7 +173,12 @@ void dd_move_f_specat(gmx_domdec_t* dd, gmx_domdec_specat_comm_t* spac, rvec* f,
     }
 }
 
-void dd_move_x_specat(gmx_domdec_t* dd, gmx_domdec_specat_comm_t* spac, const matrix box, rvec* x0, rvec* x1, gmx_bool bX1IsCoord)
+void dd_move_x_specat(const gmx_domdec_t*       dd,
+                      gmx_domdec_specat_comm_t* spac,
+                      const matrix              box,
+                      rvec*                     x0,
+                      rvec*                     x1,
+                      gmx_bool                  bX1IsCoord)
 {
     gmx_specatsend_t* spas;
     int               nvec, v, n, nn, ns0, ns1, nr0, nr1, nr, d, dim, dir, i;
index b8718362fba1f1a3e6c78dcbe8d23d7f20545cb4..ab7d723c819ed7e8de4a4fc3d28f3162bf36be54 100644 (file)
@@ -86,7 +86,7 @@ struct gmx_domdec_specat_comm_t
 };
 
 /*! \brief Communicates the force for special atoms; the shift forces are reduced when \p fshift != NULL */
-void dd_move_f_specat(gmx_domdec_t* dd, gmx_domdec_specat_comm_t* spac, rvec* f, rvec* fshift);
+void dd_move_f_specat(const gmx_domdec_t* dd, gmx_domdec_specat_comm_t* spac, rvec* f, rvec* fshift);
 
 /*! \brief Communicates the coordinates for special atoms
  *
@@ -97,7 +97,7 @@ void dd_move_f_specat(gmx_domdec_t* dd, gmx_domdec_specat_comm_t* spac, rvec* f,
  * \param[in,out] x1         Vector to communicate, when != NULL
  * \param[in]     bX1IsCoord Tells if \p x1 is a coordinate vector (needs pbc)
  */
-void dd_move_x_specat(gmx_domdec_t*             dd,
+void dd_move_x_specat(const gmx_domdec_t*       dd,
                       gmx_domdec_specat_comm_t* spac,
                       const matrix              box,
                       rvec*                     x0,
index 39661d0dfd70f934f9058d8c3105282b009b1022..df87374b7d994dddad5ecb2050bce0c7a80b68a1 100644 (file)
@@ -90,7 +90,9 @@
 #include "domdec_vsite.h"
 #include "dump.h"
 
+using gmx::ArrayRef;
 using gmx::ListOfLists;
+using gmx::RVec;
 
 /*! \brief The number of integer item in the local state, used for broadcasting of the state */
 #define NITEM_DD_INIT_LOCAL_STATE 5
@@ -119,11 +121,11 @@ struct thread_work_t
      */
     thread_work_t(const gmx_ffparams_t& ffparams) : idef(ffparams) {}
 
-    InteractionDefinitions    idef;       /**< Partial local topology */
-    std::unique_ptr<VsitePbc> vsitePbc;   /**< vsite PBC structure */
-    int                       nbonded;    /**< The number of bondeds in this struct */
-    ListOfLists<int>          excl;       /**< List of exclusions */
-    int                       excl_count; /**< The total exclusion count for \p excl */
+    InteractionDefinitions         idef;       /**< Partial local topology */
+    std::unique_ptr<gmx::VsitePbc> vsitePbc;   /**< vsite PBC structure */
+    int                            nbonded;    /**< The number of bondeds in this struct */
+    ListOfLists<int>               excl;       /**< List of exclusions */
+    int                            excl_count; /**< The total exclusion count for \p excl */
 };
 
 /*! \brief Struct for the reverse topology: links bonded interactions to atoms */
@@ -710,12 +712,12 @@ static gmx_reverse_top_t make_reverse_top(const gmx_mtop_t* mtop,
     return rt;
 }
 
-void dd_make_reverse_top(FILE*              fplog,
-                         gmx_domdec_t*      dd,
-                         const gmx_mtop_t*  mtop,
-                         const gmx_vsite_t* vsite,
-                         const t_inputrec*  ir,
-                         gmx_bool           bBCheck)
+void dd_make_reverse_top(FILE*                           fplog,
+                         gmx_domdec_t*                   dd,
+                         const gmx_mtop_t*               mtop,
+                         const gmx::VirtualSitesHandler* vsite,
+                         const t_inputrec*               ir,
+                         gmx_bool                        bBCheck)
 {
     if (fplog)
     {
@@ -744,16 +746,16 @@ void dd_make_reverse_top(FILE*              fplog,
         }
     }
 
-    if (vsite && vsite->numInterUpdategroupVsites > 0)
+    if (vsite && vsite->numInterUpdategroupVirtualSites() > 0)
     {
         if (fplog)
         {
             fprintf(fplog,
                     "There are %d inter update-group virtual sites,\n"
                     "will an extra communication step for selected coordinates and forces\n",
-                    vsite->numInterUpdategroupVsites);
+                    vsite->numInterUpdategroupVirtualSites());
         }
-        init_domdec_vsites(dd, vsite->numInterUpdategroupVsites);
+        init_domdec_vsites(dd, vsite->numInterUpdategroupVirtualSites());
     }
 
     if (dd->comm->systemInfo.haveSplitConstraints || dd->comm->systemInfo.haveSplitSettles)
@@ -1805,11 +1807,11 @@ static void update_max_bonded_distance(real r2, int ftype, int a1, int a2, bonde
     }
 }
 
-/*! \brief Set the distance, function type and atom indices for the longest distance between charge-groups of molecule type \p molt for two-body and multi-body bonded interactions */
+/*! \brief Set the distance, function type and atom indices for the longest distance between atoms of molecule type \p molt for two-body and multi-body bonded interactions */
 static void bonded_cg_distance_mol(const gmx_moltype_t* molt,
                                    gmx_bool             bBCheck,
                                    gmx_bool             bExcl,
-                                   rvec*                cg_cm,
+                                   ArrayRef<const RVec> x,
                                    bonded_distance_t*   bd_2b,
                                    bonded_distance_t*   bd_mb)
 {
@@ -1831,7 +1833,7 @@ static void bonded_cg_distance_mol(const gmx_moltype_t* molt,
                             int atomJ = il.iatoms[i + 1 + aj];
                             if (atomI != atomJ)
                             {
-                                real rij2 = distance2(cg_cm[atomI], cg_cm[atomJ]);
+                                real rij2 = distance2(x[atomI], x[atomJ]);
 
                                 update_max_bonded_distance(rij2, ftype, atomI, atomJ,
                                                            (nral == 2) ? bd_2b : bd_mb);
@@ -1851,7 +1853,7 @@ static void bonded_cg_distance_mol(const gmx_moltype_t* molt,
             {
                 if (ai != aj)
                 {
-                    real rij2 = distance2(cg_cm[ai], cg_cm[aj]);
+                    real rij2 = distance2(x[ai], x[aj]);
 
                     /* There is no function type for exclusions, use -1 */
                     update_max_bonded_distance(rij2, -1, ai, aj, bd_2b);
@@ -1864,7 +1866,7 @@ static void bonded_cg_distance_mol(const gmx_moltype_t* molt,
 /*! \brief Set the distance, function type and atom indices for the longest atom distance involved in intermolecular interactions for two-body and multi-body bonded interactions */
 static void bonded_distance_intermol(const InteractionLists& ilists_intermol,
                                      gmx_bool                bBCheck,
-                                     const rvec*             x,
+                                     ArrayRef<const RVec>    x,
                                      PbcType                 pbcType,
                                      const matrix            box,
                                      bonded_distance_t*      bd_2b,
@@ -1930,22 +1932,22 @@ static void getWholeMoleculeCoordinates(const gmx_moltype_t*  molt,
                                         PbcType               pbcType,
                                         t_graph*              graph,
                                         const matrix          box,
-                                        const rvec*           x,
-                                        rvec*                 xs)
+                                        ArrayRef<const RVec>  x,
+                                        ArrayRef<RVec>        xs)
 {
     int n, i;
 
     if (pbcType != PbcType::No)
     {
-        mk_mshift(nullptr, graph, pbcType, box, x);
+        mk_mshift(nullptr, graph, pbcType, box, as_rvec_array(x.data()));
 
-        shift_x(graph, box, x, xs);
+        shift_x(graph, box, as_rvec_array(x.data()), as_rvec_array(xs.data()));
         /* By doing an extra mk_mshift the molecules that are broken
          * because they were e.g. imported from another software
          * will be made whole again. Such are the healing powers
          * of GROMACS.
          */
-        mk_mshift(nullptr, graph, pbcType, box, xs);
+        mk_mshift(nullptr, graph, pbcType, box, as_rvec_array(xs.data()));
     }
     else
     {
@@ -1962,15 +1964,14 @@ static void getWholeMoleculeCoordinates(const gmx_moltype_t*  molt,
 
     if (moltypeHasVsite(*molt))
     {
-        construct_vsites(nullptr, xs, 0.0, nullptr, ffparams->iparams, molt->ilist, PbcType::No,
-                         TRUE, nullptr, nullptr);
+        gmx::constructVirtualSites(xs, ffparams->iparams, molt->ilist);
     }
 }
 
 void dd_bonded_cg_distance(const gmx::MDLogger& mdlog,
                            const gmx_mtop_t*    mtop,
                            const t_inputrec*    ir,
-                           const rvec*          x,
+                           ArrayRef<const RVec> x,
                            const matrix         box,
                            gmx_bool             bBCheck,
                            real*                r_2b,
@@ -1978,7 +1979,6 @@ void dd_bonded_cg_distance(const gmx::MDLogger& mdlog,
 {
     gmx_bool          bExclRequired;
     int               at_offset;
-    rvec*             xs;
     bonded_distance_t bd_2b = { 0, -1, -1, -1 };
     bonded_distance_t bd_mb = { 0, -1, -1, -1 };
 
@@ -2002,11 +2002,11 @@ void dd_bonded_cg_distance(const gmx::MDLogger& mdlog,
                 graph = mk_graph_moltype(molt);
             }
 
-            snew(xs, molt.atoms.nr);
+            std::vector<RVec> xs(molt.atoms.nr);
             for (int mol = 0; mol < molb.nmol; mol++)
             {
                 getWholeMoleculeCoordinates(&molt, &mtop->ffparams, ir->pbcType, &graph, box,
-                                            x + at_offset, xs);
+                                            x.subArray(at_offset, molt.atoms.nr), xs);
 
                 bonded_distance_t bd_mol_2b = { 0, -1, -1, -1 };
                 bonded_distance_t bd_mol_mb = { 0, -1, -1, -1 };
@@ -2021,7 +2021,6 @@ void dd_bonded_cg_distance(const gmx::MDLogger& mdlog,
 
                 at_offset += molt.atoms.nr;
             }
-            sfree(xs);
         }
     }
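
The hunk above also replaces manual snew()/sfree() with std::vector<RVec>,
so the per-molecule buffer is released on every exit path. A minimal
illustration of the pattern (perMoleculeWork is a hypothetical name):

    #include <vector>
    #include "gromacs/math/vectypes.h"

    void perMoleculeWork(int numAtoms)
    {
        // Before: snew(xs, numAtoms); ... sfree(xs);  -- leaked on early return.
        // After: value-initialized to zero like snew, freed automatically,
        // and implicitly convertible to ArrayRef<RVec> parameters.
        std::vector<gmx::RVec> xs(numAtoms);
    }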
 
index 7b8c8f903f6ffe7622df0bf8d6359b8b1d65ddf5..0fe374ee45223796bb03b70b152009a7550efa7f 100644 (file)
 
 #include "domdec_specatomcomm.h"
 
-void dd_move_f_vsites(gmx_domdec_t* dd, rvec* f, rvec* fshift)
+void dd_move_f_vsites(const gmx_domdec_t& dd, gmx::ArrayRef<gmx::RVec> f, gmx::ArrayRef<gmx::RVec> fshift)
 {
-    if (dd->vsite_comm)
+    if (dd.vsite_comm)
     {
-        dd_move_f_specat(dd, dd->vsite_comm, f, fshift);
+        dd_move_f_specat(&dd, dd.vsite_comm, as_rvec_array(f.data()), as_rvec_array(fshift.data()));
     }
 }
 
-void dd_clear_f_vsites(gmx_domdec_t* dd, rvec* f)
+void dd_clear_f_vsites(const gmx_domdec_t& dd, gmx::ArrayRef<gmx::RVec> f)
 {
-    int i;
-
-    if (dd->vsite_comm)
+    if (dd.vsite_comm)
     {
-        for (i = dd->vsite_comm->at_start; i < dd->vsite_comm->at_end; i++)
+        for (int i = dd.vsite_comm->at_start; i < dd.vsite_comm->at_end; i++)
         {
             clear_rvec(f[i]);
         }
     }
 }
 
-void dd_move_x_vsites(gmx_domdec_t* dd, const matrix box, rvec* x)
+void dd_move_x_vsites(const gmx_domdec_t& dd, const matrix box, rvec* x)
 {
-    if (dd->vsite_comm)
+    if (dd.vsite_comm)
     {
-        dd_move_x_specat(dd, dd->vsite_comm, box, x, nullptr, FALSE);
+        dd_move_x_specat(&dd, dd.vsite_comm, box, x, nullptr, FALSE);
     }
 }
 
index 218a4e559de8683548c3ebb57de3bdb6970ee011..f8d9172c6b5a7894b261199d6e84558d0ae27ff0 100644 (file)
@@ -70,7 +70,7 @@ void mdAlgorithmsSetupAtomData(const t_commrec*        cr,
                                PaddedHostVector<RVec>* force,
                                MDAtoms*                mdAtoms,
                                Constraints*            constr,
-                               gmx_vsite_t*            vsite,
+                               VirtualSitesHandler*    vsite,
                                gmx_shellfc_t*          shellfc)
 {
     bool usingDomDec = DOMAINDECOMP(cr);
@@ -117,17 +117,7 @@ void mdAlgorithmsSetupAtomData(const t_commrec*        cr,
 
     if (vsite)
     {
-        if (usingDomDec)
-        {
-            /* The vsites were already assigned by the domdec topology code.
-             * We only need to do the thread division here.
-             */
-            split_vsites_over_threads(top->idef.il, top->idef.iparams, mdatoms, vsite);
-        }
-        else
-        {
-            set_vsite_top(vsite, top, mdatoms);
-        }
+        vsite->setVirtualSites(top->idef.il, *mdatoms);
     }
 
     /* Note that with DD only flexible constraints, not shells, are supported
index 369bf64d710a3343fd561e96f4a5686aeabdac8f..d406b09f7fb824d9f11659c9202ea26fe558d36b 100644 (file)
@@ -50,7 +50,6 @@ struct bonded_threading_t;
 struct gmx_localtop_t;
 struct gmx_mtop_t;
 struct gmx_shellfc_t;
-struct gmx_vsite_t;
 struct t_commrec;
 struct t_forcerec;
 struct t_inputrec;
@@ -60,6 +59,7 @@ namespace gmx
 {
 class Constraints;
 class MDAtoms;
+class VirtualSitesHandler;
 
 /*! \brief Gets the local shell with domain decomposition
  *
@@ -95,7 +95,7 @@ void mdAlgorithmsSetupAtomData(const t_commrec*        cr,
                                PaddedHostVector<RVec>* force,
                                MDAtoms*                mdAtoms,
                                Constraints*            constr,
-                               gmx_vsite_t*            vsite,
+                               VirtualSitesHandler*    vsite,
                                gmx_shellfc_t*          shellfc);
 
 } // namespace gmx
index b4fa8901ed22578512b88ae9e28e8050a92844c2..ca7c48c14d35692b3ddf67502c072529ef99b3d2 100644 (file)
@@ -2640,7 +2640,7 @@ void dd_partition_system(FILE*                        fplog,
                          gmx::MDAtoms*                mdAtoms,
                          gmx_localtop_t*              top_local,
                          t_forcerec*                  fr,
-                         gmx_vsite_t*                 vsite,
+                         gmx::VirtualSitesHandler*    vsite,
                          gmx::Constraints*            constr,
                          t_nrnb*                      nrnb,
                          gmx_wallcycle*               wcycle,
@@ -3088,7 +3088,7 @@ void dd_partition_system(FILE*                        fplog,
         switch (range)
         {
             case DDAtomRanges::Type::Vsites:
-                if (vsite && vsite->numInterUpdategroupVsites)
+                if (vsite && vsite->numInterUpdategroupVirtualSites())
                 {
                     n = dd_make_local_vsites(dd, n, top_local->idef.il);
                 }
@@ -3119,7 +3119,7 @@ void dd_partition_system(FILE*                        fplog,
 
     if (fr->haveDirectVirialContributions)
     {
-        if (vsite && vsite->numInterUpdategroupVsites)
+        if (vsite && vsite->numInterUpdategroupVirtualSites())
         {
             nat_f_novirsum = comm->atomRanges.end(DDAtomRanges::Type::Vsites);
         }
@@ -3186,7 +3186,7 @@ void dd_partition_system(FILE*                        fplog,
      * the last vsite construction, we need to communicate the constructing
      * atom coordinates again (for spreading the forces this MD step).
      */
-    dd_move_x_vsites(dd, state_local->box, state_local->x.rvec_array());
+    dd_move_x_vsites(*dd, state_local->box, state_local->x.rvec_array());
 
     wallcycle_sub_stop(wcycle, ewcsDD_TOPOTHER);
 
index 721893a3711ed2da16a7d2cdd46838ab6705f512..e06138a1b7c01b9e0d938d9ab6957793804671b3 100644 (file)
@@ -54,7 +54,6 @@ struct gmx_ddbox_t;
 struct gmx_domdec_t;
 struct gmx_localtop_t;
 struct gmx_mtop_t;
-struct gmx_vsite_t;
 struct gmx_wallcycle;
 struct pull_t;
 struct t_commrec;
@@ -69,6 +68,7 @@ class Constraints;
 class ImdSession;
 class MDAtoms;
 class MDLogger;
+class VirtualSitesHandler;
 } // namespace gmx
 
 //! Check whether the DD grid has moved too far for correctness.
@@ -100,7 +100,7 @@ void print_dd_statistics(const t_commrec* cr, const t_inputrec* ir, FILE* fplog)
  * \param[in] mdatoms       MD atoms
  * \param[in] top_local     Local topology
  * \param[in] fr            Force record
- * \param[in] vsite         Virtual sites
+ * \param[in] vsite         Virtual sites handler
  * \param[in] constr        Constraints
  * \param[in] nrnb          Cycle counters
  * \param[in] wcycle        Timers
@@ -122,7 +122,7 @@ void dd_partition_system(FILE*                             fplog,
                          gmx::MDAtoms*                     mdatoms,
                          gmx_localtop_t*                   top_local,
                          t_forcerec*                       fr,
-                         gmx_vsite_t*                      vsite,
+                         gmx::VirtualSitesHandler*         vsite,
                          gmx::Constraints*                 constr,
                          t_nrnb*                           nrnb,
                          gmx_wallcycle*                    wcycle,
index 93e2fc4c00b3d8c1e212484df8e682098af878b2..84c5d9932861bc66db36327a4b195ec679eea182 100644 (file)
@@ -1623,7 +1623,7 @@ static void set_verlet_buffer(const gmx_mtop_t*    mtop,
     ir->rlist = calcVerletBufferSize(*mtop, det(box), *ir, ir->nstlist, ir->nstlist - 1,
                                      buffer_temp, listSetup4x4);
 
-    const int n_nonlin_vsite = countNonlinearVsites(*mtop);
+    const int n_nonlin_vsite = gmx::countNonlinearVsites(*mtop);
     if (n_nonlin_vsite > 0)
     {
         std::string warningMessage = gmx::formatString(
index 9d155f369de5e5fef8a1b6a98780359ce7ed8218..387f1e1148d981550211f4f571dbc3db5733265f 100644 (file)
@@ -49,7 +49,6 @@ struct gmx_enfrot;
 struct SimulationGroups;
 struct gmx_localtop_t;
 struct gmx_multisim_t;
-struct gmx_vsite_t;
 struct gmx_wallcycle;
 class history_t;
 class InteractionDefinitions;
@@ -71,6 +70,7 @@ class ImdSession;
 class MdrunScheduleWorkload;
 class MDLogger;
 class StepWorkload;
+class VirtualSitesHandler;
 } // namespace gmx
 
 void do_force(FILE*                               log,
@@ -96,7 +96,7 @@ void do_force(FILE*                               log,
               gmx::ArrayRef<real>                 lambda,
               t_forcerec*                         fr,
               gmx::MdrunScheduleWorkload*         runScheduleWork,
-              const gmx_vsite_t*                  vsite,
+              gmx::VirtualSitesHandler*           vsite,
               rvec                                mu_tot,
               double                              t,
               gmx_edsam*                          ed,
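
In sim_util.cpp below, the caller picks the virial-handling mode per call
site: Pbc when the virial is later derived from forces and shift forces,
NonLinear when the non-linear vsite correction must go into a separately
accumulated virial matrix (the direct-virial/PME-mesh path), and None
otherwise. A hedged sketch of that selection; the enum values are from
this change, fDirect and the surrounding names are assumed context:

    using VirialHandling = gmx::VirtualSitesHandler::VirialHandling;

    // Main spreading call: accumulate shift forces for the virial.
    const VirialHandling mode =
            stepWork.computeVirial ? VirialHandling::Pbc : VirialHandling::None;
    vsite->spreadForces(x, f, mode, fshift, nullptr, nrnb, box, wcycle);

    // Direct-virial contributions: add the non-linear correction to a
    // local virial matrix instead of using shift forces.
    matrix virial = { { 0 } };
    vsite->spreadForces(x, fDirect, VirialHandling::NonLinear, {}, virial,
                        nrnb, box, wcycle);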
index 179e3116e525860395178d1cd60ebde5df26bead..841337b6f642351acec9a727da1ee900a4e1228f 100644 (file)
 #include "gromacs/utility/strconvert.h"
 #include "gromacs/utility/sysinfo.h"
 
+using gmx::ArrayRef;
 using gmx::AtomLocality;
 using gmx::DomainLifetimeWorkload;
 using gmx::ForceOutputs;
 using gmx::InteractionLocality;
+using gmx::RVec;
 using gmx::SimulationWorkload;
 using gmx::StepWorkload;
 
@@ -242,13 +244,13 @@ static void pme_receive_force_ener(t_forcerec*           fr,
     wallcycle_stop(wcycle, ewcPP_PMEWAITRECVF);
 }
 
-static void print_large_forces(FILE*            fp,
-                               const t_mdatoms* md,
-                               const t_commrec* cr,
-                               int64_t          step,
-                               real             forceTolerance,
-                               const rvec*      x,
-                               const rvec*      f)
+static void print_large_forces(FILE*                fp,
+                               const t_mdatoms*     md,
+                               const t_commrec*     cr,
+                               int64_t              step,
+                               real                 forceTolerance,
+                               ArrayRef<const RVec> x,
+                               const rvec*          f)
 {
     real       force2Tolerance = gmx::square(forceTolerance);
     gmx::index numNonFinite    = 0;
@@ -276,26 +278,24 @@ static void print_large_forces(FILE*            fp,
     }
 }
 
-static void post_process_forces(const t_commrec*      cr,
-                                int64_t               step,
-                                t_nrnb*               nrnb,
-                                gmx_wallcycle_t       wcycle,
-                                const gmx_localtop_t* top,
-                                const matrix          box,
-                                const rvec            x[],
-                                ForceOutputs*         forceOutputs,
-                                tensor                vir_force,
-                                const t_mdatoms*      mdatoms,
-                                const t_forcerec*     fr,
-                                const gmx_vsite_t*    vsite,
-                                const StepWorkload&   stepWork)
+static void post_process_forces(const t_commrec*          cr,
+                                int64_t                   step,
+                                t_nrnb*                   nrnb,
+                                gmx_wallcycle_t           wcycle,
+                                const matrix              box,
+                                ArrayRef<const RVec>      x,
+                                ForceOutputs*             forceOutputs,
+                                tensor                    vir_force,
+                                const t_mdatoms*          mdatoms,
+                                const t_forcerec*         fr,
+                                gmx::VirtualSitesHandler* vsite,
+                                const StepWorkload&       stepWork)
 {
     rvec* f = as_rvec_array(forceOutputs->forceWithShiftForces().force().data());
 
     if (fr->haveDirectVirialContributions)
     {
         auto& forceWithVirial = forceOutputs->forceWithVirial();
-        rvec* fDirectVir      = as_rvec_array(forceWithVirial.force_.data());
 
         if (vsite)
         {
@@ -303,9 +303,11 @@ static void post_process_forces(const t_commrec*      cr,
              * This is parallelized. MPI communication is performed
              * if the constructing atoms aren't local.
              */
+            const gmx::VirtualSitesHandler::VirialHandling virialHandling =
+                    (stepWork.computeVirial ? gmx::VirtualSitesHandler::VirialHandling::NonLinear
+                                            : gmx::VirtualSitesHandler::VirialHandling::None);
             matrix virial = { { 0 } };
-            spread_vsite_f(vsite, x, fDirectVir, nullptr, stepWork.computeVirial, virial, nrnb,
-                           top->idef, fr->pbcType, fr->bMolPBC, box, cr, wcycle);
+            vsite->spreadForces(x, forceWithVirial.force_, virialHandling, {}, virial, nrnb, box, wcycle);
             forceWithVirial.addVirialContribution(virial);
         }
 
@@ -953,7 +955,7 @@ void do_force(FILE*                               fplog,
               gmx::ArrayRef<real>                 lambda,
               t_forcerec*                         fr,
               gmx::MdrunScheduleWorkload*         runScheduleWork,
-              const gmx_vsite_t*                  vsite,
+              gmx::VirtualSitesHandler*           vsite,
               rvec                                muTotal,
               double                              t,
               gmx_edsam*                          ed,
@@ -1784,16 +1786,17 @@ void do_force(FILE*                               fplog,
 
     if (stepWork.computeForces)
     {
-        rvec* f = as_rvec_array(forceOut.forceWithShiftForces().force().data());
-
         /* If we have NoVirSum forces, but we do not calculate the virial,
          * we sum fr->f_novirsum=forceOut.f later.
          */
         if (vsite && !(fr->haveDirectVirialContributions && !stepWork.computeVirial))
         {
-            rvec* fshift = as_rvec_array(forceOut.forceWithShiftForces().shiftForces().data());
-            spread_vsite_f(vsite, as_rvec_array(x.unpaddedArrayRef().data()), f, fshift, FALSE,
-                           nullptr, nrnb, top->idef, fr->pbcType, fr->bMolPBC, box, cr, wcycle);
+            auto f      = forceOut.forceWithShiftForces().force();
+            auto fshift = forceOut.forceWithShiftForces().shiftForces();
+            const gmx::VirtualSitesHandler::VirialHandling virialHandling =
+                    (stepWork.computeVirial ? gmx::VirtualSitesHandler::VirialHandling::Pbc
+                                            : gmx::VirtualSitesHandler::VirialHandling::None);
+            vsite->spreadForces(x.unpaddedArrayRef(), f, virialHandling, fshift, nullptr, nrnb, box, wcycle);
         }
 
         if (stepWork.computeVirial)
@@ -1817,8 +1820,8 @@ void do_force(FILE*                               fplog,
 
     if (stepWork.computeForces)
     {
-        post_process_forces(cr, step, nrnb, wcycle, top, box, as_rvec_array(x.unpaddedArrayRef().data()),
-                            &forceOut, vir_force, mdatoms, fr, vsite, stepWork);
+        post_process_forces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut, vir_force,
+                            mdatoms, fr, vsite, stepWork);
     }
 
     if (stepWork.computeEnergy)
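
The vsite.cpp diff below carries the two structural ideas of this change:
the PIMPL split (a small public VirtualSitesHandler forwarding to an
internal Impl, so vsite.h stops exposing threading and DD internals) and
spreading routines templated on the virial-handling mode, so each
instantiation drops the branches and buffers it does not need. A
compressed, self-contained sketch of both idioms; all names here are
illustrative, not the GROMACS ones:

    #include <memory>

    enum class VirialHandling : int { None, Pbc, NonLinear };

    // --- header side of the pimpl idiom ---
    class HandlerSketch
    {
    public:
        HandlerSketch();
        ~HandlerSketch(); // only declared; defined where Impl is complete

    private:
        class Impl;
        std::unique_ptr<Impl> impl_;
    };

    // --- source side: Impl is complete here, so the defaulted destructor
    // can instantiate unique_ptr<Impl>'s deleter ---
    class HandlerSketch::Impl {};
    HandlerSketch::HandlerSketch() : impl_(std::make_unique<Impl>()) {}
    HandlerSketch::~HandlerSketch() = default;

    // --- templating on the mode: virialHandling is a compile-time
    // constant, so the compiler emits specialized kernels with no
    // per-interaction branching ---
    template<VirialHandling virialHandling>
    void spreadOne()
    {
        if (virialHandling == VirialHandling::Pbc)
        {
            // accumulate shift forces
        }
        if (virialHandling == VirialHandling::NonLinear)
        {
            // accumulate dx*df into a thread-local virial
        }
    }

    // Runtime dispatch happens once per call, not once per vsite.
    void spreadAll(VirialHandling mode)
    {
        switch (mode)
        {
            case VirialHandling::None: spreadOne<VirialHandling::None>(); break;
            case VirialHandling::Pbc: spreadOne<VirialHandling::Pbc>(); break;
            case VirialHandling::NonLinear: spreadOne<VirialHandling::NonLinear>(); break;
        }
    }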
index bed9b8faa1cc8a67db4ae45cbd1071958a0cfa63..b2da74223b8e5fcb5c23ec2856c1c898a6b158d9 100644 (file)
  * To help us fund GROMACS development, we humbly ask that you cite
  * the research papers on the package. Check out http://www.gromacs.org.
  */
+/*! \libinternal \file
+ * \brief Implements the VirtualSitesHandler class and vsite standalone functions
+ *
+ * \author Berk Hess <hess@kth.se>
+ * \ingroup module_mdlib
+ * \inlibraryapi
+ */
+
 #include "gmxpre.h"
 
 #include "vsite.h"
  * Any remaining vsites are assigned to a separate master thread task.
  */
 
-using gmx::ArrayRef;
-using gmx::RVec;
+namespace gmx
+{
+
+//! VirialHandling is often used outside VirtualSitesHandler class members
+using VirialHandling = VirtualSitesHandler::VirialHandling;
+
+/*! \libinternal
+ * \brief Information on PBC and domain decomposition for virtual sites
+ */
+struct DomainInfo
+{
+public:
+    //! Constructs without PBC and DD
+    DomainInfo() = default;
+
+    //! Constructs with PBC, and with DD when \p domdec != nullptr
+    DomainInfo(PbcType pbcType, bool haveInterUpdateGroupVirtualSites, gmx_domdec_t* domdec) :
+        pbcType_(pbcType),
+        useMolPbc_(pbcType != PbcType::No && haveInterUpdateGroupVirtualSites),
+        domdec_(domdec)
+    {
+    }
+
+    //! Returns whether we are using domain decomposition with more than 1 DD rank
+    bool useDomdec() const { return (domdec_ != nullptr); }
 
-/*! \brief List of atom indices belonging to a task */
+    //! The pbc type
+    const PbcType pbcType_ = PbcType::No;
+    //! Whether molecules are broken over PBC
+    const bool useMolPbc_ = false;
+    //! Pointer to the domain decomposition struct, nullptr without PP DD
+    const gmx_domdec_t* domdec_ = nullptr;
+};
+
+/*! \libinternal
+ * \brief List of atom indices belonging to a task
+ */
 struct AtomIndex
 {
     //! List of atom indices
     std::vector<int> atom;
 };
 
-/*! \brief Data structure for thread tasks that use constructing atoms outside their own atom range */
+/*! \libinternal
+ * \brief Data structure for thread tasks that use constructing atoms outside their own atom range
+ */
 struct InterdependentTask
 {
     //! The interaction lists, only vsite entries are used
@@ -116,7 +159,9 @@ struct InterdependentTask
     std::vector<int> reduceTask;
 };
 
-/*! \brief Vsite thread task data structure */
+/*! \libinternal
+ * \brief Vsite thread task data structure
+ */
 struct VsiteThread
 {
     //! Start of atom range of this task
@@ -126,7 +171,7 @@ struct VsiteThread
     //! The interaction lists, only vsite entries are used
     std::array<InteractionList, F_NRE> ilist;
     //! Local fshift accumulation buffer
-    rvec fshift[SHIFTS];
+    std::array<RVec, SHIFTS> fshift;
     //! Local virial dx*df accumulation buffer
     matrix dxdf;
     //! Tells if interdependent task idTask should be used (in addition to the rest of this task), this bool has the same value on all threads
@@ -139,12 +184,118 @@ struct VsiteThread
     {
         rangeStart = -1;
         rangeEnd   = -1;
-        clear_rvecs(SHIFTS, fshift);
+        for (auto& elem : fshift)
+        {
+            elem = { 0.0_real, 0.0_real, 0.0_real };
+        }
         clear_mat(dxdf);
         useInterdependentTask = false;
     }
 };
 
+
+/*! \libinternal
+ * \brief Information on how the virtual site work is divided over thread tasks
+ */
+class ThreadingInfo
+{
+public:
+    //! Constructor, retrieves the number of threads to use from gmx_omp_nthreads.h
+    ThreadingInfo();
+
+    //! Returns the number of threads to use for vsite operations
+    int numThreads() const { return numThreads_; }
+
+    //! Returns the thread data for the given thread
+    const VsiteThread& threadData(int threadIndex) const { return *tData_[threadIndex]; }
+
+    //! Returns the thread data for the given thread
+    VsiteThread& threadData(int threadIndex) { return *tData_[threadIndex]; }
+
+    //! Returns the thread data for vsites that depend on non-local vsites
+    const VsiteThread& threadDataNonLocalDependent() const { return *tData_[numThreads_]; }
+
+    //! Returns the thread data for vsites that depend on non-local vsites
+    VsiteThread& threadDataNonLocalDependent() { return *tData_[numThreads_]; }
+
+    //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
+    void setVirtualSites(ArrayRef<const InteractionList> ilist,
+                         ArrayRef<const t_iparams>       iparams,
+                         const t_mdatoms&                mdatoms,
+                         bool                            useDomdec);
+
+private:
+    //! Number of threads used for vsite operations
+    const int numThreads_;
+    //! Thread local vsites and work structs
+    std::vector<std::unique_ptr<VsiteThread>> tData_;
+    //! Work array for dividing vsites over threads
+    std::vector<int> taskIndex_;
+};
+
+/*! \libinternal
+ * \brief Impl class for VirtualSitesHandler
+ */
+class VirtualSitesHandler::Impl
+{
+public:
+    //! Constructor, domdec should be nullptr without DD
+    Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, PbcType pbcType);
+
+    //! Returns the number of virtual sites acting over multiple update groups
+    int numInterUpdategroupVirtualSites() const { return numInterUpdategroupVirtualSites_; }
+
+    //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
+    void setVirtualSites(ArrayRef<const InteractionList> ilist, const t_mdatoms& mdatoms);
+
+    /*! \brief Create positions of vsite atoms for the local system
+     *
+     * \param[in,out] x        The coordinates
+     * \param[in]     dt       The time step
+     * \param[in,out] v        When not empty, velocities for vsites are set as displacement/dt
+     * \param[in]     box      The box
+     */
+    void construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const;
+
+    /*! \brief Spread the force operating on the vsite atoms on the surrounding atoms.
+     *
+     * The virialHandling parameter determines how virial contributions are handled.
+     * If this is set to Pbc, shift forces are accumulated into fshift.
+     * If this is set to NonLinear, non-linear contributions are added to virial.
+     * This non-linear correction is required when the virial is not calculated
+     * afterwards from the particle position and forces, but in a different way,
+     * as for instance for the PME mesh contribution.
+     */
+    void spreadForces(ArrayRef<const RVec> x,
+                      ArrayRef<RVec>       f,
+                      VirialHandling       virialHandling,
+                      ArrayRef<RVec>       fshift,
+                      matrix               virial,
+                      t_nrnb*              nrnb,
+                      const matrix         box,
+                      gmx_wallcycle*       wcycle);
+
+private:
+    // The number of vsites that cross update groups; when 0, no PBC treatment is needed
+    const int numInterUpdategroupVirtualSites_;
+    // PBC and DD information
+    const DomainInfo domainInfo_;
+    // The interaction parameters
+    const ArrayRef<const t_iparams> iparams_;
+    // The interaction lists
+    ArrayRef<const InteractionList> ilists_;
+    // Information for handling vsite threading
+    ThreadingInfo threadingInfo_;
+};
+
+VirtualSitesHandler::~VirtualSitesHandler() = default;
+
+int VirtualSitesHandler::numInterUpdategroupVirtualSites() const
+{
+    return impl_->numInterUpdategroupVirtualSites();
+}
+
 /*! \brief Returns the sum of the vsite ilist sizes over all vsite types
  *
  * \param[in] ilist  The interaction list
@@ -160,6 +311,7 @@ static int vsiteIlistNrCount(ArrayRef<const InteractionList> ilist)
     return nr;
 }
 
+//! Computes the difference vector between xi and xj; PBC is used when pbc != nullptr
 static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
 {
     if (pbc)
@@ -173,11 +325,13 @@ static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
     }
 }
 
+//! Returns 1/norm(x)
 static inline real inverseNorm(const rvec x)
 {
     return gmx::invsqrt(iprod(x, x));
 }
 
+#ifndef DOXYGEN
 /* Vsite construction routines */
 
 static void constr_vsite2(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc)
@@ -398,8 +552,7 @@ static void constr_vsite4FDN(const rvec   xi,
     /* TOTAL: 47 flops */
 }
 
-
-static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, rvec* x, const t_pbc* pbc)
+static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, ArrayRef<RVec> x, const t_pbc* pbc)
 {
     rvec x1, dx;
     dvec dsum;
@@ -436,11 +589,13 @@ static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, rvec*
     return n3;
 }
 
-/*! \brief PBC modes for vsite construction and spreading */
+#endif // DOXYGEN
+
+//! PBC modes for vsite construction and spreading
 enum class PbcMode
 {
-    all, // Apply normal, simple PBC for all vsites
-    none // No PBC treatment needed
+    all, //!< Apply normal, simple PBC for all vsites
+    none //!< No PBC treatment needed
 };
 
 /*! \brief Returns the PBC mode based on the system PBC and vsite properties
@@ -459,15 +614,24 @@ static PbcMode getPbcMode(const t_pbc* pbcPtr)
     }
 }
 
-static void construct_vsites_thread(rvec                            x[],
-                                    real                            dt,
-                                    rvec*                           v,
+/*! \brief Executes the vsite construction task for a single thread
+ *
+ * \param[in,out] x   Coordinates to construct vsites for
+ * \param[in]     dt  Time step, needed when v is not empty
+ * \param[in,out] v   When not empty, velocities are generated for virtual sites
+ * \param[in]     ip  Interaction parameters for all interaction types, only vsite parameters are used
+ * \param[in]     ilist  The interaction lists, only vsite entries are used
+ * \param[in]     pbc_null  PBC struct, used for PBC distance calculations when !=nullptr
+ */
+static void construct_vsites_thread(ArrayRef<RVec>                  x,
+                                    const real                      dt,
+                                    ArrayRef<RVec>                  v,
                                     ArrayRef<const t_iparams>       ip,
                                     ArrayRef<const InteractionList> ilist,
                                     const t_pbc*                    pbc_null)
 {
     real inv_dt;
-    if (v != nullptr)
+    if (!v.empty())
     {
         inv_dt = 1.0 / dt;
     }
@@ -575,7 +739,7 @@ static void construct_vsites_thread(rvec                            x[],
                         rvec_add(xv, dx, x[avsite]);
                     }
                 }
-                if (v != nullptr)
+                if (!v.empty())
                 {
                     /* Calculate velocity of vsite... */
                     rvec vv;
@@ -591,40 +755,44 @@ static void construct_vsites_thread(rvec                            x[],
     }
 }
 
-void construct_vsites(const gmx_vsite_t*              vsite,
-                      rvec                            x[],
-                      real                            dt,
-                      rvec*                           v,
-                      ArrayRef<const t_iparams>       ip,
-                      ArrayRef<const InteractionList> ilist,
-                      PbcType                         pbcType,
-                      gmx_bool                        bMolPBC,
-                      const t_commrec*                cr,
-                      const matrix                    box)
+/*! \brief Dispatch the vsite construction tasks for all threads
+ *
+ * \param[in]     threadingInfo  Used to divide work over threads when != nullptr
+ * \param[in,out] x   Coordinates to construct vsites for
+ * \param[in]     dt  Time step, needed when v is not empty
+ * \param[in,out] v   When not empty, velocities are generated for virtual sites
+ * \param[in]     ip  Interaction parameters for all interaction types, only vsite parameters are used
+ * \param[in]     ilist  The interaction lists, only vsite entries are used
+ * \param[in]     domainInfo  Information about PBC and DD
+ * \param[in]     box  Used for PBC when PBC is set in domainInfo
+ */
+static void construct_vsites(const ThreadingInfo*            threadingInfo,
+                             ArrayRef<RVec>                  x,
+                             real                            dt,
+                             ArrayRef<RVec>                  v,
+                             ArrayRef<const t_iparams>       ip,
+                             ArrayRef<const InteractionList> ilist,
+                             const DomainInfo&               domainInfo,
+                             const matrix                    box)
 {
-    const bool useDomdec = (vsite != nullptr && vsite->useDomdec);
-    GMX_ASSERT(!useDomdec || (cr != nullptr && DOMAINDECOMP(cr)),
-               "When vsites are set up with domain decomposition, we need a valid commrec");
-    // TODO: Remove this assertion when we remove charge groups
-    GMX_ASSERT(vsite != nullptr || pbcType == PbcType::No,
-               "Without a vsite struct we can not do PBC (in case we have charge groups)");
+    const bool useDomdec = domainInfo.useDomdec();
 
     t_pbc pbc, *pbc_null;
 
-    /* We only need to do pbc when we have inter-cg vsites.
+    /* We only need to do pbc when we have inter update-group vsites.
      * Note that with domain decomposition we do not need to apply PBC here
      * when we have at least 3 domains along each dimension. Currently we
      * do not optimize this case.
      */
-    if (pbcType != PbcType::No && (useDomdec || bMolPBC)
-        && !(vsite != nullptr && vsite->numInterUpdategroupVsites == 0))
+    if (domainInfo.pbcType_ != PbcType::No && domainInfo.useMolPbc_)
     {
         /* This is wasting some CPU time as we now do this multiple times
          * per MD step.
          */
         ivec null_ivec;
         clear_ivec(null_ivec);
-        pbc_null = set_pbc_dd(&pbc, pbcType, useDomdec ? cr->dd->numCells : null_ivec, FALSE, box);
+        pbc_null = set_pbc_dd(&pbc, domainInfo.pbcType_,
+                              useDomdec ? domainInfo.domdec_->numCells : null_ivec, FALSE, box);
     }
     else
     {
@@ -633,21 +801,21 @@ void construct_vsites(const gmx_vsite_t*              vsite,
 
     if (useDomdec)
     {
-        dd_move_x_vsites(cr->dd, box, x);
+        dd_move_x_vsites(*domainInfo.domdec_, box, as_rvec_array(x.data()));
     }
 
-    if (vsite == nullptr || vsite->nthreads == 1)
+    if (threadingInfo == nullptr || threadingInfo->numThreads() == 1)
     {
         construct_vsites_thread(x, dt, v, ip, ilist, pbc_null);
     }
     else
     {
-#pragma omp parallel num_threads(vsite->nthreads)
+#pragma omp parallel num_threads(threadingInfo->numThreads())
         {
             try
             {
                 const int          th    = gmx_omp_get_thread_num();
-                const VsiteThread& tData = *vsite->tData[th];
+                const VsiteThread& tData = threadingInfo->threadData(th);
                 GMX_ASSERT(tData.rangeStart >= 0,
                            "The thread data should be initialized before calling construct_vsites");
 
@@ -664,15 +832,41 @@ void construct_vsites(const gmx_vsite_t*              vsite,
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
         /* Now we can construct the vsites that might depend on other vsites */
-        construct_vsites_thread(x, dt, v, ip, vsite->tData[vsite->nthreads]->ilist, pbc_null);
+        construct_vsites_thread(x, dt, v, ip, threadingInfo->threadDataNonLocalDependent().ilist, pbc_null);
     }
 }
 
-static void spread_vsite2(const t_iatom ia[], real a, const rvec x[], rvec f[], rvec fshift[], const t_pbc* pbc)
+void VirtualSitesHandler::Impl::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
+{
+    construct_vsites(&threadingInfo_, x, dt, v, iparams_, ilists_, domainInfo_, box);
+}
+
+void VirtualSitesHandler::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
+{
+    impl_->construct(x, dt, v, box);
+}
+
+void constructVirtualSites(ArrayRef<RVec> x, ArrayRef<const t_iparams> ip, ArrayRef<const InteractionList> ilist)
+{
+    // No PBC, no DD
+    const DomainInfo domainInfo;
+    construct_vsites(nullptr, x, 0, {}, ip, ilist, domainInfo, nullptr);
+}
+
+#ifndef DOXYGEN
+/* Force spreading routines */
+
+template<VirialHandling virialHandling>
+static void spread_vsite2(const t_iatom        ia[],
+                          real                 a,
+                          ArrayRef<const RVec> x,
+                          ArrayRef<RVec>       f,
+                          ArrayRef<RVec>       fshift,
+                          const t_pbc*         pbc)
 {
     rvec    fi, fj, dx;
     t_iatom av, ai, aj;
-    int     siv, sij;
 
     av = ia[1];
     ai = ia[2];
@@ -686,28 +880,33 @@ static void spread_vsite2(const t_iatom ia[], real a, const rvec x[], rvec f[],
     rvec_inc(f[aj], fj);
     /* 6 Flops */
 
-    if (pbc)
-    {
-        siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
-        sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        siv = CENTRAL;
-        sij = CENTRAL;
-    }
+        int siv;
+        int sij;
+        if (pbc)
+        {
+            siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
+            sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
+        }
+        else
+        {
+            siv = CENTRAL;
+            sij = CENTRAL;
+        }
 
-    if (fshift && (siv != CENTRAL || sij != CENTRAL))
-    {
-        rvec_inc(fshift[siv], f[av]);
-        rvec_dec(fshift[CENTRAL], fi);
-        rvec_dec(fshift[sij], fj);
+        if (siv != CENTRAL || sij != CENTRAL)
+        {
+            rvec_inc(fshift[siv], f[av]);
+            rvec_dec(fshift[CENTRAL], fi);
+            rvec_dec(fshift[sij], fj);
+        }
     }
 
     /* TOTAL: 13 flops */
 }
 
-void constructVsitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
+void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
 {
     GMX_ASSERT(x.ssize() >= mtop.natoms, "x should contain the whole system");
     GMX_ASSERT(!mtop.moleculeBlockIndices.empty(),
@@ -722,22 +921,22 @@ void constructVsitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
             int atomOffset = mtop.moleculeBlockIndices[mb].globalAtomStart;
             for (int mol = 0; mol < molb.nmol; mol++)
             {
-                construct_vsites(nullptr, as_rvec_array(x.data()) + atomOffset, 0.0, nullptr,
-                                 mtop.ffparams.iparams, molt.ilist, PbcType::No, TRUE, nullptr, nullptr);
+                constructVirtualSites(x.subArray(atomOffset, molt.atoms.nr), mtop.ffparams.iparams,
+                                      molt.ilist);
                 atomOffset += molt.atoms.nr;
             }
         }
     }
 }
 
-static void spread_vsite2FD(const t_iatom ia[],
-                            real          a,
-                            const rvec    x[],
-                            rvec          f[],
-                            rvec          fshift[],
-                            gmx_bool      VirCorr,
-                            matrix        dxdf,
-                            const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite2FD(const t_iatom        ia[],
+                            real                 a,
+                            ArrayRef<const RVec> x,
+                            ArrayRef<RVec>       f,
+                            ArrayRef<RVec>       fshift,
+                            matrix               dxdf,
+                            const t_pbc*         pbc)
 {
     const int av = ia[1];
     const int ai = ia[2];
@@ -772,7 +971,7 @@ static void spread_vsite2FD(const t_iatom ia[],
     f[aj][ZZ] += fj[ZZ];
     /* 9 Flops */
 
-    if (fshift)
+    if (virialHandling == VirialHandling::Pbc)
     {
         int svi;
         if (pbc)
@@ -797,9 +996,9 @@ static void spread_vsite2FD(const t_iatom ia[],
         }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
-        /* When VirCorr=TRUE, the virial for the current forces is not
+        /* Under this condition, the virial for the current forces is not
          * calculated from the redistributed forces. This means that
          * the effect of non-linear virtual site constructions on the virial
          * needs to be added separately. This contribution can be calculated
@@ -824,11 +1023,17 @@ static void spread_vsite2FD(const t_iatom ia[],
     /* TOTAL: 38 flops */
 }
 
-static void spread_vsite3(const t_iatom ia[], real a, real b, const rvec x[], rvec f[], rvec fshift[], const t_pbc* pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3(const t_iatom        ia[],
+                          real                 a,
+                          real                 b,
+                          ArrayRef<const RVec> x,
+                          ArrayRef<RVec>       f,
+                          ArrayRef<RVec>       fshift,
+                          const t_pbc*         pbc)
 {
     rvec fi, fj, fk, dx;
     int  av, ai, aj, ak;
-    int  siv, sij, sik;
 
     av = ia[1];
     ai = ia[2];
@@ -845,44 +1050,50 @@ static void spread_vsite3(const t_iatom ia[], real a, real b, const rvec x[], rv
     rvec_inc(f[ak], fk);
     /* 9 Flops */
 
-    if (pbc)
-    {
-        siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
-        sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
-        sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        siv = CENTRAL;
-        sij = CENTRAL;
-        sik = CENTRAL;
-    }
+        int siv;
+        int sij;
+        int sik;
+        if (pbc)
+        {
+            siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
+            sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
+            sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
+        }
+        else
+        {
+            siv = CENTRAL;
+            sij = CENTRAL;
+            sik = CENTRAL;
+        }
 
-    if (fshift && (siv != CENTRAL || sij != CENTRAL || sik != CENTRAL))
-    {
-        rvec_inc(fshift[siv], f[av]);
-        rvec_dec(fshift[CENTRAL], fi);
-        rvec_dec(fshift[sij], fj);
-        rvec_dec(fshift[sik], fk);
+        if (siv != CENTRAL || sij != CENTRAL || sik != CENTRAL)
+        {
+            rvec_inc(fshift[siv], f[av]);
+            rvec_dec(fshift[CENTRAL], fi);
+            rvec_dec(fshift[sij], fj);
+            rvec_dec(fshift[sik], fk);
+        }
     }
 
     /* TOTAL: 20 flops */
 }
 
-static void spread_vsite3FD(const t_iatom ia[],
-                            real          a,
-                            real          b,
-                            const rvec    x[],
-                            rvec          f[],
-                            rvec          fshift[],
-                            gmx_bool      VirCorr,
-                            matrix        dxdf,
-                            const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3FD(const t_iatom        ia[],
+                            real                 a,
+                            real                 b,
+                            ArrayRef<const RVec> x,
+                            ArrayRef<RVec>       f,
+                            ArrayRef<RVec>       fshift,
+                            matrix               dxdf,
+                            const t_pbc*         pbc)
 {
     real    fproj, a1;
     rvec    xvi, xij, xjk, xix, fv, temp;
     t_iatom av, ai, aj, ak;
-    int     svi, sji, skj;
+    int     sji, skj;
 
     av = ia[1];
     ai = ia[2];
@@ -926,32 +1137,36 @@ static void spread_vsite3FD(const t_iatom ia[],
     f[ak][ZZ] += a * temp[ZZ];
     /* 19 Flops */
 
-    if (pbc)
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
-    {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - (1 + a) * temp[XX];
-        fshift[CENTRAL][YY] += fv[YY] - (1 + a) * temp[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
-        fshift[sji][XX] += temp[XX];
-        fshift[sji][YY] += temp[YY];
-        fshift[sji][ZZ] += temp[ZZ];
-        fshift[skj][XX] += a * temp[XX];
-        fshift[skj][YY] += a * temp[YY];
-        fshift[skj][ZZ] += a * temp[ZZ];
+        if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - (1 + a) * temp[XX];
+            fshift[CENTRAL][YY] += fv[YY] - (1 + a) * temp[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
+            fshift[sji][XX] += temp[XX];
+            fshift[sji][YY] += temp[YY];
+            fshift[sji][ZZ] += temp[ZZ];
+            fshift[skj][XX] += a * temp[XX];
+            fshift[skj][YY] += a * temp[YY];
+            fshift[skj][ZZ] += a * temp[ZZ];
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
-        /* When VirCorr=TRUE, the virial for the current forces is not
+        /* With non-linear virial handling, the virial for the current forces is not
          * calculated from the redistributed forces. This means that
          * the effect of non-linear virtual site constructions on the virial
          * needs to be added separately. This contribution can be calculated
@@ -976,20 +1191,20 @@ static void spread_vsite3FD(const t_iatom ia[],
     /* TOTAL: 61 flops */
 }
 
-static void spread_vsite3FAD(const t_iatom ia[],
-                             real          a,
-                             real          b,
-                             const rvec    x[],
-                             rvec          f[],
-                             rvec          fshift[],
-                             gmx_bool      VirCorr,
-                             matrix        dxdf,
-                             const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3FAD(const t_iatom        ia[],
+                             real                 a,
+                             real                 b,
+                             ArrayRef<const RVec> x,
+                             ArrayRef<RVec>       f,
+                             ArrayRef<RVec>       fshift,
+                             matrix               dxdf,
+                             const t_pbc*         pbc)
 {
     rvec    xvi, xij, xjk, xperp, Fpij, Fppp, fv, f1, f2, f3;
     real    a1, b1, c1, c2, invdij, invdij2, invdp, fproj;
     t_iatom av, ai, aj, ak;
-    int     svi, sji, skj, d;
+    int     sji, skj;
 
     av = ia[1];
     ai = ia[2];
@@ -1024,7 +1239,7 @@ static void spread_vsite3FAD(const t_iatom ia[],
 
     rvec_sub(fv, Fpij, f1); /* f1 = f - Fpij */
     rvec_sub(f1, Fppp, f2); /* f2 = f - Fpij - Fppp */
-    for (d = 0; (d < DIM); d++)
+    for (int d = 0; d < DIM; d++)
     {
         f1[d] *= a1;
         f2[d] *= b1;
@@ -1043,39 +1258,42 @@ static void spread_vsite3FAD(const t_iatom ia[],
     f[ak][ZZ] += f2[ZZ];
     /* 30 Flops */
 
-    if (pbc)
-    {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = CENTRAL;
-    }
+        int svi;
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
-        fshift[CENTRAL][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
-        fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
-        fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
-        fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
-        fshift[skj][XX] += f2[XX];
-        fshift[skj][YY] += f2[YY];
-        fshift[skj][ZZ] += f2[ZZ];
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
+
+        if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
+            fshift[CENTRAL][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
+            fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
+            fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
+            fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
+            fshift[skj][XX] += f2[XX];
+            fshift[skj][YY] += f2[YY];
+            fshift[skj][ZZ] += f2[ZZ];
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
-        int  i, j;
-
         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
 
-        for (i = 0; i < DIM; i++)
+        for (int i = 0; i < DIM; i++)
         {
-            for (j = 0; j < DIM; j++)
+            for (int j = 0; j < DIM; j++)
             {
                 /* Note that xik=xij+xjk, so we have to add xij*f2 */
                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * (f1[j] + (1 - c2) * f2[j] - f3[j])
@@ -1087,21 +1305,21 @@ static void spread_vsite3FAD(const t_iatom ia[],
     /* TOTAL: 113 flops */
 }
 
-static void spread_vsite3OUT(const t_iatom ia[],
-                             real          a,
-                             real          b,
-                             real          c,
-                             const rvec    x[],
-                             rvec          f[],
-                             rvec          fshift[],
-                             gmx_bool      VirCorr,
-                             matrix        dxdf,
-                             const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3OUT(const t_iatom        ia[],
+                             real                 a,
+                             real                 b,
+                             real                 c,
+                             ArrayRef<const RVec> x,
+                             ArrayRef<RVec>       f,
+                             ArrayRef<RVec>       fshift,
+                             matrix               dxdf,
+                             const t_pbc*         pbc)
 {
     rvec xvi, xij, xik, fv, fj, fk;
     real cfx, cfy, cfz;
     int  av, ai, aj, ak;
-    int  svi, sji, ski;
+    int  sji, ski;
 
     av = ia[1];
     ai = ia[2];
@@ -1135,26 +1353,30 @@ static void spread_vsite3OUT(const t_iatom ia[],
     rvec_inc(f[ak], fk);
     /* 15 Flops */
 
-    if (pbc)
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
-    {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || ski != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX];
-        fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
-        rvec_inc(fshift[sji], fj);
-        rvec_inc(fshift[ski], fk);
+        if (svi != CENTRAL || sji != CENTRAL || ski != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX];
+            fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
+            rvec_inc(fshift[sji], fj);
+            rvec_inc(fshift[ski], fk);
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
 
@@ -1172,21 +1394,21 @@ static void spread_vsite3OUT(const t_iatom ia[],
     /* TOTAL: 54 flops */
 }
 
-static void spread_vsite4FD(const t_iatom ia[],
-                            real          a,
-                            real          b,
-                            real          c,
-                            const rvec    x[],
-                            rvec          f[],
-                            rvec          fshift[],
-                            gmx_bool      VirCorr,
-                            matrix        dxdf,
-                            const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite4FD(const t_iatom        ia[],
+                            real                 a,
+                            real                 b,
+                            real                 c,
+                            ArrayRef<const RVec> x,
+                            ArrayRef<RVec>       f,
+                            ArrayRef<RVec>       fshift,
+                            matrix               dxdf,
+                            const t_pbc*         pbc)
 {
     real fproj, a1;
     rvec xvi, xij, xjk, xjl, xix, fv, temp;
     int  av, ai, aj, ak, al;
-    int  svi, sji, skj, slj, m;
+    int  sji, skj, slj, m;
 
     av = ia[1];
     ai = ia[2];
@@ -1233,28 +1455,32 @@ static void spread_vsite4FD(const t_iatom ia[],
     }
     /* 26 Flops */
 
-    if (pbc)
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
-    {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL || slj != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        for (m = 0; m < DIM; m++)
+        if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL || slj != CENTRAL)
         {
-            fshift[CENTRAL][m] += fv[m] - (1 + a + b) * temp[m];
-            fshift[sji][m] += temp[m];
-            fshift[skj][m] += a * temp[m];
-            fshift[slj][m] += b * temp[m];
+            rvec_dec(fshift[svi], fv);
+            for (m = 0; m < DIM; m++)
+            {
+                fshift[CENTRAL][m] += fv[m] - (1 + a + b) * temp[m];
+                fshift[sji][m] += temp[m];
+                fshift[skj][m] += a * temp[m];
+                fshift[slj][m] += b * temp[m];
+            }
         }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
         int  i, j;
@@ -1273,24 +1499,23 @@ static void spread_vsite4FD(const t_iatom ia[],
     /* TOTAL: 77 flops */
 }
 
-
-static void spread_vsite4FDN(const t_iatom ia[],
-                             real          a,
-                             real          b,
-                             real          c,
-                             const rvec    x[],
-                             rvec          f[],
-                             rvec          fshift[],
-                             gmx_bool      VirCorr,
-                             matrix        dxdf,
-                             const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite4FDN(const t_iatom        ia[],
+                             real                 a,
+                             real                 b,
+                             real                 c,
+                             ArrayRef<const RVec> x,
+                             ArrayRef<RVec>       f,
+                             ArrayRef<RVec>       fshift,
+                             matrix               dxdf,
+                             const t_pbc*         pbc)
 {
     rvec xvi, xij, xik, xil, ra, rb, rja, rjb, rab, rm, rt;
     rvec fv, fj, fk, fl;
     real invrm, denom;
     real cfx, cfy, cfz;
     int  av, ai, aj, ak, al;
-    int  svi, sij, sik, sil;
+    int  sij, sik, sil;
 
     /* DEBUG: check atom indices */
     av = ia[1];
@@ -1389,27 +1614,31 @@ static void spread_vsite4FDN(const t_iatom ia[],
     rvec_inc(f[al], fl);
     /* 21 flops */
 
-    if (pbc)
-    {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sij != CENTRAL || sik != CENTRAL || sil != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
-        fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
-        rvec_inc(fshift[sij], fj);
-        rvec_inc(fshift[sik], fk);
-        rvec_inc(fshift[sil], fl);
+        if (svi != CENTRAL || sij != CENTRAL || sik != CENTRAL || sil != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
+            fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
+            rvec_inc(fshift[sij], fj);
+            rvec_inc(fshift[sik], fk);
+            rvec_inc(fshift[sil], fl);
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
         int  i, j;
@@ -1428,12 +1657,12 @@ static void spread_vsite4FDN(const t_iatom ia[],
     /* Total: 207 flops (Yuck!) */
 }
 
-
+template<VirialHandling virialHandling>
 static int spread_vsiten(const t_iatom             ia[],
                          ArrayRef<const t_iparams> ip,
-                         const rvec                x[],
-                         rvec                      f[],
-                         rvec                      fshift[],
+                         ArrayRef<const RVec>      x,
+                         ArrayRef<RVec>            f,
+                         ArrayRef<RVec>            fshift,
                          const t_pbc*              pbc)
 {
     rvec xv, dx, fi;
@@ -1459,7 +1688,8 @@ static int spread_vsiten(const t_iatom             ia[],
         a = ip[ia[i]].vsiten.a;
         svmul(a, f[av], fi);
         rvec_inc(f[ai], fi);
-        if (fshift && siv != CENTRAL)
+
+        if (virialHandling == VirialHandling::Pbc && siv != CENTRAL)
         {
             rvec_inc(fshift[siv], fi);
             rvec_dec(fshift[CENTRAL], fi);
@@ -1470,7 +1700,9 @@ static int spread_vsiten(const t_iatom             ia[],
     return n3;
 }
 
+#endif // DOXYGEN
 
+//! Returns the number of virtual sites in the interaction list; for VSITEN it returns the number of atoms
 static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
 {
     if (ftype == F_VSITEN)
@@ -1483,14 +1715,15 @@ static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
     }
 }
 
-static void spread_vsite_f_thread(const rvec                      x[],
-                                  rvec                            f[],
-                                  rvec*                           fshift,
-                                  gmx_bool                        VirCorr,
-                                  matrix                          dxdf,
-                                  ArrayRef<const t_iparams>       ip,
-                                  ArrayRef<const InteractionList> ilist,
-                                  const t_pbc*                    pbc_null)
+//! Executes the force spreading task for a single thread
+template<VirialHandling virialHandling>
+static void spreadForceForThread(ArrayRef<const RVec>            x,
+                                 ArrayRef<RVec>                  f,
+                                 ArrayRef<RVec>                  fshift,
+                                 matrix                          dxdf,
+                                 ArrayRef<const t_iparams>       ip,
+                                 ArrayRef<const InteractionList> ilist,
+                                 const t_pbc*                    pbc_null)
 {
     const PbcMode pbcMode = getPbcMode(pbc_null);
     /* We need another pbc pointer, as with charge groups we switch per vsite */
@@ -1528,38 +1761,42 @@ static void spread_vsite_f_thread(const rvec                      x[],
                 /* Construct the vsite depending on type */
                 switch (ftype)
                 {
-                    case F_VSITE2: spread_vsite2(ia, a1, x, f, fshift, pbc_null2); break;
+                    case F_VSITE2:
+                        spread_vsite2<virialHandling>(ia, a1, x, f, fshift, pbc_null2);
+                        break;
                     case F_VSITE2FD:
-                        spread_vsite2FD(ia, a1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite2FD<virialHandling>(ia, a1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE3:
                         b1 = ip[tp].vsite.b;
-                        spread_vsite3(ia, a1, b1, x, f, fshift, pbc_null2);
+                        spread_vsite3<virialHandling>(ia, a1, b1, x, f, fshift, pbc_null2);
                         break;
                     case F_VSITE3FD:
                         b1 = ip[tp].vsite.b;
-                        spread_vsite3FD(ia, a1, b1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite3FD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE3FAD:
                         b1 = ip[tp].vsite.b;
-                        spread_vsite3FAD(ia, a1, b1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite3FAD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE3OUT:
                         b1 = ip[tp].vsite.b;
                         c1 = ip[tp].vsite.c;
-                        spread_vsite3OUT(ia, a1, b1, c1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite3OUT<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE4FD:
                         b1 = ip[tp].vsite.b;
                         c1 = ip[tp].vsite.c;
-                        spread_vsite4FD(ia, a1, b1, c1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite4FD<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE4FDN:
                         b1 = ip[tp].vsite.b;
                         c1 = ip[tp].vsite.c;
-                        spread_vsite4FDN(ia, a1, b1, c1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite4FDN<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
+                        break;
+                    case F_VSITEN:
+                        inc = spread_vsiten<virialHandling>(ia, ip, x, f, fshift, pbc_null2);
                         break;
-                    case F_VSITEN: inc = spread_vsiten(ia, ip, x, f, fshift, pbc_null2); break;
                     default:
                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
                 }
@@ -1573,7 +1810,37 @@ static void spread_vsite_f_thread(const rvec                      x[],
     }
 }
 
-/*! \brief Clears the task force buffer elements that are written by task idTask */
+//! Wrapper function for calling the templated thread-local spread function
+static void spreadForceWrapper(ArrayRef<const RVec>            x,
+                               ArrayRef<RVec>                  f,
+                               const VirialHandling            virialHandling,
+                               ArrayRef<RVec>                  fshift,
+                               matrix                          dxdf,
+                               const bool                      clearDxdf,
+                               ArrayRef<const t_iparams>       ip,
+                               ArrayRef<const InteractionList> ilist,
+                               const t_pbc*                    pbc_null)
+{
+    if (virialHandling == VirialHandling::NonLinear && clearDxdf)
+    {
+        clear_mat(dxdf);
+    }
+
+    switch (virialHandling)
+    {
+        case VirialHandling::None:
+            spreadForceForThread<VirialHandling::None>(x, f, fshift, dxdf, ip, ilist, pbc_null);
+            break;
+        case VirialHandling::Pbc:
+            spreadForceForThread<VirialHandling::Pbc>(x, f, fshift, dxdf, ip, ilist, pbc_null);
+            break;
+        case VirialHandling::NonLinear:
+            spreadForceForThread<VirialHandling::NonLinear>(x, f, fshift, dxdf, ip, ilist, pbc_null);
+            break;
+    }
+}
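
The wrapper above is the single place where the run-time VirialHandling value is mapped onto the three template instantiations, so the per-vsite kernels themselves contain no virial branching. A standalone sketch of the dispatch pattern (hypothetical names, not the GROMACS functions):

    #include <cstdio>

    // Stand-in for the policy enum; only the dispatch pattern is illustrated.
    enum class VirialHandling : int { None, Pbc, NonLinear };

    // Inside the template, virialHandling is a compile-time constant, so the
    // compiler eliminates the untaken branches from each instantiation.
    template<VirialHandling virialHandling>
    static void spreadKernel()
    {
        if (virialHandling == VirialHandling::Pbc)
        {
            std::printf("accumulate shift forces\n");
        }
        else if (virialHandling == VirialHandling::NonLinear)
        {
            std::printf("accumulate the dx*df virial correction\n");
        }
    }

    // The run-time value is translated into an instantiation exactly once.
    static void spreadWrapper(const VirialHandling virialHandling)
    {
        switch (virialHandling)
        {
            case VirialHandling::None: spreadKernel<VirialHandling::None>(); break;
            case VirialHandling::Pbc: spreadKernel<VirialHandling::Pbc>(); break;
            case VirialHandling::NonLinear: spreadKernel<VirialHandling::NonLinear>(); break;
        }
    }
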
+
+//! Clears the task force buffer elements that are written by task idTask
 static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
 {
     int ntask = idTask->spreadTask.size();
@@ -1589,34 +1856,28 @@ static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
     }
 }
 
-void spread_vsite_f(const gmx_vsite_t* vsite,
-                    const rvec* gmx_restrict x,
-                    rvec* gmx_restrict f,
-                    rvec* gmx_restrict            fshift,
-                    gmx_bool                      VirCorr,
-                    matrix                        vir,
-                    t_nrnb*                       nrnb,
-                    const InteractionDefinitions& idef,
-                    PbcType                       pbcType,
-                    gmx_bool                      bMolPBC,
-                    const matrix                  box,
-                    const t_commrec*              cr,
-                    gmx_wallcycle*                wcycle)
+void VirtualSitesHandler::Impl::spreadForces(ArrayRef<const RVec> x,
+                                             ArrayRef<RVec>       f,
+                                             const VirialHandling virialHandling,
+                                             ArrayRef<RVec>       fshift,
+                                             matrix               virial,
+                                             t_nrnb*              nrnb,
+                                             const matrix         box,
+                                             gmx_wallcycle*       wcycle)
 {
     wallcycle_start(wcycle, ewcVSITESPREAD);
-    const bool useDomdec = vsite->useDomdec;
-    GMX_ASSERT(!useDomdec || (cr != nullptr && DOMAINDECOMP(cr)),
-               "When vsites are set up with domain decomposition, we need a valid commrec");
+
+    const bool useDomdec = domainInfo_.useDomdec();
 
     t_pbc pbc, *pbc_null;
 
-    /* We only need to do pbc when we have inter-cg vsites */
-    if ((useDomdec || bMolPBC) && vsite->numInterUpdategroupVsites)
+    if (domainInfo_.useMolPbc_)
     {
         /* This is wasting some CPU time as we now do this multiple times
          * per MD step.
          */
-        pbc_null = set_pbc_dd(&pbc, pbcType, useDomdec ? cr->dd->numCells : nullptr, FALSE, box);
+        pbc_null = set_pbc_dd(&pbc, domainInfo_.pbcType_,
+                              useDomdec ? domainInfo_.domdec_->numCells : nullptr, FALSE, box);
     }
     else
     {
@@ -1625,25 +1886,23 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
 
     if (useDomdec)
     {
-        dd_clear_f_vsites(cr->dd, f);
+        dd_clear_f_vsites(*domainInfo_.domdec_, f);
     }
 
-    if (vsite->nthreads == 1)
+    const int numThreads = threadingInfo_.numThreads();
+
+    if (numThreads == 1)
     {
         matrix dxdf;
-        if (VirCorr)
-        {
-            clear_mat(dxdf);
-        }
-        spread_vsite_f_thread(x, f, fshift, VirCorr, dxdf, idef.iparams, idef.il, pbc_null);
+        spreadForceWrapper(x, f, virialHandling, fshift, dxdf, true, iparams_, ilists_, pbc_null);
 
-        if (VirCorr)
+        if (virialHandling == VirialHandling::NonLinear)
         {
             for (int i = 0; i < DIM; i++)
             {
                 for (int j = 0; j < DIM; j++)
                 {
-                    vir[i][j] += -0.5 * dxdf[i][j];
+                    virial[i][j] += -0.5 * dxdf[i][j];
                 }
             }
         }
@@ -1651,37 +1910,33 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
     else
     {
         /* First spread the vsites that might depend on non-local vsites */
-        if (VirCorr)
-        {
-            clear_mat(vsite->tData[vsite->nthreads]->dxdf);
-        }
-        spread_vsite_f_thread(x, f, fshift, VirCorr, vsite->tData[vsite->nthreads]->dxdf,
-                              idef.iparams, vsite->tData[vsite->nthreads]->ilist, pbc_null);
+        auto& nlDependentVSites = threadingInfo_.threadDataNonLocalDependent();
+        spreadForceWrapper(x, f, virialHandling, fshift, nlDependentVSites.dxdf, true, iparams_,
+                           nlDependentVSites.ilist, pbc_null);
 
-#pragma omp parallel num_threads(vsite->nthreads)
+#pragma omp parallel num_threads(numThreads)
         {
             try
             {
                 int          thread = gmx_omp_get_thread_num();
-                VsiteThread& tData  = *vsite->tData[thread];
+                VsiteThread& tData  = threadingInfo_.threadData(thread);
 
-                rvec* fshift_t;
-                if (thread == 0 || fshift == nullptr)
-                {
-                    fshift_t = fshift;
-                }
-                else
+                ArrayRef<RVec> fshift_t;
+                if (virialHandling == VirialHandling::Pbc)
                 {
-                    fshift_t = tData.fshift;
-
-                    for (int i = 0; i < SHIFTS; i++)
+                    if (thread == 0)
                     {
-                        clear_rvec(fshift_t[i]);
+                        fshift_t = fshift;
+                    }
+                    else
+                    {
+                        fshift_t = tData.fshift;
+
+                        for (int i = 0; i < SHIFTS; i++)
+                        {
+                            clear_rvec(fshift_t[i]);
+                        }
                     }
-                }
-                if (VirCorr)
-                {
-                    clear_mat(tData.dxdf);
                 }
 
                 if (tData.useInterdependentTask)
@@ -1702,8 +1957,8 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
                     {
                         copy_rvec(f[idTask->vsite[i]], idTask->force[idTask->vsite[i]]);
                     }
-                    spread_vsite_f_thread(x, as_rvec_array(idTask->force.data()), fshift_t, VirCorr,
-                                          tData.dxdf, idef.iparams, tData.idTask.ilist, pbc_null);
+                    spreadForceWrapper(x, idTask->force, virialHandling, fshift_t, tData.dxdf, true,
+                                       iparams_, tData.idTask.ilist, pbc_null);
 
                     /* We need a barrier before reducing forces below
                      * that have been produced by a different thread above.
@@ -1718,15 +1973,13 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
                     int ntask = idTask->reduceTask.size();
                     for (int ti = 0; ti < ntask; ti++)
                     {
-                        const InterdependentTask* idt_foreign =
-                                &vsite->tData[idTask->reduceTask[ti]]->idTask;
-                        const AtomIndex* atomList  = &idt_foreign->atomIndex[thread];
-                        const RVec*      f_foreign = idt_foreign->force.data();
+                        const InterdependentTask& idt_foreign =
+                                threadingInfo_.threadData(idTask->reduceTask[ti]).idTask;
+                        const AtomIndex& atomList  = idt_foreign.atomIndex[thread];
+                        const RVec*      f_foreign = idt_foreign.force.data();
 
-                        int natom = atomList->atom.size();
-                        for (int i = 0; i < natom; i++)
+                        for (int ind : atomList.atom)
                         {
-                            int ind = atomList->atom[i];
                             rvec_inc(f[ind], f_foreign[ind]);
                             /* Clearing of f_foreign is done at the next step */
                         }
@@ -1741,35 +1994,35 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
                 }
 
                 /* Spread the vsites that spread locally only */
-                spread_vsite_f_thread(x, f, fshift_t, VirCorr, tData.dxdf, idef.iparams,
-                                      tData.ilist, pbc_null);
+                spreadForceWrapper(x, f, virialHandling, fshift_t, tData.dxdf, false, iparams_,
+                                   tData.ilist, pbc_null);
             }
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
 
-        if (fshift != nullptr)
+        if (virialHandling == VirialHandling::Pbc)
         {
-            for (int th = 1; th < vsite->nthreads; th++)
+            for (int th = 1; th < numThreads; th++)
             {
                 for (int i = 0; i < SHIFTS; i++)
                 {
-                    rvec_inc(fshift[i], vsite->tData[th]->fshift[i]);
+                    rvec_inc(fshift[i], threadingInfo_.threadData(th).fshift[i]);
                 }
             }
         }
 
-        if (VirCorr)
+        if (virialHandling == VirialHandling::NonLinear)
         {
-            for (int th = 0; th < vsite->nthreads + 1; th++)
+            for (int th = 0; th < numThreads + 1; th++)
             {
-                /* MSVC doesn't like matrix references, so we use a pointer */
-                const matrix* dxdf = &vsite->tData[th]->dxdf;
+                const matrix& dxdf = threadingInfo_.threadData(th).dxdf;
 
                 for (int i = 0; i < DIM; i++)
                 {
                     for (int j = 0; j < DIM; j++)
                     {
-                        vir[i][j] += -0.5 * (*dxdf)[i][j];
+                        virial[i][j] += -0.5 * dxdf[i][j];
                     }
                 }
             }
@@ -1778,18 +2031,18 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
 
     if (useDomdec)
     {
-        dd_move_f_vsites(cr->dd, f, fshift);
+        dd_move_f_vsites(*domainInfo_.domdec_, f, fshift);
     }
 
-    inc_nrnb(nrnb, eNR_VSITE2, vsite_count(idef.il, F_VSITE2));
-    inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(idef.il, F_VSITE2FD));
-    inc_nrnb(nrnb, eNR_VSITE3, vsite_count(idef.il, F_VSITE3));
-    inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(idef.il, F_VSITE3FD));
-    inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(idef.il, F_VSITE3FAD));
-    inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(idef.il, F_VSITE3OUT));
-    inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(idef.il, F_VSITE4FD));
-    inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(idef.il, F_VSITE4FDN));
-    inc_nrnb(nrnb, eNR_VSITEN, vsite_count(idef.il, F_VSITEN));
+    inc_nrnb(nrnb, eNR_VSITE2, vsite_count(ilists_, F_VSITE2));
+    inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(ilists_, F_VSITE2FD));
+    inc_nrnb(nrnb, eNR_VSITE3, vsite_count(ilists_, F_VSITE3));
+    inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(ilists_, F_VSITE3FD));
+    inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(ilists_, F_VSITE3FAD));
+    inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(ilists_, F_VSITE3OUT));
+    inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(ilists_, F_VSITE4FD));
+    inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(ilists_, F_VSITE4FDN));
+    inc_nrnb(nrnb, eNR_VSITEN, vsite_count(ilists_, F_VSITEN));
 
     wallcycle_stop(wcycle, ewcVSITESPREAD);
 }
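
A detail of the multi-threaded path worth spelling out: with Pbc virial handling, thread 0 spreads directly into the caller's fshift while every other thread writes to its own tData.fshift buffer, and those buffers are summed into fshift after the parallel region. A self-contained sketch of that reduction scheme (assumed names, plain doubles instead of rvec):

    #include <omp.h>
    #include <vector>

    // Thread 0 accumulates directly into the output; the other threads use
    // private buffers that are reduced into the output afterwards.
    void reduceShiftForces(std::vector<double>& fshift, const int numThreads)
    {
        const int numShifts = static_cast<int>(fshift.size());
        std::vector<std::vector<double>> threadBuffers(
                numThreads, std::vector<double>(numShifts, 0.0));
    #pragma omp parallel num_threads(numThreads)
        {
            const int            thread = omp_get_thread_num();
            std::vector<double>& buf    = (thread == 0) ? fshift : threadBuffers[thread];
            // ... each thread spreads its assigned vsites, adding into buf ...
            (void)buf;
        }
        for (int th = 1; th < numThreads; th++)
        {
            for (int i = 0; i < numShifts; i++)
            {
                fshift[i] += threadBuffers[th][i];
            }
        }
    }
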
@@ -1831,6 +2084,18 @@ int countNonlinearVsites(const gmx_mtop_t& mtop)
     return numNonlinearVsites;
 }
 
+void VirtualSitesHandler::spreadForces(ArrayRef<const RVec> x,
+                                       ArrayRef<RVec>       f,
+                                       const VirialHandling virialHandling,
+                                       ArrayRef<RVec>       fshift,
+                                       matrix               virial,
+                                       t_nrnb*              nrnb,
+                                       const matrix         box,
+                                       gmx_wallcycle*       wcycle)
+{
+    impl_->spreadForces(x, f, virialHandling, fshift, virial, nrnb, box, wcycle);
+}
+
 int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
                                 gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype)
 {
@@ -1874,11 +2139,13 @@ int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop
     return n_intercg_vsite;
 }
 
-std::unique_ptr<gmx_vsite_t> initVsite(const gmx_mtop_t& mtop, const t_commrec* cr)
+std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
+                                                             const t_commrec*  cr,
+                                                             PbcType           pbcType)
 {
     GMX_RELEASE_ASSERT(cr != nullptr, "We need a valid commrec");
 
-    std::unique_ptr<gmx_vsite_t> vsite;
+    std::unique_ptr<VirtualSitesHandler> vsite;
 
     /* check if there are vsites */
     int nvsite = 0;
@@ -1903,57 +2170,68 @@ std::unique_ptr<gmx_vsite_t> initVsite(const gmx_mtop_t& mtop, const t_commrec*
         return vsite;
     }
 
-    vsite = std::make_unique<gmx_vsite_t>();
-
-    gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype;
-    if (DOMAINDECOMP(cr))
-    {
-        updateGroupingPerMoleculetype = getUpdateGroupingPerMoleculetype(*cr->dd);
-    }
-    vsite->numInterUpdategroupVsites = countInterUpdategroupVsites(mtop, updateGroupingPerMoleculetype);
-
-    vsite->useDomdec = (DOMAINDECOMP(cr) && cr->dd->nnodes > 1);
-
-    vsite->nthreads = gmx_omp_nthreads_get(emntVSITE);
+    return std::make_unique<VirtualSitesHandler>(mtop, cr->dd, pbcType);
+}
 
-    if (vsite->nthreads > 1)
+ThreadingInfo::ThreadingInfo() : numThreads_(gmx_omp_nthreads_get(emntVSITE))
+{
+    if (numThreads_ > 1)
     {
         /* We need one extra thread data structure for the overlap vsites */
-        vsite->tData.resize(vsite->nthreads + 1);
-#pragma omp parallel for num_threads(vsite->nthreads) schedule(static)
-        for (int thread = 0; thread < vsite->nthreads; thread++)
+        tData_.resize(numThreads_ + 1);
+#pragma omp parallel for num_threads(numThreads_) schedule(static)
+        for (int thread = 0; thread < numThreads_; thread++)
         {
             try
             {
-                vsite->tData[thread] = std::make_unique<VsiteThread>();
+                tData_[thread] = std::make_unique<VsiteThread>();
 
-                InterdependentTask& idTask = vsite->tData[thread]->idTask;
+                InterdependentTask& idTask = tData_[thread]->idTask;
                 idTask.nuse                = 0;
-                idTask.atomIndex.resize(vsite->nthreads);
+                idTask.atomIndex.resize(numThreads_);
             }
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
-        if (vsite->nthreads > 1)
+        if (numThreads_ > 1)
         {
-            vsite->tData[vsite->nthreads] = std::make_unique<VsiteThread>();
+            tData_[numThreads_] = std::make_unique<VsiteThread>();
         }
     }
+}
+
+//! Returns the number of inter update-group vsites
+static int getNumInterUpdategroupVsites(const gmx_mtop_t& mtop, const gmx_domdec_t* domdec)
+{
+    gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype;
+    if (domdec)
+    {
+        updateGroupingPerMoleculetype = getUpdateGroupingPerMoleculetype(*domdec);
+    }
 
-    return vsite;
+    return countInterUpdategroupVsites(mtop, updateGroupingPerMoleculetype);
 }
 
-gmx_vsite_t::gmx_vsite_t() {}
+VirtualSitesHandler::Impl::Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
+    numInterUpdategroupVirtualSites_(getNumInterUpdategroupVsites(mtop, domdec)),
+    domainInfo_({ pbcType, pbcType != PbcType::No && numInterUpdategroupVirtualSites_ > 0, domdec }),
+    iparams_(mtop.ffparams.iparams)
+{
+}
 
-gmx_vsite_t::~gmx_vsite_t() {}
+VirtualSitesHandler::VirtualSitesHandler(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
+    impl_(new Impl(mtop, domdec, pbcType))
+{
+}
 
-static inline void flagAtom(InterdependentTask* idTask, int atom, int thread, int nthread, int natperthread)
+//! Flags atom \p atom, which is home in another task, if it has not already been flagged before
+static inline void flagAtom(InterdependentTask* idTask, const int atom, const int numThreads, const int numAtomsPerThread)
 {
     if (!idTask->use[atom])
     {
         idTask->use[atom] = true;
-        thread            = atom / natperthread;
+        int thread        = atom / numAtomsPerThread;
         /* Assign all non-local atom force writes to thread 0 */
-        if (thread >= nthread)
+        if (thread >= numThreads)
         {
             thread = 0;
         }
@@ -1961,7 +2239,7 @@ static inline void flagAtom(InterdependentTask* idTask, int atom, int thread, in
     }
 }
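
The division above stripes the home atoms uniformly over threads and funnels any atom beyond the home range (a non-local atom) to thread 0. In isolation the mapping is just (hypothetical helper, not part of the patch):

    // Owner of an atom's force-buffer slot under uniform striping; atoms past
    // the home range would map to thread >= numThreads and go to thread 0.
    static int ownerThread(const int atom, const int numThreads, const int numAtomsPerThread)
    {
        const int thread = atom / numAtomsPerThread;
        return thread < numThreads ? thread : 0;
    }
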
 
-/*\brief Here we try to assign all vsites that are in our local range.
+/*! \brief Here we try to assign all vsites that are in our local range.
  *
  * Our task local atom range is tData->rangeStart - tData->rangeEnd.
  * Vsites that depend only on local atoms, as indicated by taskIndex[]==thread,
@@ -2082,14 +2360,14 @@ static void assignVsitesToThread(VsiteThread*                    tData,
                     {
                         for (int j = i + 2; j < i + nral1; j++)
                         {
-                            flagAtom(&tData->idTask, iat[j], thread, nthread, natperthread);
+                            flagAtom(&tData->idTask, iat[j], nthread, natperthread);
                         }
                     }
                     else
                     {
                         for (int j = i + 2; j < i + inc; j += 3)
                         {
-                            flagAtom(&tData->idTask, iat[j], thread, nthread, natperthread);
+                            flagAtom(&tData->idTask, iat[j], nthread, natperthread);
                         }
                     }
                 }
@@ -2136,14 +2414,12 @@ static void assignVsitesToSingleTask(VsiteThread*                    tData,
     }
 }
 
-void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
-                               ArrayRef<const t_iparams>       ip,
-                               const t_mdatoms*                mdatoms,
-                               gmx_vsite_t*                    vsite)
+void ThreadingInfo::setVirtualSites(ArrayRef<const InteractionList> ilists,
+                                    ArrayRef<const t_iparams>       iparams,
+                                    const t_mdatoms&                mdatoms,
+                                    const bool                      useDomdec)
 {
-    int vsite_atom_range, natperthread;
-
-    if (vsite->nthreads == 1)
+    if (numThreads_ <= 1)
     {
         /* Nothing to do */
         return;
@@ -2159,7 +2435,9 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
      * uniformly in each domain along the major dimension, usually x,
      * it will also perform well.
      */
-    if (!vsite->useDomdec)
+    int vsite_atom_range;
+    int natperthread;
+    if (!useDomdec)
     {
         vsite_atom_range = -1;
         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
@@ -2168,8 +2446,8 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                 if (ftype != F_VSITEN)
                 {
                     int                 nral1 = 1 + NRAL(ftype);
-                    ArrayRef<const int> iat   = ilist[ftype].iatoms;
-                    for (int i = 0; i < ilist[ftype].size(); i += nral1)
+                    ArrayRef<const int> iat   = ilists[ftype].iatoms;
+                    for (int i = 0; i < ilists[ftype].size(); i += nral1)
                     {
                         for (int j = i + 1; j < i + nral1; j++)
                         {
@@ -2181,13 +2459,13 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                 {
                     int vs_ind_end;
 
-                    ArrayRef<const int> iat = ilist[ftype].iatoms;
+                    ArrayRef<const int> iat = ilists[ftype].iatoms;
 
                     int i = 0;
-                    while (i < ilist[ftype].size())
+                    while (i < ilists[ftype].size())
                     {
                         /* The 3 below is from 1+NRAL(ftype)=3 */
-                        vs_ind_end = i + ip[iat[i]].vsiten.n * 3;
+                        vs_ind_end = i + iparams[iat[i]].vsiten.n * 3;
 
                         vsite_atom_range = std::max(vsite_atom_range, iat[i + 1]);
                         while (i < vs_ind_end)
@@ -2200,7 +2478,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
             }
         }
         vsite_atom_range++;
-        natperthread = (vsite_atom_range + vsite->nthreads - 1) / vsite->nthreads;
+        natperthread = (vsite_atom_range + numThreads_ - 1) / numThreads_;
     }
     else
     {
@@ -2211,53 +2489,52 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
          * When assigning vsites to threads, we should take care that the last
          * threads also covers the non-local range.
          */
-        vsite_atom_range = mdatoms->nr;
-        natperthread     = (mdatoms->homenr + vsite->nthreads - 1) / vsite->nthreads;
+        vsite_atom_range = mdatoms.nr;
+        natperthread     = (mdatoms.homenr + numThreads_ - 1) / numThreads_;
     }
 
     if (debug)
     {
         fprintf(debug, "virtual site thread dist: natoms %d, range %d, natperthread %d\n",
-                mdatoms->nr, vsite_atom_range, natperthread);
+                mdatoms.nr, vsite_atom_range, natperthread);
     }
 
     /* To simplify the vsite assignment, we make an index which tells us
      * to which task particles, both non-vsites and vsites, are assigned.
      */
-    vsite->taskIndex.resize(mdatoms->nr);
+    taskIndex_.resize(mdatoms.nr);
 
     /* Initialize the task index array. Here we assign the non-vsite
      * particles to task=thread, so we easily figure out if vsites
      * depend on local and/or non-local particles in assignVsitesToThread.
      */
-    gmx::ArrayRef<int> taskIndex = vsite->taskIndex;
     {
         int thread = 0;
-        for (int i = 0; i < mdatoms->nr; i++)
+        for (int i = 0; i < mdatoms.nr; i++)
         {
-            if (mdatoms->ptype[i] == eptVSite)
+            if (mdatoms.ptype[i] == eptVSite)
             {
                 /* vsites are not assigned to a task yet */
-                taskIndex[i] = -1;
+                taskIndex_[i] = -1;
             }
             else
             {
                 /* assign non-vsite particles to task thread */
-                taskIndex[i] = thread;
+                taskIndex_[i] = thread;
             }
-            if (i == (thread + 1) * natperthread && thread < vsite->nthreads)
+            if (i == (thread + 1) * natperthread && thread < numThreads_)
             {
                 thread++;
             }
         }
     }
 
-#pragma omp parallel num_threads(vsite->nthreads)
+#pragma omp parallel num_threads(numThreads_)
     {
         try
         {
             int          thread = gmx_omp_get_thread_num();
-            VsiteThread& tData  = *vsite->tData[thread];
+            VsiteThread& tData  = *tData_[thread];
 
             /* Clear the buffer use flags that were set before */
             if (tData.useInterdependentTask)
@@ -2271,7 +2548,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                 clearTaskForceBufferUsedElements(&idTask);
 
                 idTask.vsite.resize(0);
-                for (int t = 0; t < vsite->nthreads; t++)
+                for (int t = 0; t < numThreads_; t++)
                 {
                     AtomIndex& atomIndex = idTask.atomIndex[t];
                     int        natom     = atomIndex.atom.size();
@@ -2307,17 +2584,17 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
 
             /* Assign all vsites that can execute independently on threads */
             tData.rangeStart = thread * natperthread;
-            if (thread < vsite->nthreads - 1)
+            if (thread < numThreads_ - 1)
             {
                 tData.rangeEnd = (thread + 1) * natperthread;
             }
             else
             {
                 /* The last thread should cover up to the end of the range */
-                tData.rangeEnd = mdatoms->nr;
+                tData.rangeEnd = mdatoms.nr;
             }
-            assignVsitesToThread(&tData, thread, vsite->nthreads, natperthread, taskIndex, ilist,
-                                 ip, mdatoms->ptype);
+            assignVsitesToThread(&tData, thread, numThreads_, natperthread, taskIndex_, ilists,
+                                 iparams, mdatoms.ptype);
 
             if (tData.useInterdependentTask)
             {
@@ -2335,7 +2612,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
 
                 idTask.spreadTask.resize(0);
                 idTask.reduceTask.resize(0);
-                for (int t = 0; t < vsite->nthreads; t++)
+                for (int t = 0; t < numThreads_; t++)
                 {
                     /* Do we write to the force buffer of task t? */
                     if (!idTask.atomIndex[t].atom.empty())
@@ -2343,7 +2620,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                         idTask.spreadTask.push_back(t);
                     }
                     /* Does task t write to our force buffer? */
-                    if (!vsite->tData[t]->idTask.atomIndex[thread].atom.empty())
+                    if (!tData_[t]->idTask.atomIndex[thread].atom.empty())
                     {
                         idTask.reduceTask.push_back(t);
                     }
@@ -2355,28 +2632,27 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
     /* Assign all remaining vsites, that will have taskIndex[]=2*vsite->nthreads,
      * to a single task that will not run in parallel with other tasks.
      */
-    assignVsitesToSingleTask(vsite->tData[vsite->nthreads].get(), 2 * vsite->nthreads, taskIndex,
-                             ilist, ip);
+    assignVsitesToSingleTask(tData_[numThreads_].get(), 2 * numThreads_, taskIndex_, ilists, iparams);
 
-    if (debug && vsite->nthreads > 1)
+    if (debug && numThreads_ > 1)
     {
         fprintf(debug, "virtual site useInterdependentTask %d, nuse:\n",
-                static_cast<int>(vsite->tData[0]->useInterdependentTask));
-        for (int th = 0; th < vsite->nthreads + 1; th++)
+                static_cast<int>(tData_[0]->useInterdependentTask));
+        for (int th = 0; th < numThreads_ + 1; th++)
         {
-            fprintf(debug, " %4d", vsite->tData[th]->idTask.nuse);
+            fprintf(debug, " %4d", tData_[th]->idTask.nuse);
         }
         fprintf(debug, "\n");
 
         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
         {
-            if (!ilist[ftype].empty())
+            if (!ilists[ftype].empty())
             {
                 fprintf(debug, "%-20s thread dist:", interaction_function[ftype].longname);
-                for (int th = 0; th < vsite->nthreads + 1; th++)
+                for (int th = 0; th < numThreads_ + 1; th++)
                 {
-                    fprintf(debug, " %4d %4d ", vsite->tData[th]->ilist[ftype].size(),
-                            vsite->tData[th]->idTask.ilist[ftype].size());
+                    fprintf(debug, " %4d %4d ", tData_[th]->ilist[ftype].size(),
+                            tData_[th]->idTask.ilist[ftype].size());
                 }
                 fprintf(debug, "\n");
             }
@@ -2384,12 +2660,11 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
     }
 
 #ifndef NDEBUG
-    int nrOrig     = vsiteIlistNrCount(ilist);
+    int nrOrig     = vsiteIlistNrCount(ilists);
     int nrThreaded = 0;
-    for (int th = 0; th < vsite->nthreads + 1; th++)
+    for (int th = 0; th < numThreads_ + 1; th++)
     {
-        nrThreaded += vsiteIlistNrCount(vsite->tData[th]->ilist)
-                      + vsiteIlistNrCount(vsite->tData[th]->idTask.ilist);
+        nrThreaded += vsiteIlistNrCount(tData_[th]->ilist) + vsiteIlistNrCount(tData_[th]->idTask.ilist);
     }
     GMX_ASSERT(nrThreaded == nrOrig,
                "The number of virtual sites assigned to all thread task has to match the total "
@@ -2397,10 +2672,17 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
 #endif
 }
 
-void set_vsite_top(gmx_vsite_t* vsite, const gmx_localtop_t* top, const t_mdatoms* md)
+void VirtualSitesHandler::Impl::setVirtualSites(ArrayRef<const InteractionList> ilists,
+                                                const t_mdatoms&                mdatoms)
 {
-    if (vsite->nthreads > 1)
-    {
-        split_vsites_over_threads(top->idef.il, top->idef.iparams, md, vsite);
-    }
+    ilists_ = ilists;
+
+    threadingInfo_.setVirtualSites(ilists, iparams_, mdatoms, domainInfo_.useDomdec());
 }
+
+void VirtualSitesHandler::setVirtualSites(ArrayRef<const InteractionList> ilists, const t_mdatoms& mdatoms)
+{
+    impl_->setVirtualSites(ilists, mdatoms);
+}
+
+} // namespace gmx
index e6894059421d03a4349526e196a0cb8a2513e475..be1d8c9ebc0f46c6fd4ba3738ac8cc9d1151b9e1 100644 (file)
  * To help us fund GROMACS development, we humbly ask that you cite
  * the research papers on the package. Check out http://www.gromacs.org.
  */
+/*! \libinternal \file
+ * \brief Declares the VirtualSitesHandler class and vsite standalone functions
+ *
+ * \author Berk Hess <hess@kth.se>
+ * \ingroup module_mdlib
+ * \inlibraryapi
+ */
+
 #ifndef GMX_MDLIB_VSITE_H
 #define GMX_MDLIB_VSITE_H
 
 #include <memory>
 
 #include "gromacs/math/vectypes.h"
-#include "gromacs/pbcutil/ishift.h"
 #include "gromacs/topology/idef.h"
 #include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/basedefinitions.h"
+#include "gromacs/utility/classhelpers.h"
 #include "gromacs/utility/real.h"
 
-struct gmx_localtop_t;
+struct gmx_domdec_t;
 struct gmx_mtop_t;
 struct t_commrec;
 struct InteractionList;
 struct t_mdatoms;
 struct t_nrnb;
 struct gmx_wallcycle;
-struct VsiteThread;
 enum class PbcType : int;
 
 namespace gmx
 {
 class RangePartitioning;
-}
 
-/* The start and end values of for the vsite indices in the ftype enum.
- * The validity of these values is checked in init_vsite.
+/*! \brief The start value of the vsite indices in the ftype enum
+ *
+ * The validity of the start and end values is checked in makeVirtualSitesHandler().
  * This is used to avoid loops over all ftypes just to get the vsite entries.
  * (We should replace the fixed ilist array by only the used entries.)
  */
 static constexpr int c_ftypeVsiteStart = F_VSITE2;
-static constexpr int c_ftypeVsiteEnd   = F_VSITEN + 1;
+//! The end value of the vsite indices in the ftype enum
+static constexpr int c_ftypeVsiteEnd = F_VSITEN + 1;
 
-/* Type for storing PBC atom information for all vsite types in the system */
+//! Type for storing PBC atom information for all vsite types in the system
 typedef std::array<std::vector<int>, c_ftypeVsiteEnd - c_ftypeVsiteStart> VsitePbc;
 
-/* Data for handling vsites, needed with OpenMP threading or with charge-groups and PBC */
-struct gmx_vsite_t
+/*! \libinternal
+ * \brief Class that handles construction of vsites and spreading of vsite forces
+ */
+class VirtualSitesHandler
 {
-    gmx_vsite_t();
-
-    ~gmx_vsite_t();
-
-    /* The number of vsites that cross update groups, when =0 no PBC treatment is needed */
-    int numInterUpdategroupVsites;
-    int nthreads;                                    /* Number of threads used for vsites       */
-    std::vector<std::unique_ptr<VsiteThread>> tData; /* Thread local vsites and work structs    */
-    std::vector<int> taskIndex;                      /* Work array                              */
-    bool useDomdec; /* Tells whether we use domain decomposition with more than 1 DD rank */
+public:
+    //! Constructor, used only by the makeVirtualSitesHandler() factory function
+    VirtualSitesHandler(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, PbcType pbcType);
+
+    ~VirtualSitesHandler();
+
+    //! Returns the number of virtual sites acting over multiple update groups
+    int numInterUpdategroupVirtualSites() const;
+
+    //! Set VSites and distribute VSite work over threads, should be called after each DD partitioning
+    void setVirtualSites(ArrayRef<const InteractionList> ilist, const t_mdatoms& mdatoms);
+
+    /*! \brief Create positions of vsite atoms for the local system
+     *
+     * \param[in,out] x        The coordinates
+     * \param[in]     dt       The time step
+     * \param[in,out] v        When not empty, velocities for vsites are set as displacement/dt
+     * \param[in]     box      The box
+     */
+    void construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const;
+
+    //! Tells how to handle virial contributions due to virtual sites
+    enum class VirialHandling : int
+    {
+        None,     //!< Do not compute virial contributions
+        Pbc,      //!< Add contributions working over PBC to shift forces
+        NonLinear //!< Compute contributions due to non-linear virtual sites
+    };
+
+    /*! \brief Spread the forces acting on the vsite atoms onto the surrounding atoms.
+     *
+     * The virialHandling parameter determines how virial contributions are handled.
+     * If this is set to Pbc, shift forces working over PBC are accumulated into fshift.
+     * If this is set to NonLinear, non-linear contributions are added to virial.
+     * This non-linear correction is required when the virial is not calculated
+     * afterwards from the particle positions and forces, but in a different way,
+     * as for instance for the PME mesh contribution.
+     */
+    void spreadForces(ArrayRef<const RVec> x,
+                      ArrayRef<RVec>       f,
+                      VirialHandling       virialHandling,
+                      ArrayRef<RVec>       fshift,
+                      matrix               virial,
+                      t_nrnb*              nrnb,
+                      const matrix         box,
+                      gmx_wallcycle*       wcycle);
+
+private:
+    //! Implementation type.
+    class Impl;
+    //! Implementation object.
+    PrivateImplPointer<Impl> impl_;
 };
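
A minimal sketch of a spreadForces() call using the non-linear virial
correction (illustrative only: the names x, f, virial, nrnb, box and wcycle
are assumed surrounding objects, and it is assumed that fshift is only read
in VirialHandling::Pbc mode, so an empty ArrayRef is passed for it):

    // Spread vsite forces, adding the non-linear virial correction, as
    // needed when the virial is not computed from positions and forces
    // (e.g. for the PME mesh contribution).
    vsite->spreadForces(x, f, gmx::VirtualSitesHandler::VirialHandling::NonLinear,
                        {}, virial, nrnb, box, wcycle);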
 
 /*! \brief Create positions of vsite atoms for the local system
  *
- * \param[in]     vsite    The vsite struct, when nullptr is passed, no MPI and no multi-threading
- *                         is used
  * \param[in,out] x        The coordinates
- * \param[in]     dt       The time step
- * \param[in,out] v        When != nullptr, velocities for vsites are set as displacement/dt
  * \param[in]     ip       Interaction parameters
  * \param[in]     ilist    The interaction list
- * \param[in]     pbcType  The type of periodic boundary conditions
- * \param[in]     bMolPBC  When true, molecules are broken over PBC
- * \param[in]     cr       The communication record
- * \param[in]     box      The box
  */
-void construct_vsites(const gmx_vsite_t*                   vsite,
-                      rvec                                 x[],
-                      real                                 dt,
-                      rvec                                 v[],
-                      gmx::ArrayRef<const t_iparams>       ip,
-                      gmx::ArrayRef<const InteractionList> ilist,
-                      PbcType                              pbcType,
-                      gmx_bool                             bMolPBC,
-                      const t_commrec*                     cr,
-                      const matrix                         box);
+void constructVirtualSites(ArrayRef<RVec>                  x,
+                           ArrayRef<const t_iparams>       ip,
+                           ArrayRef<const InteractionList> ilist);
 
 /*! \brief Create positions of vsite atoms for the whole system assuming all molecules are whole
  *
  * \param[in]     mtop  The global topology
  * \param[in,out] x     The global coordinates
  */
-void constructVsitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x);
-
-void spread_vsite_f(const gmx_vsite_t*            vsite,
-                    const rvec                    x[],
-                    rvec                          f[],
-                    rvec*                         fshift,
-                    gmx_bool                      VirCorr,
-                    matrix                        vir,
-                    t_nrnb*                       nrnb,
-                    const InteractionDefinitions& idef,
-                    PbcType                       pbcType,
-                    gmx_bool                      bMolPBC,
-                    const matrix                  box,
-                    const t_commrec*              cr,
-                    gmx_wallcycle*                wcycle);
-/* Spread the force operating on the vsite atoms on the surrounding atoms.
- * If fshift!=NULL also update the shift forces.
- * If VirCorr=TRUE add the virial correction for non-linear vsite constructs
- * to vir. This correction is required when the virial is not calculated
- * afterwards from the particle position and forces, but in a different way,
- * as for instance for the PME mesh contribution.
- */
+void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, ArrayRef<RVec> x);
+
+//! Tells how to handle virial contributions due to virtual sites
+enum class VirtualSiteVirialHandling : int
+{
+    None,     //!< Do not compute virial contributions
+    Pbc,      //!< Add contributions working over PBC to shift forces
+    NonLinear //!< Compute contributions due to non-linear virtual sites
+};
 
-/* Return the number of non-linear virtual site constructions in the system */
+//! Return the number of non-linear virtual site constructions in the system
 int countNonlinearVsites(const gmx_mtop_t& mtop);
 
-/* Return the number of virtual sites that cross update groups
+/*! \brief Return the number of virtual sites that cross update groups
  *
  * \param[in] mtop                           The global topology
  * \param[in] updateGroupingPerMoleculetype  Update grouping per molecule type, pass empty when not using update groups
  */
-int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
-                                gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype);
+int countInterUpdategroupVsites(const gmx_mtop_t&                 mtop,
+                                ArrayRef<const RangePartitioning> updateGroupingPerMoleculetype);
 
-/* Initialize the virtual site struct,
+/*! \brief Create the virtual site handler
  *
- * \param[in] mtop  The global topology
- * \param[in] cr    The communication record
- * \returns A valid vsite struct or nullptr when there are no virtual sites
- */
-std::unique_ptr<gmx_vsite_t> initVsite(const gmx_mtop_t& mtop, const t_commrec* cr);
-
-void split_vsites_over_threads(gmx::ArrayRef<const InteractionList> ilist,
-                               gmx::ArrayRef<const t_iparams>       ip,
-                               const t_mdatoms*                     mdatoms,
-                               gmx_vsite_t*                         vsite);
-/* Divide the vsite work-load over the threads.
- * Should be called at the end of the domain decomposition.
+ * \param[in] mtop      The global topology
+ * \param[in] cr        The communication record
+ * \param[in] pbcType   The type of PBC
+ * \returns A valid vsite handler object or nullptr when there are no virtual sites
  */
+std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
+                                                             const t_commrec*  cr,
+                                                             PbcType           pbcType);
 
-void set_vsite_top(gmx_vsite_t* vsite, const gmx_localtop_t* top, const t_mdatoms* md);
-/* Set some vsite data for runs without domain decomposition.
- * Should be called once after init_vsite, before calling other routines.
- */
+} // namespace gmx
 
 #endif
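
For overview, a minimal lifecycle sketch of the new interface (a sketch, not
code from this change; it assumes the usual mdrunner objects mtop, cr, fr,
ir, top, mdatoms and state are in scope):

    // Create the handler; nullptr is returned when there are no vsites.
    std::unique_ptr<gmx::VirtualSitesHandler> vsite =
            gmx::makeVirtualSitesHandler(mtop, cr, fr->pbcType);
    if (vsite)
    {
        // After each DD partitioning: redistribute the vsite work over threads.
        vsite->setVirtualSites(top.idef.il, *mdatoms);
        // Construct vsite positions; velocities are set as displacement/dt.
        vsite->construct(state->x, ir->delta_t, state->v, state->box);
    }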
index 8f7ee114832d4791962941c2da7c3289fbb0a9b8..c51aae29a4c9ab0d400a5faf3d19f0047f0b8f6a 100644 (file)
@@ -51,7 +51,6 @@ struct gmx_mtop_t;
 struct gmx_membed_t;
 struct gmx_multisim_t;
 struct gmx_output_env_t;
-struct gmx_vsite_t;
 struct gmx_wallcycle;
 struct gmx_walltime_accounting;
 struct ObservablesHistory;
@@ -79,6 +78,7 @@ class MDLogger;
 class MDAtoms;
 class StopHandlerBuilder;
 struct MdrunOptions;
+class VirtualSitesHandler;
 
 /*! \internal
  * \brief The Simulator interface
@@ -108,7 +108,7 @@ public:
                const gmx_output_env_t*             oenv,
                const MdrunOptions&                 mdrunOptions,
                StartingBehavior                    startingBehavior,
-               gmx_vsite_t*                        vsite,
+               VirtualSitesHandler*                vsite,
                Constraints*                        constr,
                gmx_enfrot*                         enforcedRotation,
                BoxDeformation*                     deform,
@@ -192,7 +192,7 @@ protected:
     //! Whether the simulation will start afresh, or restart with/without appending.
     const StartingBehavior startingBehavior;
     //! Handles virtual sites.
-    gmx_vsite_t* vsite;
+    VirtualSitesHandler* vsite;
     //! Handles constraints.
     Constraints* constr;
     //! Handles enforced rotation.
index 10b76df3778955f5a83038e18ce2edd367911be7..34037ecef7ba5c49607f29af2f522dd526dbc471 100644 (file)
@@ -507,8 +507,7 @@ void gmx::LegacySimulator::do_md()
         if (vsite)
         {
             /* Construct the virtual sites for the initial configuration */
-            construct_vsites(vsite, state->x.rvec_array(), ir->delta_t, nullptr, top.idef.iparams,
-                             top.idef.il, fr->pbcType, fr->bMolPBC, cr, state->box);
+            vsite->construct(state->x, ir->delta_t, {}, state->box);
         }
     }
 
@@ -1391,8 +1390,7 @@ void gmx::LegacySimulator::do_md()
         if (vsite != nullptr)
         {
             wallcycle_start(wcycle, ewcVSITECONSTR);
-            construct_vsites(vsite, state->x.rvec_array(), ir->delta_t, state->v.rvec_array(),
-                             top.idef.iparams, top.idef.il, fr->pbcType, fr->bMolPBC, cr, state->box);
+            vsite->construct(state->x, ir->delta_t, state->v, state->box);
             wallcycle_stop(wcycle, ewcVSITECONSTR);
         }
 
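
The two md.cpp call sites above also show the new velocity convention: an
empty ArrayRef replaces the old nullptr argument when vsite velocities should
not be set, i.e.

    vsite->construct(state->x, ir->delta_t, {}, state->box);       // no velocities
    vsite->construct(state->x, ir->delta_t, state->v, state->box); // v set as displacement/dt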
index c3697842a4657355252d42c3d8796b0c76a50c35..0b7e667ff40643fc0438cdf0f62b14d6ef13a064 100644 (file)
@@ -454,9 +454,7 @@ void gmx::LegacySimulator::do_mimic()
         if (vsite != nullptr)
         {
             wallcycle_start(wcycle, ewcVSITECONSTR);
-            construct_vsites(vsite, as_rvec_array(state->x.data()), ir->delta_t,
-                             as_rvec_array(state->v.data()), top.idef.iparams, top.idef.il,
-                             fr->pbcType, fr->bMolPBC, cr, state->box);
+            vsite->construct(state->x, ir->delta_t, state->v, state->box);
             wallcycle_stop(wcycle, ewcVSITECONSTR);
         }
 
index c6959e3643f567f4cc37fca2eca5547f81b10f8b..dc2658f1d3290c87490f879901b8ef446da25c47 100644 (file)
 using gmx::ArrayRef;
 using gmx::MdrunScheduleWorkload;
 using gmx::RVec;
+using gmx::VirtualSitesHandler;
 
 //! Utility structure for manipulating states during EM
 typedef struct
@@ -372,7 +373,7 @@ static void init_em(FILE*                fplog,
                     t_forcerec*          fr,
                     gmx::MDAtoms*        mdAtoms,
                     gmx_global_stat_t*   gstat,
-                    gmx_vsite_t*         vsite,
+                    VirtualSitesHandler* vsite,
                     gmx::Constraints*    constr,
                     gmx_shellfc_t**      shellfc)
 {
@@ -411,7 +412,6 @@ static void init_em(FILE*                fplog,
         }
     }
 
-    auto mdatoms = mdAtoms->mdatoms();
     if (DOMAINDECOMP(cr))
     {
         dd_init_local_state(cr->dd, state_global, &ems->s);
@@ -431,11 +431,6 @@ static void init_em(FILE*                fplog,
 
         mdAlgorithmsSetupAtomData(cr, ir, *top_global, top, fr, &ems->f, mdAtoms, constr, vsite,
                                   shellfc ? *shellfc : nullptr);
-
-        if (vsite)
-        {
-            set_vsite_top(vsite, top, mdatoms);
-        }
     }
 
     update_mdatoms(mdAtoms->mdatoms(), ems->s.lambda[efptMASS]);
@@ -721,7 +716,7 @@ static void em_dd_partition_system(FILE*                fplog,
                                    gmx_localtop_t*      top,
                                    gmx::MDAtoms*        mdAtoms,
                                    t_forcerec*          fr,
-                                   gmx_vsite_t*         vsite,
+                                   VirtualSitesHandler* vsite,
                                    gmx::Constraints*    constr,
                                    t_nrnb*              nrnb,
                                    gmx_wallcycle_t      wcycle)
@@ -784,7 +779,7 @@ public:
     //! Coordinates global reduction.
     gmx_global_stat_t gstat;
     //! Handles virtual sites.
-    gmx_vsite_t* vsite;
+    VirtualSitesHandler* vsite;
     //! Handles constraints.
     gmx::Constraints* constr;
     //! Handles strange things.
@@ -826,8 +821,7 @@ void EnergyEvaluator::run(em_state_t* ems, rvec mu_tot, tensor vir, tensor pres,
 
     if (vsite)
     {
-        construct_vsites(vsite, ems->s.x.rvec_array(), 1, nullptr, top->idef.iparams, top->idef.il,
-                         fr->pbcType, fr->bMolPBC, cr, ems->s.box);
+        vsite->construct(ems->s.x, 1, {}, ems->s.box);
     }
 
     if (DOMAINDECOMP(cr) && bNS)
@@ -1748,8 +1742,7 @@ void LegacySimulator::do_lbfgs()
 
     if (vsite)
     {
-        construct_vsites(vsite, state_global->x.rvec_array(), 1, nullptr, top.idef.iparams,
-                         top.idef.il, fr->pbcType, fr->bMolPBC, cr, state_global->box);
+        vsite->construct(state_global->x, 1, {}, state_global->box);
     }
 
     /* Call the force routine and some auxiliary (neighboursearching etc.) */
index ed3675117b886aa0bcaa0b6f65a16ca13b4f6682..5a66c06a4d174d7fc41d036ae4104b6d893b3a3c 100644 (file)
 #include "shellfc.h"
 
 using gmx::SimulationSignaller;
+using gmx::VirtualSitesHandler;
 
 /*! \brief Copy the state from \p rerunFrame to \p globalState and, if requested, construct vsites
  *
@@ -142,17 +143,13 @@ using gmx::SimulationSignaller;
  * \param[in,out] globalState     The global state container
  * \param[in]     constructVsites When true, vsite coordinates are constructed
  * \param[in]     vsite           Vsite setup, can be nullptr when \p constructVsites = false
- * \param[in]     idef            Topology parameters, used for constructing vsites
  * \param[in]     timeStep        Time step, used for constructing vsites
- * \param[in]     forceRec        Force record, used for constructing vsites
  */
-static void prepareRerunState(const t_trxframe&             rerunFrame,
-                              t_state*                      globalState,
-                              bool                          constructVsites,
-                              const gmx_vsite_t*            vsite,
-                              const InteractionDefinitions& idef,
-                              double                        timeStep,
-                              const t_forcerec&             forceRec)
+static void prepareRerunState(const t_trxframe&          rerunFrame,
+                              t_state*                   globalState,
+                              bool                       constructVsites,
+                              const VirtualSitesHandler* vsite,
+                              double                     timeStep)
 {
     auto x      = makeArrayRef(globalState->x);
     auto rerunX = arrayRefFromArray(reinterpret_cast<gmx::RVec*>(rerunFrame.x), globalState->natoms);
@@ -163,9 +160,7 @@ static void prepareRerunState(const t_trxframe&             rerunFrame,
     {
         GMX_ASSERT(vsite, "Need valid vsite for constructing vsites");
 
-        construct_vsites(vsite, globalState->x.rvec_array(), timeStep, globalState->v.rvec_array(),
-                         idef.iparams, idef.il, forceRec.pbcType, forceRec.bMolPBC, nullptr,
-                         globalState->box);
+        vsite->construct(globalState->x, timeStep, globalState->v, globalState->box);
     }
 }
 
@@ -494,7 +489,7 @@ void gmx::LegacySimulator::do_rerun()
                           "decomposition, "
                           "use a single rank");
             }
-            prepareRerunState(rerun_fr, state_global, constructVsites, vsite, top.idef, ir->delta_t, *fr);
+            prepareRerunState(rerun_fr, state_global, constructVsites, vsite, ir->delta_t);
         }
 
         isLastStep = isLastStep || stopHandler->stoppingAfterCurrentStep(bNS);
@@ -567,9 +562,7 @@ void gmx::LegacySimulator::do_rerun()
         if (vsite != nullptr)
         {
             wallcycle_start(wcycle, ewcVSITECONSTR);
-            construct_vsites(vsite, as_rvec_array(state->x.data()), ir->delta_t,
-                             as_rvec_array(state->v.data()), top.idef.iparams, top.idef.il,
-                             fr->pbcType, fr->bMolPBC, cr, state->box);
+            vsite->construct(state->x, ir->delta_t, state->v, state->box);
             wallcycle_stop(wcycle, ewcVSITECONSTR);
         }
 
index daec6ce305bb2a94f3d7faa93b24f9807b7fa494..33f3a23b0941a7c777ca9d3cdadf8c064a27c95d 100644 (file)
@@ -1349,10 +1349,10 @@ int Mdrunner::mdrunner()
                              globalState.get(), cr, &mdrunOptions.checkpointOptions.period);
     }
 
-    const bool                   thisRankHasPmeGpuTask = gpuTaskAssignments.thisRankHasPmeGpuTask();
-    std::unique_ptr<MDAtoms>     mdAtoms;
-    std::unique_ptr<gmx_vsite_t> vsite;
-    std::unique_ptr<GpuBonded>   gpuBonded;
+    const bool                           thisRankHasPmeGpuTask = gpuTaskAssignments.thisRankHasPmeGpuTask();
+    std::unique_ptr<MDAtoms>             mdAtoms;
+    std::unique_ptr<VirtualSitesHandler> vsite;
+    std::unique_ptr<GpuBonded>           gpuBonded;
 
     t_nrnb nrnb;
     if (thisRankHasDuty(cr, DUTY_PP))
@@ -1418,7 +1418,7 @@ int Mdrunner::mdrunner()
         }
 
         /* Initialize the virtual site communication */
-        vsite = initVsite(mtop, cr);
+        vsite = makeVirtualSitesHandler(mtop, cr, fr->pbcType);
 
         calc_shifts(box, fr->shift_vec);
 
@@ -1439,7 +1439,7 @@ int Mdrunner::mdrunner()
                  * for the initial distribution in the domain decomposition
                  * and for the initial shell prediction.
                  */
-                constructVsitesGlobal(mtop, globalState->x);
+                constructVirtualSitesGlobal(mtop, globalState->x);
             }
         }
 
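
Since makeVirtualSitesHandler() returns nullptr for systems without virtual
sites, the existing pointer guards remain valid; a minimal sketch (the exact
runner.cpp conditions, which also depend on the starting behavior, are
omitted here):

    vsite = makeVirtualSitesHandler(mtop, cr, fr->pbcType);
    if (vsite != nullptr)
    {
        // Construct all global vsite positions once, e.g. for the initial
        // domain decomposition distribution and the initial shell prediction.
        constructVirtualSitesGlobal(mtop, globalState->x);
    }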
index 4eeae2fe782d5183efa0e12cabbabd97fba3af6c..1e17a1551ac9e88b99dc27c9f685af3f3943142a 100644 (file)
@@ -922,7 +922,7 @@ void relax_shell_flexcon(FILE*                         fplog,
                          gmx::MdrunScheduleWorkload*   runScheduleWork,
                          double                        t,
                          rvec                          mu_tot,
-                         const gmx_vsite_t*            vsite,
+                         gmx::VirtualSitesHandler*     vsite,
                          const DDBalanceRegionHandler& ddBalanceRegionHandler)
 {
     real Epot[2], df[2];
@@ -941,8 +941,6 @@ void relax_shell_flexcon(FILE*                         fplog,
     ArrayRef<t_shell> shells       = shfc->shells;
     const int         nflexcon     = shfc->nflexcon;
 
-    const InteractionDefinitions& idef = top->idef;
-
     if (DOMAINDECOMP(cr))
     {
         nat = dd_natoms_vsite(cr->dd);
@@ -1082,9 +1080,7 @@ void relax_shell_flexcon(FILE*                         fplog,
     {
         if (vsite)
         {
-            construct_vsites(vsite, as_rvec_array(pos[Min].data()), inputrec->delta_t,
-                             as_rvec_array(v.data()), idef.iparams, idef.il, fr->pbcType,
-                             fr->bMolPBC, cr, box);
+            vsite->construct(pos[Min], inputrec->delta_t, v, box);
         }
 
         if (nflexcon)
index 5a3f31d7e403588431596a2d99b5a0e17d9d56ef..7b5be6592dd881100608a945c24ebd955037f75f 100644 (file)
 
 #include <cstdio>
 
-#include "gromacs/mdlib/vsite.h"
+#include "gromacs/math/vectypes.h"
 #include "gromacs/timing/wallcycle.h"
 #include "gromacs/topology/atoms.h"
 
 class DDBalanceRegionHandler;
 struct gmx_enerdata_t;
 struct gmx_enfrot;
+struct gmx_localtop_t;
 struct gmx_multisim_t;
 struct gmx_shellfc_t;
 struct gmx_mtop_t;
@@ -55,6 +56,8 @@ struct pull_t;
 struct t_forcerec;
 struct t_fcdata;
 struct t_inputrec;
+struct t_mdatoms;
+struct t_nrnb;
 class t_state;
 
 namespace gmx
@@ -66,6 +69,7 @@ class ArrayRefWithPadding;
 class Constraints;
 class ImdSession;
 class MdrunScheduleWorkload;
+class VirtualSitesHandler;
 } // namespace gmx
 
 /* Initialization function, also predicts the initial shell positions.
@@ -109,7 +113,7 @@ void relax_shell_flexcon(FILE*                               log,
                          gmx::MdrunScheduleWorkload*         runScheduleWork,
                          double                              t,
                          rvec                                mu_tot,
-                         const gmx_vsite_t*                  vsite,
+                         gmx::VirtualSitesHandler*           vsite,
                          const DDBalanceRegionHandler&       ddBalanceRegionHandler);
 
 /* Print some final output and delete shellfc */
index 67cc9d52e138907f131c1c2237c6ab7b39b36c82..db030fda3e7e6d34a4879a2482ceb6398c587331 100644 (file)
@@ -70,7 +70,7 @@ DomDecHelper::DomDecHelper(bool                               isVerbose,
                            t_nrnb*                            nrnb,
                            gmx_wallcycle*                     wcycle,
                            t_forcerec*                        fr,
-                           gmx_vsite_t*                       vsite,
+                           VirtualSitesHandler*               vsite,
                            ImdSession*                        imdSession,
                            pull_t*                            pull_work) :
     nextNSStep_(-1),
index 181698f3cc6de26a2f56499bb5bd3c2a96212717..5375c73e244c1004bc9ca3f27688fb6355f8659d 100644 (file)
@@ -1,7 +1,7 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2019, by the GROMACS development team, led by
+ * Copyright (c) 2019,2020, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
@@ -45,7 +45,6 @@
 #include "modularsimulatorinterfaces.h"
 
 struct gmx_localtop_t;
-struct gmx_vsite_t;
 struct gmx_wallcycle;
 struct pull_t;
 struct t_commrec;
@@ -61,6 +60,7 @@ class MDAtoms;
 class MDLogger;
 class StatePropagatorData;
 class TopologyHolder;
+class VirtualSitesHandler;
 
 //! \addtogroup module_modularsimulator
 //! \{
@@ -102,7 +102,7 @@ public:
                  t_nrnb*                            nrnb,
                  gmx_wallcycle*                     wcycle,
                  t_forcerec*                        fr,
-                 gmx_vsite_t*                       vsite,
+                 VirtualSitesHandler*               vsite,
                  ImdSession*                        imdSession,
                  pull_t*                            pull_work);
 
@@ -160,7 +160,7 @@ private:
     //! Parameters for force calculations.
     t_forcerec* fr_;
     //! Handles virtual sites.
-    gmx_vsite_t* vsite_;
+    VirtualSitesHandler* vsite_;
     //! The Interactive Molecular Dynamics session.
     ImdSession* imdSession_;
     //! The pull work object.
index 734ba4eda49981a0b06b87c7083bd0bdcca1ec45..d2ca77e8b27a2bb6d14da2347be201b4e2e86499 100644 (file)
@@ -79,7 +79,7 @@ ForceElement::ForceElement(StatePropagatorData*           statePropagatorData,
                            t_fcdata*                      fcd,
                            gmx_wallcycle*                 wcycle,
                            MdrunScheduleWorkload*         runScheduleWork,
-                           gmx_vsite_t*                   vsite,
+                           VirtualSitesHandler*           vsite,
                            ImdSession*                    imdSession,
                            pull_t*                        pull_work,
                            Constraints*                   constr,
index 6bf90b37a4d9f46d590a24b77fa46324e4ef15d4..9abbe21ec76b6ba176870e284283c963dbe3abc7 100644 (file)
@@ -70,6 +70,7 @@ class ImdSession;
 class MDAtoms;
 class MdrunScheduleWorkload;
 class StatePropagatorData;
+class VirtualSitesHandler;
 
 /*! \libinternal
  * \ingroup module_modularsimulator
@@ -100,7 +101,7 @@ public:
                  t_fcdata*                      fcd,
                  gmx_wallcycle*                 wcycle,
                  MdrunScheduleWorkload*         runScheduleWork,
-                 gmx_vsite_t*                   vsite,
+                 VirtualSitesHandler*           vsite,
                  ImdSession*                    imdSession,
                  pull_t*                        pull_work,
                  Constraints*                   constr,
@@ -188,7 +189,7 @@ private:
     //! Parameters for force calculations.
     t_forcerec* fr_;
     //! Handles virtual sites.
-    gmx_vsite_t* vsite_;
+    VirtualSitesHandler* vsite_;
     //! The Interactive Molecular Dynamics session.
     ImdSession* imdSession_;
     //! The pull work object.
index ed4ba05e0782aa8db7c97ececb0d23a575e758d7..f7c811a9e389e8f430248d823f38f5fd350590ca 100644 (file)
 
 namespace gmx
 {
-TopologyHolder::TopologyHolder(const gmx_mtop_t& globalTopology,
-                               const t_commrec*  cr,
-                               const t_inputrec* inputrec,
-                               t_forcerec*       fr,
-                               MDAtoms*          mdAtoms,
-                               Constraints*      constr,
-                               gmx_vsite_t*      vsite) :
+TopologyHolder::TopologyHolder(const gmx_mtop_t&    globalTopology,
+                               const t_commrec*     cr,
+                               const t_inputrec*    inputrec,
+                               t_forcerec*          fr,
+                               MDAtoms*             mdAtoms,
+                               Constraints*         constr,
+                               VirtualSitesHandler* vsite) :
     globalTopology_(globalTopology),
     localTopology_(std::make_unique<gmx_localtop_t>(globalTopology.ffparams))
 {
index 53f03e6599f39bff0eb8994deca412c188f6d7eb..5e3f1171b7b1f946f27bd1d8276e95a2c3557e09 100644 (file)
@@ -49,7 +49,6 @@
 
 struct gmx_localtop_t;
 struct gmx_mtop_t;
-struct gmx_vsite_t;
 struct t_commrec;
 struct t_forcerec;
 struct t_inputrec;
@@ -58,6 +57,7 @@ namespace gmx
 {
 class Constraints;
 class MDAtoms;
+class VirtualSitesHandler;
 
 /*! \libinternal
  * \ingroup module_modularsimulator
@@ -70,13 +70,13 @@ class TopologyHolder final
 {
 public:
     //! Constructor
-    TopologyHolder(const gmx_mtop_t& globalTopology,
-                   const t_commrec*  cr,
-                   const t_inputrec* inputrec,
-                   t_forcerec*       fr,
-                   MDAtoms*          mdAtoms,
-                   Constraints*      constr,
-                   gmx_vsite_t*      vsite);
+    TopologyHolder(const gmx_mtop_t&    globalTopology,
+                   const t_commrec*     cr,
+                   const t_inputrec*    inputrec,
+                   t_forcerec*          fr,
+                   MDAtoms*             mdAtoms,
+                   Constraints*         constr,
+                   VirtualSitesHandler* vsite);
 
     //! Get global topology
     const gmx_mtop_t& globalTopology() const;