Decouple update-group aspects of vsites and constraints
[alexxy/gromacs.git] / src / gromacs / mdlib / vsite.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 The GROMACS development team.
7  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /*! \internal \file
39  * \brief Implements the VirtualSitesHandler class and vsite standalone functions
40  *
41  * \author Berk Hess <hess@kth.se>
42  * \ingroup module_mdlib
43  */
44
45 #include "gmxpre.h"
46
47 #include "vsite.h"
48
49 #include <cstdio>
50
51 #include <algorithm>
52 #include <memory>
53 #include <vector>
54
55 #include "gromacs/domdec/domdec.h"
56 #include "gromacs/domdec/domdec_struct.h"
57 #include "gromacs/gmxlib/network.h"
58 #include "gromacs/gmxlib/nrnb.h"
59 #include "gromacs/math/functions.h"
60 #include "gromacs/math/vec.h"
61 #include "gromacs/mdlib/gmx_omp_nthreads.h"
62 #include "gromacs/mdtypes/commrec.h"
63 #include "gromacs/mdtypes/mdatom.h"
64 #include "gromacs/pbcutil/ishift.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/timing/wallcycle.h"
67 #include "gromacs/topology/ifunc.h"
68 #include "gromacs/topology/mtop_util.h"
69 #include "gromacs/topology/topology.h"
70 #include "gromacs/utility/exceptions.h"
71 #include "gromacs/utility/fatalerror.h"
72 #include "gromacs/utility/gmxassert.h"
73 #include "gromacs/utility/gmxomp.h"
74
75 /* The strategy used here for assigning virtual sites to (thread-)tasks
76  * is as follows:
77  *
78  * We divide the atom range that vsites operate on (natoms_local with DD,
79  * 0 - last atom involved in vsites without DD) equally over all threads.
80  *
81  * Vsites in the local range constructed from atoms in the local range
82  * and/or other vsites that are fully local are assigned to a simple,
83  * independent task.
84  *
85  * Vsites that are not assigned after using the above criterion get assigned
86  * to a so called "interdependent" thread task when none of the constructing
87  * atoms is a vsite. These tasks are called interdependent, because one task
88  * accesses atoms assigned to a different task/thread.
89  * Note that this option is turned off with large (local) atom counts
90  * to avoid high memory usage.
91  *
92  * Any remaining vsites are assigned to a separate master thread task.
93  */
94 namespace gmx
95 {
96
97 //! VirialHandling is often used outside VirtualSitesHandler class members
98 using VirialHandling = VirtualSitesHandler::VirialHandling;
99
100 /*! \brief Information on PBC and domain decomposition for virtual sites
101  */
102 struct DomainInfo
103 {
104 public:
105     //! Constructs without PBC and DD
106     DomainInfo() = default;
107
108     //! Constructs with PBC and DD, if !=nullptr
109     DomainInfo(PbcType pbcType, bool haveInterUpdateGroupVirtualSites, gmx_domdec_t* domdec) :
110         pbcType_(pbcType),
111         useMolPbc_(pbcType != PbcType::No && haveInterUpdateGroupVirtualSites),
112         domdec_(domdec)
113     {
114     }
115
116     //! Returns whether we are using domain decomposition with more than 1 DD rank
117     bool useDomdec() const { return (domdec_ != nullptr); }
118
119     //! The pbc type
120     const PbcType pbcType_ = PbcType::No;
121     //! Whether molecules are broken over PBC
122     const bool useMolPbc_ = false;
123     //! Pointer to the domain decomposition struct, nullptr without PP DD
124     const gmx_domdec_t* domdec_ = nullptr;
125 };
126
127 /*! \brief List of atom indices belonging to a task
128  */
129 struct AtomIndex
130 {
131     //! List of atom indices
132     std::vector<int> atom;
133 };
134
135 /*! \brief Data structure for thread tasks that use constructing atoms outside their own atom range
136  */
137 struct InterdependentTask
138 {
139     //! The interaction lists, only vsite entries are used
140     InteractionLists ilist;
141     //! Thread/task-local force buffer
142     std::vector<RVec> force;
143     //! The atom indices of the vsites of our task
144     std::vector<int> vsite;
145     //! Flags if elements in force are spread to or not
146     std::vector<bool> use;
147     //! The number of entries set to true in use
148     int nuse = 0;
149     //! Array of atoms indices, size nthreads, covering all nuse set elements in use
150     std::vector<AtomIndex> atomIndex;
151     //! List of tasks (force blocks) this task spread forces to
152     std::vector<int> spreadTask;
153     //! List of tasks that write to this tasks force block range
154     std::vector<int> reduceTask;
155 };
156
157 /*! \brief Vsite thread task data structure
158  */
159 struct VsiteThread
160 {
161     //! Start of atom range of this task
162     int rangeStart;
163     //! End of atom range of this task
164     int rangeEnd;
165     //! The interaction lists, only vsite entries are used
166     std::array<InteractionList, F_NRE> ilist;
167     //! Local fshift accumulation buffer
168     std::array<RVec, c_numShiftVectors> fshift;
169     //! Local virial dx*df accumulation buffer
170     matrix dxdf;
171     //! Tells if interdependent task idTask should be used (in addition to the rest of this task), this bool has the same value on all threads
172     bool useInterdependentTask;
173     //! Data for vsites that involve constructing atoms in the atom range of other threads/tasks
174     InterdependentTask idTask;
175
176     /*! \brief Constructor */
177     VsiteThread()
178     {
179         rangeStart = -1;
180         rangeEnd   = -1;
181         for (auto& elem : fshift)
182         {
183             elem = { 0.0_real, 0.0_real, 0.0_real };
184         }
185         clear_mat(dxdf);
186         useInterdependentTask = false;
187     }
188 };
189
190
191 /*! \brief Information on how the virtual site work is divided over thread tasks
192  */
193 class ThreadingInfo
194 {
195 public:
196     //! Constructor, retrieves the number of threads to use from gmx_omp_nthreads.h
197     ThreadingInfo();
198
199     //! Returns the number of threads to use for vsite operations
200     int numThreads() const { return numThreads_; }
201
202     //! Returns the thread data for the given thread
203     const VsiteThread& threadData(int threadIndex) const { return *tData_[threadIndex]; }
204
205     //! Returns the thread data for the given thread
206     VsiteThread& threadData(int threadIndex) { return *tData_[threadIndex]; }
207
208     //! Returns the thread data for vsites that depend on non-local vsites
209     const VsiteThread& threadDataNonLocalDependent() const { return *tData_[numThreads_]; }
210
211     //! Returns the thread data for vsites that depend on non-local vsites
212     VsiteThread& threadDataNonLocalDependent() { return *tData_[numThreads_]; }
213
214     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
215     void setVirtualSites(ArrayRef<const InteractionList> ilist,
216                          ArrayRef<const t_iparams>       iparams,
217                          const t_mdatoms&                mdatoms,
218                          bool                            useDomdec);
219
220 private:
221     //! Number of threads used for vsite operations
222     const int numThreads_;
223     //! Thread local vsites and work structs
224     std::vector<std::unique_ptr<VsiteThread>> tData_;
225     //! Work array for dividing vsites over threads
226     std::vector<int> taskIndex_;
227 };
228
229 /*! \brief Impl class for VirtualSitesHandler
230  */
231 class VirtualSitesHandler::Impl
232 {
233 public:
234     //! Constructor, domdec should be nullptr without DD
235     Impl(const gmx_mtop_t&                 mtop,
236          gmx_domdec_t*                     domdec,
237          PbcType                           pbcType,
238          ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType);
239
240     //! Returns the number of virtual sites acting over multiple update groups
241     int numInterUpdategroupVirtualSites() const { return numInterUpdategroupVirtualSites_; }
242
243     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
244     void setVirtualSites(ArrayRef<const InteractionList> ilist, const t_mdatoms& mdatoms);
245
246     /*! \brief Create positions of vsite atoms based for the local system
247      *
248      * \param[in,out] x          The coordinates
249      * \param[in,out] v          The velocities, needed if operation requires it
250      * \param[in]     box        The box
251      * \param[in]     operation  Whether we calculate positions, velocities, or both
252      */
253     void construct(ArrayRef<RVec> x, ArrayRef<RVec> v, const matrix box, VSiteOperation operation) const;
254
255     /*! \brief Spread the force operating on the vsite atoms on the surrounding atoms.
256      *
257      * vsite should point to a valid object.
258      * The virialHandling parameter determines how virial contributions are handled.
259      * If this is set to Linear, shift forces are accumulated into fshift.
260      * If this is set to NonLinear, non-linear contributions are added to virial.
261      * This non-linear correction is required when the virial is not calculated
262      * afterwards from the particle position and forces, but in a different way,
263      * as for instance for the PME mesh contribution.
264      */
265     void spreadForces(ArrayRef<const RVec> x,
266                       ArrayRef<RVec>       f,
267                       VirialHandling       virialHandling,
268                       ArrayRef<RVec>       fshift,
269                       matrix               virial,
270                       t_nrnb*              nrnb,
271                       const matrix         box,
272                       gmx_wallcycle*       wcycle);
273
274 private:
275     //! The number of vsites that cross update groups, when =0 no PBC treatment is needed
276     const int numInterUpdategroupVirtualSites_;
277     //! PBC and DD information
278     const DomainInfo domainInfo_;
279     //! The interaction parameters
280     const ArrayRef<const t_iparams> iparams_;
281     //! The interaction lists
282     ArrayRef<const InteractionList> ilists_;
283     //! Information for handling vsite threading
284     ThreadingInfo threadingInfo_;
285 };
286
287 VirtualSitesHandler::~VirtualSitesHandler() = default;
288
289 int VirtualSitesHandler::numInterUpdategroupVirtualSites() const
290 {
291     return impl_->numInterUpdategroupVirtualSites();
292 }
293
294 /*! \brief Returns the sum of the vsite ilist sizes over all vsite types
295  *
296  * \param[in] ilist  The interaction list
297  */
298 static int vsiteIlistNrCount(ArrayRef<const InteractionList> ilist)
299 {
300     int nr = 0;
301     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
302     {
303         nr += ilist[ftype].size();
304     }
305
306     return nr;
307 }
308
309 //! Computes the distance between xi and xj, pbc is used when pbc!=nullptr
310 static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
311 {
312     if (pbc)
313     {
314         return pbc_dx_aiuc(pbc, xi, xj, dx);
315     }
316     else
317     {
318         rvec_sub(xi, xj, dx);
319         return c_centralShiftIndex;
320     }
321 }
322
323 //! Returns the 1/norm(x)
324 static inline real inverseNorm(const rvec x)
325 {
326     return gmx::invsqrt(iprod(x, x));
327 }
328
329 //! Whether we're calculating the virtual site position
330 enum class VSiteCalculatePosition
331 {
332     Yes,
333     No
334 };
335 //! Whether we're calculating the virtual site velocity
336 enum class VSiteCalculateVelocity
337 {
338     Yes,
339     No
340 };
341
342 #ifndef DOXYGEN
343 /* Vsite construction routines */
344
345 // GCC 8 falsely flags unused variables if constexpr prunes a code path, fixed in GCC 9
346 // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85827
347 // clang-format off
348 GCC_DIAGNOSTIC_IGNORE(-Wunused-but-set-parameter)
349 // clang-format on
350
351 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
352 static void constr_vsite1(const rvec xi, rvec x, const rvec vi, rvec v)
353 {
354     if (calculatePosition == VSiteCalculatePosition::Yes)
355     {
356         copy_rvec(xi, x);
357         /* TOTAL: 0 flops */
358     }
359     if (calculateVelocity == VSiteCalculateVelocity::Yes)
360     {
361         copy_rvec(vi, v);
362     }
363 }
364
365 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
366 static void
367 constr_vsite2(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc, const rvec vi, const rvec vj, rvec v)
368 {
369     const real b = 1 - a;
370     /* 1 flop */
371
372     if (calculatePosition == VSiteCalculatePosition::Yes)
373     {
374         if (pbc)
375         {
376             rvec dx;
377             pbc_dx_aiuc(pbc, xj, xi, dx);
378             x[XX] = xi[XX] + a * dx[XX];
379             x[YY] = xi[YY] + a * dx[YY];
380             x[ZZ] = xi[ZZ] + a * dx[ZZ];
381         }
382         else
383         {
384             x[XX] = b * xi[XX] + a * xj[XX];
385             x[YY] = b * xi[YY] + a * xj[YY];
386             x[ZZ] = b * xi[ZZ] + a * xj[ZZ];
387             /* 9 Flops */
388         }
389         /* TOTAL: 10 flops */
390     }
391     if (calculateVelocity == VSiteCalculateVelocity::Yes)
392     {
393         v[XX] = b * vi[XX] + a * vj[XX];
394         v[YY] = b * vi[YY] + a * vj[YY];
395         v[ZZ] = b * vi[ZZ] + a * vj[ZZ];
396     }
397 }
398
399 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
400 static void
401 constr_vsite2FD(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc, const rvec vi, const rvec vj, rvec v)
402 {
403     rvec xij = { 0 };
404     pbc_rvec_sub(pbc, xj, xi, xij);
405     /* 3 flops */
406
407     const real invNormXij = inverseNorm(xij);
408     const real b          = a * invNormXij;
409     /* 6 + 10 flops */
410
411     if (calculatePosition == VSiteCalculatePosition::Yes)
412     {
413         x[XX] = xi[XX] + b * xij[XX];
414         x[YY] = xi[YY] + b * xij[YY];
415         x[ZZ] = xi[ZZ] + b * xij[ZZ];
416         /* 6 Flops */
417         /* TOTAL: 25 flops */
418     }
419     if (calculateVelocity == VSiteCalculateVelocity::Yes)
420     {
421         rvec vij = { 0 };
422         rvec_sub(vj, vi, vij);
423         const real vijDotXij = iprod(vij, xij);
424
425         v[XX] = vi[XX] + b * (vij[XX] - xij[XX] * vijDotXij * invNormXij * invNormXij);
426         v[YY] = vi[YY] + b * (vij[YY] - xij[YY] * vijDotXij * invNormXij * invNormXij);
427         v[ZZ] = vi[ZZ] + b * (vij[ZZ] - xij[ZZ] * vijDotXij * invNormXij * invNormXij);
428     }
429 }
430
431 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
432 static void constr_vsite3(const rvec   xi,
433                           const rvec   xj,
434                           const rvec   xk,
435                           rvec         x,
436                           real         a,
437                           real         b,
438                           const t_pbc* pbc,
439                           const rvec   vi,
440                           const rvec   vj,
441                           const rvec   vk,
442                           rvec         v)
443 {
444     const real c = 1 - a - b;
445     /* 2 flops */
446
447     if (calculatePosition == VSiteCalculatePosition::Yes)
448     {
449         if (pbc)
450         {
451             rvec dxj, dxk;
452
453             pbc_dx_aiuc(pbc, xj, xi, dxj);
454             pbc_dx_aiuc(pbc, xk, xi, dxk);
455             x[XX] = xi[XX] + a * dxj[XX] + b * dxk[XX];
456             x[YY] = xi[YY] + a * dxj[YY] + b * dxk[YY];
457             x[ZZ] = xi[ZZ] + a * dxj[ZZ] + b * dxk[ZZ];
458         }
459         else
460         {
461             x[XX] = c * xi[XX] + a * xj[XX] + b * xk[XX];
462             x[YY] = c * xi[YY] + a * xj[YY] + b * xk[YY];
463             x[ZZ] = c * xi[ZZ] + a * xj[ZZ] + b * xk[ZZ];
464             /* 15 Flops */
465         }
466         /* TOTAL: 17 flops */
467     }
468     if (calculateVelocity == VSiteCalculateVelocity::Yes)
469     {
470         v[XX] = c * vi[XX] + a * vj[XX] + b * vk[XX];
471         v[YY] = c * vi[YY] + a * vj[YY] + b * vk[YY];
472         v[ZZ] = c * vi[ZZ] + a * vj[ZZ] + b * vk[ZZ];
473     }
474 }
475
476 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
477 static void constr_vsite3FD(const rvec   xi,
478                             const rvec   xj,
479                             const rvec   xk,
480                             rvec         x,
481                             real         a,
482                             real         b,
483                             const t_pbc* pbc,
484                             const rvec   vi,
485                             const rvec   vj,
486                             const rvec   vk,
487                             rvec         v)
488 {
489     rvec xij, xjk, temp;
490
491     pbc_rvec_sub(pbc, xj, xi, xij);
492     pbc_rvec_sub(pbc, xk, xj, xjk);
493     /* 6 flops */
494
495     /* temp goes from i to a point on the line jk */
496     temp[XX] = xij[XX] + a * xjk[XX];
497     temp[YY] = xij[YY] + a * xjk[YY];
498     temp[ZZ] = xij[ZZ] + a * xjk[ZZ];
499     /* 6 flops */
500
501     const real invNormTemp = inverseNorm(temp);
502     const real c           = b * invNormTemp;
503     /* 6 + 10 flops */
504
505     if (calculatePosition == VSiteCalculatePosition::Yes)
506     {
507         x[XX] = xi[XX] + c * temp[XX];
508         x[YY] = xi[YY] + c * temp[YY];
509         x[ZZ] = xi[ZZ] + c * temp[ZZ];
510         /* 6 Flops */
511         /* TOTAL: 34 flops */
512     }
513     if (calculateVelocity == VSiteCalculateVelocity::Yes)
514     {
515         rvec vij = { 0 };
516         rvec vjk = { 0 };
517         rvec_sub(vj, vi, vij);
518         rvec_sub(vk, vj, vjk);
519         const rvec tempV = { vij[XX] + a * vjk[XX], vij[YY] + a * vjk[YY], vij[ZZ] + a * vjk[ZZ] };
520         const real tempDotTempV = iprod(temp, tempV);
521
522         v[XX] = vi[XX] + c * (tempV[XX] - temp[XX] * tempDotTempV * invNormTemp * invNormTemp);
523         v[YY] = vi[YY] + c * (tempV[YY] - temp[YY] * tempDotTempV * invNormTemp * invNormTemp);
524         v[ZZ] = vi[ZZ] + c * (tempV[ZZ] - temp[ZZ] * tempDotTempV * invNormTemp * invNormTemp);
525     }
526 }
527
528 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
529 static void constr_vsite3FAD(const rvec   xi,
530                              const rvec   xj,
531                              const rvec   xk,
532                              rvec         x,
533                              real         a,
534                              real         b,
535                              const t_pbc* pbc,
536                              const rvec   vi,
537                              const rvec   vj,
538                              const rvec   vk,
539                              rvec         v)
540 { // Note: a = d * cos(theta)
541     //       b = d * sin(theta)
542     rvec xij, xjk, xp;
543
544     pbc_rvec_sub(pbc, xj, xi, xij);
545     pbc_rvec_sub(pbc, xk, xj, xjk);
546     /* 6 flops */
547
548     const real invdij    = inverseNorm(xij);
549     const real xijDotXjk = iprod(xij, xjk);
550     const real c1        = invdij * invdij * xijDotXjk;
551     xp[XX]               = xjk[XX] - c1 * xij[XX];
552     xp[YY]               = xjk[YY] - c1 * xij[YY];
553     xp[ZZ]               = xjk[ZZ] - c1 * xij[ZZ];
554     const real a1        = a * invdij;
555     const real invNormXp = inverseNorm(xp);
556     const real b1        = b * invNormXp;
557     /* 45 */
558
559     if (calculatePosition == VSiteCalculatePosition::Yes)
560     {
561         x[XX] = xi[XX] + a1 * xij[XX] + b1 * xp[XX];
562         x[YY] = xi[YY] + a1 * xij[YY] + b1 * xp[YY];
563         x[ZZ] = xi[ZZ] + a1 * xij[ZZ] + b1 * xp[ZZ];
564         /* 12 Flops */
565         /* TOTAL: 63 flops */
566     }
567
568     if (calculateVelocity == VSiteCalculateVelocity::Yes)
569     {
570         rvec vij = { 0 };
571         rvec vjk = { 0 };
572         rvec_sub(vj, vi, vij);
573         rvec_sub(vk, vj, vjk);
574
575         const real vijDotXjkPlusXijDotVjk = iprod(vij, xjk) + iprod(xij, vjk);
576         const real xijDotVij              = iprod(xij, vij);
577         const real invNormXij2            = invdij * invdij;
578
579         rvec vp = { 0 };
580         vp[XX]  = vjk[XX]
581                  - xij[XX] * invNormXij2
582                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
583                  - vij[XX] * xijDotXjk * invNormXij2;
584         vp[YY] = vjk[YY]
585                  - xij[YY] * invNormXij2
586                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
587                  - vij[YY] * xijDotXjk * invNormXij2;
588         vp[ZZ] = vjk[ZZ]
589                  - xij[ZZ] * invNormXij2
590                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
591                  - vij[ZZ] * xijDotXjk * invNormXij2;
592
593         const real xpDotVp = iprod(xp, vp);
594
595         v[XX] = vi[XX] + a1 * (vij[XX] - xij[XX] * xijDotVij * invdij * invdij)
596                 + b1 * (vp[XX] - xp[XX] * xpDotVp * invNormXp * invNormXp);
597         v[YY] = vi[YY] + a1 * (vij[YY] - xij[YY] * xijDotVij * invdij * invdij)
598                 + b1 * (vp[YY] - xp[YY] * xpDotVp * invNormXp * invNormXp);
599         v[ZZ] = vi[ZZ] + a1 * (vij[ZZ] - xij[ZZ] * xijDotVij * invdij * invdij)
600                 + b1 * (vp[ZZ] - xp[ZZ] * xpDotVp * invNormXp * invNormXp);
601     }
602 }
603
604 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
605 static void constr_vsite3OUT(const rvec   xi,
606                              const rvec   xj,
607                              const rvec   xk,
608                              rvec         x,
609                              real         a,
610                              real         b,
611                              real         c,
612                              const t_pbc* pbc,
613                              const rvec   vi,
614                              const rvec   vj,
615                              const rvec   vk,
616                              rvec         v)
617 {
618     rvec xij, xik, temp;
619
620     pbc_rvec_sub(pbc, xj, xi, xij);
621     pbc_rvec_sub(pbc, xk, xi, xik);
622     cprod(xij, xik, temp);
623     /* 15 Flops */
624
625     if (calculatePosition == VSiteCalculatePosition::Yes)
626     {
627         x[XX] = xi[XX] + a * xij[XX] + b * xik[XX] + c * temp[XX];
628         x[YY] = xi[YY] + a * xij[YY] + b * xik[YY] + c * temp[YY];
629         x[ZZ] = xi[ZZ] + a * xij[ZZ] + b * xik[ZZ] + c * temp[ZZ];
630         /* 18 Flops */
631         /* TOTAL: 33 flops */
632     }
633
634     if (calculateVelocity == VSiteCalculateVelocity::Yes)
635     {
636         rvec vij = { 0 };
637         rvec vik = { 0 };
638         rvec_sub(vj, vi, vij);
639         rvec_sub(vk, vi, vik);
640
641         rvec temp1 = { 0 };
642         rvec temp2 = { 0 };
643         cprod(vij, xik, temp1);
644         cprod(xij, vik, temp2);
645
646         v[XX] = vi[XX] + a * vij[XX] + b * vik[XX] + c * (temp1[XX] + temp2[XX]);
647         v[YY] = vi[YY] + a * vij[YY] + b * vik[YY] + c * (temp1[YY] + temp2[YY]);
648         v[ZZ] = vi[ZZ] + a * vij[ZZ] + b * vik[ZZ] + c * (temp1[ZZ] + temp2[ZZ]);
649     }
650 }
651
652 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
653 static void constr_vsite4FD(const rvec   xi,
654                             const rvec   xj,
655                             const rvec   xk,
656                             const rvec   xl,
657                             rvec         x,
658                             real         a,
659                             real         b,
660                             real         c,
661                             const t_pbc* pbc,
662                             const rvec   vi,
663                             const rvec   vj,
664                             const rvec   vk,
665                             const rvec   vl,
666                             rvec         v)
667 {
668     rvec xij, xjk, xjl, temp;
669     real d;
670
671     pbc_rvec_sub(pbc, xj, xi, xij);
672     pbc_rvec_sub(pbc, xk, xj, xjk);
673     pbc_rvec_sub(pbc, xl, xj, xjl);
674     /* 9 flops */
675
676     /* temp goes from i to a point on the plane jkl */
677     temp[XX] = xij[XX] + a * xjk[XX] + b * xjl[XX];
678     temp[YY] = xij[YY] + a * xjk[YY] + b * xjl[YY];
679     temp[ZZ] = xij[ZZ] + a * xjk[ZZ] + b * xjl[ZZ];
680     /* 12 flops */
681
682     const real invRm = inverseNorm(temp);
683     d                = c * invRm;
684     /* 6 + 10 flops */
685
686     if (calculatePosition == VSiteCalculatePosition::Yes)
687     {
688         x[XX] = xi[XX] + d * temp[XX];
689         x[YY] = xi[YY] + d * temp[YY];
690         x[ZZ] = xi[ZZ] + d * temp[ZZ];
691         /* 6 Flops */
692         /* TOTAL: 43 flops */
693     }
694     if (calculateVelocity == VSiteCalculateVelocity::Yes)
695     {
696         rvec vij = { 0 };
697         rvec vjk = { 0 };
698         rvec vjl = { 0 };
699
700         rvec_sub(vj, vi, vij);
701         rvec_sub(vk, vj, vjk);
702         rvec_sub(vl, vj, vjl);
703
704         rvec vm = { 0 };
705         vm[XX]  = vij[XX] + a * vjk[XX] + b * vjl[XX];
706         vm[YY]  = vij[YY] + a * vjk[YY] + b * vjl[YY];
707         vm[ZZ]  = vij[ZZ] + a * vjk[ZZ] + b * vjl[ZZ];
708
709         const real vmDotRm = iprod(vm, temp);
710         v[XX]              = vi[XX] + d * (vm[XX] - temp[XX] * vmDotRm * invRm * invRm);
711         v[YY]              = vi[YY] + d * (vm[YY] - temp[YY] * vmDotRm * invRm * invRm);
712         v[ZZ]              = vi[ZZ] + d * (vm[ZZ] - temp[ZZ] * vmDotRm * invRm * invRm);
713     }
714 }
715
716 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
717 static void constr_vsite4FDN(const rvec   xi,
718                              const rvec   xj,
719                              const rvec   xk,
720                              const rvec   xl,
721                              rvec         x,
722                              real         a,
723                              real         b,
724                              real         c,
725                              const t_pbc* pbc,
726                              const rvec   vi,
727                              const rvec   vj,
728                              const rvec   vk,
729                              const rvec   vl,
730                              rvec         v)
731 {
732     rvec xij, xik, xil, ra, rb, rja, rjb, rm;
733     real d;
734
735     pbc_rvec_sub(pbc, xj, xi, xij);
736     pbc_rvec_sub(pbc, xk, xi, xik);
737     pbc_rvec_sub(pbc, xl, xi, xil);
738     /* 9 flops */
739
740     ra[XX] = a * xik[XX];
741     ra[YY] = a * xik[YY];
742     ra[ZZ] = a * xik[ZZ];
743
744     rb[XX] = b * xil[XX];
745     rb[YY] = b * xil[YY];
746     rb[ZZ] = b * xil[ZZ];
747
748     /* 6 flops */
749
750     rvec_sub(ra, xij, rja);
751     rvec_sub(rb, xij, rjb);
752     /* 6 flops */
753
754     cprod(rja, rjb, rm);
755     /* 9 flops */
756
757     const real invNormRm = inverseNorm(rm);
758     d                    = c * invNormRm;
759     /* 5+5+1 flops */
760
761     if (calculatePosition == VSiteCalculatePosition::Yes)
762     {
763         x[XX] = xi[XX] + d * rm[XX];
764         x[YY] = xi[YY] + d * rm[YY];
765         x[ZZ] = xi[ZZ] + d * rm[ZZ];
766         /* 6 Flops */
767         /* TOTAL: 47 flops */
768     }
769
770     if (calculateVelocity == VSiteCalculateVelocity::Yes)
771     {
772         rvec vij = { 0 };
773         rvec vik = { 0 };
774         rvec vil = { 0 };
775         rvec_sub(vj, vi, vij);
776         rvec_sub(vk, vi, vik);
777         rvec_sub(vl, vi, vil);
778
779         rvec vja = { 0 };
780         rvec vjb = { 0 };
781
782         vja[XX] = a * vik[XX] - vij[XX];
783         vja[YY] = a * vik[YY] - vij[YY];
784         vja[ZZ] = a * vik[ZZ] - vij[ZZ];
785         vjb[XX] = b * vil[XX] - vij[XX];
786         vjb[YY] = b * vil[YY] - vij[YY];
787         vjb[ZZ] = b * vil[ZZ] - vij[ZZ];
788
789         rvec temp1 = { 0 };
790         rvec temp2 = { 0 };
791         cprod(vja, rjb, temp1);
792         cprod(rja, vjb, temp2);
793
794         rvec vm = { 0 };
795         vm[XX]  = temp1[XX] + temp2[XX];
796         vm[YY]  = temp1[YY] + temp2[YY];
797         vm[ZZ]  = temp1[ZZ] + temp2[ZZ];
798
799         const real rmDotVm = iprod(rm, vm);
800         v[XX]              = vi[XX] + d * (vm[XX] - rm[XX] * rmDotVm * invNormRm * invNormRm);
801         v[YY]              = vi[YY] + d * (vm[YY] - rm[YY] * rmDotVm * invNormRm * invNormRm);
802         v[ZZ]              = vi[ZZ] + d * (vm[ZZ] - rm[ZZ] * rmDotVm * invNormRm * invNormRm);
803     }
804 }
805
806 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
807 static int constr_vsiten(const t_iatom*            ia,
808                          ArrayRef<const t_iparams> ip,
809                          ArrayRef<RVec>            x,
810                          const t_pbc*              pbc,
811                          ArrayRef<RVec>            v)
812 {
813     rvec x1, dx;
814     dvec dsum;
815     real a;
816     dvec dvsum = { 0 };
817     rvec v1    = { 0 };
818
819     const int n3 = 3 * ip[ia[0]].vsiten.n;
820     const int av = ia[1];
821     int       ai = ia[2];
822     copy_rvec(x[ai], x1);
823     copy_rvec(v[ai], v1);
824     clear_dvec(dsum);
825     for (int i = 3; i < n3; i += 3)
826     {
827         ai = ia[i + 2];
828         a  = ip[ia[i]].vsiten.a;
829         if (calculatePosition == VSiteCalculatePosition::Yes)
830         {
831             if (pbc)
832             {
833                 pbc_dx_aiuc(pbc, x[ai], x1, dx);
834             }
835             else
836             {
837                 rvec_sub(x[ai], x1, dx);
838             }
839             dsum[XX] += a * dx[XX];
840             dsum[YY] += a * dx[YY];
841             dsum[ZZ] += a * dx[ZZ];
842             /* 9 Flops */
843         }
844         if (calculateVelocity == VSiteCalculateVelocity::Yes)
845         {
846             rvec_sub(v[ai], v1, dx);
847             dvsum[XX] += a * dx[XX];
848             dvsum[YY] += a * dx[YY];
849             dvsum[ZZ] += a * dx[ZZ];
850             /* 9 Flops */
851         }
852     }
853
854     if (calculatePosition == VSiteCalculatePosition::Yes)
855     {
856         x[av][XX] = x1[XX] + dsum[XX];
857         x[av][YY] = x1[YY] + dsum[YY];
858         x[av][ZZ] = x1[ZZ] + dsum[ZZ];
859     }
860
861     if (calculateVelocity == VSiteCalculateVelocity::Yes)
862     {
863         v[av][XX] = v1[XX] + dvsum[XX];
864         v[av][YY] = v1[YY] + dvsum[YY];
865         v[av][ZZ] = v1[ZZ] + dvsum[ZZ];
866     }
867
868     return n3;
869 }
870 // End GCC 8 bug
871 GCC_DIAGNOSTIC_RESET
872
873 #endif // DOXYGEN
874
875 //! PBC modes for vsite construction and spreading
876 enum class PbcMode
877 {
878     all, //!< Apply normal, simple PBC for all vsites
879     none //!< No PBC treatment needed
880 };
881
882 /*! \brief Returns the PBC mode based on the system PBC and vsite properties
883  *
884  * \param[in] pbcPtr  A pointer to a PBC struct or nullptr when no PBC treatment is required
885  */
886 static PbcMode getPbcMode(const t_pbc* pbcPtr)
887 {
888     if (pbcPtr == nullptr)
889     {
890         return PbcMode::none;
891     }
892     else
893     {
894         return PbcMode::all;
895     }
896 }
897
898 /*! \brief Executes the vsite construction task for a single thread
899  *
900  * \tparam        operation  Whether we are calculating positions, velocities, or both
901  * \param[in,out] x   Coordinates to construct vsites for
902  * \param[in,out] v   Velocities are generated for virtual sites if `operation` requires it
903  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
904  * \param[in]     ilist  The interaction lists, only vsites are usesd
905  * \param[in]     pbc_null  PBC struct, used for PBC distance calculations when !=nullptr
906  */
907 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
908 static void construct_vsites_thread(ArrayRef<RVec>                  x,
909                                     ArrayRef<RVec>                  v,
910                                     ArrayRef<const t_iparams>       ip,
911                                     ArrayRef<const InteractionList> ilist,
912                                     const t_pbc*                    pbc_null)
913 {
914     if (calculateVelocity == VSiteCalculateVelocity::Yes)
915     {
916         GMX_RELEASE_ASSERT(!v.empty(),
917                            "Can't calculate velocities without access to velocity vector.");
918     }
919
920     // Work around clang bug (unfixed as of Feb 2021)
921     // https://bugs.llvm.org/show_bug.cgi?id=35450
922     // clang-format off
923     CLANG_DIAGNOSTIC_IGNORE(-Wunused-lambda-capture)
924     // clang-format on
925     // GCC 8 falsely flags unused variables if constexpr prunes a code path, fixed in GCC 9
926     // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85827
927     // clang-format off
928     GCC_DIAGNOSTIC_IGNORE(-Wunused-but-set-parameter)
929     // clang-format on
930     // getVOrNull returns a velocity rvec if we need it, nullptr otherwise.
931     auto getVOrNull = [v](int idx) -> real* {
932         if (calculateVelocity == VSiteCalculateVelocity::Yes)
933         {
934             return v[idx].as_vec();
935         }
936         else
937         {
938             return nullptr;
939         }
940     };
941     GCC_DIAGNOSTIC_RESET
942     CLANG_DIAGNOSTIC_RESET
943
944     const PbcMode pbcMode = getPbcMode(pbc_null);
945     /* We need another pbc pointer, as with charge groups we switch per vsite */
946     const t_pbc* pbc_null2 = pbc_null;
947
948     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
949     {
950         if (ilist[ftype].empty())
951         {
952             continue;
953         }
954
955         { // TODO remove me
956             int nra = interaction_function[ftype].nratoms;
957             int inc = 1 + nra;
958             int nr  = ilist[ftype].size();
959
960             const t_iatom* ia = ilist[ftype].iatoms.data();
961
962             for (int i = 0; i < nr;)
963             {
964                 int tp = ia[0];
965                 /* The vsite and constructing atoms */
966                 int avsite = ia[1];
967                 int ai     = ia[2];
968                 /* Constants for constructing vsites */
969                 real a1 = ip[tp].vsite.a;
970                 /* Copy the old position */
971                 rvec xv;
972                 copy_rvec(x[avsite], xv);
973
974                 /* Construct the vsite depending on type */
975                 int  aj, ak, al;
976                 real b1, c1;
977                 switch (ftype)
978                 {
979                     case F_VSITE1:
980                         constr_vsite1<calculatePosition, calculateVelocity>(
981                                 x[ai], x[avsite], getVOrNull(ai), getVOrNull(avsite));
982                         break;
983                     case F_VSITE2:
984                         aj = ia[3];
985                         constr_vsite2<calculatePosition, calculateVelocity>(x[ai],
986                                                                             x[aj],
987                                                                             x[avsite],
988                                                                             a1,
989                                                                             pbc_null2,
990                                                                             getVOrNull(ai),
991                                                                             getVOrNull(aj),
992                                                                             getVOrNull(avsite));
993                         break;
994                     case F_VSITE2FD:
995                         aj = ia[3];
996                         constr_vsite2FD<calculatePosition, calculateVelocity>(x[ai],
997                                                                               x[aj],
998                                                                               x[avsite],
999                                                                               a1,
1000                                                                               pbc_null2,
1001                                                                               getVOrNull(ai),
1002                                                                               getVOrNull(aj),
1003                                                                               getVOrNull(avsite));
1004                         break;
1005                     case F_VSITE3:
1006                         aj = ia[3];
1007                         ak = ia[4];
1008                         b1 = ip[tp].vsite.b;
1009                         constr_vsite3<calculatePosition, calculateVelocity>(x[ai],
1010                                                                             x[aj],
1011                                                                             x[ak],
1012                                                                             x[avsite],
1013                                                                             a1,
1014                                                                             b1,
1015                                                                             pbc_null2,
1016                                                                             getVOrNull(ai),
1017                                                                             getVOrNull(aj),
1018                                                                             getVOrNull(ak),
1019                                                                             getVOrNull(avsite));
1020                         break;
1021                     case F_VSITE3FD:
1022                         aj = ia[3];
1023                         ak = ia[4];
1024                         b1 = ip[tp].vsite.b;
1025                         constr_vsite3FD<calculatePosition, calculateVelocity>(x[ai],
1026                                                                               x[aj],
1027                                                                               x[ak],
1028                                                                               x[avsite],
1029                                                                               a1,
1030                                                                               b1,
1031                                                                               pbc_null2,
1032                                                                               getVOrNull(ai),
1033                                                                               getVOrNull(aj),
1034                                                                               getVOrNull(ak),
1035                                                                               getVOrNull(avsite));
1036                         break;
1037                     case F_VSITE3FAD:
1038                         aj = ia[3];
1039                         ak = ia[4];
1040                         b1 = ip[tp].vsite.b;
1041                         constr_vsite3FAD<calculatePosition, calculateVelocity>(x[ai],
1042                                                                                x[aj],
1043                                                                                x[ak],
1044                                                                                x[avsite],
1045                                                                                a1,
1046                                                                                b1,
1047                                                                                pbc_null2,
1048                                                                                getVOrNull(ai),
1049                                                                                getVOrNull(aj),
1050                                                                                getVOrNull(ak),
1051                                                                                getVOrNull(avsite));
1052                         break;
1053                     case F_VSITE3OUT:
1054                         aj = ia[3];
1055                         ak = ia[4];
1056                         b1 = ip[tp].vsite.b;
1057                         c1 = ip[tp].vsite.c;
1058                         constr_vsite3OUT<calculatePosition, calculateVelocity>(x[ai],
1059                                                                                x[aj],
1060                                                                                x[ak],
1061                                                                                x[avsite],
1062                                                                                a1,
1063                                                                                b1,
1064                                                                                c1,
1065                                                                                pbc_null2,
1066                                                                                getVOrNull(ai),
1067                                                                                getVOrNull(aj),
1068                                                                                getVOrNull(ak),
1069                                                                                getVOrNull(avsite));
1070                         break;
1071                     case F_VSITE4FD:
1072                         aj = ia[3];
1073                         ak = ia[4];
1074                         al = ia[5];
1075                         b1 = ip[tp].vsite.b;
1076                         c1 = ip[tp].vsite.c;
1077                         constr_vsite4FD<calculatePosition, calculateVelocity>(x[ai],
1078                                                                               x[aj],
1079                                                                               x[ak],
1080                                                                               x[al],
1081                                                                               x[avsite],
1082                                                                               a1,
1083                                                                               b1,
1084                                                                               c1,
1085                                                                               pbc_null2,
1086                                                                               getVOrNull(ai),
1087                                                                               getVOrNull(aj),
1088                                                                               getVOrNull(ak),
1089                                                                               getVOrNull(al),
1090                                                                               getVOrNull(avsite));
1091                         break;
1092                     case F_VSITE4FDN:
1093                         aj = ia[3];
1094                         ak = ia[4];
1095                         al = ia[5];
1096                         b1 = ip[tp].vsite.b;
1097                         c1 = ip[tp].vsite.c;
1098                         constr_vsite4FDN<calculatePosition, calculateVelocity>(x[ai],
1099                                                                                x[aj],
1100                                                                                x[ak],
1101                                                                                x[al],
1102                                                                                x[avsite],
1103                                                                                a1,
1104                                                                                b1,
1105                                                                                c1,
1106                                                                                pbc_null2,
1107                                                                                getVOrNull(ai),
1108                                                                                getVOrNull(aj),
1109                                                                                getVOrNull(ak),
1110                                                                                getVOrNull(al),
1111                                                                                getVOrNull(avsite));
1112                         break;
1113                     case F_VSITEN:
1114                         inc = constr_vsiten<calculatePosition, calculateVelocity>(ia, ip, x, pbc_null2, v);
1115                         break;
1116                     default:
1117                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
1118                 }
1119
1120                 if (pbcMode == PbcMode::all)
1121                 {
1122                     /* Keep the vsite in the same periodic image as before */
1123                     rvec dx;
1124                     int  ishift = pbc_dx_aiuc(pbc_null, x[avsite], xv, dx);
1125                     if (ishift != c_centralShiftIndex)
1126                     {
1127                         rvec_add(xv, dx, x[avsite]);
1128                     }
1129                 }
1130
1131                 /* Increment loop variables */
1132                 i += inc;
1133                 ia += inc;
1134             }
1135         }
1136     }
1137 }
1138
1139 /*! \brief Dispatch the vsite construction tasks for all threads
1140  *
1141  * \param[in]     threadingInfo  Used to divide work over threads when != nullptr
1142  * \param[in,out] x   Coordinates to construct vsites for
1143  * \param[in,out] v   When not empty, velocities are generated for virtual sites
1144  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
1145  * \param[in]     ilist  The interaction lists, only vsites are usesd
1146  * \param[in]     domainInfo  Information about PBC and DD
1147  * \param[in]     box  Used for PBC when PBC is set in domainInfo
1148  */
1149 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
1150 static void construct_vsites(const ThreadingInfo*            threadingInfo,
1151                              ArrayRef<RVec>                  x,
1152                              ArrayRef<RVec>                  v,
1153                              ArrayRef<const t_iparams>       ip,
1154                              ArrayRef<const InteractionList> ilist,
1155                              const DomainInfo&               domainInfo,
1156                              const matrix                    box)
1157 {
1158     const bool useDomdec = domainInfo.useDomdec();
1159
1160     t_pbc pbc, *pbc_null;
1161
1162     /* We only need to do pbc when we have inter update-group vsites.
1163      * Note that with domain decomposition we do not need to apply PBC here
1164      * when we have at least 3 domains along each dimension. Currently we
1165      * do not optimize this case.
1166      */
1167     if (domainInfo.pbcType_ != PbcType::No && domainInfo.useMolPbc_)
1168     {
1169         /* This is wasting some CPU time as we now do this multiple times
1170          * per MD step.
1171          */
1172         ivec null_ivec;
1173         clear_ivec(null_ivec);
1174         pbc_null = set_pbc_dd(
1175                 &pbc, domainInfo.pbcType_, useDomdec ? domainInfo.domdec_->numCells : null_ivec, FALSE, box);
1176     }
1177     else
1178     {
1179         pbc_null = nullptr;
1180     }
1181
1182     if (useDomdec)
1183     {
1184         if (calculateVelocity == VSiteCalculateVelocity::Yes)
1185         {
1186             dd_move_x_and_v_vsites(
1187                     *domainInfo.domdec_, box, as_rvec_array(x.data()), as_rvec_array(v.data()));
1188         }
1189         else
1190         {
1191             dd_move_x_vsites(*domainInfo.domdec_, box, as_rvec_array(x.data()));
1192         }
1193     }
1194
1195     if (threadingInfo == nullptr || threadingInfo->numThreads() == 1)
1196     {
1197         construct_vsites_thread<calculatePosition, calculateVelocity>(x, v, ip, ilist, pbc_null);
1198     }
1199     else
1200     {
1201 #pragma omp parallel num_threads(threadingInfo->numThreads())
1202         {
1203             try
1204             {
1205                 const int          th    = gmx_omp_get_thread_num();
1206                 const VsiteThread& tData = threadingInfo->threadData(th);
1207                 GMX_ASSERT(tData.rangeStart >= 0,
1208                            "The thread data should be initialized before calling construct_vsites");
1209
1210                 construct_vsites_thread<calculatePosition, calculateVelocity>(
1211                         x, v, ip, tData.ilist, pbc_null);
1212                 if (tData.useInterdependentTask)
1213                 {
1214                     /* Here we don't need a barrier (unlike the spreading),
1215                      * since both tasks only construct vsites from particles,
1216                      * or local vsites, not from non-local vsites.
1217                      */
1218                     construct_vsites_thread<calculatePosition, calculateVelocity>(
1219                             x, v, ip, tData.idTask.ilist, pbc_null);
1220                 }
1221             }
1222             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
1223         }
1224         /* Now we can construct the vsites that might depend on other vsites */
1225         construct_vsites_thread<calculatePosition, calculateVelocity>(
1226                 x, v, ip, threadingInfo->threadDataNonLocalDependent().ilist, pbc_null);
1227     }
1228 }
1229
1230 void VirtualSitesHandler::Impl::construct(ArrayRef<RVec> x,
1231                                           ArrayRef<RVec> v,
1232                                           const matrix   box,
1233                                           VSiteOperation operation) const
1234 {
1235     switch (operation)
1236     {
1237         case VSiteOperation::Positions:
1238             construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::No>(
1239                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1240             break;
1241         case VSiteOperation::Velocities:
1242             construct_vsites<VSiteCalculatePosition::No, VSiteCalculateVelocity::Yes>(
1243                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1244             break;
1245         case VSiteOperation::PositionsAndVelocities:
1246             construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::Yes>(
1247                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1248             break;
1249         default: gmx_fatal(FARGS, "Unknown virtual site operation");
1250     }
1251 }
1252
1253 void VirtualSitesHandler::construct(ArrayRef<RVec> x, ArrayRef<RVec> v, const matrix box, VSiteOperation operation) const
1254 {
1255     impl_->construct(x, v, box, operation);
1256 }
1257
1258 void constructVirtualSites(ArrayRef<RVec> x, ArrayRef<const t_iparams> ip, ArrayRef<const InteractionList> ilist)
1259
1260 {
1261     // No PBC, no DD
1262     const DomainInfo domainInfo;
1263     construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::No>(
1264             nullptr, x, {}, ip, ilist, domainInfo, nullptr);
1265 }
1266
1267 #ifndef DOXYGEN
1268 /* Force spreading routines */
1269
1270 static void spread_vsite1(const t_iatom ia[], ArrayRef<RVec> f)
1271 {
1272     const int av = ia[1];
1273     const int ai = ia[2];
1274
1275     f[av] += f[ai];
1276 }
1277
1278 template<VirialHandling virialHandling>
1279 static void spread_vsite2(const t_iatom        ia[],
1280                           real                 a,
1281                           ArrayRef<const RVec> x,
1282                           ArrayRef<RVec>       f,
1283                           ArrayRef<RVec>       fshift,
1284                           const t_pbc*         pbc)
1285 {
1286     rvec    fi, fj, dx;
1287     t_iatom av, ai, aj;
1288
1289     av = ia[1];
1290     ai = ia[2];
1291     aj = ia[3];
1292
1293     svmul(1 - a, f[av], fi);
1294     svmul(a, f[av], fj);
1295     /* 7 flop */
1296
1297     rvec_inc(f[ai], fi);
1298     rvec_inc(f[aj], fj);
1299     /* 6 Flops */
1300
1301     if (virialHandling == VirialHandling::Pbc)
1302     {
1303         int siv;
1304         int sij;
1305         if (pbc)
1306         {
1307             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1308             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1309         }
1310         else
1311         {
1312             siv = c_centralShiftIndex;
1313             sij = c_centralShiftIndex;
1314         }
1315
1316         if (siv != c_centralShiftIndex || sij != c_centralShiftIndex)
1317         {
1318             rvec_inc(fshift[siv], f[av]);
1319             rvec_dec(fshift[c_centralShiftIndex], fi);
1320             rvec_dec(fshift[sij], fj);
1321         }
1322     }
1323
1324     /* TOTAL: 13 flops */
1325 }
1326
1327 void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
1328 {
1329     GMX_ASSERT(x.ssize() >= mtop.natoms, "x should contain the whole system");
1330     GMX_ASSERT(!mtop.moleculeBlockIndices.empty(),
1331                "molblock indices are needed in constructVsitesGlobal");
1332
1333     for (size_t mb = 0; mb < mtop.molblock.size(); mb++)
1334     {
1335         const gmx_molblock_t& molb = mtop.molblock[mb];
1336         const gmx_moltype_t&  molt = mtop.moltype[molb.type];
1337         if (vsiteIlistNrCount(molt.ilist) > 0)
1338         {
1339             int atomOffset = mtop.moleculeBlockIndices[mb].globalAtomStart;
1340             for (int mol = 0; mol < molb.nmol; mol++)
1341             {
1342                 constructVirtualSites(
1343                         x.subArray(atomOffset, molt.atoms.nr), mtop.ffparams.iparams, molt.ilist);
1344                 atomOffset += molt.atoms.nr;
1345             }
1346         }
1347     }
1348 }
1349
1350 template<VirialHandling virialHandling>
1351 static void spread_vsite2FD(const t_iatom        ia[],
1352                             real                 a,
1353                             ArrayRef<const RVec> x,
1354                             ArrayRef<RVec>       f,
1355                             ArrayRef<RVec>       fshift,
1356                             matrix               dxdf,
1357                             const t_pbc*         pbc)
1358 {
1359     const int av = ia[1];
1360     const int ai = ia[2];
1361     const int aj = ia[3];
1362     rvec      fv;
1363     copy_rvec(f[av], fv);
1364
1365     rvec xij;
1366     int  sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1367     /* 6 flops */
1368
1369     const real invDistance = inverseNorm(xij);
1370     const real b           = a * invDistance;
1371     /* 4 + ?10? flops */
1372
1373     const real fproj = iprod(xij, fv) * invDistance * invDistance;
1374
1375     rvec fj;
1376     fj[XX] = b * (fv[XX] - fproj * xij[XX]);
1377     fj[YY] = b * (fv[YY] - fproj * xij[YY]);
1378     fj[ZZ] = b * (fv[ZZ] - fproj * xij[ZZ]);
1379     /* 9 */
1380
1381     /* b is already calculated in constr_vsite2FD
1382        storing b somewhere will save flops.     */
1383
1384     f[ai][XX] += fv[XX] - fj[XX];
1385     f[ai][YY] += fv[YY] - fj[YY];
1386     f[ai][ZZ] += fv[ZZ] - fj[ZZ];
1387     f[aj][XX] += fj[XX];
1388     f[aj][YY] += fj[YY];
1389     f[aj][ZZ] += fj[ZZ];
1390     /* 9 Flops */
1391
1392     if (virialHandling == VirialHandling::Pbc)
1393     {
1394         int svi;
1395         if (pbc)
1396         {
1397             rvec xvi;
1398             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1399         }
1400         else
1401         {
1402             svi = c_centralShiftIndex;
1403         }
1404
1405         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex)
1406         {
1407             rvec_dec(fshift[svi], fv);
1408             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX];
1409             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY];
1410             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ];
1411             fshift[sji][XX] += fj[XX];
1412             fshift[sji][YY] += fj[YY];
1413             fshift[sji][ZZ] += fj[ZZ];
1414         }
1415     }
1416
1417     if (virialHandling == VirialHandling::NonLinear)
1418     {
1419         /* Under this condition, the virial for the current forces is not
1420          * calculated from the redistributed forces. This means that
1421          * the effect of non-linear virtual site constructions on the virial
1422          * needs to be added separately. This contribution can be calculated
1423          * in many ways, but the simplest and cheapest way is to use
1424          * the first constructing atom ai as a reference position in space:
1425          * subtract (xv-xi)*fv and add (xj-xi)*fj.
1426          */
1427         rvec xiv;
1428
1429         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1430
1431         for (int i = 0; i < DIM; i++)
1432         {
1433             for (int j = 0; j < DIM; j++)
1434             {
1435                 /* As xix is a linear combination of j and k, use that here */
1436                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j];
1437             }
1438         }
1439     }
1440
1441     /* TOTAL: 38 flops */
1442 }
1443
1444 template<VirialHandling virialHandling>
1445 static void spread_vsite3(const t_iatom        ia[],
1446                           real                 a,
1447                           real                 b,
1448                           ArrayRef<const RVec> x,
1449                           ArrayRef<RVec>       f,
1450                           ArrayRef<RVec>       fshift,
1451                           const t_pbc*         pbc)
1452 {
1453     rvec fi, fj, fk, dx;
1454     int  av, ai, aj, ak;
1455
1456     av = ia[1];
1457     ai = ia[2];
1458     aj = ia[3];
1459     ak = ia[4];
1460
1461     svmul(1 - a - b, f[av], fi);
1462     svmul(a, f[av], fj);
1463     svmul(b, f[av], fk);
1464     /* 11 flops */
1465
1466     rvec_inc(f[ai], fi);
1467     rvec_inc(f[aj], fj);
1468     rvec_inc(f[ak], fk);
1469     /* 9 Flops */
1470
1471     if (virialHandling == VirialHandling::Pbc)
1472     {
1473         int siv;
1474         int sij;
1475         int sik;
1476         if (pbc)
1477         {
1478             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1479             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1480             sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
1481         }
1482         else
1483         {
1484             siv = c_centralShiftIndex;
1485             sij = c_centralShiftIndex;
1486             sik = c_centralShiftIndex;
1487         }
1488
1489         if (siv != c_centralShiftIndex || sij != c_centralShiftIndex || sik != c_centralShiftIndex)
1490         {
1491             rvec_inc(fshift[siv], f[av]);
1492             rvec_dec(fshift[c_centralShiftIndex], fi);
1493             rvec_dec(fshift[sij], fj);
1494             rvec_dec(fshift[sik], fk);
1495         }
1496     }
1497
1498     /* TOTAL: 20 flops */
1499 }
1500
1501 template<VirialHandling virialHandling>
1502 static void spread_vsite3FD(const t_iatom        ia[],
1503                             real                 a,
1504                             real                 b,
1505                             ArrayRef<const RVec> x,
1506                             ArrayRef<RVec>       f,
1507                             ArrayRef<RVec>       fshift,
1508                             matrix               dxdf,
1509                             const t_pbc*         pbc)
1510 {
1511     real    fproj, a1;
1512     rvec    xvi, xij, xjk, xix, fv, temp;
1513     t_iatom av, ai, aj, ak;
1514     int     sji, skj;
1515
1516     av = ia[1];
1517     ai = ia[2];
1518     aj = ia[3];
1519     ak = ia[4];
1520     copy_rvec(f[av], fv);
1521
1522     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1523     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1524     /* 6 flops */
1525
1526     /* xix goes from i to point x on the line jk */
1527     xix[XX] = xij[XX] + a * xjk[XX];
1528     xix[YY] = xij[YY] + a * xjk[YY];
1529     xix[ZZ] = xij[ZZ] + a * xjk[ZZ];
1530     /* 6 flops */
1531
1532     const real invDistance = inverseNorm(xix);
1533     const real c           = b * invDistance;
1534     /* 4 + ?10? flops */
1535
1536     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1537
1538     temp[XX] = c * (fv[XX] - fproj * xix[XX]);
1539     temp[YY] = c * (fv[YY] - fproj * xix[YY]);
1540     temp[ZZ] = c * (fv[ZZ] - fproj * xix[ZZ]);
1541     /* 16 */
1542
1543     /* c is already calculated in constr_vsite3FD
1544        storing c somewhere will save 26 flops!     */
1545
1546     a1 = 1 - a;
1547     f[ai][XX] += fv[XX] - temp[XX];
1548     f[ai][YY] += fv[YY] - temp[YY];
1549     f[ai][ZZ] += fv[ZZ] - temp[ZZ];
1550     f[aj][XX] += a1 * temp[XX];
1551     f[aj][YY] += a1 * temp[YY];
1552     f[aj][ZZ] += a1 * temp[ZZ];
1553     f[ak][XX] += a * temp[XX];
1554     f[ak][YY] += a * temp[YY];
1555     f[ak][ZZ] += a * temp[ZZ];
1556     /* 19 Flops */
1557
1558     if (virialHandling == VirialHandling::Pbc)
1559     {
1560         int svi;
1561         if (pbc)
1562         {
1563             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1564         }
1565         else
1566         {
1567             svi = c_centralShiftIndex;
1568         }
1569
1570         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex)
1571         {
1572             rvec_dec(fshift[svi], fv);
1573             fshift[c_centralShiftIndex][XX] += fv[XX] - (1 + a) * temp[XX];
1574             fshift[c_centralShiftIndex][YY] += fv[YY] - (1 + a) * temp[YY];
1575             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
1576             fshift[sji][XX] += temp[XX];
1577             fshift[sji][YY] += temp[YY];
1578             fshift[sji][ZZ] += temp[ZZ];
1579             fshift[skj][XX] += a * temp[XX];
1580             fshift[skj][YY] += a * temp[YY];
1581             fshift[skj][ZZ] += a * temp[ZZ];
1582         }
1583     }
1584
1585     if (virialHandling == VirialHandling::NonLinear)
1586     {
1587         /* Under this condition, the virial for the current forces is not
1588          * calculated from the redistributed forces. This means that
1589          * the effect of non-linear virtual site constructions on the virial
1590          * needs to be added separately. This contribution can be calculated
1591          * in many ways, but the simplest and cheapest way is to use
1592          * the first constructing atom ai as a reference position in space:
1593          * subtract (xv-xi)*fv and add (xj-xi)*fj + (xk-xi)*fk.
1594          */
1595         rvec xiv;
1596
1597         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1598
1599         for (int i = 0; i < DIM; i++)
1600         {
1601             for (int j = 0; j < DIM; j++)
1602             {
1603                 /* As xix is a linear combination of j and k, use that here */
1604                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1605             }
1606         }
1607     }
1608
1609     /* TOTAL: 61 flops */
1610 }
1611
1612 template<VirialHandling virialHandling>
1613 static void spread_vsite3FAD(const t_iatom        ia[],
1614                              real                 a,
1615                              real                 b,
1616                              ArrayRef<const RVec> x,
1617                              ArrayRef<RVec>       f,
1618                              ArrayRef<RVec>       fshift,
1619                              matrix               dxdf,
1620                              const t_pbc*         pbc)
1621 {
1622     rvec    xvi, xij, xjk, xperp, Fpij, Fppp, fv, f1, f2, f3;
1623     real    a1, b1, c1, c2, invdij, invdij2, invdp, fproj;
1624     t_iatom av, ai, aj, ak;
1625     int     sji, skj;
1626
1627     av = ia[1];
1628     ai = ia[2];
1629     aj = ia[3];
1630     ak = ia[4];
1631     copy_rvec(f[ia[1]], fv);
1632
1633     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1634     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1635     /* 6 flops */
1636
1637     invdij    = inverseNorm(xij);
1638     invdij2   = invdij * invdij;
1639     c1        = iprod(xij, xjk) * invdij2;
1640     xperp[XX] = xjk[XX] - c1 * xij[XX];
1641     xperp[YY] = xjk[YY] - c1 * xij[YY];
1642     xperp[ZZ] = xjk[ZZ] - c1 * xij[ZZ];
1643     /* xperp in plane ijk, perp. to ij */
1644     invdp = inverseNorm(xperp);
1645     a1    = a * invdij;
1646     b1    = b * invdp;
1647     /* 45 flops */
1648
1649     /* a1, b1 and c1 are already calculated in constr_vsite3FAD
1650        storing them somewhere will save 45 flops!     */
1651
1652     fproj = iprod(xij, fv) * invdij2;
1653     svmul(fproj, xij, Fpij);                              /* proj. f on xij */
1654     svmul(iprod(xperp, fv) * invdp * invdp, xperp, Fppp); /* proj. f on xperp */
1655     svmul(b1 * fproj, xperp, f3);
1656     /* 23 flops */
1657
1658     rvec_sub(fv, Fpij, f1); /* f1 = f - Fpij */
1659     rvec_sub(f1, Fppp, f2); /* f2 = f - Fpij - Fppp */
1660     for (int d = 0; d < DIM; d++)
1661     {
1662         f1[d] *= a1;
1663         f2[d] *= b1;
1664     }
1665     /* 12 flops */
1666
1667     c2 = 1 + c1;
1668     f[ai][XX] += fv[XX] - f1[XX] + c1 * f2[XX] + f3[XX];
1669     f[ai][YY] += fv[YY] - f1[YY] + c1 * f2[YY] + f3[YY];
1670     f[ai][ZZ] += fv[ZZ] - f1[ZZ] + c1 * f2[ZZ] + f3[ZZ];
1671     f[aj][XX] += f1[XX] - c2 * f2[XX] - f3[XX];
1672     f[aj][YY] += f1[YY] - c2 * f2[YY] - f3[YY];
1673     f[aj][ZZ] += f1[ZZ] - c2 * f2[ZZ] - f3[ZZ];
1674     f[ak][XX] += f2[XX];
1675     f[ak][YY] += f2[YY];
1676     f[ak][ZZ] += f2[ZZ];
1677     /* 30 Flops */
1678
1679     if (virialHandling == VirialHandling::Pbc)
1680     {
1681         int svi;
1682
1683         if (pbc)
1684         {
1685             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1686         }
1687         else
1688         {
1689             svi = c_centralShiftIndex;
1690         }
1691
1692         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex)
1693         {
1694             rvec_dec(fshift[svi], fv);
1695             fshift[c_centralShiftIndex][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
1696             fshift[c_centralShiftIndex][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
1697             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
1698             fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
1699             fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
1700             fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
1701             fshift[skj][XX] += f2[XX];
1702             fshift[skj][YY] += f2[YY];
1703             fshift[skj][ZZ] += f2[ZZ];
1704         }
1705     }
1706
1707     if (virialHandling == VirialHandling::NonLinear)
1708     {
1709         rvec xiv;
1710         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1711
1712         for (int i = 0; i < DIM; i++)
1713         {
1714             for (int j = 0; j < DIM; j++)
1715             {
1716                 /* Note that xik=xij+xjk, so we have to add xij*f2 */
1717                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * (f1[j] + (1 - c2) * f2[j] - f3[j])
1718                               + xjk[i] * f2[j];
1719             }
1720         }
1721     }
1722
1723     /* TOTAL: 113 flops */
1724 }
1725
1726 template<VirialHandling virialHandling>
1727 static void spread_vsite3OUT(const t_iatom        ia[],
1728                              real                 a,
1729                              real                 b,
1730                              real                 c,
1731                              ArrayRef<const RVec> x,
1732                              ArrayRef<RVec>       f,
1733                              ArrayRef<RVec>       fshift,
1734                              matrix               dxdf,
1735                              const t_pbc*         pbc)
1736 {
1737     rvec xvi, xij, xik, fv, fj, fk;
1738     real cfx, cfy, cfz;
1739     int  av, ai, aj, ak;
1740     int  sji, ski;
1741
1742     av = ia[1];
1743     ai = ia[2];
1744     aj = ia[3];
1745     ak = ia[4];
1746
1747     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1748     ski = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1749     /* 6 Flops */
1750
1751     copy_rvec(f[av], fv);
1752
1753     cfx = c * fv[XX];
1754     cfy = c * fv[YY];
1755     cfz = c * fv[ZZ];
1756     /* 3 Flops */
1757
1758     fj[XX] = a * fv[XX] - xik[ZZ] * cfy + xik[YY] * cfz;
1759     fj[YY] = xik[ZZ] * cfx + a * fv[YY] - xik[XX] * cfz;
1760     fj[ZZ] = -xik[YY] * cfx + xik[XX] * cfy + a * fv[ZZ];
1761
1762     fk[XX] = b * fv[XX] + xij[ZZ] * cfy - xij[YY] * cfz;
1763     fk[YY] = -xij[ZZ] * cfx + b * fv[YY] + xij[XX] * cfz;
1764     fk[ZZ] = xij[YY] * cfx - xij[XX] * cfy + b * fv[ZZ];
1765     /* 30 Flops */
1766
1767     f[ai][XX] += fv[XX] - fj[XX] - fk[XX];
1768     f[ai][YY] += fv[YY] - fj[YY] - fk[YY];
1769     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1770     rvec_inc(f[aj], fj);
1771     rvec_inc(f[ak], fk);
1772     /* 15 Flops */
1773
1774     if (virialHandling == VirialHandling::Pbc)
1775     {
1776         int svi;
1777         if (pbc)
1778         {
1779             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1780         }
1781         else
1782         {
1783             svi = c_centralShiftIndex;
1784         }
1785
1786         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || ski != c_centralShiftIndex)
1787         {
1788             rvec_dec(fshift[svi], fv);
1789             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX] - fk[XX];
1790             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY] - fk[YY];
1791             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1792             rvec_inc(fshift[sji], fj);
1793             rvec_inc(fshift[ski], fk);
1794         }
1795     }
1796
1797     if (virialHandling == VirialHandling::NonLinear)
1798     {
1799         rvec xiv;
1800
1801         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1802
1803         for (int i = 0; i < DIM; i++)
1804         {
1805             for (int j = 0; j < DIM; j++)
1806             {
1807                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j];
1808             }
1809         }
1810     }
1811
1812     /* TOTAL: 54 flops */
1813 }
1814
1815 template<VirialHandling virialHandling>
1816 static void spread_vsite4FD(const t_iatom        ia[],
1817                             real                 a,
1818                             real                 b,
1819                             real                 c,
1820                             ArrayRef<const RVec> x,
1821                             ArrayRef<RVec>       f,
1822                             ArrayRef<RVec>       fshift,
1823                             matrix               dxdf,
1824                             const t_pbc*         pbc)
1825 {
1826     real fproj, a1;
1827     rvec xvi, xij, xjk, xjl, xix, fv, temp;
1828     int  av, ai, aj, ak, al;
1829     int  sji, skj, slj, m;
1830
1831     av = ia[1];
1832     ai = ia[2];
1833     aj = ia[3];
1834     ak = ia[4];
1835     al = ia[5];
1836
1837     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1838     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1839     slj = pbc_rvec_sub(pbc, x[al], x[aj], xjl);
1840     /* 9 flops */
1841
1842     /* xix goes from i to point x on the plane jkl */
1843     for (m = 0; m < DIM; m++)
1844     {
1845         xix[m] = xij[m] + a * xjk[m] + b * xjl[m];
1846     }
1847     /* 12 flops */
1848
1849     const real invDistance = inverseNorm(xix);
1850     const real d           = c * invDistance;
1851     /* 4 + ?10? flops */
1852
1853     copy_rvec(f[av], fv);
1854
1855     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1856
1857     for (m = 0; m < DIM; m++)
1858     {
1859         temp[m] = d * (fv[m] - fproj * xix[m]);
1860     }
1861     /* 16 */
1862
1863     /* c is already calculated in constr_vsite3FD
1864        storing c somewhere will save 35 flops!     */
1865
1866     a1 = 1 - a - b;
1867     for (m = 0; m < DIM; m++)
1868     {
1869         f[ai][m] += fv[m] - temp[m];
1870         f[aj][m] += a1 * temp[m];
1871         f[ak][m] += a * temp[m];
1872         f[al][m] += b * temp[m];
1873     }
1874     /* 26 Flops */
1875
1876     if (virialHandling == VirialHandling::Pbc)
1877     {
1878         int svi;
1879         if (pbc)
1880         {
1881             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1882         }
1883         else
1884         {
1885             svi = c_centralShiftIndex;
1886         }
1887
1888         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex
1889             || slj != c_centralShiftIndex)
1890         {
1891             rvec_dec(fshift[svi], fv);
1892             for (m = 0; m < DIM; m++)
1893             {
1894                 fshift[c_centralShiftIndex][m] += fv[m] - (1 + a + b) * temp[m];
1895                 fshift[sji][m] += temp[m];
1896                 fshift[skj][m] += a * temp[m];
1897                 fshift[slj][m] += b * temp[m];
1898             }
1899         }
1900     }
1901
1902     if (virialHandling == VirialHandling::NonLinear)
1903     {
1904         rvec xiv;
1905         int  i, j;
1906
1907         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1908
1909         for (i = 0; i < DIM; i++)
1910         {
1911             for (j = 0; j < DIM; j++)
1912             {
1913                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1914             }
1915         }
1916     }
1917
1918     /* TOTAL: 77 flops */
1919 }
1920
1921 template<VirialHandling virialHandling>
1922 static void spread_vsite4FDN(const t_iatom        ia[],
1923                              real                 a,
1924                              real                 b,
1925                              real                 c,
1926                              ArrayRef<const RVec> x,
1927                              ArrayRef<RVec>       f,
1928                              ArrayRef<RVec>       fshift,
1929                              matrix               dxdf,
1930                              const t_pbc*         pbc)
1931 {
1932     rvec xvi, xij, xik, xil, ra, rb, rja, rjb, rab, rm, rt;
1933     rvec fv, fj, fk, fl;
1934     real invrm, denom;
1935     real cfx, cfy, cfz;
1936     int  av, ai, aj, ak, al;
1937     int  sij, sik, sil;
1938
1939     /* DEBUG: check atom indices */
1940     av = ia[1];
1941     ai = ia[2];
1942     aj = ia[3];
1943     ak = ia[4];
1944     al = ia[5];
1945
1946     copy_rvec(f[av], fv);
1947
1948     sij = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1949     sik = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1950     sil = pbc_rvec_sub(pbc, x[al], x[ai], xil);
1951     /* 9 flops */
1952
1953     ra[XX] = a * xik[XX];
1954     ra[YY] = a * xik[YY];
1955     ra[ZZ] = a * xik[ZZ];
1956
1957     rb[XX] = b * xil[XX];
1958     rb[YY] = b * xil[YY];
1959     rb[ZZ] = b * xil[ZZ];
1960
1961     /* 6 flops */
1962
1963     rvec_sub(ra, xij, rja);
1964     rvec_sub(rb, xij, rjb);
1965     rvec_sub(rb, ra, rab);
1966     /* 9 flops */
1967
1968     cprod(rja, rjb, rm);
1969     /* 9 flops */
1970
1971     invrm = inverseNorm(rm);
1972     denom = invrm * invrm;
1973     /* 5+5+2 flops */
1974
1975     cfx = c * invrm * fv[XX];
1976     cfy = c * invrm * fv[YY];
1977     cfz = c * invrm * fv[ZZ];
1978     /* 6 Flops */
1979
1980     cprod(rm, rab, rt);
1981     /* 9 flops */
1982
1983     rt[XX] *= denom;
1984     rt[YY] *= denom;
1985     rt[ZZ] *= denom;
1986     /* 3flops */
1987
1988     fj[XX] = (-rm[XX] * rt[XX]) * cfx + (rab[ZZ] - rm[YY] * rt[XX]) * cfy
1989              + (-rab[YY] - rm[ZZ] * rt[XX]) * cfz;
1990     fj[YY] = (-rab[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1991              + (rab[XX] - rm[ZZ] * rt[YY]) * cfz;
1992     fj[ZZ] = (rab[YY] - rm[XX] * rt[ZZ]) * cfx + (-rab[XX] - rm[YY] * rt[ZZ]) * cfy
1993              + (-rm[ZZ] * rt[ZZ]) * cfz;
1994     /* 30 flops */
1995
1996     cprod(rjb, rm, rt);
1997     /* 9 flops */
1998
1999     rt[XX] *= denom * a;
2000     rt[YY] *= denom * a;
2001     rt[ZZ] *= denom * a;
2002     /* 3flops */
2003
2004     fk[XX] = (-rm[XX] * rt[XX]) * cfx + (-a * rjb[ZZ] - rm[YY] * rt[XX]) * cfy
2005              + (a * rjb[YY] - rm[ZZ] * rt[XX]) * cfz;
2006     fk[YY] = (a * rjb[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
2007              + (-a * rjb[XX] - rm[ZZ] * rt[YY]) * cfz;
2008     fk[ZZ] = (-a * rjb[YY] - rm[XX] * rt[ZZ]) * cfx + (a * rjb[XX] - rm[YY] * rt[ZZ]) * cfy
2009              + (-rm[ZZ] * rt[ZZ]) * cfz;
2010     /* 36 flops */
2011
2012     cprod(rm, rja, rt);
2013     /* 9 flops */
2014
2015     rt[XX] *= denom * b;
2016     rt[YY] *= denom * b;
2017     rt[ZZ] *= denom * b;
2018     /* 3flops */
2019
2020     fl[XX] = (-rm[XX] * rt[XX]) * cfx + (b * rja[ZZ] - rm[YY] * rt[XX]) * cfy
2021              + (-b * rja[YY] - rm[ZZ] * rt[XX]) * cfz;
2022     fl[YY] = (-b * rja[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
2023              + (b * rja[XX] - rm[ZZ] * rt[YY]) * cfz;
2024     fl[ZZ] = (b * rja[YY] - rm[XX] * rt[ZZ]) * cfx + (-b * rja[XX] - rm[YY] * rt[ZZ]) * cfy
2025              + (-rm[ZZ] * rt[ZZ]) * cfz;
2026     /* 36 flops */
2027
2028     f[ai][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
2029     f[ai][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
2030     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
2031     rvec_inc(f[aj], fj);
2032     rvec_inc(f[ak], fk);
2033     rvec_inc(f[al], fl);
2034     /* 21 flops */
2035
2036     if (virialHandling == VirialHandling::Pbc)
2037     {
2038         int svi;
2039         if (pbc)
2040         {
2041             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
2042         }
2043         else
2044         {
2045             svi = c_centralShiftIndex;
2046         }
2047
2048         if (svi != c_centralShiftIndex || sij != c_centralShiftIndex || sik != c_centralShiftIndex
2049             || sil != c_centralShiftIndex)
2050         {
2051             rvec_dec(fshift[svi], fv);
2052             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
2053             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
2054             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
2055             rvec_inc(fshift[sij], fj);
2056             rvec_inc(fshift[sik], fk);
2057             rvec_inc(fshift[sil], fl);
2058         }
2059     }
2060
2061     if (virialHandling == VirialHandling::NonLinear)
2062     {
2063         rvec xiv;
2064         int  i, j;
2065
2066         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
2067
2068         for (i = 0; i < DIM; i++)
2069         {
2070             for (j = 0; j < DIM; j++)
2071             {
2072                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j] + xil[i] * fl[j];
2073             }
2074         }
2075     }
2076
2077     /* Total: 207 flops (Yuck!) */
2078 }
2079
2080 template<VirialHandling virialHandling>
2081 static int spread_vsiten(const t_iatom             ia[],
2082                          ArrayRef<const t_iparams> ip,
2083                          ArrayRef<const RVec>      x,
2084                          ArrayRef<RVec>            f,
2085                          ArrayRef<RVec>            fshift,
2086                          const t_pbc*              pbc)
2087 {
2088     rvec xv, dx, fi;
2089     int  n3, av, i, ai;
2090     real a;
2091     int  siv;
2092
2093     n3 = 3 * ip[ia[0]].vsiten.n;
2094     av = ia[1];
2095     copy_rvec(x[av], xv);
2096
2097     for (i = 0; i < n3; i += 3)
2098     {
2099         ai = ia[i + 2];
2100         if (pbc)
2101         {
2102             siv = pbc_dx_aiuc(pbc, x[ai], xv, dx);
2103         }
2104         else
2105         {
2106             siv = c_centralShiftIndex;
2107         }
2108         a = ip[ia[i]].vsiten.a;
2109         svmul(a, f[av], fi);
2110         rvec_inc(f[ai], fi);
2111
2112         if (virialHandling == VirialHandling::Pbc && siv != c_centralShiftIndex)
2113         {
2114             rvec_inc(fshift[siv], fi);
2115             rvec_dec(fshift[c_centralShiftIndex], fi);
2116         }
2117         /* 6 Flops */
2118     }
2119
2120     return n3;
2121 }
2122
2123 #endif // DOXYGEN
2124
2125 //! Returns the number of virtual sites in the interaction list, for VSITEN the number of atoms
2126 static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
2127 {
2128     if (ftype == F_VSITEN)
2129     {
2130         return ilist[ftype].size() / 3;
2131     }
2132     else
2133     {
2134         return ilist[ftype].size() / (1 + interaction_function[ftype].nratoms);
2135     }
2136 }
2137
2138 //! Executes the force spreading task for a single thread
2139 template<VirialHandling virialHandling>
2140 static void spreadForceForThread(ArrayRef<const RVec>            x,
2141                                  ArrayRef<RVec>                  f,
2142                                  ArrayRef<RVec>                  fshift,
2143                                  matrix                          dxdf,
2144                                  ArrayRef<const t_iparams>       ip,
2145                                  ArrayRef<const InteractionList> ilist,
2146                                  const t_pbc*                    pbc_null)
2147 {
2148     const PbcMode pbcMode = getPbcMode(pbc_null);
2149     /* We need another pbc pointer, as with charge groups we switch per vsite */
2150     const t_pbc*             pbc_null2 = pbc_null;
2151     gmx::ArrayRef<const int> vsite_pbc;
2152
2153     /* this loop goes backwards to be able to build *
2154      * higher type vsites from lower types         */
2155     for (int ftype = c_ftypeVsiteEnd - 1; ftype >= c_ftypeVsiteStart; ftype--)
2156     {
2157         if (ilist[ftype].empty())
2158         {
2159             continue;
2160         }
2161
2162         { // TODO remove me
2163             int nra = interaction_function[ftype].nratoms;
2164             int inc = 1 + nra;
2165             int nr  = ilist[ftype].size();
2166
2167             const t_iatom* ia = ilist[ftype].iatoms.data();
2168
2169             if (pbcMode == PbcMode::all)
2170             {
2171                 pbc_null2 = pbc_null;
2172             }
2173
2174             for (int i = 0; i < nr;)
2175             {
2176                 int tp = ia[0];
2177
2178                 /* Constants for constructing */
2179                 real a1, b1, c1;
2180                 a1 = ip[tp].vsite.a;
2181                 /* Construct the vsite depending on type */
2182                 switch (ftype)
2183                 {
2184                     case F_VSITE1: spread_vsite1(ia, f); break;
2185                     case F_VSITE2:
2186                         spread_vsite2<virialHandling>(ia, a1, x, f, fshift, pbc_null2);
2187                         break;
2188                     case F_VSITE2FD:
2189                         spread_vsite2FD<virialHandling>(ia, a1, x, f, fshift, dxdf, pbc_null2);
2190                         break;
2191                     case F_VSITE3:
2192                         b1 = ip[tp].vsite.b;
2193                         spread_vsite3<virialHandling>(ia, a1, b1, x, f, fshift, pbc_null2);
2194                         break;
2195                     case F_VSITE3FD:
2196                         b1 = ip[tp].vsite.b;
2197                         spread_vsite3FD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
2198                         break;
2199                     case F_VSITE3FAD:
2200                         b1 = ip[tp].vsite.b;
2201                         spread_vsite3FAD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
2202                         break;
2203                     case F_VSITE3OUT:
2204                         b1 = ip[tp].vsite.b;
2205                         c1 = ip[tp].vsite.c;
2206                         spread_vsite3OUT<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2207                         break;
2208                     case F_VSITE4FD:
2209                         b1 = ip[tp].vsite.b;
2210                         c1 = ip[tp].vsite.c;
2211                         spread_vsite4FD<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2212                         break;
2213                     case F_VSITE4FDN:
2214                         b1 = ip[tp].vsite.b;
2215                         c1 = ip[tp].vsite.c;
2216                         spread_vsite4FDN<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2217                         break;
2218                     case F_VSITEN:
2219                         inc = spread_vsiten<virialHandling>(ia, ip, x, f, fshift, pbc_null2);
2220                         break;
2221                     default:
2222                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
2223                 }
2224                 clear_rvec(f[ia[1]]);
2225
2226                 /* Increment loop variables */
2227                 i += inc;
2228                 ia += inc;
2229             }
2230         }
2231     }
2232 }
2233
2234 //! Wrapper function for calling the templated thread-local spread function
2235 static void spreadForceWrapper(ArrayRef<const RVec>            x,
2236                                ArrayRef<RVec>                  f,
2237                                const VirialHandling            virialHandling,
2238                                ArrayRef<RVec>                  fshift,
2239                                matrix                          dxdf,
2240                                const bool                      clearDxdf,
2241                                ArrayRef<const t_iparams>       ip,
2242                                ArrayRef<const InteractionList> ilist,
2243                                const t_pbc*                    pbc_null)
2244 {
2245     if (virialHandling == VirialHandling::NonLinear && clearDxdf)
2246     {
2247         clear_mat(dxdf);
2248     }
2249
2250     switch (virialHandling)
2251     {
2252         case VirialHandling::None:
2253             spreadForceForThread<VirialHandling::None>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2254             break;
2255         case VirialHandling::Pbc:
2256             spreadForceForThread<VirialHandling::Pbc>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2257             break;
2258         case VirialHandling::NonLinear:
2259             spreadForceForThread<VirialHandling::NonLinear>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2260             break;
2261     }
2262 }
2263
2264 //! Clears the task force buffer elements that are written by task idTask
2265 static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
2266 {
2267     int ntask = idTask->spreadTask.size();
2268     for (int ti = 0; ti < ntask; ti++)
2269     {
2270         const AtomIndex* atomList = &idTask->atomIndex[idTask->spreadTask[ti]];
2271         int              natom    = atomList->atom.size();
2272         RVec*            force    = idTask->force.data();
2273         for (int i = 0; i < natom; i++)
2274         {
2275             clear_rvec(force[atomList->atom[i]]);
2276         }
2277     }
2278 }
2279
2280 void VirtualSitesHandler::Impl::spreadForces(ArrayRef<const RVec> x,
2281                                              ArrayRef<RVec>       f,
2282                                              const VirialHandling virialHandling,
2283                                              ArrayRef<RVec>       fshift,
2284                                              matrix               virial,
2285                                              t_nrnb*              nrnb,
2286                                              const matrix         box,
2287                                              gmx_wallcycle*       wcycle)
2288 {
2289     wallcycle_start(wcycle, WallCycleCounter::VsiteSpread);
2290
2291     const bool useDomdec = domainInfo_.useDomdec();
2292
2293     t_pbc pbc, *pbc_null;
2294
2295     if (domainInfo_.useMolPbc_)
2296     {
2297         /* This is wasting some CPU time as we now do this multiple times
2298          * per MD step.
2299          */
2300         pbc_null = set_pbc_dd(
2301                 &pbc, domainInfo_.pbcType_, useDomdec ? domainInfo_.domdec_->numCells : nullptr, FALSE, box);
2302     }
2303     else
2304     {
2305         pbc_null = nullptr;
2306     }
2307
2308     if (useDomdec)
2309     {
2310         dd_clear_f_vsites(*domainInfo_.domdec_, f);
2311     }
2312
2313     const int numThreads = threadingInfo_.numThreads();
2314
2315     if (numThreads == 1)
2316     {
2317         matrix dxdf;
2318         spreadForceWrapper(x, f, virialHandling, fshift, dxdf, true, iparams_, ilists_, pbc_null);
2319
2320         if (virialHandling == VirialHandling::NonLinear)
2321         {
2322             for (int i = 0; i < DIM; i++)
2323             {
2324                 for (int j = 0; j < DIM; j++)
2325                 {
2326                     virial[i][j] += -0.5 * dxdf[i][j];
2327                 }
2328             }
2329         }
2330     }
2331     else
2332     {
2333         /* First spread the vsites that might depend on non-local vsites */
2334         auto& nlDependentVSites = threadingInfo_.threadDataNonLocalDependent();
2335         spreadForceWrapper(x,
2336                            f,
2337                            virialHandling,
2338                            fshift,
2339                            nlDependentVSites.dxdf,
2340                            true,
2341                            iparams_,
2342                            nlDependentVSites.ilist,
2343                            pbc_null);
2344
2345 #pragma omp parallel num_threads(numThreads)
2346         {
2347             try
2348             {
2349                 int          thread = gmx_omp_get_thread_num();
2350                 VsiteThread& tData  = threadingInfo_.threadData(thread);
2351
2352                 ArrayRef<RVec> fshift_t;
2353                 if (virialHandling == VirialHandling::Pbc)
2354                 {
2355                     if (thread == 0)
2356                     {
2357                         fshift_t = fshift;
2358                     }
2359                     else
2360                     {
2361                         fshift_t = tData.fshift;
2362
2363                         for (int i = 0; i < c_numShiftVectors; i++)
2364                         {
2365                             clear_rvec(fshift_t[i]);
2366                         }
2367                     }
2368                 }
2369
2370                 if (tData.useInterdependentTask)
2371                 {
2372                     /* Spread the vsites that spread outside our local range.
2373                      * This is done using a thread-local force buffer force.
2374                      * First we need to copy the input vsite forces to force.
2375                      */
2376                     InterdependentTask* idTask = &tData.idTask;
2377
2378                     /* Clear the buffer elements set by our task during
2379                      * the last call to spread_vsite_f.
2380                      */
2381                     clearTaskForceBufferUsedElements(idTask);
2382
2383                     int nvsite = idTask->vsite.size();
2384                     for (int i = 0; i < nvsite; i++)
2385                     {
2386                         copy_rvec(f[idTask->vsite[i]], idTask->force[idTask->vsite[i]]);
2387                     }
2388                     spreadForceWrapper(x,
2389                                        idTask->force,
2390                                        virialHandling,
2391                                        fshift_t,
2392                                        tData.dxdf,
2393                                        true,
2394                                        iparams_,
2395                                        tData.idTask.ilist,
2396                                        pbc_null);
2397
2398                     /* We need a barrier before reducing forces below
2399                      * that have been produced by a different thread above.
2400                      */
2401 #pragma omp barrier
2402
2403                     /* Loop over all thread task and reduce forces they
2404                      * produced on atoms that fall in our range.
2405                      * Note that atomic reduction would be a simpler solution,
2406                      * but that might not have good support on all platforms.
2407                      */
2408                     int ntask = idTask->reduceTask.size();
2409                     for (int ti = 0; ti < ntask; ti++)
2410                     {
2411                         const InterdependentTask& idt_foreign =
2412                                 threadingInfo_.threadData(idTask->reduceTask[ti]).idTask;
2413                         const AtomIndex& atomList  = idt_foreign.atomIndex[thread];
2414                         const RVec*      f_foreign = idt_foreign.force.data();
2415
2416                         for (int ind : atomList.atom)
2417                         {
2418                             rvec_inc(f[ind], f_foreign[ind]);
2419                             /* Clearing of f_foreign is done at the next step */
2420                         }
2421                     }
2422                     /* Clear the vsite forces, both in f and force */
2423                     for (int i = 0; i < nvsite; i++)
2424                     {
2425                         int ind = tData.idTask.vsite[i];
2426                         clear_rvec(f[ind]);
2427                         clear_rvec(tData.idTask.force[ind]);
2428                     }
2429                 }
2430
2431                 /* Spread the vsites that spread locally only */
2432                 spreadForceWrapper(
2433                         x, f, virialHandling, fshift_t, tData.dxdf, false, iparams_, tData.ilist, pbc_null);
2434             }
2435             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2436         }
2437
2438         if (virialHandling == VirialHandling::Pbc)
2439         {
2440             for (int th = 1; th < numThreads; th++)
2441             {
2442                 for (int i = 0; i < c_numShiftVectors; i++)
2443                 {
2444                     rvec_inc(fshift[i], threadingInfo_.threadData(th).fshift[i]);
2445                 }
2446             }
2447         }
2448
2449         if (virialHandling == VirialHandling::NonLinear)
2450         {
2451             for (int th = 0; th < numThreads + 1; th++)
2452             {
2453                 /* MSVC doesn't like matrix references, so we use a pointer */
2454                 const matrix& dxdf = threadingInfo_.threadData(th).dxdf;
2455
2456                 for (int i = 0; i < DIM; i++)
2457                 {
2458                     for (int j = 0; j < DIM; j++)
2459                     {
2460                         virial[i][j] += -0.5 * dxdf[i][j];
2461                     }
2462                 }
2463             }
2464         }
2465     }
2466
2467     if (useDomdec)
2468     {
2469         dd_move_f_vsites(*domainInfo_.domdec_, f, fshift);
2470     }
2471
2472     inc_nrnb(nrnb, eNR_VSITE1, vsite_count(ilists_, F_VSITE1));
2473     inc_nrnb(nrnb, eNR_VSITE2, vsite_count(ilists_, F_VSITE2));
2474     inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(ilists_, F_VSITE2FD));
2475     inc_nrnb(nrnb, eNR_VSITE3, vsite_count(ilists_, F_VSITE3));
2476     inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(ilists_, F_VSITE3FD));
2477     inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(ilists_, F_VSITE3FAD));
2478     inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(ilists_, F_VSITE3OUT));
2479     inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(ilists_, F_VSITE4FD));
2480     inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(ilists_, F_VSITE4FDN));
2481     inc_nrnb(nrnb, eNR_VSITEN, vsite_count(ilists_, F_VSITEN));
2482
2483     wallcycle_stop(wcycle, WallCycleCounter::VsiteSpread);
2484 }
2485
2486 /*! \brief Returns the an array with group indices for each atom
2487  *
2488  * \param[in] grouping  The paritioning of the atom range into atom groups
2489  */
2490 static std::vector<int> makeAtomToGroupMapping(const gmx::RangePartitioning& grouping)
2491 {
2492     std::vector<int> atomToGroup(grouping.fullRange().end(), 0);
2493
2494     for (int group = 0; group < grouping.numBlocks(); group++)
2495     {
2496         auto block = grouping.block(group);
2497         std::fill(atomToGroup.begin() + block.begin(), atomToGroup.begin() + block.end(), group);
2498     }
2499
2500     return atomToGroup;
2501 }
2502
2503 int countNonlinearVsites(const gmx_mtop_t& mtop)
2504 {
2505     int numNonlinearVsites = 0;
2506     for (const gmx_molblock_t& molb : mtop.molblock)
2507     {
2508         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2509
2510         for (const auto& ilist : extractILists(molt.ilist, IF_VSITE))
2511         {
2512             if (ilist.functionType != F_VSITE2 && ilist.functionType != F_VSITE3
2513                 && ilist.functionType != F_VSITEN)
2514             {
2515                 numNonlinearVsites += molb.nmol * ilist.iatoms.size() / (1 + NRAL(ilist.functionType));
2516             }
2517         }
2518     }
2519
2520     return numNonlinearVsites;
2521 }
2522
2523 void VirtualSitesHandler::spreadForces(ArrayRef<const RVec> x,
2524                                        ArrayRef<RVec>       f,
2525                                        const VirialHandling virialHandling,
2526                                        ArrayRef<RVec>       fshift,
2527                                        matrix               virial,
2528                                        t_nrnb*              nrnb,
2529                                        const matrix         box,
2530                                        gmx_wallcycle*       wcycle)
2531 {
2532     impl_->spreadForces(x, f, virialHandling, fshift, virial, nrnb, box, wcycle);
2533 }
2534
2535 int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
2536                                 gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingsPerMoleculeType)
2537 {
2538     int n_intercg_vsite = 0;
2539     for (const gmx_molblock_t& molb : mtop.molblock)
2540     {
2541         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2542
2543         std::vector<int> atomToGroup;
2544         if (!updateGroupingsPerMoleculeType.empty())
2545         {
2546             atomToGroup = makeAtomToGroupMapping(updateGroupingsPerMoleculeType[molb.type]);
2547         }
2548         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2549         {
2550             const int              nral = NRAL(ftype);
2551             const InteractionList& il   = molt.ilist[ftype];
2552             for (int i = 0; i < il.size(); i += 1 + nral)
2553             {
2554                 bool isInterGroup = atomToGroup.empty();
2555                 if (!isInterGroup)
2556                 {
2557                     const int group = atomToGroup[il.iatoms[1 + i]];
2558                     for (int a = 1; a < nral; a++)
2559                     {
2560                         if (atomToGroup[il.iatoms[1 + a]] != group)
2561                         {
2562                             isInterGroup = true;
2563                             break;
2564                         }
2565                     }
2566                 }
2567                 if (isInterGroup)
2568                 {
2569                     n_intercg_vsite += molb.nmol;
2570                 }
2571             }
2572         }
2573     }
2574
2575     return n_intercg_vsite;
2576 }
2577
2578 std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
2579                                                              const t_commrec*  cr,
2580                                                              PbcType           pbcType,
2581                                                              ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType)
2582 {
2583     GMX_RELEASE_ASSERT(cr != nullptr, "We need a valid commrec");
2584
2585     std::unique_ptr<VirtualSitesHandler> vsite;
2586
2587     /* check if there are vsites */
2588     int nvsite = 0;
2589     for (int ftype = 0; ftype < F_NRE; ftype++)
2590     {
2591         if (interaction_function[ftype].flags & IF_VSITE)
2592         {
2593             GMX_ASSERT(ftype >= c_ftypeVsiteStart && ftype < c_ftypeVsiteEnd,
2594                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2595
2596             nvsite += gmx_mtop_ftype_count(mtop, ftype);
2597         }
2598         else
2599         {
2600             GMX_ASSERT(ftype < c_ftypeVsiteStart || ftype >= c_ftypeVsiteEnd,
2601                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2602         }
2603     }
2604
2605     if (nvsite == 0)
2606     {
2607         return vsite;
2608     }
2609
2610     return std::make_unique<VirtualSitesHandler>(mtop, cr->dd, pbcType, updateGroupingPerMoleculeType);
2611 }
2612
2613 ThreadingInfo::ThreadingInfo() : numThreads_(gmx_omp_nthreads_get(ModuleMultiThread::VirtualSite))
2614 {
2615     if (numThreads_ > 1)
2616     {
2617         /* We need one extra thread data structure for the overlap vsites */
2618         tData_.resize(numThreads_ + 1);
2619 #pragma omp parallel for num_threads(numThreads_) schedule(static)
2620         for (int thread = 0; thread < numThreads_; thread++)
2621         {
2622             try
2623             {
2624                 tData_[thread] = std::make_unique<VsiteThread>();
2625
2626                 InterdependentTask& idTask = tData_[thread]->idTask;
2627                 idTask.nuse                = 0;
2628                 idTask.atomIndex.resize(numThreads_);
2629             }
2630             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2631         }
2632         if (numThreads_ > 1)
2633         {
2634             tData_[numThreads_] = std::make_unique<VsiteThread>();
2635         }
2636     }
2637 }
2638
2639 VirtualSitesHandler::Impl::Impl(const gmx_mtop_t&                       mtop,
2640                                 gmx_domdec_t*                           domdec,
2641                                 const PbcType                           pbcType,
2642                                 const ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType) :
2643     numInterUpdategroupVirtualSites_(countInterUpdategroupVsites(mtop, updateGroupingPerMoleculeType)),
2644     domainInfo_({ pbcType, pbcType != PbcType::No && numInterUpdategroupVirtualSites_ > 0, domdec }),
2645     iparams_(mtop.ffparams.iparams)
2646 {
2647 }
2648
2649 VirtualSitesHandler::VirtualSitesHandler(const gmx_mtop_t&                       mtop,
2650                                          gmx_domdec_t*                           domdec,
2651                                          const PbcType                           pbcType,
2652                                          const ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType) :
2653     impl_(new Impl(mtop, domdec, pbcType, updateGroupingPerMoleculeType))
2654 {
2655 }
2656
2657 //! Flag that atom \p atom which is home in another task, if it has not already been added before
2658 static inline void flagAtom(InterdependentTask* idTask, const int atom, const int numThreads, const int numAtomsPerThread)
2659 {
2660     if (!idTask->use[atom])
2661     {
2662         idTask->use[atom] = true;
2663         int thread        = atom / numAtomsPerThread;
2664         /* Assign all non-local atom force writes to thread 0 */
2665         if (thread >= numThreads)
2666         {
2667             thread = 0;
2668         }
2669         idTask->atomIndex[thread].atom.push_back(atom);
2670     }
2671 }
2672
2673 /*! \brief Here we try to assign all vsites that are in our local range.
2674  *
2675  * Our task local atom range is tData->rangeStart - tData->rangeEnd.
2676  * Vsites that depend only on local atoms, as indicated by taskIndex[]==thread,
2677  * are assigned to task tData->ilist. Vsites that depend on non-local atoms
2678  * but not on other vsites are assigned to task tData->id_task.ilist.
2679  * taskIndex[] is set for all vsites in our range, either to our local tasks
2680  * or to the single last task as taskIndex[]=2*nthreads.
2681  */
2682 static void assignVsitesToThread(VsiteThread*                    tData,
2683                                  int                             thread,
2684                                  int                             nthread,
2685                                  int                             natperthread,
2686                                  gmx::ArrayRef<int>              taskIndex,
2687                                  ArrayRef<const InteractionList> ilist,
2688                                  ArrayRef<const t_iparams>       ip,
2689                                  const ParticleType*             ptype)
2690 {
2691     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2692     {
2693         tData->ilist[ftype].clear();
2694         tData->idTask.ilist[ftype].clear();
2695
2696         const int  nral1 = 1 + NRAL(ftype);
2697         const int* iat   = ilist[ftype].iatoms.data();
2698         for (int i = 0; i < ilist[ftype].size();)
2699         {
2700             /* Get the number of iatom entries in this virtual site.
2701              * The 3 below for F_VSITEN is from 1+NRAL(ftype)=3
2702              */
2703             const int numIAtoms = (ftype == F_VSITEN ? ip[iat[i]].vsiten.n * 3 : nral1);
2704
2705             if (iat[1 + i] < tData->rangeStart || iat[1 + i] >= tData->rangeEnd)
2706             {
2707                 /* This vsite belongs to a different thread */
2708                 i += numIAtoms;
2709                 continue;
2710             }
2711
2712             /* We would like to assign this vsite to task thread,
2713              * but it might depend on atoms outside the atom range of thread
2714              * or on another vsite not assigned to task thread.
2715              */
2716             int task = thread;
2717             if (ftype != F_VSITEN)
2718             {
2719                 for (int j = i + 2; j < i + nral1; j++)
2720                 {
2721                     /* Do a range check to avoid a harmless race on taskIndex */
2722                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2723                     {
2724                         if (!tData->useInterdependentTask || ptype[iat[j]] == ParticleType::VSite)
2725                         {
2726                             /* At least one constructing atom is a vsite
2727                              * that is not assigned to the same thread.
2728                              * Put this vsite into a separate task.
2729                              */
2730                             task = 2 * nthread;
2731                             break;
2732                         }
2733
2734                         /* There are constructing atoms outside our range,
2735                          * put this vsite into a second task to be executed
2736                          * on the same thread. During construction no barrier
2737                          * is needed between the two tasks on the same thread.
2738                          * During spreading we need to run this task with
2739                          * an additional thread-local intermediate force buffer
2740                          * (or atomic reduction) and a barrier between the two
2741                          * tasks.
2742                          */
2743                         task = nthread + thread;
2744                     }
2745                 }
2746             }
2747             else
2748             {
2749                 for (int j = i + 2; j < i + numIAtoms; j += 3)
2750                 {
2751                     /* Do a range check to avoid a harmless race on taskIndex */
2752                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2753                     {
2754                         GMX_ASSERT(ptype[iat[j]] != ParticleType::VSite,
2755                                    "A vsite to be assigned in assignVsitesToThread has a vsite as "
2756                                    "a constructing atom that does not belong to our task, such "
2757                                    "vsites should be assigned to the single 'master' task");
2758
2759                         task = nthread + thread;
2760                     }
2761                 }
2762             }
2763
2764             /* Update this vsite's thread index entry */
2765             taskIndex[iat[1 + i]] = task;
2766
2767             if (task == thread || task == nthread + thread)
2768             {
2769                 /* Copy this vsite to the thread data struct of thread */
2770                 InteractionList* il_task;
2771                 if (task == thread)
2772                 {
2773                     il_task = &tData->ilist[ftype];
2774                 }
2775                 else
2776                 {
2777                     il_task = &tData->idTask.ilist[ftype];
2778                 }
2779                 /* Copy the vsite data to the thread-task local array */
2780                 il_task->push_back(iat[i], numIAtoms - 1, iat + i + 1);
2781                 if (task == nthread + thread)
2782                 {
2783                     /* This vsite writes outside our own task force block.
2784                      * Put it into the interdependent task list and flag
2785                      * the atoms involved for reduction.
2786                      */
2787                     tData->idTask.vsite.push_back(iat[i + 1]);
2788                     if (ftype != F_VSITEN)
2789                     {
2790                         for (int j = i + 2; j < i + nral1; j++)
2791                         {
2792                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2793                         }
2794                     }
2795                     else
2796                     {
2797                         for (int j = i + 2; j < i + numIAtoms; j += 3)
2798                         {
2799                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2800                         }
2801                     }
2802                 }
2803             }
2804
2805             i += numIAtoms;
2806         }
2807     }
2808 }
2809
2810 /*! \brief Assign all vsites with taskIndex[]==task to task tData */
2811 static void assignVsitesToSingleTask(VsiteThread*                    tData,
2812                                      int                             task,
2813                                      gmx::ArrayRef<const int>        taskIndex,
2814                                      ArrayRef<const InteractionList> ilist,
2815                                      ArrayRef<const t_iparams>       ip)
2816 {
2817     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2818     {
2819         tData->ilist[ftype].clear();
2820         tData->idTask.ilist[ftype].clear();
2821
2822         int              nral1   = 1 + NRAL(ftype);
2823         int              inc     = nral1;
2824         const int*       iat     = ilist[ftype].iatoms.data();
2825         InteractionList* il_task = &tData->ilist[ftype];
2826
2827         for (int i = 0; i < ilist[ftype].size();)
2828         {
2829             if (ftype == F_VSITEN)
2830             {
2831                 /* The 3 below is from 1+NRAL(ftype)=3 */
2832                 inc = ip[iat[i]].vsiten.n * 3;
2833             }
2834             /* Check if the vsite is assigned to our task */
2835             if (taskIndex[iat[1 + i]] == task)
2836             {
2837                 /* Copy the vsite data to the thread-task local array */
2838                 il_task->push_back(iat[i], inc - 1, iat + i + 1);
2839             }
2840
2841             i += inc;
2842         }
2843     }
2844 }
2845
2846 void ThreadingInfo::setVirtualSites(ArrayRef<const InteractionList> ilists,
2847                                     ArrayRef<const t_iparams>       iparams,
2848                                     const t_mdatoms&                mdatoms,
2849                                     const bool                      useDomdec)
2850 {
2851     if (numThreads_ <= 1)
2852     {
2853         /* Nothing to do */
2854         return;
2855     }
2856
2857     /* The current way of distributing the vsites over threads in primitive.
2858      * We divide the atom range 0 - natoms_in_vsite uniformly over threads,
2859      * without taking into account how the vsites are distributed.
2860      * Without domain decomposition we at least tighten the upper bound
2861      * of the range (useful for common systems such as a vsite-protein
2862      * in 3-site water).
2863      * With domain decomposition, as long as the vsites are distributed
2864      * uniformly in each domain along the major dimension, usually x,
2865      * it will also perform well.
2866      */
2867     int vsite_atom_range;
2868     int natperthread;
2869     if (!useDomdec)
2870     {
2871         vsite_atom_range = -1;
2872         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2873         {
2874             { // TODO remove me
2875                 if (ftype != F_VSITEN)
2876                 {
2877                     int                 nral1 = 1 + NRAL(ftype);
2878                     ArrayRef<const int> iat   = ilists[ftype].iatoms;
2879                     for (int i = 0; i < ilists[ftype].size(); i += nral1)
2880                     {
2881                         for (int j = i + 1; j < i + nral1; j++)
2882                         {
2883                             vsite_atom_range = std::max(vsite_atom_range, iat[j]);
2884                         }
2885                     }
2886                 }
2887                 else
2888                 {
2889                     int vs_ind_end;
2890
2891                     ArrayRef<const int> iat = ilists[ftype].iatoms;
2892
2893                     int i = 0;
2894                     while (i < ilists[ftype].size())
2895                     {
2896                         /* The 3 below is from 1+NRAL(ftype)=3 */
2897                         vs_ind_end = i + iparams[iat[i]].vsiten.n * 3;
2898
2899                         vsite_atom_range = std::max(vsite_atom_range, iat[i + 1]);
2900                         while (i < vs_ind_end)
2901                         {
2902                             vsite_atom_range = std::max(vsite_atom_range, iat[i + 2]);
2903                             i += 3;
2904                         }
2905                     }
2906                 }
2907             }
2908         }
2909         vsite_atom_range++;
2910         natperthread = (vsite_atom_range + numThreads_ - 1) / numThreads_;
2911     }
2912     else
2913     {
2914         /* Any local or not local atom could be involved in virtual sites.
2915          * But since we usually have very few non-local virtual sites
2916          * (only non-local vsites that depend on local vsites),
2917          * we distribute the local atom range equally over the threads.
2918          * When assigning vsites to threads, we should take care that the last
2919          * threads also covers the non-local range.
2920          */
2921         vsite_atom_range = mdatoms.nr;
2922         natperthread     = (mdatoms.homenr + numThreads_ - 1) / numThreads_;
2923     }
2924
2925     if (debug)
2926     {
2927         fprintf(debug,
2928                 "virtual site thread dist: natoms %d, range %d, natperthread %d\n",
2929                 mdatoms.nr,
2930                 vsite_atom_range,
2931                 natperthread);
2932     }
2933
2934     /* To simplify the vsite assignment, we make an index which tells us
2935      * to which task particles, both non-vsites and vsites, are assigned.
2936      */
2937     taskIndex_.resize(mdatoms.nr);
2938
2939     /* Initialize the task index array. Here we assign the non-vsite
2940      * particles to task=thread, so we easily figure out if vsites
2941      * depend on local and/or non-local particles in assignVsitesToThread.
2942      */
2943     {
2944         int thread = 0;
2945         for (int i = 0; i < mdatoms.nr; i++)
2946         {
2947             if (mdatoms.ptype[i] == ParticleType::VSite)
2948             {
2949                 /* vsites are not assigned to a task yet */
2950                 taskIndex_[i] = -1;
2951             }
2952             else
2953             {
2954                 /* assign non-vsite particles to task thread */
2955                 taskIndex_[i] = thread;
2956             }
2957             if (i == (thread + 1) * natperthread && thread < numThreads_)
2958             {
2959                 thread++;
2960             }
2961         }
2962     }
2963
2964 #pragma omp parallel num_threads(numThreads_)
2965     {
2966         try
2967         {
2968             int          thread = gmx_omp_get_thread_num();
2969             VsiteThread& tData  = *tData_[thread];
2970
2971             /* Clear the buffer use flags that were set before */
2972             if (tData.useInterdependentTask)
2973             {
2974                 InterdependentTask& idTask = tData.idTask;
2975
2976                 /* To avoid an extra OpenMP barrier in spread_vsite_f,
2977                  * we clear the force buffer at the next step,
2978                  * so we need to do it here as well.
2979                  */
2980                 clearTaskForceBufferUsedElements(&idTask);
2981
2982                 idTask.vsite.resize(0);
2983                 for (int t = 0; t < numThreads_; t++)
2984                 {
2985                     AtomIndex& atomIndex = idTask.atomIndex[t];
2986                     int        natom     = atomIndex.atom.size();
2987                     for (int i = 0; i < natom; i++)
2988                     {
2989                         idTask.use[atomIndex.atom[i]] = false;
2990                     }
2991                     atomIndex.atom.resize(0);
2992                 }
2993                 idTask.nuse = 0;
2994             }
2995
2996             /* To avoid large f_buf allocations of #threads*vsite_atom_range
2997              * we don't use task2 with more than 200000 atoms. This doesn't
2998              * affect performance, since with such a large range relatively few
2999              * vsites will end up in the separate task.
3000              * Note that useTask2 should be the same for all threads.
3001              */
3002             tData.useInterdependentTask = (vsite_atom_range <= 200000);
3003             if (tData.useInterdependentTask)
3004             {
3005                 size_t              natoms_use_in_vsites = vsite_atom_range;
3006                 InterdependentTask& idTask               = tData.idTask;
3007                 /* To avoid resizing and re-clearing every nstlist steps,
3008                  * we never down size the force buffer.
3009                  */
3010                 if (natoms_use_in_vsites > idTask.force.size() || natoms_use_in_vsites > idTask.use.size())
3011                 {
3012                     idTask.force.resize(natoms_use_in_vsites, { 0, 0, 0 });
3013                     idTask.use.resize(natoms_use_in_vsites, false);
3014                 }
3015             }
3016
3017             /* Assign all vsites that can execute independently on threads */
3018             tData.rangeStart = thread * natperthread;
3019             if (thread < numThreads_ - 1)
3020             {
3021                 tData.rangeEnd = (thread + 1) * natperthread;
3022             }
3023             else
3024             {
3025                 /* The last thread should cover up to the end of the range */
3026                 tData.rangeEnd = mdatoms.nr;
3027             }
3028             assignVsitesToThread(
3029                     &tData, thread, numThreads_, natperthread, taskIndex_, ilists, iparams, mdatoms.ptype);
3030
3031             if (tData.useInterdependentTask)
3032             {
3033                 /* In the worst case, all tasks write to force ranges of
3034                  * all other tasks, leading to #tasks^2 scaling (this is only
3035                  * the overhead, the actual flops remain constant).
3036                  * But in most cases there is far less coupling. To improve
3037                  * scaling at high thread counts we therefore construct
3038                  * an index to only loop over the actually affected tasks.
3039                  */
3040                 InterdependentTask& idTask = tData.idTask;
3041
3042                 /* Ensure assignVsitesToThread finished on other threads */
3043 #pragma omp barrier
3044
3045                 idTask.spreadTask.resize(0);
3046                 idTask.reduceTask.resize(0);
3047                 for (int t = 0; t < numThreads_; t++)
3048                 {
3049                     /* Do we write to the force buffer of task t? */
3050                     if (!idTask.atomIndex[t].atom.empty())
3051                     {
3052                         idTask.spreadTask.push_back(t);
3053                     }
3054                     /* Does task t write to our force buffer? */
3055                     if (!tData_[t]->idTask.atomIndex[thread].atom.empty())
3056                     {
3057                         idTask.reduceTask.push_back(t);
3058                     }
3059                 }
3060             }
3061         }
3062         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
3063     }
3064     /* Assign all remaining vsites, that will have taskIndex[]=2*vsite->nthreads,
3065      * to a single task that will not run in parallel with other tasks.
3066      */
3067     assignVsitesToSingleTask(tData_[numThreads_].get(), 2 * numThreads_, taskIndex_, ilists, iparams);
3068
3069     if (debug && numThreads_ > 1)
3070     {
3071         fprintf(debug,
3072                 "virtual site useInterdependentTask %d, nuse:\n",
3073                 static_cast<int>(tData_[0]->useInterdependentTask));
3074         for (int th = 0; th < numThreads_ + 1; th++)
3075         {
3076             fprintf(debug, " %4d", tData_[th]->idTask.nuse);
3077         }
3078         fprintf(debug, "\n");
3079
3080         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
3081         {
3082             if (!ilists[ftype].empty())
3083             {
3084                 fprintf(debug, "%-20s thread dist:", interaction_function[ftype].longname);
3085                 for (int th = 0; th < numThreads_ + 1; th++)
3086                 {
3087                     fprintf(debug,
3088                             " %4d %4d ",
3089                             tData_[th]->ilist[ftype].size(),
3090                             tData_[th]->idTask.ilist[ftype].size());
3091                 }
3092                 fprintf(debug, "\n");
3093             }
3094         }
3095     }
3096
3097 #ifndef NDEBUG
3098     int nrOrig     = vsiteIlistNrCount(ilists);
3099     int nrThreaded = 0;
3100     for (int th = 0; th < numThreads_ + 1; th++)
3101     {
3102         nrThreaded += vsiteIlistNrCount(tData_[th]->ilist) + vsiteIlistNrCount(tData_[th]->idTask.ilist);
3103     }
3104     GMX_ASSERT(nrThreaded == nrOrig,
3105                "The number of virtual sites assigned to all thread task has to match the total "
3106                "number of virtual sites");
3107 #endif
3108 }
3109
3110 void VirtualSitesHandler::Impl::setVirtualSites(ArrayRef<const InteractionList> ilists,
3111                                                 const t_mdatoms&                mdatoms)
3112 {
3113     ilists_ = ilists;
3114
3115     threadingInfo_.setVirtualSites(ilists, iparams_, mdatoms, domainInfo_.useDomdec());
3116 }
3117
3118 void VirtualSitesHandler::setVirtualSites(ArrayRef<const InteractionList> ilists, const t_mdatoms& mdatoms)
3119 {
3120     impl_->setVirtualSites(ilists, mdatoms);
3121 }
3122
3123 } // namespace gmx