Refactor virtual site interface
[alexxy/gromacs.git] / src / gromacs / mdlib / vsite.cpp
index bed9b8faa1cc8a67db4ae45cbd1071958a0cfa63..b2da74223b8e5fcb5c23ec2856c1c898a6b158d9 100644 (file)
  * To help us fund GROMACS development, we humbly ask that you cite
  * the research papers on the package. Check out http://www.gromacs.org.
  */
+/*! \libinternal \file
+ * \brief Implements the VirtualSitesHandler class and vsite standalone functions
+ *
+ * \author Berk Hess <hess@kth.se>
+ * \ingroup module_mdlib
+ * \inlibraryapi
+ */
+
 #include "gmxpre.h"
 
 #include "vsite.h"
  * Any remaining vsites are assigned to a separate master thread task.
  */
 
-using gmx::ArrayRef;
-using gmx::RVec;
+namespace gmx
+{
+
+//! VirialHandling is often used outside VirtualSitesHandler class members
+using VirialHandling = VirtualSitesHandler::VirialHandling;
+
+/*! \libinternal
+ * \brief Information on PBC and domain decomposition for virtual sites
+ */
+struct DomainInfo
+{
+public:
+    //! Constructs without PBC and DD
+    DomainInfo() = default;
+
+    //! Constructs with PBC and DD, if !=nullptr
+    DomainInfo(PbcType pbcType, bool haveInterUpdateGroupVirtualSites, gmx_domdec_t* domdec) :
+        pbcType_(pbcType),
+        useMolPbc_(pbcType != PbcType::No && haveInterUpdateGroupVirtualSites),
+        domdec_(domdec)
+    {
+    }
+
+    //! Returns whether we are using domain decomposition with more than 1 DD rank
+    bool useDomdec() const { return (domdec_ != nullptr); }
 
-/*! \brief List of atom indices belonging to a task */
+    //! The pbc type
+    const PbcType pbcType_ = PbcType::No;
+    //! Whether molecules are broken over PBC
+    const bool useMolPbc_ = false;
+    //! Pointer to the domain decomposition struct, nullptr without PP DD
+    const gmx_domdec_t* domdec_ = nullptr;
+};
+
+/*! \libinternal
+ * \brief List of atom indices belonging to a task
+ */
 struct AtomIndex
 {
     //! List of atom indices
     std::vector<int> atom;
 };
 
-/*! \brief Data structure for thread tasks that use constructing atoms outside their own atom range */
+/*! \libinternal
+ * \brief Data structure for thread tasks that use constructing atoms outside their own atom range
+ */
 struct InterdependentTask
 {
     //! The interaction lists, only vsite entries are used
@@ -116,7 +159,9 @@ struct InterdependentTask
     std::vector<int> reduceTask;
 };
 
-/*! \brief Vsite thread task data structure */
+/*! \libinternal
+ * \brief Vsite thread task data structure
+ */
 struct VsiteThread
 {
     //! Start of atom range of this task
@@ -126,7 +171,7 @@ struct VsiteThread
     //! The interaction lists, only vsite entries are used
     std::array<InteractionList, F_NRE> ilist;
     //! Local fshift accumulation buffer
-    rvec fshift[SHIFTS];
+    std::array<RVec, SHIFTS> fshift;
     //! Local virial dx*df accumulation buffer
     matrix dxdf;
     //! Tells if interdependent task idTask should be used (in addition to the rest of this task), this bool has the same value on all threads
@@ -139,12 +184,118 @@ struct VsiteThread
     {
         rangeStart = -1;
         rangeEnd   = -1;
-        clear_rvecs(SHIFTS, fshift);
+        for (auto& elem : fshift)
+        {
+            elem = { 0.0_real, 0.0_real, 0.0_real };
+        }
         clear_mat(dxdf);
         useInterdependentTask = false;
     }
 };
 
+
+/*! \libinternal
+ * \brief Information on how the virtual site work is divided over thread tasks
+ */
+class ThreadingInfo
+{
+public:
+    //! Constructor, retrieves the number of threads to use from gmx_omp_nthreads.h
+    ThreadingInfo();
+
+    //! Returns the number of threads to use for vsite operations
+    int numThreads() const { return numThreads_; }
+
+    //! Returns the thread data for the given thread
+    const VsiteThread& threadData(int threadIndex) const { return *tData_[threadIndex]; }
+
+    //! Returns the thread data for the given thread
+    VsiteThread& threadData(int threadIndex) { return *tData_[threadIndex]; }
+
+    //! Returns the thread data for vsites that depend on non-local vsites
+    const VsiteThread& threadDataNonLocalDependent() const { return *tData_[numThreads_]; }
+
+    //! Returns the thread data for vsites that depend on non-local vsites
+    VsiteThread& threadDataNonLocalDependent() { return *tData_[numThreads_]; }
+
+    //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
+    void setVirtualSites(ArrayRef<const InteractionList> ilist,
+                         ArrayRef<const t_iparams>       iparams,
+                         const t_mdatoms&                mdatoms,
+                         bool                            useDomdec);
+
+private:
+    //! Number of threads used for vsite operations
+    const int numThreads_;
+    //! Thread local vsites and work structs
+    std::vector<std::unique_ptr<VsiteThread>> tData_;
+    //! Work array for dividing vsites over threads
+    std::vector<int> taskIndex_;
+};
+
+/*! \libinternal
+ * \brief Impl class for VirtualSitesHandler
+ */
+class VirtualSitesHandler::Impl
+{
+public:
+    //! Constructor, domdec should be nullptr without DD
+    Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, PbcType pbcType);
+
+    //! Returns the number of virtual sites acting over multiple update groups
+    int numInterUpdategroupVirtualSites() const { return numInterUpdategroupVirtualSites_; }
+
+    //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
+    void setVirtualSites(ArrayRef<const InteractionList> ilist, const t_mdatoms& mdatoms);
+
+    /*! \brief Create positions of vsite atoms based for the local system
+     *
+     * \param[in,out] x        The coordinates
+     * \param[in]     dt       The time step
+     * \param[in,out] v        When != nullptr, velocities for vsites are set as displacement/dt
+     * \param[in]     box      The box
+     */
+    void construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const;
+
+    /*! \brief Spread the force operating on the vsite atoms on the surrounding atoms.
+     *
+     * vsite should point to a valid object.
+     * The virialHandling parameter determines how virial contributions are handled.
+     * If this is set to Linear, shift forces are accumulated into fshift.
+     * If this is set to NonLinear, non-linear contributions are added to virial.
+     * This non-linear correction is required when the virial is not calculated
+     * afterwards from the particle position and forces, but in a different way,
+     * as for instance for the PME mesh contribution.
+     */
+    void spreadForces(ArrayRef<const RVec> x,
+                      ArrayRef<RVec>       f,
+                      VirialHandling       virialHandling,
+                      ArrayRef<RVec>       fshift,
+                      matrix               virial,
+                      t_nrnb*              nrnb,
+                      const matrix         box,
+                      gmx_wallcycle*       wcycle);
+
+private:
+    // The number of vsites that cross update groups, when =0 no PBC treatment is needed
+    const int numInterUpdategroupVirtualSites_;
+    // PBC and DD information
+    const DomainInfo domainInfo_;
+    // The interaction parameters
+    const ArrayRef<const t_iparams> iparams_;
+    // The interaction lists
+    ArrayRef<const InteractionList> ilists_;
+    // Information for handling vsite threading
+    ThreadingInfo threadingInfo_;
+};
+
+VirtualSitesHandler::~VirtualSitesHandler() = default;
+
+int VirtualSitesHandler::numInterUpdategroupVirtualSites() const
+{
+    return impl_->numInterUpdategroupVirtualSites();
+}
+
 /*! \brief Returns the sum of the vsite ilist sizes over all vsite types
  *
  * \param[in] ilist  The interaction list
@@ -160,6 +311,7 @@ static int vsiteIlistNrCount(ArrayRef<const InteractionList> ilist)
     return nr;
 }
 
+//! Computes the distance between xi and xj, pbc is used when pbc!=nullptr
 static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
 {
     if (pbc)
@@ -173,11 +325,13 @@ static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
     }
 }
 
+//! Returns the 1/norm(x)
 static inline real inverseNorm(const rvec x)
 {
     return gmx::invsqrt(iprod(x, x));
 }
 
+#ifndef DOXYGEN
 /* Vsite construction routines */
 
 static void constr_vsite2(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc)
@@ -398,8 +552,7 @@ static void constr_vsite4FDN(const rvec   xi,
     /* TOTAL: 47 flops */
 }
 
-
-static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, rvec* x, const t_pbc* pbc)
+static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, ArrayRef<RVec> x, const t_pbc* pbc)
 {
     rvec x1, dx;
     dvec dsum;
@@ -436,11 +589,13 @@ static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, rvec*
     return n3;
 }
 
-/*! \brief PBC modes for vsite construction and spreading */
+#endif // DOXYGEN
+
+//! PBC modes for vsite construction and spreading
 enum class PbcMode
 {
-    all, // Apply normal, simple PBC for all vsites
-    none // No PBC treatment needed
+    all, //!< Apply normal, simple PBC for all vsites
+    none //!< No PBC treatment needed
 };
 
 /*! \brief Returns the PBC mode based on the system PBC and vsite properties
@@ -459,15 +614,24 @@ static PbcMode getPbcMode(const t_pbc* pbcPtr)
     }
 }
 
-static void construct_vsites_thread(rvec                            x[],
-                                    real                            dt,
-                                    rvec*                           v,
+/*! \brief Executes the vsite construction task for a single thread
+ *
+ * \param[in,out] x   Coordinates to construct vsites for
+ * \param[in]     dt  Time step, needed when v is not empty
+ * \param[in,out] v   When not empty, velocities are generated for virtual sites
+ * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
+ * \param[in]     ilist  The interaction lists, only vsites are usesd
+ * \param[in]     pbc_null  PBC struct, used for PBC distance calculations when !=nullptr
+ */
+static void construct_vsites_thread(ArrayRef<RVec>                  x,
+                                    const real                      dt,
+                                    ArrayRef<RVec>                  v,
                                     ArrayRef<const t_iparams>       ip,
                                     ArrayRef<const InteractionList> ilist,
                                     const t_pbc*                    pbc_null)
 {
     real inv_dt;
-    if (v != nullptr)
+    if (!v.empty())
     {
         inv_dt = 1.0 / dt;
     }
@@ -575,7 +739,7 @@ static void construct_vsites_thread(rvec                            x[],
                         rvec_add(xv, dx, x[avsite]);
                     }
                 }
-                if (v != nullptr)
+                if (!v.empty())
                 {
                     /* Calculate velocity of vsite... */
                     rvec vv;
@@ -591,40 +755,44 @@ static void construct_vsites_thread(rvec                            x[],
     }
 }
 
-void construct_vsites(const gmx_vsite_t*              vsite,
-                      rvec                            x[],
-                      real                            dt,
-                      rvec*                           v,
-                      ArrayRef<const t_iparams>       ip,
-                      ArrayRef<const InteractionList> ilist,
-                      PbcType                         pbcType,
-                      gmx_bool                        bMolPBC,
-                      const t_commrec*                cr,
-                      const matrix                    box)
+/*! \brief Dispatch the vsite construction tasks for all threads
+ *
+ * \param[in]     threadingInfo  Used to divide work over threads when != nullptr
+ * \param[in,out] x   Coordinates to construct vsites for
+ * \param[in]     dt  Time step, needed when v is not empty
+ * \param[in,out] v   When not empty, velocities are generated for virtual sites
+ * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
+ * \param[in]     ilist  The interaction lists, only vsites are usesd
+ * \param[in]     domainInfo  Information about PBC and DD
+ * \param[in]     box  Used for PBC when PBC is set in domainInfo
+ */
+static void construct_vsites(const ThreadingInfo*            threadingInfo,
+                             ArrayRef<RVec>                  x,
+                             real                            dt,
+                             ArrayRef<RVec>                  v,
+                             ArrayRef<const t_iparams>       ip,
+                             ArrayRef<const InteractionList> ilist,
+                             const DomainInfo&               domainInfo,
+                             const matrix                    box)
 {
-    const bool useDomdec = (vsite != nullptr && vsite->useDomdec);
-    GMX_ASSERT(!useDomdec || (cr != nullptr && DOMAINDECOMP(cr)),
-               "When vsites are set up with domain decomposition, we need a valid commrec");
-    // TODO: Remove this assertion when we remove charge groups
-    GMX_ASSERT(vsite != nullptr || pbcType == PbcType::No,
-               "Without a vsite struct we can not do PBC (in case we have charge groups)");
+    const bool useDomdec = domainInfo.useDomdec();
 
     t_pbc pbc, *pbc_null;
 
-    /* We only need to do pbc when we have inter-cg vsites.
+    /* We only need to do pbc when we have inter update-group vsites.
      * Note that with domain decomposition we do not need to apply PBC here
      * when we have at least 3 domains along each dimension. Currently we
      * do not optimize this case.
      */
-    if (pbcType != PbcType::No && (useDomdec || bMolPBC)
-        && !(vsite != nullptr && vsite->numInterUpdategroupVsites == 0))
+    if (domainInfo.pbcType_ != PbcType::No && domainInfo.useMolPbc_)
     {
         /* This is wasting some CPU time as we now do this multiple times
          * per MD step.
          */
         ivec null_ivec;
         clear_ivec(null_ivec);
-        pbc_null = set_pbc_dd(&pbc, pbcType, useDomdec ? cr->dd->numCells : null_ivec, FALSE, box);
+        pbc_null = set_pbc_dd(&pbc, domainInfo.pbcType_,
+                              useDomdec ? domainInfo.domdec_->numCells : null_ivec, FALSE, box);
     }
     else
     {
@@ -633,21 +801,21 @@ void construct_vsites(const gmx_vsite_t*              vsite,
 
     if (useDomdec)
     {
-        dd_move_x_vsites(cr->dd, box, x);
+        dd_move_x_vsites(*domainInfo.domdec_, box, as_rvec_array(x.data()));
     }
 
-    if (vsite == nullptr || vsite->nthreads == 1)
+    if (threadingInfo == nullptr || threadingInfo->numThreads() == 1)
     {
         construct_vsites_thread(x, dt, v, ip, ilist, pbc_null);
     }
     else
     {
-#pragma omp parallel num_threads(vsite->nthreads)
+#pragma omp parallel num_threads(threadingInfo->numThreads())
         {
             try
             {
                 const int          th    = gmx_omp_get_thread_num();
-                const VsiteThread& tData = *vsite->tData[th];
+                const VsiteThread& tData = threadingInfo->threadData(th);
                 GMX_ASSERT(tData.rangeStart >= 0,
                            "The thread data should be initialized before calling construct_vsites");
 
@@ -664,15 +832,41 @@ void construct_vsites(const gmx_vsite_t*              vsite,
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
         /* Now we can construct the vsites that might depend on other vsites */
-        construct_vsites_thread(x, dt, v, ip, vsite->tData[vsite->nthreads]->ilist, pbc_null);
+        construct_vsites_thread(x, dt, v, ip, threadingInfo->threadDataNonLocalDependent().ilist, pbc_null);
     }
 }
 
-static void spread_vsite2(const t_iatom ia[], real a, const rvec x[], rvec f[], rvec fshift[], const t_pbc* pbc)
+void VirtualSitesHandler::Impl::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
+{
+    construct_vsites(&threadingInfo_, x, dt, v, iparams_, ilists_, domainInfo_, box);
+}
+
+void VirtualSitesHandler::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
+{
+    impl_->construct(x, dt, v, box);
+}
+
+void constructVirtualSites(ArrayRef<RVec> x, ArrayRef<const t_iparams> ip, ArrayRef<const InteractionList> ilist)
+
+{
+    // No PBC, no DD
+    const DomainInfo domainInfo;
+    construct_vsites(nullptr, x, 0, {}, ip, ilist, domainInfo, nullptr);
+}
+
+#ifndef DOXYGEN
+/* Force spreading routines */
+
+template<VirialHandling virialHandling>
+static void spread_vsite2(const t_iatom        ia[],
+                          real                 a,
+                          ArrayRef<const RVec> x,
+                          ArrayRef<RVec>       f,
+                          ArrayRef<RVec>       fshift,
+                          const t_pbc*         pbc)
 {
     rvec    fi, fj, dx;
     t_iatom av, ai, aj;
-    int     siv, sij;
 
     av = ia[1];
     ai = ia[2];
@@ -686,28 +880,33 @@ static void spread_vsite2(const t_iatom ia[], real a, const rvec x[], rvec f[],
     rvec_inc(f[aj], fj);
     /* 6 Flops */
 
-    if (pbc)
-    {
-        siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
-        sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        siv = CENTRAL;
-        sij = CENTRAL;
-    }
+        int siv;
+        int sij;
+        if (pbc)
+        {
+            siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
+            sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
+        }
+        else
+        {
+            siv = CENTRAL;
+            sij = CENTRAL;
+        }
 
-    if (fshift && (siv != CENTRAL || sij != CENTRAL))
-    {
-        rvec_inc(fshift[siv], f[av]);
-        rvec_dec(fshift[CENTRAL], fi);
-        rvec_dec(fshift[sij], fj);
+        if (siv != CENTRAL || sij != CENTRAL)
+        {
+            rvec_inc(fshift[siv], f[av]);
+            rvec_dec(fshift[CENTRAL], fi);
+            rvec_dec(fshift[sij], fj);
+        }
     }
 
     /* TOTAL: 13 flops */
 }
 
-void constructVsitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
+void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
 {
     GMX_ASSERT(x.ssize() >= mtop.natoms, "x should contain the whole system");
     GMX_ASSERT(!mtop.moleculeBlockIndices.empty(),
@@ -722,22 +921,22 @@ void constructVsitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
             int atomOffset = mtop.moleculeBlockIndices[mb].globalAtomStart;
             for (int mol = 0; mol < molb.nmol; mol++)
             {
-                construct_vsites(nullptr, as_rvec_array(x.data()) + atomOffset, 0.0, nullptr,
-                                 mtop.ffparams.iparams, molt.ilist, PbcType::No, TRUE, nullptr, nullptr);
+                constructVirtualSites(x.subArray(atomOffset, molt.atoms.nr), mtop.ffparams.iparams,
+                                      molt.ilist);
                 atomOffset += molt.atoms.nr;
             }
         }
     }
 }
 
-static void spread_vsite2FD(const t_iatom ia[],
-                            real          a,
-                            const rvec    x[],
-                            rvec          f[],
-                            rvec          fshift[],
-                            gmx_bool      VirCorr,
-                            matrix        dxdf,
-                            const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite2FD(const t_iatom        ia[],
+                            real                 a,
+                            ArrayRef<const RVec> x,
+                            ArrayRef<RVec>       f,
+                            ArrayRef<RVec>       fshift,
+                            matrix               dxdf,
+                            const t_pbc*         pbc)
 {
     const int av = ia[1];
     const int ai = ia[2];
@@ -772,7 +971,7 @@ static void spread_vsite2FD(const t_iatom ia[],
     f[aj][ZZ] += fj[ZZ];
     /* 9 Flops */
 
-    if (fshift)
+    if (virialHandling == VirialHandling::Pbc)
     {
         int svi;
         if (pbc)
@@ -797,9 +996,9 @@ static void spread_vsite2FD(const t_iatom ia[],
         }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
-        /* When VirCorr=TRUE, the virial for the current forces is not
+        /* Under this condition, the virial for the current forces is not
          * calculated from the redistributed forces. This means that
          * the effect of non-linear virtual site constructions on the virial
          * needs to be added separately. This contribution can be calculated
@@ -824,11 +1023,17 @@ static void spread_vsite2FD(const t_iatom ia[],
     /* TOTAL: 38 flops */
 }
 
-static void spread_vsite3(const t_iatom ia[], real a, real b, const rvec x[], rvec f[], rvec fshift[], const t_pbc* pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3(const t_iatom        ia[],
+                          real                 a,
+                          real                 b,
+                          ArrayRef<const RVec> x,
+                          ArrayRef<RVec>       f,
+                          ArrayRef<RVec>       fshift,
+                          const t_pbc*         pbc)
 {
     rvec fi, fj, fk, dx;
     int  av, ai, aj, ak;
-    int  siv, sij, sik;
 
     av = ia[1];
     ai = ia[2];
@@ -845,44 +1050,50 @@ static void spread_vsite3(const t_iatom ia[], real a, real b, const rvec x[], rv
     rvec_inc(f[ak], fk);
     /* 9 Flops */
 
-    if (pbc)
-    {
-        siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
-        sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
-        sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        siv = CENTRAL;
-        sij = CENTRAL;
-        sik = CENTRAL;
-    }
+        int siv;
+        int sij;
+        int sik;
+        if (pbc)
+        {
+            siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
+            sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
+            sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
+        }
+        else
+        {
+            siv = CENTRAL;
+            sij = CENTRAL;
+            sik = CENTRAL;
+        }
 
-    if (fshift && (siv != CENTRAL || sij != CENTRAL || sik != CENTRAL))
-    {
-        rvec_inc(fshift[siv], f[av]);
-        rvec_dec(fshift[CENTRAL], fi);
-        rvec_dec(fshift[sij], fj);
-        rvec_dec(fshift[sik], fk);
+        if (siv != CENTRAL || sij != CENTRAL || sik != CENTRAL)
+        {
+            rvec_inc(fshift[siv], f[av]);
+            rvec_dec(fshift[CENTRAL], fi);
+            rvec_dec(fshift[sij], fj);
+            rvec_dec(fshift[sik], fk);
+        }
     }
 
     /* TOTAL: 20 flops */
 }
 
-static void spread_vsite3FD(const t_iatom ia[],
-                            real          a,
-                            real          b,
-                            const rvec    x[],
-                            rvec          f[],
-                            rvec          fshift[],
-                            gmx_bool      VirCorr,
-                            matrix        dxdf,
-                            const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3FD(const t_iatom        ia[],
+                            real                 a,
+                            real                 b,
+                            ArrayRef<const RVec> x,
+                            ArrayRef<RVec>       f,
+                            ArrayRef<RVec>       fshift,
+                            matrix               dxdf,
+                            const t_pbc*         pbc)
 {
     real    fproj, a1;
     rvec    xvi, xij, xjk, xix, fv, temp;
     t_iatom av, ai, aj, ak;
-    int     svi, sji, skj;
+    int     sji, skj;
 
     av = ia[1];
     ai = ia[2];
@@ -926,32 +1137,36 @@ static void spread_vsite3FD(const t_iatom ia[],
     f[ak][ZZ] += a * temp[ZZ];
     /* 19 Flops */
 
-    if (pbc)
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
-    {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - (1 + a) * temp[XX];
-        fshift[CENTRAL][YY] += fv[YY] - (1 + a) * temp[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
-        fshift[sji][XX] += temp[XX];
-        fshift[sji][YY] += temp[YY];
-        fshift[sji][ZZ] += temp[ZZ];
-        fshift[skj][XX] += a * temp[XX];
-        fshift[skj][YY] += a * temp[YY];
-        fshift[skj][ZZ] += a * temp[ZZ];
+        if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - (1 + a) * temp[XX];
+            fshift[CENTRAL][YY] += fv[YY] - (1 + a) * temp[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
+            fshift[sji][XX] += temp[XX];
+            fshift[sji][YY] += temp[YY];
+            fshift[sji][ZZ] += temp[ZZ];
+            fshift[skj][XX] += a * temp[XX];
+            fshift[skj][YY] += a * temp[YY];
+            fshift[skj][ZZ] += a * temp[ZZ];
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
-        /* When VirCorr=TRUE, the virial for the current forces is not
+        /* Under this condition, the virial for the current forces is not
          * calculated from the redistributed forces. This means that
          * the effect of non-linear virtual site constructions on the virial
          * needs to be added separately. This contribution can be calculated
@@ -976,20 +1191,20 @@ static void spread_vsite3FD(const t_iatom ia[],
     /* TOTAL: 61 flops */
 }
 
-static void spread_vsite3FAD(const t_iatom ia[],
-                             real          a,
-                             real          b,
-                             const rvec    x[],
-                             rvec          f[],
-                             rvec          fshift[],
-                             gmx_bool      VirCorr,
-                             matrix        dxdf,
-                             const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3FAD(const t_iatom        ia[],
+                             real                 a,
+                             real                 b,
+                             ArrayRef<const RVec> x,
+                             ArrayRef<RVec>       f,
+                             ArrayRef<RVec>       fshift,
+                             matrix               dxdf,
+                             const t_pbc*         pbc)
 {
     rvec    xvi, xij, xjk, xperp, Fpij, Fppp, fv, f1, f2, f3;
     real    a1, b1, c1, c2, invdij, invdij2, invdp, fproj;
     t_iatom av, ai, aj, ak;
-    int     svi, sji, skj, d;
+    int     sji, skj;
 
     av = ia[1];
     ai = ia[2];
@@ -1024,7 +1239,7 @@ static void spread_vsite3FAD(const t_iatom ia[],
 
     rvec_sub(fv, Fpij, f1); /* f1 = f - Fpij */
     rvec_sub(f1, Fppp, f2); /* f2 = f - Fpij - Fppp */
-    for (d = 0; (d < DIM); d++)
+    for (int d = 0; d < DIM; d++)
     {
         f1[d] *= a1;
         f2[d] *= b1;
@@ -1043,39 +1258,42 @@ static void spread_vsite3FAD(const t_iatom ia[],
     f[ak][ZZ] += f2[ZZ];
     /* 30 Flops */
 
-    if (pbc)
-    {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = CENTRAL;
-    }
+        int svi;
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
-        fshift[CENTRAL][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
-        fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
-        fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
-        fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
-        fshift[skj][XX] += f2[XX];
-        fshift[skj][YY] += f2[YY];
-        fshift[skj][ZZ] += f2[ZZ];
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
+
+        if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
+            fshift[CENTRAL][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
+            fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
+            fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
+            fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
+            fshift[skj][XX] += f2[XX];
+            fshift[skj][YY] += f2[YY];
+            fshift[skj][ZZ] += f2[ZZ];
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
-        int  i, j;
-
         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
 
-        for (i = 0; i < DIM; i++)
+        for (int i = 0; i < DIM; i++)
         {
-            for (j = 0; j < DIM; j++)
+            for (int j = 0; j < DIM; j++)
             {
                 /* Note that xik=xij+xjk, so we have to add xij*f2 */
                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * (f1[j] + (1 - c2) * f2[j] - f3[j])
@@ -1087,21 +1305,21 @@ static void spread_vsite3FAD(const t_iatom ia[],
     /* TOTAL: 113 flops */
 }
 
-static void spread_vsite3OUT(const t_iatom ia[],
-                             real          a,
-                             real          b,
-                             real          c,
-                             const rvec    x[],
-                             rvec          f[],
-                             rvec          fshift[],
-                             gmx_bool      VirCorr,
-                             matrix        dxdf,
-                             const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite3OUT(const t_iatom        ia[],
+                             real                 a,
+                             real                 b,
+                             real                 c,
+                             ArrayRef<const RVec> x,
+                             ArrayRef<RVec>       f,
+                             ArrayRef<RVec>       fshift,
+                             matrix               dxdf,
+                             const t_pbc*         pbc)
 {
     rvec xvi, xij, xik, fv, fj, fk;
     real cfx, cfy, cfz;
     int  av, ai, aj, ak;
-    int  svi, sji, ski;
+    int  sji, ski;
 
     av = ia[1];
     ai = ia[2];
@@ -1135,26 +1353,30 @@ static void spread_vsite3OUT(const t_iatom ia[],
     rvec_inc(f[ak], fk);
     /* 15 Flops */
 
-    if (pbc)
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
-    {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || ski != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX];
-        fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
-        rvec_inc(fshift[sji], fj);
-        rvec_inc(fshift[ski], fk);
+        if (svi != CENTRAL || sji != CENTRAL || ski != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX];
+            fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
+            rvec_inc(fshift[sji], fj);
+            rvec_inc(fshift[ski], fk);
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
 
@@ -1172,21 +1394,21 @@ static void spread_vsite3OUT(const t_iatom ia[],
     /* TOTAL: 54 flops */
 }
 
-static void spread_vsite4FD(const t_iatom ia[],
-                            real          a,
-                            real          b,
-                            real          c,
-                            const rvec    x[],
-                            rvec          f[],
-                            rvec          fshift[],
-                            gmx_bool      VirCorr,
-                            matrix        dxdf,
-                            const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite4FD(const t_iatom        ia[],
+                            real                 a,
+                            real                 b,
+                            real                 c,
+                            ArrayRef<const RVec> x,
+                            ArrayRef<RVec>       f,
+                            ArrayRef<RVec>       fshift,
+                            matrix               dxdf,
+                            const t_pbc*         pbc)
 {
     real fproj, a1;
     rvec xvi, xij, xjk, xjl, xix, fv, temp;
     int  av, ai, aj, ak, al;
-    int  svi, sji, skj, slj, m;
+    int  sji, skj, slj, m;
 
     av = ia[1];
     ai = ia[2];
@@ -1233,28 +1455,32 @@ static void spread_vsite4FD(const t_iatom ia[],
     }
     /* 26 Flops */
 
-    if (pbc)
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
-    {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL || slj != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        for (m = 0; m < DIM; m++)
+        if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL || slj != CENTRAL)
         {
-            fshift[CENTRAL][m] += fv[m] - (1 + a + b) * temp[m];
-            fshift[sji][m] += temp[m];
-            fshift[skj][m] += a * temp[m];
-            fshift[slj][m] += b * temp[m];
+            rvec_dec(fshift[svi], fv);
+            for (m = 0; m < DIM; m++)
+            {
+                fshift[CENTRAL][m] += fv[m] - (1 + a + b) * temp[m];
+                fshift[sji][m] += temp[m];
+                fshift[skj][m] += a * temp[m];
+                fshift[slj][m] += b * temp[m];
+            }
         }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
         int  i, j;
@@ -1273,24 +1499,23 @@ static void spread_vsite4FD(const t_iatom ia[],
     /* TOTAL: 77 flops */
 }
 
-
-static void spread_vsite4FDN(const t_iatom ia[],
-                             real          a,
-                             real          b,
-                             real          c,
-                             const rvec    x[],
-                             rvec          f[],
-                             rvec          fshift[],
-                             gmx_bool      VirCorr,
-                             matrix        dxdf,
-                             const t_pbc*  pbc)
+template<VirialHandling virialHandling>
+static void spread_vsite4FDN(const t_iatom        ia[],
+                             real                 a,
+                             real                 b,
+                             real                 c,
+                             ArrayRef<const RVec> x,
+                             ArrayRef<RVec>       f,
+                             ArrayRef<RVec>       fshift,
+                             matrix               dxdf,
+                             const t_pbc*         pbc)
 {
     rvec xvi, xij, xik, xil, ra, rb, rja, rjb, rab, rm, rt;
     rvec fv, fj, fk, fl;
     real invrm, denom;
     real cfx, cfy, cfz;
     int  av, ai, aj, ak, al;
-    int  svi, sij, sik, sil;
+    int  sij, sik, sil;
 
     /* DEBUG: check atom indices */
     av = ia[1];
@@ -1389,27 +1614,31 @@ static void spread_vsite4FDN(const t_iatom ia[],
     rvec_inc(f[al], fl);
     /* 21 flops */
 
-    if (pbc)
-    {
-        svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
-    }
-    else
+    if (virialHandling == VirialHandling::Pbc)
     {
-        svi = CENTRAL;
-    }
+        int svi;
+        if (pbc)
+        {
+            svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
+        }
+        else
+        {
+            svi = CENTRAL;
+        }
 
-    if (fshift && (svi != CENTRAL || sij != CENTRAL || sik != CENTRAL || sil != CENTRAL))
-    {
-        rvec_dec(fshift[svi], fv);
-        fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
-        fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
-        fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
-        rvec_inc(fshift[sij], fj);
-        rvec_inc(fshift[sik], fk);
-        rvec_inc(fshift[sil], fl);
+        if (svi != CENTRAL || sij != CENTRAL || sik != CENTRAL || sil != CENTRAL)
+        {
+            rvec_dec(fshift[svi], fv);
+            fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
+            fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
+            fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
+            rvec_inc(fshift[sij], fj);
+            rvec_inc(fshift[sik], fk);
+            rvec_inc(fshift[sil], fl);
+        }
     }
 
-    if (VirCorr)
+    if (virialHandling == VirialHandling::NonLinear)
     {
         rvec xiv;
         int  i, j;
@@ -1428,12 +1657,12 @@ static void spread_vsite4FDN(const t_iatom ia[],
     /* Total: 207 flops (Yuck!) */
 }
 
-
+template<VirialHandling virialHandling>
 static int spread_vsiten(const t_iatom             ia[],
                          ArrayRef<const t_iparams> ip,
-                         const rvec                x[],
-                         rvec                      f[],
-                         rvec                      fshift[],
+                         ArrayRef<const RVec>      x,
+                         ArrayRef<RVec>            f,
+                         ArrayRef<RVec>            fshift,
                          const t_pbc*              pbc)
 {
     rvec xv, dx, fi;
@@ -1459,7 +1688,8 @@ static int spread_vsiten(const t_iatom             ia[],
         a = ip[ia[i]].vsiten.a;
         svmul(a, f[av], fi);
         rvec_inc(f[ai], fi);
-        if (fshift && siv != CENTRAL)
+
+        if (virialHandling == VirialHandling::Pbc && siv != CENTRAL)
         {
             rvec_inc(fshift[siv], fi);
             rvec_dec(fshift[CENTRAL], fi);
@@ -1470,7 +1700,9 @@ static int spread_vsiten(const t_iatom             ia[],
     return n3;
 }
 
+#endif // DOXYGEN
 
+//! Returns the number of virtual sites in the interaction list, for VSITEN the number of atoms
 static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
 {
     if (ftype == F_VSITEN)
@@ -1483,14 +1715,15 @@ static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
     }
 }
 
-static void spread_vsite_f_thread(const rvec                      x[],
-                                  rvec                            f[],
-                                  rvec*                           fshift,
-                                  gmx_bool                        VirCorr,
-                                  matrix                          dxdf,
-                                  ArrayRef<const t_iparams>       ip,
-                                  ArrayRef<const InteractionList> ilist,
-                                  const t_pbc*                    pbc_null)
+//! Executes the force spreading task for a single thread
+template<VirialHandling virialHandling>
+static void spreadForceForThread(ArrayRef<const RVec>            x,
+                                 ArrayRef<RVec>                  f,
+                                 ArrayRef<RVec>                  fshift,
+                                 matrix                          dxdf,
+                                 ArrayRef<const t_iparams>       ip,
+                                 ArrayRef<const InteractionList> ilist,
+                                 const t_pbc*                    pbc_null)
 {
     const PbcMode pbcMode = getPbcMode(pbc_null);
     /* We need another pbc pointer, as with charge groups we switch per vsite */
@@ -1528,38 +1761,42 @@ static void spread_vsite_f_thread(const rvec                      x[],
                 /* Construct the vsite depending on type */
                 switch (ftype)
                 {
-                    case F_VSITE2: spread_vsite2(ia, a1, x, f, fshift, pbc_null2); break;
+                    case F_VSITE2:
+                        spread_vsite2<virialHandling>(ia, a1, x, f, fshift, pbc_null2);
+                        break;
                     case F_VSITE2FD:
-                        spread_vsite2FD(ia, a1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite2FD<virialHandling>(ia, a1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE3:
                         b1 = ip[tp].vsite.b;
-                        spread_vsite3(ia, a1, b1, x, f, fshift, pbc_null2);
+                        spread_vsite3<virialHandling>(ia, a1, b1, x, f, fshift, pbc_null2);
                         break;
                     case F_VSITE3FD:
                         b1 = ip[tp].vsite.b;
-                        spread_vsite3FD(ia, a1, b1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite3FD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE3FAD:
                         b1 = ip[tp].vsite.b;
-                        spread_vsite3FAD(ia, a1, b1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite3FAD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE3OUT:
                         b1 = ip[tp].vsite.b;
                         c1 = ip[tp].vsite.c;
-                        spread_vsite3OUT(ia, a1, b1, c1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite3OUT<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE4FD:
                         b1 = ip[tp].vsite.b;
                         c1 = ip[tp].vsite.c;
-                        spread_vsite4FD(ia, a1, b1, c1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite4FD<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
                         break;
                     case F_VSITE4FDN:
                         b1 = ip[tp].vsite.b;
                         c1 = ip[tp].vsite.c;
-                        spread_vsite4FDN(ia, a1, b1, c1, x, f, fshift, VirCorr, dxdf, pbc_null2);
+                        spread_vsite4FDN<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
+                        break;
+                    case F_VSITEN:
+                        inc = spread_vsiten<virialHandling>(ia, ip, x, f, fshift, pbc_null2);
                         break;
-                    case F_VSITEN: inc = spread_vsiten(ia, ip, x, f, fshift, pbc_null2); break;
                     default:
                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
                 }
@@ -1573,7 +1810,37 @@ static void spread_vsite_f_thread(const rvec                      x[],
     }
 }
 
-/*! \brief Clears the task force buffer elements that are written by task idTask */
+//! Wrapper function for calling the templated thread-local spread function
+static void spreadForceWrapper(ArrayRef<const RVec>            x,
+                               ArrayRef<RVec>                  f,
+                               const VirialHandling            virialHandling,
+                               ArrayRef<RVec>                  fshift,
+                               matrix                          dxdf,
+                               const bool                      clearDxdf,
+                               ArrayRef<const t_iparams>       ip,
+                               ArrayRef<const InteractionList> ilist,
+                               const t_pbc*                    pbc_null)
+{
+    if (virialHandling == VirialHandling::NonLinear && clearDxdf)
+    {
+        clear_mat(dxdf);
+    }
+
+    switch (virialHandling)
+    {
+        case VirialHandling::None:
+            spreadForceForThread<VirialHandling::None>(x, f, fshift, dxdf, ip, ilist, pbc_null);
+            break;
+        case VirialHandling::Pbc:
+            spreadForceForThread<VirialHandling::Pbc>(x, f, fshift, dxdf, ip, ilist, pbc_null);
+            break;
+        case VirialHandling::NonLinear:
+            spreadForceForThread<VirialHandling::NonLinear>(x, f, fshift, dxdf, ip, ilist, pbc_null);
+            break;
+    }
+}
+
+//! Clears the task force buffer elements that are written by task idTask
 static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
 {
     int ntask = idTask->spreadTask.size();
@@ -1589,34 +1856,28 @@ static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
     }
 }
 
-void spread_vsite_f(const gmx_vsite_t* vsite,
-                    const rvec* gmx_restrict x,
-                    rvec* gmx_restrict f,
-                    rvec* gmx_restrict            fshift,
-                    gmx_bool                      VirCorr,
-                    matrix                        vir,
-                    t_nrnb*                       nrnb,
-                    const InteractionDefinitions& idef,
-                    PbcType                       pbcType,
-                    gmx_bool                      bMolPBC,
-                    const matrix                  box,
-                    const t_commrec*              cr,
-                    gmx_wallcycle*                wcycle)
+void VirtualSitesHandler::Impl::spreadForces(ArrayRef<const RVec> x,
+                                             ArrayRef<RVec>       f,
+                                             const VirialHandling virialHandling,
+                                             ArrayRef<RVec>       fshift,
+                                             matrix               virial,
+                                             t_nrnb*              nrnb,
+                                             const matrix         box,
+                                             gmx_wallcycle*       wcycle)
 {
     wallcycle_start(wcycle, ewcVSITESPREAD);
-    const bool useDomdec = vsite->useDomdec;
-    GMX_ASSERT(!useDomdec || (cr != nullptr && DOMAINDECOMP(cr)),
-               "When vsites are set up with domain decomposition, we need a valid commrec");
+
+    const bool useDomdec = domainInfo_.useDomdec();
 
     t_pbc pbc, *pbc_null;
 
-    /* We only need to do pbc when we have inter-cg vsites */
-    if ((useDomdec || bMolPBC) && vsite->numInterUpdategroupVsites)
+    if (domainInfo_.useMolPbc_)
     {
         /* This is wasting some CPU time as we now do this multiple times
          * per MD step.
          */
-        pbc_null = set_pbc_dd(&pbc, pbcType, useDomdec ? cr->dd->numCells : nullptr, FALSE, box);
+        pbc_null = set_pbc_dd(&pbc, domainInfo_.pbcType_,
+                              useDomdec ? domainInfo_.domdec_->numCells : nullptr, FALSE, box);
     }
     else
     {
@@ -1625,25 +1886,23 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
 
     if (useDomdec)
     {
-        dd_clear_f_vsites(cr->dd, f);
+        dd_clear_f_vsites(*domainInfo_.domdec_, f);
     }
 
-    if (vsite->nthreads == 1)
+    const int numThreads = threadingInfo_.numThreads();
+
+    if (numThreads == 1)
     {
         matrix dxdf;
-        if (VirCorr)
-        {
-            clear_mat(dxdf);
-        }
-        spread_vsite_f_thread(x, f, fshift, VirCorr, dxdf, idef.iparams, idef.il, pbc_null);
+        spreadForceWrapper(x, f, virialHandling, fshift, dxdf, true, iparams_, ilists_, pbc_null);
 
-        if (VirCorr)
+        if (virialHandling == VirialHandling::NonLinear)
         {
             for (int i = 0; i < DIM; i++)
             {
                 for (int j = 0; j < DIM; j++)
                 {
-                    vir[i][j] += -0.5 * dxdf[i][j];
+                    virial[i][j] += -0.5 * dxdf[i][j];
                 }
             }
         }
@@ -1651,37 +1910,33 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
     else
     {
         /* First spread the vsites that might depend on non-local vsites */
-        if (VirCorr)
-        {
-            clear_mat(vsite->tData[vsite->nthreads]->dxdf);
-        }
-        spread_vsite_f_thread(x, f, fshift, VirCorr, vsite->tData[vsite->nthreads]->dxdf,
-                              idef.iparams, vsite->tData[vsite->nthreads]->ilist, pbc_null);
+        auto& nlDependentVSites = threadingInfo_.threadDataNonLocalDependent();
+        spreadForceWrapper(x, f, virialHandling, fshift, nlDependentVSites.dxdf, true, iparams_,
+                           nlDependentVSites.ilist, pbc_null);
 
-#pragma omp parallel num_threads(vsite->nthreads)
+#pragma omp parallel num_threads(numThreads)
         {
             try
             {
                 int          thread = gmx_omp_get_thread_num();
-                VsiteThread& tData  = *vsite->tData[thread];
+                VsiteThread& tData  = threadingInfo_.threadData(thread);
 
-                rvec* fshift_t;
-                if (thread == 0 || fshift == nullptr)
-                {
-                    fshift_t = fshift;
-                }
-                else
+                ArrayRef<RVec> fshift_t;
+                if (virialHandling == VirialHandling::Pbc)
                 {
-                    fshift_t = tData.fshift;
-
-                    for (int i = 0; i < SHIFTS; i++)
+                    if (thread == 0)
                     {
-                        clear_rvec(fshift_t[i]);
+                        fshift_t = fshift;
+                    }
+                    else
+                    {
+                        fshift_t = tData.fshift;
+
+                        for (int i = 0; i < SHIFTS; i++)
+                        {
+                            clear_rvec(fshift_t[i]);
+                        }
                     }
-                }
-                if (VirCorr)
-                {
-                    clear_mat(tData.dxdf);
                 }
 
                 if (tData.useInterdependentTask)
@@ -1702,8 +1957,8 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
                     {
                         copy_rvec(f[idTask->vsite[i]], idTask->force[idTask->vsite[i]]);
                     }
-                    spread_vsite_f_thread(x, as_rvec_array(idTask->force.data()), fshift_t, VirCorr,
-                                          tData.dxdf, idef.iparams, tData.idTask.ilist, pbc_null);
+                    spreadForceWrapper(x, idTask->force, virialHandling, fshift_t, tData.dxdf, true,
+                                       iparams_, tData.idTask.ilist, pbc_null);
 
                     /* We need a barrier before reducing forces below
                      * that have been produced by a different thread above.
@@ -1718,15 +1973,13 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
                     int ntask = idTask->reduceTask.size();
                     for (int ti = 0; ti < ntask; ti++)
                     {
-                        const InterdependentTask* idt_foreign =
-                                &vsite->tData[idTask->reduceTask[ti]]->idTask;
-                        const AtomIndex* atomList  = &idt_foreign->atomIndex[thread];
-                        const RVec*      f_foreign = idt_foreign->force.data();
+                        const InterdependentTask& idt_foreign =
+                                threadingInfo_.threadData(idTask->reduceTask[ti]).idTask;
+                        const AtomIndex& atomList  = idt_foreign.atomIndex[thread];
+                        const RVec*      f_foreign = idt_foreign.force.data();
 
-                        int natom = atomList->atom.size();
-                        for (int i = 0; i < natom; i++)
+                        for (int ind : atomList.atom)
                         {
-                            int ind = atomList->atom[i];
                             rvec_inc(f[ind], f_foreign[ind]);
                             /* Clearing of f_foreign is done at the next step */
                         }
@@ -1741,35 +1994,35 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
                 }
 
                 /* Spread the vsites that spread locally only */
-                spread_vsite_f_thread(x, f, fshift_t, VirCorr, tData.dxdf, idef.iparams,
-                                      tData.ilist, pbc_null);
+                spreadForceWrapper(x, f, virialHandling, fshift_t, tData.dxdf, false, iparams_,
+                                   tData.ilist, pbc_null);
             }
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
 
-        if (fshift != nullptr)
+        if (virialHandling == VirialHandling::Pbc)
         {
-            for (int th = 1; th < vsite->nthreads; th++)
+            for (int th = 1; th < numThreads; th++)
             {
                 for (int i = 0; i < SHIFTS; i++)
                 {
-                    rvec_inc(fshift[i], vsite->tData[th]->fshift[i]);
+                    rvec_inc(fshift[i], threadingInfo_.threadData(th).fshift[i]);
                 }
             }
         }
 
-        if (VirCorr)
+        if (virialHandling == VirialHandling::NonLinear)
         {
-            for (int th = 0; th < vsite->nthreads + 1; th++)
+            for (int th = 0; th < numThreads + 1; th++)
             {
                 /* MSVC doesn't like matrix references, so we use a pointer */
-                const matrix* dxdf = &vsite->tData[th]->dxdf;
+                const matrix& dxdf = threadingInfo_.threadData(th).dxdf;
 
                 for (int i = 0; i < DIM; i++)
                 {
                     for (int j = 0; j < DIM; j++)
                     {
-                        vir[i][j] += -0.5 * (*dxdf)[i][j];
+                        virial[i][j] += -0.5 * dxdf[i][j];
                     }
                 }
             }
@@ -1778,18 +2031,18 @@ void spread_vsite_f(const gmx_vsite_t* vsite,
 
     if (useDomdec)
     {
-        dd_move_f_vsites(cr->dd, f, fshift);
+        dd_move_f_vsites(*domainInfo_.domdec_, f, fshift);
     }
 
-    inc_nrnb(nrnb, eNR_VSITE2, vsite_count(idef.il, F_VSITE2));
-    inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(idef.il, F_VSITE2FD));
-    inc_nrnb(nrnb, eNR_VSITE3, vsite_count(idef.il, F_VSITE3));
-    inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(idef.il, F_VSITE3FD));
-    inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(idef.il, F_VSITE3FAD));
-    inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(idef.il, F_VSITE3OUT));
-    inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(idef.il, F_VSITE4FD));
-    inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(idef.il, F_VSITE4FDN));
-    inc_nrnb(nrnb, eNR_VSITEN, vsite_count(idef.il, F_VSITEN));
+    inc_nrnb(nrnb, eNR_VSITE2, vsite_count(ilists_, F_VSITE2));
+    inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(ilists_, F_VSITE2FD));
+    inc_nrnb(nrnb, eNR_VSITE3, vsite_count(ilists_, F_VSITE3));
+    inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(ilists_, F_VSITE3FD));
+    inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(ilists_, F_VSITE3FAD));
+    inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(ilists_, F_VSITE3OUT));
+    inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(ilists_, F_VSITE4FD));
+    inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(ilists_, F_VSITE4FDN));
+    inc_nrnb(nrnb, eNR_VSITEN, vsite_count(ilists_, F_VSITEN));
 
     wallcycle_stop(wcycle, ewcVSITESPREAD);
 }
@@ -1831,6 +2084,18 @@ int countNonlinearVsites(const gmx_mtop_t& mtop)
     return numNonlinearVsites;
 }
 
+void VirtualSitesHandler::spreadForces(ArrayRef<const RVec> x,
+                                       ArrayRef<RVec>       f,
+                                       const VirialHandling virialHandling,
+                                       ArrayRef<RVec>       fshift,
+                                       matrix               virial,
+                                       t_nrnb*              nrnb,
+                                       const matrix         box,
+                                       gmx_wallcycle*       wcycle)
+{
+    impl_->spreadForces(x, f, virialHandling, fshift, virial, nrnb, box, wcycle);
+}
+
 int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
                                 gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype)
 {
@@ -1874,11 +2139,13 @@ int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop
     return n_intercg_vsite;
 }
 
-std::unique_ptr<gmx_vsite_t> initVsite(const gmx_mtop_t& mtop, const t_commrec* cr)
+std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
+                                                             const t_commrec*  cr,
+                                                             PbcType           pbcType)
 {
     GMX_RELEASE_ASSERT(cr != nullptr, "We need a valid commrec");
 
-    std::unique_ptr<gmx_vsite_t> vsite;
+    std::unique_ptr<VirtualSitesHandler> vsite;
 
     /* check if there are vsites */
     int nvsite = 0;
@@ -1903,57 +2170,68 @@ std::unique_ptr<gmx_vsite_t> initVsite(const gmx_mtop_t& mtop, const t_commrec*
         return vsite;
     }
 
-    vsite = std::make_unique<gmx_vsite_t>();
-
-    gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype;
-    if (DOMAINDECOMP(cr))
-    {
-        updateGroupingPerMoleculetype = getUpdateGroupingPerMoleculetype(*cr->dd);
-    }
-    vsite->numInterUpdategroupVsites = countInterUpdategroupVsites(mtop, updateGroupingPerMoleculetype);
-
-    vsite->useDomdec = (DOMAINDECOMP(cr) && cr->dd->nnodes > 1);
-
-    vsite->nthreads = gmx_omp_nthreads_get(emntVSITE);
+    return std::make_unique<VirtualSitesHandler>(mtop, cr->dd, pbcType);
+}
 
-    if (vsite->nthreads > 1)
+ThreadingInfo::ThreadingInfo() : numThreads_(gmx_omp_nthreads_get(emntVSITE))
+{
+    if (numThreads_ > 1)
     {
         /* We need one extra thread data structure for the overlap vsites */
-        vsite->tData.resize(vsite->nthreads + 1);
-#pragma omp parallel for num_threads(vsite->nthreads) schedule(static)
-        for (int thread = 0; thread < vsite->nthreads; thread++)
+        tData_.resize(numThreads_ + 1);
+#pragma omp parallel for num_threads(numThreads_) schedule(static)
+        for (int thread = 0; thread < numThreads_; thread++)
         {
             try
             {
-                vsite->tData[thread] = std::make_unique<VsiteThread>();
+                tData_[thread] = std::make_unique<VsiteThread>();
 
-                InterdependentTask& idTask = vsite->tData[thread]->idTask;
+                InterdependentTask& idTask = tData_[thread]->idTask;
                 idTask.nuse                = 0;
-                idTask.atomIndex.resize(vsite->nthreads);
+                idTask.atomIndex.resize(numThreads_);
             }
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
-        if (vsite->nthreads > 1)
+        if (numThreads_ > 1)
         {
-            vsite->tData[vsite->nthreads] = std::make_unique<VsiteThread>();
+            tData_[numThreads_] = std::make_unique<VsiteThread>();
         }
     }
+}
+
+//! Returns the number of inter update-group vsites
+static int getNumInterUpdategroupVsites(const gmx_mtop_t& mtop, const gmx_domdec_t* domdec)
+{
+    gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype;
+    if (domdec)
+    {
+        updateGroupingPerMoleculetype = getUpdateGroupingPerMoleculetype(*domdec);
+    }
 
-    return vsite;
+    return countInterUpdategroupVsites(mtop, updateGroupingPerMoleculetype);
 }
 
-gmx_vsite_t::gmx_vsite_t() {}
+VirtualSitesHandler::Impl::Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
+    numInterUpdategroupVirtualSites_(getNumInterUpdategroupVsites(mtop, domdec)),
+    domainInfo_({ pbcType, pbcType != PbcType::No && numInterUpdategroupVirtualSites_ > 0, domdec }),
+    iparams_(mtop.ffparams.iparams)
+{
+}
 
-gmx_vsite_t::~gmx_vsite_t() {}
+VirtualSitesHandler::VirtualSitesHandler(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
+    impl_(new Impl(mtop, domdec, pbcType))
+{
+}
 
-static inline void flagAtom(InterdependentTask* idTask, int atom, int thread, int nthread, int natperthread)
+//! Flag that atom \p atom which is home in another task, if it has not already been added before
+static inline void flagAtom(InterdependentTask* idTask, const int atom, const int numThreads, const int numAtomsPerThread)
 {
     if (!idTask->use[atom])
     {
         idTask->use[atom] = true;
-        thread            = atom / natperthread;
+        int thread        = atom / numAtomsPerThread;
         /* Assign all non-local atom force writes to thread 0 */
-        if (thread >= nthread)
+        if (thread >= numThreads)
         {
             thread = 0;
         }
@@ -1961,7 +2239,7 @@ static inline void flagAtom(InterdependentTask* idTask, int atom, int thread, in
     }
 }
 
-/*\brief Here we try to assign all vsites that are in our local range.
+/*\brief Here we try to assign all vsites that are in our local range.
  *
  * Our task local atom range is tData->rangeStart - tData->rangeEnd.
  * Vsites that depend only on local atoms, as indicated by taskIndex[]==thread,
@@ -2082,14 +2360,14 @@ static void assignVsitesToThread(VsiteThread*                    tData,
                     {
                         for (int j = i + 2; j < i + nral1; j++)
                         {
-                            flagAtom(&tData->idTask, iat[j], thread, nthread, natperthread);
+                            flagAtom(&tData->idTask, iat[j], nthread, natperthread);
                         }
                     }
                     else
                     {
                         for (int j = i + 2; j < i + inc; j += 3)
                         {
-                            flagAtom(&tData->idTask, iat[j], thread, nthread, natperthread);
+                            flagAtom(&tData->idTask, iat[j], nthread, natperthread);
                         }
                     }
                 }
@@ -2136,14 +2414,12 @@ static void assignVsitesToSingleTask(VsiteThread*                    tData,
     }
 }
 
-void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
-                               ArrayRef<const t_iparams>       ip,
-                               const t_mdatoms*                mdatoms,
-                               gmx_vsite_t*                    vsite)
+void ThreadingInfo::setVirtualSites(ArrayRef<const InteractionList> ilists,
+                                    ArrayRef<const t_iparams>       iparams,
+                                    const t_mdatoms&                mdatoms,
+                                    const bool                      useDomdec)
 {
-    int vsite_atom_range, natperthread;
-
-    if (vsite->nthreads == 1)
+    if (numThreads_ <= 1)
     {
         /* Nothing to do */
         return;
@@ -2159,7 +2435,9 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
      * uniformly in each domain along the major dimension, usually x,
      * it will also perform well.
      */
-    if (!vsite->useDomdec)
+    int vsite_atom_range;
+    int natperthread;
+    if (!useDomdec)
     {
         vsite_atom_range = -1;
         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
@@ -2168,8 +2446,8 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                 if (ftype != F_VSITEN)
                 {
                     int                 nral1 = 1 + NRAL(ftype);
-                    ArrayRef<const int> iat   = ilist[ftype].iatoms;
-                    for (int i = 0; i < ilist[ftype].size(); i += nral1)
+                    ArrayRef<const int> iat   = ilists[ftype].iatoms;
+                    for (int i = 0; i < ilists[ftype].size(); i += nral1)
                     {
                         for (int j = i + 1; j < i + nral1; j++)
                         {
@@ -2181,13 +2459,13 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                 {
                     int vs_ind_end;
 
-                    ArrayRef<const int> iat = ilist[ftype].iatoms;
+                    ArrayRef<const int> iat = ilists[ftype].iatoms;
 
                     int i = 0;
-                    while (i < ilist[ftype].size())
+                    while (i < ilists[ftype].size())
                     {
                         /* The 3 below is from 1+NRAL(ftype)=3 */
-                        vs_ind_end = i + ip[iat[i]].vsiten.n * 3;
+                        vs_ind_end = i + iparams[iat[i]].vsiten.n * 3;
 
                         vsite_atom_range = std::max(vsite_atom_range, iat[i + 1]);
                         while (i < vs_ind_end)
@@ -2200,7 +2478,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
             }
         }
         vsite_atom_range++;
-        natperthread = (vsite_atom_range + vsite->nthreads - 1) / vsite->nthreads;
+        natperthread = (vsite_atom_range + numThreads_ - 1) / numThreads_;
     }
     else
     {
@@ -2211,53 +2489,52 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
          * When assigning vsites to threads, we should take care that the last
          * threads also covers the non-local range.
          */
-        vsite_atom_range = mdatoms->nr;
-        natperthread     = (mdatoms->homenr + vsite->nthreads - 1) / vsite->nthreads;
+        vsite_atom_range = mdatoms.nr;
+        natperthread     = (mdatoms.homenr + numThreads_ - 1) / numThreads_;
     }
 
     if (debug)
     {
         fprintf(debug, "virtual site thread dist: natoms %d, range %d, natperthread %d\n",
-                mdatoms->nr, vsite_atom_range, natperthread);
+                mdatoms.nr, vsite_atom_range, natperthread);
     }
 
     /* To simplify the vsite assignment, we make an index which tells us
      * to which task particles, both non-vsites and vsites, are assigned.
      */
-    vsite->taskIndex.resize(mdatoms->nr);
+    taskIndex_.resize(mdatoms.nr);
 
     /* Initialize the task index array. Here we assign the non-vsite
      * particles to task=thread, so we easily figure out if vsites
      * depend on local and/or non-local particles in assignVsitesToThread.
      */
-    gmx::ArrayRef<int> taskIndex = vsite->taskIndex;
     {
         int thread = 0;
-        for (int i = 0; i < mdatoms->nr; i++)
+        for (int i = 0; i < mdatoms.nr; i++)
         {
-            if (mdatoms->ptype[i] == eptVSite)
+            if (mdatoms.ptype[i] == eptVSite)
             {
                 /* vsites are not assigned to a task yet */
-                taskIndex[i] = -1;
+                taskIndex_[i] = -1;
             }
             else
             {
                 /* assign non-vsite particles to task thread */
-                taskIndex[i] = thread;
+                taskIndex_[i] = thread;
             }
-            if (i == (thread + 1) * natperthread && thread < vsite->nthreads)
+            if (i == (thread + 1) * natperthread && thread < numThreads_)
             {
                 thread++;
             }
         }
     }
 
-#pragma omp parallel num_threads(vsite->nthreads)
+#pragma omp parallel num_threads(numThreads_)
     {
         try
         {
             int          thread = gmx_omp_get_thread_num();
-            VsiteThread& tData  = *vsite->tData[thread];
+            VsiteThread& tData  = *tData_[thread];
 
             /* Clear the buffer use flags that were set before */
             if (tData.useInterdependentTask)
@@ -2271,7 +2548,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                 clearTaskForceBufferUsedElements(&idTask);
 
                 idTask.vsite.resize(0);
-                for (int t = 0; t < vsite->nthreads; t++)
+                for (int t = 0; t < numThreads_; t++)
                 {
                     AtomIndex& atomIndex = idTask.atomIndex[t];
                     int        natom     = atomIndex.atom.size();
@@ -2307,17 +2584,17 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
 
             /* Assign all vsites that can execute independently on threads */
             tData.rangeStart = thread * natperthread;
-            if (thread < vsite->nthreads - 1)
+            if (thread < numThreads_ - 1)
             {
                 tData.rangeEnd = (thread + 1) * natperthread;
             }
             else
             {
                 /* The last thread should cover up to the end of the range */
-                tData.rangeEnd = mdatoms->nr;
+                tData.rangeEnd = mdatoms.nr;
             }
-            assignVsitesToThread(&tData, thread, vsite->nthreads, natperthread, taskIndex, ilist,
-                                 ip, mdatoms->ptype);
+            assignVsitesToThread(&tData, thread, numThreads_, natperthread, taskIndex_, ilists,
+                                 iparams, mdatoms.ptype);
 
             if (tData.useInterdependentTask)
             {
@@ -2335,7 +2612,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
 
                 idTask.spreadTask.resize(0);
                 idTask.reduceTask.resize(0);
-                for (int t = 0; t < vsite->nthreads; t++)
+                for (int t = 0; t < numThreads_; t++)
                 {
                     /* Do we write to the force buffer of task t? */
                     if (!idTask.atomIndex[t].atom.empty())
@@ -2343,7 +2620,7 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
                         idTask.spreadTask.push_back(t);
                     }
                     /* Does task t write to our force buffer? */
-                    if (!vsite->tData[t]->idTask.atomIndex[thread].atom.empty())
+                    if (!tData_[t]->idTask.atomIndex[thread].atom.empty())
                     {
                         idTask.reduceTask.push_back(t);
                     }
@@ -2355,28 +2632,27 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
     /* Assign all remaining vsites, that will have taskIndex[]=2*vsite->nthreads,
      * to a single task that will not run in parallel with other tasks.
      */
-    assignVsitesToSingleTask(vsite->tData[vsite->nthreads].get(), 2 * vsite->nthreads, taskIndex,
-                             ilist, ip);
+    assignVsitesToSingleTask(tData_[numThreads_].get(), 2 * numThreads_, taskIndex_, ilists, iparams);
 
-    if (debug && vsite->nthreads > 1)
+    if (debug && numThreads_ > 1)
     {
         fprintf(debug, "virtual site useInterdependentTask %d, nuse:\n",
-                static_cast<int>(vsite->tData[0]->useInterdependentTask));
-        for (int th = 0; th < vsite->nthreads + 1; th++)
+                static_cast<int>(tData_[0]->useInterdependentTask));
+        for (int th = 0; th < numThreads_ + 1; th++)
         {
-            fprintf(debug, " %4d", vsite->tData[th]->idTask.nuse);
+            fprintf(debug, " %4d", tData_[th]->idTask.nuse);
         }
         fprintf(debug, "\n");
 
         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
         {
-            if (!ilist[ftype].empty())
+            if (!ilists[ftype].empty())
             {
                 fprintf(debug, "%-20s thread dist:", interaction_function[ftype].longname);
-                for (int th = 0; th < vsite->nthreads + 1; th++)
+                for (int th = 0; th < numThreads_ + 1; th++)
                 {
-                    fprintf(debug, " %4d %4d ", vsite->tData[th]->ilist[ftype].size(),
-                            vsite->tData[th]->idTask.ilist[ftype].size());
+                    fprintf(debug, " %4d %4d ", tData_[th]->ilist[ftype].size(),
+                            tData_[th]->idTask.ilist[ftype].size());
                 }
                 fprintf(debug, "\n");
             }
@@ -2384,12 +2660,11 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
     }
 
 #ifndef NDEBUG
-    int nrOrig     = vsiteIlistNrCount(ilist);
+    int nrOrig     = vsiteIlistNrCount(ilists);
     int nrThreaded = 0;
-    for (int th = 0; th < vsite->nthreads + 1; th++)
+    for (int th = 0; th < numThreads_ + 1; th++)
     {
-        nrThreaded += vsiteIlistNrCount(vsite->tData[th]->ilist)
-                      + vsiteIlistNrCount(vsite->tData[th]->idTask.ilist);
+        nrThreaded += vsiteIlistNrCount(tData_[th]->ilist) + vsiteIlistNrCount(tData_[th]->idTask.ilist);
     }
     GMX_ASSERT(nrThreaded == nrOrig,
                "The number of virtual sites assigned to all thread task has to match the total "
@@ -2397,10 +2672,17 @@ void split_vsites_over_threads(ArrayRef<const InteractionList> ilist,
 #endif
 }
 
-void set_vsite_top(gmx_vsite_t* vsite, const gmx_localtop_t* top, const t_mdatoms* md)
+void VirtualSitesHandler::Impl::setVirtualSites(ArrayRef<const InteractionList> ilists,
+                                                const t_mdatoms&                mdatoms)
 {
-    if (vsite->nthreads > 1)
-    {
-        split_vsites_over_threads(top->idef.il, top->idef.iparams, md, vsite);
-    }
+    ilists_ = ilists;
+
+    threadingInfo_.setVirtualSites(ilists, iparams_, mdatoms, domainInfo_.useDomdec());
 }
+
+void VirtualSitesHandler::setVirtualSites(ArrayRef<const InteractionList> ilists, const t_mdatoms& mdatoms)
+{
+    impl_->setVirtualSites(ilists, mdatoms);
+}
+
+} // namespace gmx