Merge release-2021 into master
[alexxy/gromacs.git] / src / gromacs / mdlib / vsite.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 The GROMACS development team.
7  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /*! \internal \file
39  * \brief Implements the VirtualSitesHandler class and vsite standalone functions
40  *
41  * \author Berk Hess <hess@kth.se>
42  * \ingroup module_mdlib
43  */
44
45 #include "gmxpre.h"
46
47 #include "vsite.h"
48
49 #include <cstdio>
50
51 #include <algorithm>
52 #include <memory>
53 #include <vector>
54
55 #include "gromacs/domdec/domdec.h"
56 #include "gromacs/domdec/domdec_struct.h"
57 #include "gromacs/gmxlib/network.h"
58 #include "gromacs/gmxlib/nrnb.h"
59 #include "gromacs/math/functions.h"
60 #include "gromacs/math/vec.h"
61 #include "gromacs/mdlib/gmx_omp_nthreads.h"
62 #include "gromacs/mdtypes/commrec.h"
63 #include "gromacs/pbcutil/ishift.h"
64 #include "gromacs/pbcutil/pbc.h"
65 #include "gromacs/timing/wallcycle.h"
66 #include "gromacs/topology/ifunc.h"
67 #include "gromacs/topology/mtop_util.h"
68 #include "gromacs/topology/topology.h"
69 #include "gromacs/utility/exceptions.h"
70 #include "gromacs/utility/fatalerror.h"
71 #include "gromacs/utility/gmxassert.h"
72 #include "gromacs/utility/gmxomp.h"
73
74 /* The strategy used here for assigning virtual sites to (thread-)tasks
75  * is as follows:
76  *
77  * We divide the atom range that vsites operate on (natoms_local with DD,
78  * 0 - last atom involved in vsites without DD) equally over all threads.
79  *
80  * Vsites in the local range constructed from atoms in the local range
81  * and/or other vsites that are fully local are assigned to a simple,
82  * independent task.
83  *
84  * Vsites that are not assigned after using the above criterion get assigned
85  * to a so called "interdependent" thread task when none of the constructing
86  * atoms is a vsite. These tasks are called interdependent, because one task
87  * accesses atoms assigned to a different task/thread.
88  * Note that this option is turned off with large (local) atom counts
89  * to avoid high memory usage.
90  *
91  * Any remaining vsites are assigned to a separate master thread task.
92  */
93 namespace gmx
94 {
95
96 //! VirialHandling is often used outside VirtualSitesHandler class members
97 using VirialHandling = VirtualSitesHandler::VirialHandling;
98
99 /*! \brief Information on PBC and domain decomposition for virtual sites
100  */
101 struct DomainInfo
102 {
103 public:
104     //! Constructs without PBC and DD
105     DomainInfo() = default;
106
107     //! Constructs with PBC and DD, if !=nullptr
108     DomainInfo(PbcType pbcType, bool haveInterUpdateGroupVirtualSites, gmx_domdec_t* domdec) :
109         pbcType_(pbcType),
110         useMolPbc_(pbcType != PbcType::No && haveInterUpdateGroupVirtualSites),
111         domdec_(domdec)
112     {
113     }
114
115     //! Returns whether we are using domain decomposition with more than 1 DD rank
116     bool useDomdec() const { return (domdec_ != nullptr); }
117
118     //! The pbc type
119     const PbcType pbcType_ = PbcType::No;
120     //! Whether molecules are broken over PBC
121     const bool useMolPbc_ = false;
122     //! Pointer to the domain decomposition struct, nullptr without PP DD
123     const gmx_domdec_t* domdec_ = nullptr;
124 };
125
126 /*! \brief List of atom indices belonging to a task
127  */
128 struct AtomIndex
129 {
130     //! List of atom indices
131     std::vector<int> atom;
132 };
133
134 /*! \brief Data structure for thread tasks that use constructing atoms outside their own atom range
135  */
136 struct InterdependentTask
137 {
138     //! The interaction lists, only vsite entries are used
139     InteractionLists ilist;
140     //! Thread/task-local force buffer
141     std::vector<RVec> force;
142     //! The atom indices of the vsites of our task
143     std::vector<int> vsite;
144     //! Flags if elements in force are spread to or not
145     std::vector<bool> use;
146     //! The number of entries set to true in use
147     int nuse = 0;
148     //! Array of atoms indices, size nthreads, covering all nuse set elements in use
149     std::vector<AtomIndex> atomIndex;
150     //! List of tasks (force blocks) this task spread forces to
151     std::vector<int> spreadTask;
152     //! List of tasks that write to this tasks force block range
153     std::vector<int> reduceTask;
154 };
155
156 /*! \brief Vsite thread task data structure
157  */
158 struct VsiteThread
159 {
160     //! Start of atom range of this task
161     int rangeStart;
162     //! End of atom range of this task
163     int rangeEnd;
164     //! The interaction lists, only vsite entries are used
165     std::array<InteractionList, F_NRE> ilist;
166     //! Local fshift accumulation buffer
167     std::array<RVec, c_numShiftVectors> fshift;
168     //! Local virial dx*df accumulation buffer
169     matrix dxdf;
170     //! Tells if interdependent task idTask should be used (in addition to the rest of this task), this bool has the same value on all threads
171     bool useInterdependentTask;
172     //! Data for vsites that involve constructing atoms in the atom range of other threads/tasks
173     InterdependentTask idTask;
174
175     /*! \brief Constructor */
176     VsiteThread()
177     {
178         rangeStart = -1;
179         rangeEnd   = -1;
180         for (auto& elem : fshift)
181         {
182             elem = { 0.0_real, 0.0_real, 0.0_real };
183         }
184         clear_mat(dxdf);
185         useInterdependentTask = false;
186     }
187 };
188
189
190 /*! \brief Information on how the virtual site work is divided over thread tasks
191  */
192 class ThreadingInfo
193 {
194 public:
195     //! Constructor, retrieves the number of threads to use from gmx_omp_nthreads.h
196     ThreadingInfo();
197
198     //! Returns the number of threads to use for vsite operations
199     int numThreads() const { return numThreads_; }
200
201     //! Returns the thread data for the given thread
202     const VsiteThread& threadData(int threadIndex) const { return *tData_[threadIndex]; }
203
204     //! Returns the thread data for the given thread
205     VsiteThread& threadData(int threadIndex) { return *tData_[threadIndex]; }
206
207     //! Returns the thread data for vsites that depend on non-local vsites
208     const VsiteThread& threadDataNonLocalDependent() const { return *tData_[numThreads_]; }
209
210     //! Returns the thread data for vsites that depend on non-local vsites
211     VsiteThread& threadDataNonLocalDependent() { return *tData_[numThreads_]; }
212
213     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
214     void setVirtualSites(ArrayRef<const InteractionList> ilist,
215                          ArrayRef<const t_iparams>       iparams,
216                          int                             numAtoms,
217                          int                             homenr,
218                          ArrayRef<const ParticleType>    ptype,
219                          bool                            useDomdec);
220
221 private:
222     //! Number of threads used for vsite operations
223     const int numThreads_;
224     //! Thread local vsites and work structs
225     std::vector<std::unique_ptr<VsiteThread>> tData_;
226     //! Work array for dividing vsites over threads
227     std::vector<int> taskIndex_;
228 };
229
230 /*! \brief Impl class for VirtualSitesHandler
231  */
232 class VirtualSitesHandler::Impl
233 {
234 public:
235     //! Constructor, domdec should be nullptr without DD
236     Impl(const gmx_mtop_t&                 mtop,
237          gmx_domdec_t*                     domdec,
238          PbcType                           pbcType,
239          ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType);
240
241     //! Returns the number of virtual sites acting over multiple update groups
242     int numInterUpdategroupVirtualSites() const { return numInterUpdategroupVirtualSites_; }
243
244     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
245     void setVirtualSites(ArrayRef<const InteractionList> ilist,
246                          int                             numAtoms,
247                          int                             homenr,
248                          ArrayRef<const ParticleType>    ptype);
249
250     /*! \brief Create positions of vsite atoms based for the local system
251      *
252      * \param[in,out] x          The coordinates
253      * \param[in,out] v          The velocities, needed if operation requires it
254      * \param[in]     box        The box
255      * \param[in]     operation  Whether we calculate positions, velocities, or both
256      */
257     void construct(ArrayRef<RVec> x, ArrayRef<RVec> v, const matrix box, VSiteOperation operation) const;
258
259     /*! \brief Spread the force operating on the vsite atoms on the surrounding atoms.
260      *
261      * vsite should point to a valid object.
262      * The virialHandling parameter determines how virial contributions are handled.
263      * If this is set to Linear, shift forces are accumulated into fshift.
264      * If this is set to NonLinear, non-linear contributions are added to virial.
265      * This non-linear correction is required when the virial is not calculated
266      * afterwards from the particle position and forces, but in a different way,
267      * as for instance for the PME mesh contribution.
268      */
269     void spreadForces(ArrayRef<const RVec> x,
270                       ArrayRef<RVec>       f,
271                       VirialHandling       virialHandling,
272                       ArrayRef<RVec>       fshift,
273                       matrix               virial,
274                       t_nrnb*              nrnb,
275                       const matrix         box,
276                       gmx_wallcycle*       wcycle);
277
278 private:
279     //! The number of vsites that cross update groups, when =0 no PBC treatment is needed
280     const int numInterUpdategroupVirtualSites_;
281     //! PBC and DD information
282     const DomainInfo domainInfo_;
283     //! The interaction parameters
284     const ArrayRef<const t_iparams> iparams_;
285     //! The interaction lists
286     ArrayRef<const InteractionList> ilists_;
287     //! Information for handling vsite threading
288     ThreadingInfo threadingInfo_;
289 };
290
291 VirtualSitesHandler::~VirtualSitesHandler() = default;
292
293 int VirtualSitesHandler::numInterUpdategroupVirtualSites() const
294 {
295     return impl_->numInterUpdategroupVirtualSites();
296 }
297
298 /*! \brief Returns the sum of the vsite ilist sizes over all vsite types
299  *
300  * \param[in] ilist  The interaction list
301  */
302 static int vsiteIlistNrCount(ArrayRef<const InteractionList> ilist)
303 {
304     int nr = 0;
305     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
306     {
307         nr += ilist[ftype].size();
308     }
309
310     return nr;
311 }
312
313 //! Computes the distance between xi and xj, pbc is used when pbc!=nullptr
314 static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
315 {
316     if (pbc)
317     {
318         return pbc_dx_aiuc(pbc, xi, xj, dx);
319     }
320     else
321     {
322         rvec_sub(xi, xj, dx);
323         return c_centralShiftIndex;
324     }
325 }
326
327 //! Returns the 1/norm(x)
328 static inline real inverseNorm(const rvec x)
329 {
330     return gmx::invsqrt(iprod(x, x));
331 }
332
333 //! Whether we're calculating the virtual site position
334 enum class VSiteCalculatePosition
335 {
336     Yes,
337     No
338 };
339 //! Whether we're calculating the virtual site velocity
340 enum class VSiteCalculateVelocity
341 {
342     Yes,
343     No
344 };
345
346 #ifndef DOXYGEN
347 /* Vsite construction routines */
348
349 // GCC 8 falsely flags unused variables if constexpr prunes a code path, fixed in GCC 9
350 // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85827
351 // clang-format off
352 GCC_DIAGNOSTIC_IGNORE(-Wunused-but-set-parameter)
353 // clang-format on
354
355 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
356 static void constr_vsite1(const rvec xi, rvec x, const rvec vi, rvec v)
357 {
358     if (calculatePosition == VSiteCalculatePosition::Yes)
359     {
360         copy_rvec(xi, x);
361         /* TOTAL: 0 flops */
362     }
363     if (calculateVelocity == VSiteCalculateVelocity::Yes)
364     {
365         copy_rvec(vi, v);
366     }
367 }
368
369 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
370 static void
371 constr_vsite2(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc, const rvec vi, const rvec vj, rvec v)
372 {
373     const real b = 1 - a;
374     /* 1 flop */
375
376     if (calculatePosition == VSiteCalculatePosition::Yes)
377     {
378         if (pbc)
379         {
380             rvec dx;
381             pbc_dx_aiuc(pbc, xj, xi, dx);
382             x[XX] = xi[XX] + a * dx[XX];
383             x[YY] = xi[YY] + a * dx[YY];
384             x[ZZ] = xi[ZZ] + a * dx[ZZ];
385         }
386         else
387         {
388             x[XX] = b * xi[XX] + a * xj[XX];
389             x[YY] = b * xi[YY] + a * xj[YY];
390             x[ZZ] = b * xi[ZZ] + a * xj[ZZ];
391             /* 9 Flops */
392         }
393         /* TOTAL: 10 flops */
394     }
395     if (calculateVelocity == VSiteCalculateVelocity::Yes)
396     {
397         v[XX] = b * vi[XX] + a * vj[XX];
398         v[YY] = b * vi[YY] + a * vj[YY];
399         v[ZZ] = b * vi[ZZ] + a * vj[ZZ];
400     }
401 }
402
403 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
404 static void
405 constr_vsite2FD(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc, const rvec vi, const rvec vj, rvec v)
406 {
407     rvec xij = { 0 };
408     pbc_rvec_sub(pbc, xj, xi, xij);
409     /* 3 flops */
410
411     const real invNormXij = inverseNorm(xij);
412     const real b          = a * invNormXij;
413     /* 6 + 10 flops */
414
415     if (calculatePosition == VSiteCalculatePosition::Yes)
416     {
417         x[XX] = xi[XX] + b * xij[XX];
418         x[YY] = xi[YY] + b * xij[YY];
419         x[ZZ] = xi[ZZ] + b * xij[ZZ];
420         /* 6 Flops */
421         /* TOTAL: 25 flops */
422     }
423     if (calculateVelocity == VSiteCalculateVelocity::Yes)
424     {
425         rvec vij = { 0 };
426         rvec_sub(vj, vi, vij);
427         const real vijDotXij = iprod(vij, xij);
428
429         v[XX] = vi[XX] + b * (vij[XX] - xij[XX] * vijDotXij * invNormXij * invNormXij);
430         v[YY] = vi[YY] + b * (vij[YY] - xij[YY] * vijDotXij * invNormXij * invNormXij);
431         v[ZZ] = vi[ZZ] + b * (vij[ZZ] - xij[ZZ] * vijDotXij * invNormXij * invNormXij);
432     }
433 }
434
435 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
436 static void constr_vsite3(const rvec   xi,
437                           const rvec   xj,
438                           const rvec   xk,
439                           rvec         x,
440                           real         a,
441                           real         b,
442                           const t_pbc* pbc,
443                           const rvec   vi,
444                           const rvec   vj,
445                           const rvec   vk,
446                           rvec         v)
447 {
448     const real c = 1 - a - b;
449     /* 2 flops */
450
451     if (calculatePosition == VSiteCalculatePosition::Yes)
452     {
453         if (pbc)
454         {
455             rvec dxj, dxk;
456
457             pbc_dx_aiuc(pbc, xj, xi, dxj);
458             pbc_dx_aiuc(pbc, xk, xi, dxk);
459             x[XX] = xi[XX] + a * dxj[XX] + b * dxk[XX];
460             x[YY] = xi[YY] + a * dxj[YY] + b * dxk[YY];
461             x[ZZ] = xi[ZZ] + a * dxj[ZZ] + b * dxk[ZZ];
462         }
463         else
464         {
465             x[XX] = c * xi[XX] + a * xj[XX] + b * xk[XX];
466             x[YY] = c * xi[YY] + a * xj[YY] + b * xk[YY];
467             x[ZZ] = c * xi[ZZ] + a * xj[ZZ] + b * xk[ZZ];
468             /* 15 Flops */
469         }
470         /* TOTAL: 17 flops */
471     }
472     if (calculateVelocity == VSiteCalculateVelocity::Yes)
473     {
474         v[XX] = c * vi[XX] + a * vj[XX] + b * vk[XX];
475         v[YY] = c * vi[YY] + a * vj[YY] + b * vk[YY];
476         v[ZZ] = c * vi[ZZ] + a * vj[ZZ] + b * vk[ZZ];
477     }
478 }
479
480 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
481 static void constr_vsite3FD(const rvec   xi,
482                             const rvec   xj,
483                             const rvec   xk,
484                             rvec         x,
485                             real         a,
486                             real         b,
487                             const t_pbc* pbc,
488                             const rvec   vi,
489                             const rvec   vj,
490                             const rvec   vk,
491                             rvec         v)
492 {
493     rvec xij, xjk, temp;
494
495     pbc_rvec_sub(pbc, xj, xi, xij);
496     pbc_rvec_sub(pbc, xk, xj, xjk);
497     /* 6 flops */
498
499     /* temp goes from i to a point on the line jk */
500     temp[XX] = xij[XX] + a * xjk[XX];
501     temp[YY] = xij[YY] + a * xjk[YY];
502     temp[ZZ] = xij[ZZ] + a * xjk[ZZ];
503     /* 6 flops */
504
505     const real invNormTemp = inverseNorm(temp);
506     const real c           = b * invNormTemp;
507     /* 6 + 10 flops */
508
509     if (calculatePosition == VSiteCalculatePosition::Yes)
510     {
511         x[XX] = xi[XX] + c * temp[XX];
512         x[YY] = xi[YY] + c * temp[YY];
513         x[ZZ] = xi[ZZ] + c * temp[ZZ];
514         /* 6 Flops */
515         /* TOTAL: 34 flops */
516     }
517     if (calculateVelocity == VSiteCalculateVelocity::Yes)
518     {
519         rvec vij = { 0 };
520         rvec vjk = { 0 };
521         rvec_sub(vj, vi, vij);
522         rvec_sub(vk, vj, vjk);
523         const rvec tempV = { vij[XX] + a * vjk[XX], vij[YY] + a * vjk[YY], vij[ZZ] + a * vjk[ZZ] };
524         const real tempDotTempV = iprod(temp, tempV);
525
526         v[XX] = vi[XX] + c * (tempV[XX] - temp[XX] * tempDotTempV * invNormTemp * invNormTemp);
527         v[YY] = vi[YY] + c * (tempV[YY] - temp[YY] * tempDotTempV * invNormTemp * invNormTemp);
528         v[ZZ] = vi[ZZ] + c * (tempV[ZZ] - temp[ZZ] * tempDotTempV * invNormTemp * invNormTemp);
529     }
530 }
531
532 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
533 static void constr_vsite3FAD(const rvec   xi,
534                              const rvec   xj,
535                              const rvec   xk,
536                              rvec         x,
537                              real         a,
538                              real         b,
539                              const t_pbc* pbc,
540                              const rvec   vi,
541                              const rvec   vj,
542                              const rvec   vk,
543                              rvec         v)
544 { // Note: a = d * cos(theta)
545     //       b = d * sin(theta)
546     rvec xij, xjk, xp;
547
548     pbc_rvec_sub(pbc, xj, xi, xij);
549     pbc_rvec_sub(pbc, xk, xj, xjk);
550     /* 6 flops */
551
552     const real invdij    = inverseNorm(xij);
553     const real xijDotXjk = iprod(xij, xjk);
554     const real c1        = invdij * invdij * xijDotXjk;
555     xp[XX]               = xjk[XX] - c1 * xij[XX];
556     xp[YY]               = xjk[YY] - c1 * xij[YY];
557     xp[ZZ]               = xjk[ZZ] - c1 * xij[ZZ];
558     const real a1        = a * invdij;
559     const real invNormXp = inverseNorm(xp);
560     const real b1        = b * invNormXp;
561     /* 45 */
562
563     if (calculatePosition == VSiteCalculatePosition::Yes)
564     {
565         x[XX] = xi[XX] + a1 * xij[XX] + b1 * xp[XX];
566         x[YY] = xi[YY] + a1 * xij[YY] + b1 * xp[YY];
567         x[ZZ] = xi[ZZ] + a1 * xij[ZZ] + b1 * xp[ZZ];
568         /* 12 Flops */
569         /* TOTAL: 63 flops */
570     }
571
572     if (calculateVelocity == VSiteCalculateVelocity::Yes)
573     {
574         rvec vij = { 0 };
575         rvec vjk = { 0 };
576         rvec_sub(vj, vi, vij);
577         rvec_sub(vk, vj, vjk);
578
579         const real vijDotXjkPlusXijDotVjk = iprod(vij, xjk) + iprod(xij, vjk);
580         const real xijDotVij              = iprod(xij, vij);
581         const real invNormXij2            = invdij * invdij;
582
583         rvec vp = { 0 };
584         vp[XX]  = vjk[XX]
585                  - xij[XX] * invNormXij2
586                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
587                  - vij[XX] * xijDotXjk * invNormXij2;
588         vp[YY] = vjk[YY]
589                  - xij[YY] * invNormXij2
590                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
591                  - vij[YY] * xijDotXjk * invNormXij2;
592         vp[ZZ] = vjk[ZZ]
593                  - xij[ZZ] * invNormXij2
594                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
595                  - vij[ZZ] * xijDotXjk * invNormXij2;
596
597         const real xpDotVp = iprod(xp, vp);
598
599         v[XX] = vi[XX] + a1 * (vij[XX] - xij[XX] * xijDotVij * invdij * invdij)
600                 + b1 * (vp[XX] - xp[XX] * xpDotVp * invNormXp * invNormXp);
601         v[YY] = vi[YY] + a1 * (vij[YY] - xij[YY] * xijDotVij * invdij * invdij)
602                 + b1 * (vp[YY] - xp[YY] * xpDotVp * invNormXp * invNormXp);
603         v[ZZ] = vi[ZZ] + a1 * (vij[ZZ] - xij[ZZ] * xijDotVij * invdij * invdij)
604                 + b1 * (vp[ZZ] - xp[ZZ] * xpDotVp * invNormXp * invNormXp);
605     }
606 }
607
608 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
609 static void constr_vsite3OUT(const rvec   xi,
610                              const rvec   xj,
611                              const rvec   xk,
612                              rvec         x,
613                              real         a,
614                              real         b,
615                              real         c,
616                              const t_pbc* pbc,
617                              const rvec   vi,
618                              const rvec   vj,
619                              const rvec   vk,
620                              rvec         v)
621 {
622     rvec xij, xik, temp;
623
624     pbc_rvec_sub(pbc, xj, xi, xij);
625     pbc_rvec_sub(pbc, xk, xi, xik);
626     cprod(xij, xik, temp);
627     /* 15 Flops */
628
629     if (calculatePosition == VSiteCalculatePosition::Yes)
630     {
631         x[XX] = xi[XX] + a * xij[XX] + b * xik[XX] + c * temp[XX];
632         x[YY] = xi[YY] + a * xij[YY] + b * xik[YY] + c * temp[YY];
633         x[ZZ] = xi[ZZ] + a * xij[ZZ] + b * xik[ZZ] + c * temp[ZZ];
634         /* 18 Flops */
635         /* TOTAL: 33 flops */
636     }
637
638     if (calculateVelocity == VSiteCalculateVelocity::Yes)
639     {
640         rvec vij = { 0 };
641         rvec vik = { 0 };
642         rvec_sub(vj, vi, vij);
643         rvec_sub(vk, vi, vik);
644
645         rvec temp1 = { 0 };
646         rvec temp2 = { 0 };
647         cprod(vij, xik, temp1);
648         cprod(xij, vik, temp2);
649
650         v[XX] = vi[XX] + a * vij[XX] + b * vik[XX] + c * (temp1[XX] + temp2[XX]);
651         v[YY] = vi[YY] + a * vij[YY] + b * vik[YY] + c * (temp1[YY] + temp2[YY]);
652         v[ZZ] = vi[ZZ] + a * vij[ZZ] + b * vik[ZZ] + c * (temp1[ZZ] + temp2[ZZ]);
653     }
654 }
655
656 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
657 static void constr_vsite4FD(const rvec   xi,
658                             const rvec   xj,
659                             const rvec   xk,
660                             const rvec   xl,
661                             rvec         x,
662                             real         a,
663                             real         b,
664                             real         c,
665                             const t_pbc* pbc,
666                             const rvec   vi,
667                             const rvec   vj,
668                             const rvec   vk,
669                             const rvec   vl,
670                             rvec         v)
671 {
672     rvec xij, xjk, xjl, temp;
673     real d;
674
675     pbc_rvec_sub(pbc, xj, xi, xij);
676     pbc_rvec_sub(pbc, xk, xj, xjk);
677     pbc_rvec_sub(pbc, xl, xj, xjl);
678     /* 9 flops */
679
680     /* temp goes from i to a point on the plane jkl */
681     temp[XX] = xij[XX] + a * xjk[XX] + b * xjl[XX];
682     temp[YY] = xij[YY] + a * xjk[YY] + b * xjl[YY];
683     temp[ZZ] = xij[ZZ] + a * xjk[ZZ] + b * xjl[ZZ];
684     /* 12 flops */
685
686     const real invRm = inverseNorm(temp);
687     d                = c * invRm;
688     /* 6 + 10 flops */
689
690     if (calculatePosition == VSiteCalculatePosition::Yes)
691     {
692         x[XX] = xi[XX] + d * temp[XX];
693         x[YY] = xi[YY] + d * temp[YY];
694         x[ZZ] = xi[ZZ] + d * temp[ZZ];
695         /* 6 Flops */
696         /* TOTAL: 43 flops */
697     }
698     if (calculateVelocity == VSiteCalculateVelocity::Yes)
699     {
700         rvec vij = { 0 };
701         rvec vjk = { 0 };
702         rvec vjl = { 0 };
703
704         rvec_sub(vj, vi, vij);
705         rvec_sub(vk, vj, vjk);
706         rvec_sub(vl, vj, vjl);
707
708         rvec vm = { 0 };
709         vm[XX]  = vij[XX] + a * vjk[XX] + b * vjl[XX];
710         vm[YY]  = vij[YY] + a * vjk[YY] + b * vjl[YY];
711         vm[ZZ]  = vij[ZZ] + a * vjk[ZZ] + b * vjl[ZZ];
712
713         const real vmDotRm = iprod(vm, temp);
714         v[XX]              = vi[XX] + d * (vm[XX] - temp[XX] * vmDotRm * invRm * invRm);
715         v[YY]              = vi[YY] + d * (vm[YY] - temp[YY] * vmDotRm * invRm * invRm);
716         v[ZZ]              = vi[ZZ] + d * (vm[ZZ] - temp[ZZ] * vmDotRm * invRm * invRm);
717     }
718 }
719
720 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
721 static void constr_vsite4FDN(const rvec   xi,
722                              const rvec   xj,
723                              const rvec   xk,
724                              const rvec   xl,
725                              rvec         x,
726                              real         a,
727                              real         b,
728                              real         c,
729                              const t_pbc* pbc,
730                              const rvec   vi,
731                              const rvec   vj,
732                              const rvec   vk,
733                              const rvec   vl,
734                              rvec         v)
735 {
736     rvec xij, xik, xil, ra, rb, rja, rjb, rm;
737     real d;
738
739     pbc_rvec_sub(pbc, xj, xi, xij);
740     pbc_rvec_sub(pbc, xk, xi, xik);
741     pbc_rvec_sub(pbc, xl, xi, xil);
742     /* 9 flops */
743
744     ra[XX] = a * xik[XX];
745     ra[YY] = a * xik[YY];
746     ra[ZZ] = a * xik[ZZ];
747
748     rb[XX] = b * xil[XX];
749     rb[YY] = b * xil[YY];
750     rb[ZZ] = b * xil[ZZ];
751
752     /* 6 flops */
753
754     rvec_sub(ra, xij, rja);
755     rvec_sub(rb, xij, rjb);
756     /* 6 flops */
757
758     cprod(rja, rjb, rm);
759     /* 9 flops */
760
761     const real invNormRm = inverseNorm(rm);
762     d                    = c * invNormRm;
763     /* 5+5+1 flops */
764
765     if (calculatePosition == VSiteCalculatePosition::Yes)
766     {
767         x[XX] = xi[XX] + d * rm[XX];
768         x[YY] = xi[YY] + d * rm[YY];
769         x[ZZ] = xi[ZZ] + d * rm[ZZ];
770         /* 6 Flops */
771         /* TOTAL: 47 flops */
772     }
773
774     if (calculateVelocity == VSiteCalculateVelocity::Yes)
775     {
776         rvec vij = { 0 };
777         rvec vik = { 0 };
778         rvec vil = { 0 };
779         rvec_sub(vj, vi, vij);
780         rvec_sub(vk, vi, vik);
781         rvec_sub(vl, vi, vil);
782
783         rvec vja = { 0 };
784         rvec vjb = { 0 };
785
786         vja[XX] = a * vik[XX] - vij[XX];
787         vja[YY] = a * vik[YY] - vij[YY];
788         vja[ZZ] = a * vik[ZZ] - vij[ZZ];
789         vjb[XX] = b * vil[XX] - vij[XX];
790         vjb[YY] = b * vil[YY] - vij[YY];
791         vjb[ZZ] = b * vil[ZZ] - vij[ZZ];
792
793         rvec temp1 = { 0 };
794         rvec temp2 = { 0 };
795         cprod(vja, rjb, temp1);
796         cprod(rja, vjb, temp2);
797
798         rvec vm = { 0 };
799         vm[XX]  = temp1[XX] + temp2[XX];
800         vm[YY]  = temp1[YY] + temp2[YY];
801         vm[ZZ]  = temp1[ZZ] + temp2[ZZ];
802
803         const real rmDotVm = iprod(rm, vm);
804         v[XX]              = vi[XX] + d * (vm[XX] - rm[XX] * rmDotVm * invNormRm * invNormRm);
805         v[YY]              = vi[YY] + d * (vm[YY] - rm[YY] * rmDotVm * invNormRm * invNormRm);
806         v[ZZ]              = vi[ZZ] + d * (vm[ZZ] - rm[ZZ] * rmDotVm * invNormRm * invNormRm);
807     }
808 }
809
810 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
811 static int constr_vsiten(const t_iatom*            ia,
812                          ArrayRef<const t_iparams> ip,
813                          ArrayRef<RVec>            x,
814                          const t_pbc*              pbc,
815                          ArrayRef<RVec>            v)
816 {
817     rvec x1, dx;
818     dvec dsum;
819     real a;
820     dvec dvsum = { 0 };
821     rvec v1    = { 0 };
822
823     const int n3 = 3 * ip[ia[0]].vsiten.n;
824     const int av = ia[1];
825     int       ai = ia[2];
826     copy_rvec(x[ai], x1);
827     copy_rvec(v[ai], v1);
828     clear_dvec(dsum);
829     for (int i = 3; i < n3; i += 3)
830     {
831         ai = ia[i + 2];
832         a  = ip[ia[i]].vsiten.a;
833         if (calculatePosition == VSiteCalculatePosition::Yes)
834         {
835             if (pbc)
836             {
837                 pbc_dx_aiuc(pbc, x[ai], x1, dx);
838             }
839             else
840             {
841                 rvec_sub(x[ai], x1, dx);
842             }
843             dsum[XX] += a * dx[XX];
844             dsum[YY] += a * dx[YY];
845             dsum[ZZ] += a * dx[ZZ];
846             /* 9 Flops */
847         }
848         if (calculateVelocity == VSiteCalculateVelocity::Yes)
849         {
850             rvec_sub(v[ai], v1, dx);
851             dvsum[XX] += a * dx[XX];
852             dvsum[YY] += a * dx[YY];
853             dvsum[ZZ] += a * dx[ZZ];
854             /* 9 Flops */
855         }
856     }
857
858     if (calculatePosition == VSiteCalculatePosition::Yes)
859     {
860         x[av][XX] = x1[XX] + dsum[XX];
861         x[av][YY] = x1[YY] + dsum[YY];
862         x[av][ZZ] = x1[ZZ] + dsum[ZZ];
863     }
864
865     if (calculateVelocity == VSiteCalculateVelocity::Yes)
866     {
867         v[av][XX] = v1[XX] + dvsum[XX];
868         v[av][YY] = v1[YY] + dvsum[YY];
869         v[av][ZZ] = v1[ZZ] + dvsum[ZZ];
870     }
871
872     return n3;
873 }
874 // End GCC 8 bug
875 GCC_DIAGNOSTIC_RESET
876
877 #endif // DOXYGEN
878
879 //! PBC modes for vsite construction and spreading
880 enum class PbcMode
881 {
882     all, //!< Apply normal, simple PBC for all vsites
883     none //!< No PBC treatment needed
884 };
885
886 /*! \brief Returns the PBC mode based on the system PBC and vsite properties
887  *
888  * \param[in] pbcPtr  A pointer to a PBC struct or nullptr when no PBC treatment is required
889  */
890 static PbcMode getPbcMode(const t_pbc* pbcPtr)
891 {
892     if (pbcPtr == nullptr)
893     {
894         return PbcMode::none;
895     }
896     else
897     {
898         return PbcMode::all;
899     }
900 }
901
902 /*! \brief Executes the vsite construction task for a single thread
903  *
904  * \tparam        operation  Whether we are calculating positions, velocities, or both
905  * \param[in,out] x   Coordinates to construct vsites for
906  * \param[in,out] v   Velocities are generated for virtual sites if `operation` requires it
907  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
908  * \param[in]     ilist  The interaction lists, only vsites are usesd
909  * \param[in]     pbc_null  PBC struct, used for PBC distance calculations when !=nullptr
910  */
911 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
912 static void construct_vsites_thread(ArrayRef<RVec>                  x,
913                                     ArrayRef<RVec>                  v,
914                                     ArrayRef<const t_iparams>       ip,
915                                     ArrayRef<const InteractionList> ilist,
916                                     const t_pbc*                    pbc_null)
917 {
918     if (calculateVelocity == VSiteCalculateVelocity::Yes)
919     {
920         GMX_RELEASE_ASSERT(!v.empty(),
921                            "Can't calculate velocities without access to velocity vector.");
922     }
923
924     // Work around clang bug (unfixed as of Feb 2021)
925     // https://bugs.llvm.org/show_bug.cgi?id=35450
926     // clang-format off
927     CLANG_DIAGNOSTIC_IGNORE(-Wunused-lambda-capture)
928     // clang-format on
929     // GCC 8 falsely flags unused variables if constexpr prunes a code path, fixed in GCC 9
930     // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85827
931     // clang-format off
932     GCC_DIAGNOSTIC_IGNORE(-Wunused-but-set-parameter)
933     // clang-format on
934     // getVOrNull returns a velocity rvec if we need it, nullptr otherwise.
935     auto getVOrNull = [v](int idx) -> real* {
936         if (calculateVelocity == VSiteCalculateVelocity::Yes)
937         {
938             return v[idx].as_vec();
939         }
940         else
941         {
942             return nullptr;
943         }
944     };
945     GCC_DIAGNOSTIC_RESET
946     CLANG_DIAGNOSTIC_RESET
947
948     const PbcMode pbcMode = getPbcMode(pbc_null);
949     /* We need another pbc pointer, as with charge groups we switch per vsite */
950     const t_pbc* pbc_null2 = pbc_null;
951
952     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
953     {
954         if (ilist[ftype].empty())
955         {
956             continue;
957         }
958
959         { // TODO remove me
960             int nra = interaction_function[ftype].nratoms;
961             int inc = 1 + nra;
962             int nr  = ilist[ftype].size();
963
964             const t_iatom* ia = ilist[ftype].iatoms.data();
965
966             for (int i = 0; i < nr;)
967             {
968                 int tp = ia[0];
969                 /* The vsite and constructing atoms */
970                 int avsite = ia[1];
971                 int ai     = ia[2];
972                 /* Constants for constructing vsites */
973                 real a1 = ip[tp].vsite.a;
974                 /* Copy the old position */
975                 rvec xv;
976                 copy_rvec(x[avsite], xv);
977
978                 /* Construct the vsite depending on type */
979                 int  aj, ak, al;
980                 real b1, c1;
981                 switch (ftype)
982                 {
983                     case F_VSITE1:
984                         constr_vsite1<calculatePosition, calculateVelocity>(
985                                 x[ai], x[avsite], getVOrNull(ai), getVOrNull(avsite));
986                         break;
987                     case F_VSITE2:
988                         aj = ia[3];
989                         constr_vsite2<calculatePosition, calculateVelocity>(x[ai],
990                                                                             x[aj],
991                                                                             x[avsite],
992                                                                             a1,
993                                                                             pbc_null2,
994                                                                             getVOrNull(ai),
995                                                                             getVOrNull(aj),
996                                                                             getVOrNull(avsite));
997                         break;
998                     case F_VSITE2FD:
999                         aj = ia[3];
1000                         constr_vsite2FD<calculatePosition, calculateVelocity>(x[ai],
1001                                                                               x[aj],
1002                                                                               x[avsite],
1003                                                                               a1,
1004                                                                               pbc_null2,
1005                                                                               getVOrNull(ai),
1006                                                                               getVOrNull(aj),
1007                                                                               getVOrNull(avsite));
1008                         break;
1009                     case F_VSITE3:
1010                         aj = ia[3];
1011                         ak = ia[4];
1012                         b1 = ip[tp].vsite.b;
1013                         constr_vsite3<calculatePosition, calculateVelocity>(x[ai],
1014                                                                             x[aj],
1015                                                                             x[ak],
1016                                                                             x[avsite],
1017                                                                             a1,
1018                                                                             b1,
1019                                                                             pbc_null2,
1020                                                                             getVOrNull(ai),
1021                                                                             getVOrNull(aj),
1022                                                                             getVOrNull(ak),
1023                                                                             getVOrNull(avsite));
1024                         break;
1025                     case F_VSITE3FD:
1026                         aj = ia[3];
1027                         ak = ia[4];
1028                         b1 = ip[tp].vsite.b;
1029                         constr_vsite3FD<calculatePosition, calculateVelocity>(x[ai],
1030                                                                               x[aj],
1031                                                                               x[ak],
1032                                                                               x[avsite],
1033                                                                               a1,
1034                                                                               b1,
1035                                                                               pbc_null2,
1036                                                                               getVOrNull(ai),
1037                                                                               getVOrNull(aj),
1038                                                                               getVOrNull(ak),
1039                                                                               getVOrNull(avsite));
1040                         break;
1041                     case F_VSITE3FAD:
1042                         aj = ia[3];
1043                         ak = ia[4];
1044                         b1 = ip[tp].vsite.b;
1045                         constr_vsite3FAD<calculatePosition, calculateVelocity>(x[ai],
1046                                                                                x[aj],
1047                                                                                x[ak],
1048                                                                                x[avsite],
1049                                                                                a1,
1050                                                                                b1,
1051                                                                                pbc_null2,
1052                                                                                getVOrNull(ai),
1053                                                                                getVOrNull(aj),
1054                                                                                getVOrNull(ak),
1055                                                                                getVOrNull(avsite));
1056                         break;
1057                     case F_VSITE3OUT:
1058                         aj = ia[3];
1059                         ak = ia[4];
1060                         b1 = ip[tp].vsite.b;
1061                         c1 = ip[tp].vsite.c;
1062                         constr_vsite3OUT<calculatePosition, calculateVelocity>(x[ai],
1063                                                                                x[aj],
1064                                                                                x[ak],
1065                                                                                x[avsite],
1066                                                                                a1,
1067                                                                                b1,
1068                                                                                c1,
1069                                                                                pbc_null2,
1070                                                                                getVOrNull(ai),
1071                                                                                getVOrNull(aj),
1072                                                                                getVOrNull(ak),
1073                                                                                getVOrNull(avsite));
1074                         break;
1075                     case F_VSITE4FD:
1076                         aj = ia[3];
1077                         ak = ia[4];
1078                         al = ia[5];
1079                         b1 = ip[tp].vsite.b;
1080                         c1 = ip[tp].vsite.c;
1081                         constr_vsite4FD<calculatePosition, calculateVelocity>(x[ai],
1082                                                                               x[aj],
1083                                                                               x[ak],
1084                                                                               x[al],
1085                                                                               x[avsite],
1086                                                                               a1,
1087                                                                               b1,
1088                                                                               c1,
1089                                                                               pbc_null2,
1090                                                                               getVOrNull(ai),
1091                                                                               getVOrNull(aj),
1092                                                                               getVOrNull(ak),
1093                                                                               getVOrNull(al),
1094                                                                               getVOrNull(avsite));
1095                         break;
1096                     case F_VSITE4FDN:
1097                         aj = ia[3];
1098                         ak = ia[4];
1099                         al = ia[5];
1100                         b1 = ip[tp].vsite.b;
1101                         c1 = ip[tp].vsite.c;
1102                         constr_vsite4FDN<calculatePosition, calculateVelocity>(x[ai],
1103                                                                                x[aj],
1104                                                                                x[ak],
1105                                                                                x[al],
1106                                                                                x[avsite],
1107                                                                                a1,
1108                                                                                b1,
1109                                                                                c1,
1110                                                                                pbc_null2,
1111                                                                                getVOrNull(ai),
1112                                                                                getVOrNull(aj),
1113                                                                                getVOrNull(ak),
1114                                                                                getVOrNull(al),
1115                                                                                getVOrNull(avsite));
1116                         break;
1117                     case F_VSITEN:
1118                         inc = constr_vsiten<calculatePosition, calculateVelocity>(ia, ip, x, pbc_null2, v);
1119                         break;
1120                     default:
1121                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
1122                 }
1123
1124                 if (pbcMode == PbcMode::all)
1125                 {
1126                     /* Keep the vsite in the same periodic image as before */
1127                     rvec dx;
1128                     int  ishift = pbc_dx_aiuc(pbc_null, x[avsite], xv, dx);
1129                     if (ishift != c_centralShiftIndex)
1130                     {
1131                         rvec_add(xv, dx, x[avsite]);
1132                     }
1133                 }
1134
1135                 /* Increment loop variables */
1136                 i += inc;
1137                 ia += inc;
1138             }
1139         }
1140     }
1141 }
1142
1143 /*! \brief Dispatch the vsite construction tasks for all threads
1144  *
1145  * \param[in]     threadingInfo  Used to divide work over threads when != nullptr
1146  * \param[in,out] x   Coordinates to construct vsites for
1147  * \param[in,out] v   When not empty, velocities are generated for virtual sites
1148  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
1149  * \param[in]     ilist  The interaction lists, only vsites are usesd
1150  * \param[in]     domainInfo  Information about PBC and DD
1151  * \param[in]     box  Used for PBC when PBC is set in domainInfo
1152  */
1153 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
1154 static void construct_vsites(const ThreadingInfo*            threadingInfo,
1155                              ArrayRef<RVec>                  x,
1156                              ArrayRef<RVec>                  v,
1157                              ArrayRef<const t_iparams>       ip,
1158                              ArrayRef<const InteractionList> ilist,
1159                              const DomainInfo&               domainInfo,
1160                              const matrix                    box)
1161 {
1162     const bool useDomdec = domainInfo.useDomdec();
1163
1164     t_pbc pbc, *pbc_null;
1165
1166     /* We only need to do pbc when we have inter update-group vsites.
1167      * Note that with domain decomposition we do not need to apply PBC here
1168      * when we have at least 3 domains along each dimension. Currently we
1169      * do not optimize this case.
1170      */
1171     if (domainInfo.pbcType_ != PbcType::No && domainInfo.useMolPbc_)
1172     {
1173         /* This is wasting some CPU time as we now do this multiple times
1174          * per MD step.
1175          */
1176         ivec null_ivec;
1177         clear_ivec(null_ivec);
1178         pbc_null = set_pbc_dd(
1179                 &pbc, domainInfo.pbcType_, useDomdec ? domainInfo.domdec_->numCells : null_ivec, FALSE, box);
1180     }
1181     else
1182     {
1183         pbc_null = nullptr;
1184     }
1185
1186     if (useDomdec)
1187     {
1188         if (calculateVelocity == VSiteCalculateVelocity::Yes)
1189         {
1190             dd_move_x_and_v_vsites(
1191                     *domainInfo.domdec_, box, as_rvec_array(x.data()), as_rvec_array(v.data()));
1192         }
1193         else
1194         {
1195             dd_move_x_vsites(*domainInfo.domdec_, box, as_rvec_array(x.data()));
1196         }
1197     }
1198
1199     if (threadingInfo == nullptr || threadingInfo->numThreads() == 1)
1200     {
1201         construct_vsites_thread<calculatePosition, calculateVelocity>(x, v, ip, ilist, pbc_null);
1202     }
1203     else
1204     {
1205 #pragma omp parallel num_threads(threadingInfo->numThreads())
1206         {
1207             try
1208             {
1209                 const int          th    = gmx_omp_get_thread_num();
1210                 const VsiteThread& tData = threadingInfo->threadData(th);
1211                 GMX_ASSERT(tData.rangeStart >= 0,
1212                            "The thread data should be initialized before calling construct_vsites");
1213
1214                 construct_vsites_thread<calculatePosition, calculateVelocity>(
1215                         x, v, ip, tData.ilist, pbc_null);
1216                 if (tData.useInterdependentTask)
1217                 {
1218                     /* Here we don't need a barrier (unlike the spreading),
1219                      * since both tasks only construct vsites from particles,
1220                      * or local vsites, not from non-local vsites.
1221                      */
1222                     construct_vsites_thread<calculatePosition, calculateVelocity>(
1223                             x, v, ip, tData.idTask.ilist, pbc_null);
1224                 }
1225             }
1226             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
1227         }
1228         /* Now we can construct the vsites that might depend on other vsites */
1229         construct_vsites_thread<calculatePosition, calculateVelocity>(
1230                 x, v, ip, threadingInfo->threadDataNonLocalDependent().ilist, pbc_null);
1231     }
1232 }
1233
1234 void VirtualSitesHandler::Impl::construct(ArrayRef<RVec> x,
1235                                           ArrayRef<RVec> v,
1236                                           const matrix   box,
1237                                           VSiteOperation operation) const
1238 {
1239     switch (operation)
1240     {
1241         case VSiteOperation::Positions:
1242             construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::No>(
1243                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1244             break;
1245         case VSiteOperation::Velocities:
1246             construct_vsites<VSiteCalculatePosition::No, VSiteCalculateVelocity::Yes>(
1247                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1248             break;
1249         case VSiteOperation::PositionsAndVelocities:
1250             construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::Yes>(
1251                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1252             break;
1253         default: gmx_fatal(FARGS, "Unknown virtual site operation");
1254     }
1255 }
1256
1257 void VirtualSitesHandler::construct(ArrayRef<RVec> x, ArrayRef<RVec> v, const matrix box, VSiteOperation operation) const
1258 {
1259     impl_->construct(x, v, box, operation);
1260 }
1261
1262 void constructVirtualSites(ArrayRef<RVec> x, ArrayRef<const t_iparams> ip, ArrayRef<const InteractionList> ilist)
1263
1264 {
1265     // No PBC, no DD
1266     const DomainInfo domainInfo;
1267     construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::No>(
1268             nullptr, x, {}, ip, ilist, domainInfo, nullptr);
1269 }
1270
1271 #ifndef DOXYGEN
1272 /* Force spreading routines */
1273
1274 static void spread_vsite1(const t_iatom ia[], ArrayRef<RVec> f)
1275 {
1276     const int av = ia[1];
1277     const int ai = ia[2];
1278
1279     f[av] += f[ai];
1280 }
1281
1282 template<VirialHandling virialHandling>
1283 static void spread_vsite2(const t_iatom        ia[],
1284                           real                 a,
1285                           ArrayRef<const RVec> x,
1286                           ArrayRef<RVec>       f,
1287                           ArrayRef<RVec>       fshift,
1288                           const t_pbc*         pbc)
1289 {
1290     rvec    fi, fj, dx;
1291     t_iatom av, ai, aj;
1292
1293     av = ia[1];
1294     ai = ia[2];
1295     aj = ia[3];
1296
1297     svmul(1 - a, f[av], fi);
1298     svmul(a, f[av], fj);
1299     /* 7 flop */
1300
1301     rvec_inc(f[ai], fi);
1302     rvec_inc(f[aj], fj);
1303     /* 6 Flops */
1304
1305     if (virialHandling == VirialHandling::Pbc)
1306     {
1307         int siv;
1308         int sij;
1309         if (pbc)
1310         {
1311             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1312             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1313         }
1314         else
1315         {
1316             siv = c_centralShiftIndex;
1317             sij = c_centralShiftIndex;
1318         }
1319
1320         if (siv != c_centralShiftIndex || sij != c_centralShiftIndex)
1321         {
1322             rvec_inc(fshift[siv], f[av]);
1323             rvec_dec(fshift[c_centralShiftIndex], fi);
1324             rvec_dec(fshift[sij], fj);
1325         }
1326     }
1327
1328     /* TOTAL: 13 flops */
1329 }
1330
1331 void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
1332 {
1333     GMX_ASSERT(x.ssize() >= mtop.natoms, "x should contain the whole system");
1334     GMX_ASSERT(!mtop.moleculeBlockIndices.empty(),
1335                "molblock indices are needed in constructVsitesGlobal");
1336
1337     for (size_t mb = 0; mb < mtop.molblock.size(); mb++)
1338     {
1339         const gmx_molblock_t& molb = mtop.molblock[mb];
1340         const gmx_moltype_t&  molt = mtop.moltype[molb.type];
1341         if (vsiteIlistNrCount(molt.ilist) > 0)
1342         {
1343             int atomOffset = mtop.moleculeBlockIndices[mb].globalAtomStart;
1344             for (int mol = 0; mol < molb.nmol; mol++)
1345             {
1346                 constructVirtualSites(
1347                         x.subArray(atomOffset, molt.atoms.nr), mtop.ffparams.iparams, molt.ilist);
1348                 atomOffset += molt.atoms.nr;
1349             }
1350         }
1351     }
1352 }
1353
1354 template<VirialHandling virialHandling>
1355 static void spread_vsite2FD(const t_iatom        ia[],
1356                             real                 a,
1357                             ArrayRef<const RVec> x,
1358                             ArrayRef<RVec>       f,
1359                             ArrayRef<RVec>       fshift,
1360                             matrix               dxdf,
1361                             const t_pbc*         pbc)
1362 {
1363     const int av = ia[1];
1364     const int ai = ia[2];
1365     const int aj = ia[3];
1366     rvec      fv;
1367     copy_rvec(f[av], fv);
1368
1369     rvec xij;
1370     int  sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1371     /* 6 flops */
1372
1373     const real invDistance = inverseNorm(xij);
1374     const real b           = a * invDistance;
1375     /* 4 + ?10? flops */
1376
1377     const real fproj = iprod(xij, fv) * invDistance * invDistance;
1378
1379     rvec fj;
1380     fj[XX] = b * (fv[XX] - fproj * xij[XX]);
1381     fj[YY] = b * (fv[YY] - fproj * xij[YY]);
1382     fj[ZZ] = b * (fv[ZZ] - fproj * xij[ZZ]);
1383     /* 9 */
1384
1385     /* b is already calculated in constr_vsite2FD
1386        storing b somewhere will save flops.     */
1387
1388     f[ai][XX] += fv[XX] - fj[XX];
1389     f[ai][YY] += fv[YY] - fj[YY];
1390     f[ai][ZZ] += fv[ZZ] - fj[ZZ];
1391     f[aj][XX] += fj[XX];
1392     f[aj][YY] += fj[YY];
1393     f[aj][ZZ] += fj[ZZ];
1394     /* 9 Flops */
1395
1396     if (virialHandling == VirialHandling::Pbc)
1397     {
1398         int svi;
1399         if (pbc)
1400         {
1401             rvec xvi;
1402             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1403         }
1404         else
1405         {
1406             svi = c_centralShiftIndex;
1407         }
1408
1409         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex)
1410         {
1411             rvec_dec(fshift[svi], fv);
1412             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX];
1413             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY];
1414             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ];
1415             fshift[sji][XX] += fj[XX];
1416             fshift[sji][YY] += fj[YY];
1417             fshift[sji][ZZ] += fj[ZZ];
1418         }
1419     }
1420
1421     if (virialHandling == VirialHandling::NonLinear)
1422     {
1423         /* Under this condition, the virial for the current forces is not
1424          * calculated from the redistributed forces. This means that
1425          * the effect of non-linear virtual site constructions on the virial
1426          * needs to be added separately. This contribution can be calculated
1427          * in many ways, but the simplest and cheapest way is to use
1428          * the first constructing atom ai as a reference position in space:
1429          * subtract (xv-xi)*fv and add (xj-xi)*fj.
1430          */
1431         rvec xiv;
1432
1433         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1434
1435         for (int i = 0; i < DIM; i++)
1436         {
1437             for (int j = 0; j < DIM; j++)
1438             {
1439                 /* As xix is a linear combination of j and k, use that here */
1440                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j];
1441             }
1442         }
1443     }
1444
1445     /* TOTAL: 38 flops */
1446 }
1447
1448 template<VirialHandling virialHandling>
1449 static void spread_vsite3(const t_iatom        ia[],
1450                           real                 a,
1451                           real                 b,
1452                           ArrayRef<const RVec> x,
1453                           ArrayRef<RVec>       f,
1454                           ArrayRef<RVec>       fshift,
1455                           const t_pbc*         pbc)
1456 {
1457     rvec fi, fj, fk, dx;
1458     int  av, ai, aj, ak;
1459
1460     av = ia[1];
1461     ai = ia[2];
1462     aj = ia[3];
1463     ak = ia[4];
1464
1465     svmul(1 - a - b, f[av], fi);
1466     svmul(a, f[av], fj);
1467     svmul(b, f[av], fk);
1468     /* 11 flops */
1469
1470     rvec_inc(f[ai], fi);
1471     rvec_inc(f[aj], fj);
1472     rvec_inc(f[ak], fk);
1473     /* 9 Flops */
1474
1475     if (virialHandling == VirialHandling::Pbc)
1476     {
1477         int siv;
1478         int sij;
1479         int sik;
1480         if (pbc)
1481         {
1482             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1483             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1484             sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
1485         }
1486         else
1487         {
1488             siv = c_centralShiftIndex;
1489             sij = c_centralShiftIndex;
1490             sik = c_centralShiftIndex;
1491         }
1492
1493         if (siv != c_centralShiftIndex || sij != c_centralShiftIndex || sik != c_centralShiftIndex)
1494         {
1495             rvec_inc(fshift[siv], f[av]);
1496             rvec_dec(fshift[c_centralShiftIndex], fi);
1497             rvec_dec(fshift[sij], fj);
1498             rvec_dec(fshift[sik], fk);
1499         }
1500     }
1501
1502     /* TOTAL: 20 flops */
1503 }
1504
1505 template<VirialHandling virialHandling>
1506 static void spread_vsite3FD(const t_iatom        ia[],
1507                             real                 a,
1508                             real                 b,
1509                             ArrayRef<const RVec> x,
1510                             ArrayRef<RVec>       f,
1511                             ArrayRef<RVec>       fshift,
1512                             matrix               dxdf,
1513                             const t_pbc*         pbc)
1514 {
1515     real    fproj, a1;
1516     rvec    xvi, xij, xjk, xix, fv, temp;
1517     t_iatom av, ai, aj, ak;
1518     int     sji, skj;
1519
1520     av = ia[1];
1521     ai = ia[2];
1522     aj = ia[3];
1523     ak = ia[4];
1524     copy_rvec(f[av], fv);
1525
1526     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1527     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1528     /* 6 flops */
1529
1530     /* xix goes from i to point x on the line jk */
1531     xix[XX] = xij[XX] + a * xjk[XX];
1532     xix[YY] = xij[YY] + a * xjk[YY];
1533     xix[ZZ] = xij[ZZ] + a * xjk[ZZ];
1534     /* 6 flops */
1535
1536     const real invDistance = inverseNorm(xix);
1537     const real c           = b * invDistance;
1538     /* 4 + ?10? flops */
1539
1540     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1541
1542     temp[XX] = c * (fv[XX] - fproj * xix[XX]);
1543     temp[YY] = c * (fv[YY] - fproj * xix[YY]);
1544     temp[ZZ] = c * (fv[ZZ] - fproj * xix[ZZ]);
1545     /* 16 */
1546
1547     /* c is already calculated in constr_vsite3FD
1548        storing c somewhere will save 26 flops!     */
1549
1550     a1 = 1 - a;
1551     f[ai][XX] += fv[XX] - temp[XX];
1552     f[ai][YY] += fv[YY] - temp[YY];
1553     f[ai][ZZ] += fv[ZZ] - temp[ZZ];
1554     f[aj][XX] += a1 * temp[XX];
1555     f[aj][YY] += a1 * temp[YY];
1556     f[aj][ZZ] += a1 * temp[ZZ];
1557     f[ak][XX] += a * temp[XX];
1558     f[ak][YY] += a * temp[YY];
1559     f[ak][ZZ] += a * temp[ZZ];
1560     /* 19 Flops */
1561
1562     if (virialHandling == VirialHandling::Pbc)
1563     {
1564         int svi;
1565         if (pbc)
1566         {
1567             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1568         }
1569         else
1570         {
1571             svi = c_centralShiftIndex;
1572         }
1573
1574         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex)
1575         {
1576             rvec_dec(fshift[svi], fv);
1577             fshift[c_centralShiftIndex][XX] += fv[XX] - (1 + a) * temp[XX];
1578             fshift[c_centralShiftIndex][YY] += fv[YY] - (1 + a) * temp[YY];
1579             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
1580             fshift[sji][XX] += temp[XX];
1581             fshift[sji][YY] += temp[YY];
1582             fshift[sji][ZZ] += temp[ZZ];
1583             fshift[skj][XX] += a * temp[XX];
1584             fshift[skj][YY] += a * temp[YY];
1585             fshift[skj][ZZ] += a * temp[ZZ];
1586         }
1587     }
1588
1589     if (virialHandling == VirialHandling::NonLinear)
1590     {
1591         /* Under this condition, the virial for the current forces is not
1592          * calculated from the redistributed forces. This means that
1593          * the effect of non-linear virtual site constructions on the virial
1594          * needs to be added separately. This contribution can be calculated
1595          * in many ways, but the simplest and cheapest way is to use
1596          * the first constructing atom ai as a reference position in space:
1597          * subtract (xv-xi)*fv and add (xj-xi)*fj + (xk-xi)*fk.
1598          */
1599         rvec xiv;
1600
1601         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1602
1603         for (int i = 0; i < DIM; i++)
1604         {
1605             for (int j = 0; j < DIM; j++)
1606             {
1607                 /* As xix is a linear combination of j and k, use that here */
1608                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1609             }
1610         }
1611     }
1612
1613     /* TOTAL: 61 flops */
1614 }
1615
1616 template<VirialHandling virialHandling>
1617 static void spread_vsite3FAD(const t_iatom        ia[],
1618                              real                 a,
1619                              real                 b,
1620                              ArrayRef<const RVec> x,
1621                              ArrayRef<RVec>       f,
1622                              ArrayRef<RVec>       fshift,
1623                              matrix               dxdf,
1624                              const t_pbc*         pbc)
1625 {
1626     rvec    xvi, xij, xjk, xperp, Fpij, Fppp, fv, f1, f2, f3;
1627     real    a1, b1, c1, c2, invdij, invdij2, invdp, fproj;
1628     t_iatom av, ai, aj, ak;
1629     int     sji, skj;
1630
1631     av = ia[1];
1632     ai = ia[2];
1633     aj = ia[3];
1634     ak = ia[4];
1635     copy_rvec(f[ia[1]], fv);
1636
1637     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1638     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1639     /* 6 flops */
1640
1641     invdij    = inverseNorm(xij);
1642     invdij2   = invdij * invdij;
1643     c1        = iprod(xij, xjk) * invdij2;
1644     xperp[XX] = xjk[XX] - c1 * xij[XX];
1645     xperp[YY] = xjk[YY] - c1 * xij[YY];
1646     xperp[ZZ] = xjk[ZZ] - c1 * xij[ZZ];
1647     /* xperp in plane ijk, perp. to ij */
1648     invdp = inverseNorm(xperp);
1649     a1    = a * invdij;
1650     b1    = b * invdp;
1651     /* 45 flops */
1652
1653     /* a1, b1 and c1 are already calculated in constr_vsite3FAD
1654        storing them somewhere will save 45 flops!     */
1655
1656     fproj = iprod(xij, fv) * invdij2;
1657     svmul(fproj, xij, Fpij);                              /* proj. f on xij */
1658     svmul(iprod(xperp, fv) * invdp * invdp, xperp, Fppp); /* proj. f on xperp */
1659     svmul(b1 * fproj, xperp, f3);
1660     /* 23 flops */
1661
1662     rvec_sub(fv, Fpij, f1); /* f1 = f - Fpij */
1663     rvec_sub(f1, Fppp, f2); /* f2 = f - Fpij - Fppp */
1664     for (int d = 0; d < DIM; d++)
1665     {
1666         f1[d] *= a1;
1667         f2[d] *= b1;
1668     }
1669     /* 12 flops */
1670
1671     c2 = 1 + c1;
1672     f[ai][XX] += fv[XX] - f1[XX] + c1 * f2[XX] + f3[XX];
1673     f[ai][YY] += fv[YY] - f1[YY] + c1 * f2[YY] + f3[YY];
1674     f[ai][ZZ] += fv[ZZ] - f1[ZZ] + c1 * f2[ZZ] + f3[ZZ];
1675     f[aj][XX] += f1[XX] - c2 * f2[XX] - f3[XX];
1676     f[aj][YY] += f1[YY] - c2 * f2[YY] - f3[YY];
1677     f[aj][ZZ] += f1[ZZ] - c2 * f2[ZZ] - f3[ZZ];
1678     f[ak][XX] += f2[XX];
1679     f[ak][YY] += f2[YY];
1680     f[ak][ZZ] += f2[ZZ];
1681     /* 30 Flops */
1682
1683     if (virialHandling == VirialHandling::Pbc)
1684     {
1685         int svi;
1686
1687         if (pbc)
1688         {
1689             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1690         }
1691         else
1692         {
1693             svi = c_centralShiftIndex;
1694         }
1695
1696         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex)
1697         {
1698             rvec_dec(fshift[svi], fv);
1699             fshift[c_centralShiftIndex][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
1700             fshift[c_centralShiftIndex][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
1701             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
1702             fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
1703             fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
1704             fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
1705             fshift[skj][XX] += f2[XX];
1706             fshift[skj][YY] += f2[YY];
1707             fshift[skj][ZZ] += f2[ZZ];
1708         }
1709     }
1710
1711     if (virialHandling == VirialHandling::NonLinear)
1712     {
1713         rvec xiv;
1714         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1715
1716         for (int i = 0; i < DIM; i++)
1717         {
1718             for (int j = 0; j < DIM; j++)
1719             {
1720                 /* Note that xik=xij+xjk, so we have to add xij*f2 */
1721                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * (f1[j] + (1 - c2) * f2[j] - f3[j])
1722                               + xjk[i] * f2[j];
1723             }
1724         }
1725     }
1726
1727     /* TOTAL: 113 flops */
1728 }
1729
1730 template<VirialHandling virialHandling>
1731 static void spread_vsite3OUT(const t_iatom        ia[],
1732                              real                 a,
1733                              real                 b,
1734                              real                 c,
1735                              ArrayRef<const RVec> x,
1736                              ArrayRef<RVec>       f,
1737                              ArrayRef<RVec>       fshift,
1738                              matrix               dxdf,
1739                              const t_pbc*         pbc)
1740 {
1741     rvec xvi, xij, xik, fv, fj, fk;
1742     real cfx, cfy, cfz;
1743     int  av, ai, aj, ak;
1744     int  sji, ski;
1745
1746     av = ia[1];
1747     ai = ia[2];
1748     aj = ia[3];
1749     ak = ia[4];
1750
1751     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1752     ski = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1753     /* 6 Flops */
1754
1755     copy_rvec(f[av], fv);
1756
1757     cfx = c * fv[XX];
1758     cfy = c * fv[YY];
1759     cfz = c * fv[ZZ];
1760     /* 3 Flops */
1761
1762     fj[XX] = a * fv[XX] - xik[ZZ] * cfy + xik[YY] * cfz;
1763     fj[YY] = xik[ZZ] * cfx + a * fv[YY] - xik[XX] * cfz;
1764     fj[ZZ] = -xik[YY] * cfx + xik[XX] * cfy + a * fv[ZZ];
1765
1766     fk[XX] = b * fv[XX] + xij[ZZ] * cfy - xij[YY] * cfz;
1767     fk[YY] = -xij[ZZ] * cfx + b * fv[YY] + xij[XX] * cfz;
1768     fk[ZZ] = xij[YY] * cfx - xij[XX] * cfy + b * fv[ZZ];
1769     /* 30 Flops */
1770
1771     f[ai][XX] += fv[XX] - fj[XX] - fk[XX];
1772     f[ai][YY] += fv[YY] - fj[YY] - fk[YY];
1773     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1774     rvec_inc(f[aj], fj);
1775     rvec_inc(f[ak], fk);
1776     /* 15 Flops */
1777
1778     if (virialHandling == VirialHandling::Pbc)
1779     {
1780         int svi;
1781         if (pbc)
1782         {
1783             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1784         }
1785         else
1786         {
1787             svi = c_centralShiftIndex;
1788         }
1789
1790         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || ski != c_centralShiftIndex)
1791         {
1792             rvec_dec(fshift[svi], fv);
1793             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX] - fk[XX];
1794             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY] - fk[YY];
1795             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1796             rvec_inc(fshift[sji], fj);
1797             rvec_inc(fshift[ski], fk);
1798         }
1799     }
1800
1801     if (virialHandling == VirialHandling::NonLinear)
1802     {
1803         rvec xiv;
1804
1805         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1806
1807         for (int i = 0; i < DIM; i++)
1808         {
1809             for (int j = 0; j < DIM; j++)
1810             {
1811                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j];
1812             }
1813         }
1814     }
1815
1816     /* TOTAL: 54 flops */
1817 }
1818
1819 template<VirialHandling virialHandling>
1820 static void spread_vsite4FD(const t_iatom        ia[],
1821                             real                 a,
1822                             real                 b,
1823                             real                 c,
1824                             ArrayRef<const RVec> x,
1825                             ArrayRef<RVec>       f,
1826                             ArrayRef<RVec>       fshift,
1827                             matrix               dxdf,
1828                             const t_pbc*         pbc)
1829 {
1830     real fproj, a1;
1831     rvec xvi, xij, xjk, xjl, xix, fv, temp;
1832     int  av, ai, aj, ak, al;
1833     int  sji, skj, slj, m;
1834
1835     av = ia[1];
1836     ai = ia[2];
1837     aj = ia[3];
1838     ak = ia[4];
1839     al = ia[5];
1840
1841     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1842     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1843     slj = pbc_rvec_sub(pbc, x[al], x[aj], xjl);
1844     /* 9 flops */
1845
1846     /* xix goes from i to point x on the plane jkl */
1847     for (m = 0; m < DIM; m++)
1848     {
1849         xix[m] = xij[m] + a * xjk[m] + b * xjl[m];
1850     }
1851     /* 12 flops */
1852
1853     const real invDistance = inverseNorm(xix);
1854     const real d           = c * invDistance;
1855     /* 4 + ?10? flops */
1856
1857     copy_rvec(f[av], fv);
1858
1859     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1860
1861     for (m = 0; m < DIM; m++)
1862     {
1863         temp[m] = d * (fv[m] - fproj * xix[m]);
1864     }
1865     /* 16 */
1866
1867     /* c is already calculated in constr_vsite3FD
1868        storing c somewhere will save 35 flops!     */
1869
1870     a1 = 1 - a - b;
1871     for (m = 0; m < DIM; m++)
1872     {
1873         f[ai][m] += fv[m] - temp[m];
1874         f[aj][m] += a1 * temp[m];
1875         f[ak][m] += a * temp[m];
1876         f[al][m] += b * temp[m];
1877     }
1878     /* 26 Flops */
1879
1880     if (virialHandling == VirialHandling::Pbc)
1881     {
1882         int svi;
1883         if (pbc)
1884         {
1885             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1886         }
1887         else
1888         {
1889             svi = c_centralShiftIndex;
1890         }
1891
1892         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex
1893             || slj != c_centralShiftIndex)
1894         {
1895             rvec_dec(fshift[svi], fv);
1896             for (m = 0; m < DIM; m++)
1897             {
1898                 fshift[c_centralShiftIndex][m] += fv[m] - (1 + a + b) * temp[m];
1899                 fshift[sji][m] += temp[m];
1900                 fshift[skj][m] += a * temp[m];
1901                 fshift[slj][m] += b * temp[m];
1902             }
1903         }
1904     }
1905
1906     if (virialHandling == VirialHandling::NonLinear)
1907     {
1908         rvec xiv;
1909         int  i, j;
1910
1911         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1912
1913         for (i = 0; i < DIM; i++)
1914         {
1915             for (j = 0; j < DIM; j++)
1916             {
1917                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1918             }
1919         }
1920     }
1921
1922     /* TOTAL: 77 flops */
1923 }
1924
1925 template<VirialHandling virialHandling>
1926 static void spread_vsite4FDN(const t_iatom        ia[],
1927                              real                 a,
1928                              real                 b,
1929                              real                 c,
1930                              ArrayRef<const RVec> x,
1931                              ArrayRef<RVec>       f,
1932                              ArrayRef<RVec>       fshift,
1933                              matrix               dxdf,
1934                              const t_pbc*         pbc)
1935 {
1936     rvec xvi, xij, xik, xil, ra, rb, rja, rjb, rab, rm, rt;
1937     rvec fv, fj, fk, fl;
1938     real invrm, denom;
1939     real cfx, cfy, cfz;
1940     int  av, ai, aj, ak, al;
1941     int  sij, sik, sil;
1942
1943     /* DEBUG: check atom indices */
1944     av = ia[1];
1945     ai = ia[2];
1946     aj = ia[3];
1947     ak = ia[4];
1948     al = ia[5];
1949
1950     copy_rvec(f[av], fv);
1951
1952     sij = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1953     sik = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1954     sil = pbc_rvec_sub(pbc, x[al], x[ai], xil);
1955     /* 9 flops */
1956
1957     ra[XX] = a * xik[XX];
1958     ra[YY] = a * xik[YY];
1959     ra[ZZ] = a * xik[ZZ];
1960
1961     rb[XX] = b * xil[XX];
1962     rb[YY] = b * xil[YY];
1963     rb[ZZ] = b * xil[ZZ];
1964
1965     /* 6 flops */
1966
1967     rvec_sub(ra, xij, rja);
1968     rvec_sub(rb, xij, rjb);
1969     rvec_sub(rb, ra, rab);
1970     /* 9 flops */
1971
1972     cprod(rja, rjb, rm);
1973     /* 9 flops */
1974
1975     invrm = inverseNorm(rm);
1976     denom = invrm * invrm;
1977     /* 5+5+2 flops */
1978
1979     cfx = c * invrm * fv[XX];
1980     cfy = c * invrm * fv[YY];
1981     cfz = c * invrm * fv[ZZ];
1982     /* 6 Flops */
1983
1984     cprod(rm, rab, rt);
1985     /* 9 flops */
1986
1987     rt[XX] *= denom;
1988     rt[YY] *= denom;
1989     rt[ZZ] *= denom;
1990     /* 3flops */
1991
1992     fj[XX] = (-rm[XX] * rt[XX]) * cfx + (rab[ZZ] - rm[YY] * rt[XX]) * cfy
1993              + (-rab[YY] - rm[ZZ] * rt[XX]) * cfz;
1994     fj[YY] = (-rab[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1995              + (rab[XX] - rm[ZZ] * rt[YY]) * cfz;
1996     fj[ZZ] = (rab[YY] - rm[XX] * rt[ZZ]) * cfx + (-rab[XX] - rm[YY] * rt[ZZ]) * cfy
1997              + (-rm[ZZ] * rt[ZZ]) * cfz;
1998     /* 30 flops */
1999
2000     cprod(rjb, rm, rt);
2001     /* 9 flops */
2002
2003     rt[XX] *= denom * a;
2004     rt[YY] *= denom * a;
2005     rt[ZZ] *= denom * a;
2006     /* 3flops */
2007
2008     fk[XX] = (-rm[XX] * rt[XX]) * cfx + (-a * rjb[ZZ] - rm[YY] * rt[XX]) * cfy
2009              + (a * rjb[YY] - rm[ZZ] * rt[XX]) * cfz;
2010     fk[YY] = (a * rjb[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
2011              + (-a * rjb[XX] - rm[ZZ] * rt[YY]) * cfz;
2012     fk[ZZ] = (-a * rjb[YY] - rm[XX] * rt[ZZ]) * cfx + (a * rjb[XX] - rm[YY] * rt[ZZ]) * cfy
2013              + (-rm[ZZ] * rt[ZZ]) * cfz;
2014     /* 36 flops */
2015
2016     cprod(rm, rja, rt);
2017     /* 9 flops */
2018
2019     rt[XX] *= denom * b;
2020     rt[YY] *= denom * b;
2021     rt[ZZ] *= denom * b;
2022     /* 3flops */
2023
2024     fl[XX] = (-rm[XX] * rt[XX]) * cfx + (b * rja[ZZ] - rm[YY] * rt[XX]) * cfy
2025              + (-b * rja[YY] - rm[ZZ] * rt[XX]) * cfz;
2026     fl[YY] = (-b * rja[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
2027              + (b * rja[XX] - rm[ZZ] * rt[YY]) * cfz;
2028     fl[ZZ] = (b * rja[YY] - rm[XX] * rt[ZZ]) * cfx + (-b * rja[XX] - rm[YY] * rt[ZZ]) * cfy
2029              + (-rm[ZZ] * rt[ZZ]) * cfz;
2030     /* 36 flops */
2031
2032     f[ai][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
2033     f[ai][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
2034     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
2035     rvec_inc(f[aj], fj);
2036     rvec_inc(f[ak], fk);
2037     rvec_inc(f[al], fl);
2038     /* 21 flops */
2039
2040     if (virialHandling == VirialHandling::Pbc)
2041     {
2042         int svi;
2043         if (pbc)
2044         {
2045             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
2046         }
2047         else
2048         {
2049             svi = c_centralShiftIndex;
2050         }
2051
2052         if (svi != c_centralShiftIndex || sij != c_centralShiftIndex || sik != c_centralShiftIndex
2053             || sil != c_centralShiftIndex)
2054         {
2055             rvec_dec(fshift[svi], fv);
2056             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
2057             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
2058             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
2059             rvec_inc(fshift[sij], fj);
2060             rvec_inc(fshift[sik], fk);
2061             rvec_inc(fshift[sil], fl);
2062         }
2063     }
2064
2065     if (virialHandling == VirialHandling::NonLinear)
2066     {
2067         rvec xiv;
2068         int  i, j;
2069
2070         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
2071
2072         for (i = 0; i < DIM; i++)
2073         {
2074             for (j = 0; j < DIM; j++)
2075             {
2076                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j] + xil[i] * fl[j];
2077             }
2078         }
2079     }
2080
2081     /* Total: 207 flops (Yuck!) */
2082 }
2083
2084 template<VirialHandling virialHandling>
2085 static int spread_vsiten(const t_iatom             ia[],
2086                          ArrayRef<const t_iparams> ip,
2087                          ArrayRef<const RVec>      x,
2088                          ArrayRef<RVec>            f,
2089                          ArrayRef<RVec>            fshift,
2090                          const t_pbc*              pbc)
2091 {
2092     rvec xv, dx, fi;
2093     int  n3, av, i, ai;
2094     real a;
2095     int  siv;
2096
2097     n3 = 3 * ip[ia[0]].vsiten.n;
2098     av = ia[1];
2099     copy_rvec(x[av], xv);
2100
2101     for (i = 0; i < n3; i += 3)
2102     {
2103         ai = ia[i + 2];
2104         if (pbc)
2105         {
2106             siv = pbc_dx_aiuc(pbc, x[ai], xv, dx);
2107         }
2108         else
2109         {
2110             siv = c_centralShiftIndex;
2111         }
2112         a = ip[ia[i]].vsiten.a;
2113         svmul(a, f[av], fi);
2114         rvec_inc(f[ai], fi);
2115
2116         if (virialHandling == VirialHandling::Pbc && siv != c_centralShiftIndex)
2117         {
2118             rvec_inc(fshift[siv], fi);
2119             rvec_dec(fshift[c_centralShiftIndex], fi);
2120         }
2121         /* 6 Flops */
2122     }
2123
2124     return n3;
2125 }
2126
2127 #endif // DOXYGEN
2128
2129 //! Returns the number of virtual sites in the interaction list, for VSITEN the number of atoms
2130 static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
2131 {
2132     if (ftype == F_VSITEN)
2133     {
2134         return ilist[ftype].size() / 3;
2135     }
2136     else
2137     {
2138         return ilist[ftype].size() / (1 + interaction_function[ftype].nratoms);
2139     }
2140 }
2141
2142 //! Executes the force spreading task for a single thread
2143 template<VirialHandling virialHandling>
2144 static void spreadForceForThread(ArrayRef<const RVec>            x,
2145                                  ArrayRef<RVec>                  f,
2146                                  ArrayRef<RVec>                  fshift,
2147                                  matrix                          dxdf,
2148                                  ArrayRef<const t_iparams>       ip,
2149                                  ArrayRef<const InteractionList> ilist,
2150                                  const t_pbc*                    pbc_null)
2151 {
2152     const PbcMode pbcMode = getPbcMode(pbc_null);
2153     /* We need another pbc pointer, as with charge groups we switch per vsite */
2154     const t_pbc*             pbc_null2 = pbc_null;
2155     gmx::ArrayRef<const int> vsite_pbc;
2156
2157     /* this loop goes backwards to be able to build *
2158      * higher type vsites from lower types         */
2159     for (int ftype = c_ftypeVsiteEnd - 1; ftype >= c_ftypeVsiteStart; ftype--)
2160     {
2161         if (ilist[ftype].empty())
2162         {
2163             continue;
2164         }
2165
2166         { // TODO remove me
2167             int nra = interaction_function[ftype].nratoms;
2168             int inc = 1 + nra;
2169             int nr  = ilist[ftype].size();
2170
2171             const t_iatom* ia = ilist[ftype].iatoms.data();
2172
2173             if (pbcMode == PbcMode::all)
2174             {
2175                 pbc_null2 = pbc_null;
2176             }
2177
2178             for (int i = 0; i < nr;)
2179             {
2180                 int tp = ia[0];
2181
2182                 /* Constants for constructing */
2183                 real a1, b1, c1;
2184                 a1 = ip[tp].vsite.a;
2185                 /* Construct the vsite depending on type */
2186                 switch (ftype)
2187                 {
2188                     case F_VSITE1: spread_vsite1(ia, f); break;
2189                     case F_VSITE2:
2190                         spread_vsite2<virialHandling>(ia, a1, x, f, fshift, pbc_null2);
2191                         break;
2192                     case F_VSITE2FD:
2193                         spread_vsite2FD<virialHandling>(ia, a1, x, f, fshift, dxdf, pbc_null2);
2194                         break;
2195                     case F_VSITE3:
2196                         b1 = ip[tp].vsite.b;
2197                         spread_vsite3<virialHandling>(ia, a1, b1, x, f, fshift, pbc_null2);
2198                         break;
2199                     case F_VSITE3FD:
2200                         b1 = ip[tp].vsite.b;
2201                         spread_vsite3FD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
2202                         break;
2203                     case F_VSITE3FAD:
2204                         b1 = ip[tp].vsite.b;
2205                         spread_vsite3FAD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
2206                         break;
2207                     case F_VSITE3OUT:
2208                         b1 = ip[tp].vsite.b;
2209                         c1 = ip[tp].vsite.c;
2210                         spread_vsite3OUT<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2211                         break;
2212                     case F_VSITE4FD:
2213                         b1 = ip[tp].vsite.b;
2214                         c1 = ip[tp].vsite.c;
2215                         spread_vsite4FD<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2216                         break;
2217                     case F_VSITE4FDN:
2218                         b1 = ip[tp].vsite.b;
2219                         c1 = ip[tp].vsite.c;
2220                         spread_vsite4FDN<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2221                         break;
2222                     case F_VSITEN:
2223                         inc = spread_vsiten<virialHandling>(ia, ip, x, f, fshift, pbc_null2);
2224                         break;
2225                     default:
2226                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
2227                 }
2228                 clear_rvec(f[ia[1]]);
2229
2230                 /* Increment loop variables */
2231                 i += inc;
2232                 ia += inc;
2233             }
2234         }
2235     }
2236 }
2237
2238 //! Wrapper function for calling the templated thread-local spread function
2239 static void spreadForceWrapper(ArrayRef<const RVec>            x,
2240                                ArrayRef<RVec>                  f,
2241                                const VirialHandling            virialHandling,
2242                                ArrayRef<RVec>                  fshift,
2243                                matrix                          dxdf,
2244                                const bool                      clearDxdf,
2245                                ArrayRef<const t_iparams>       ip,
2246                                ArrayRef<const InteractionList> ilist,
2247                                const t_pbc*                    pbc_null)
2248 {
2249     if (virialHandling == VirialHandling::NonLinear && clearDxdf)
2250     {
2251         clear_mat(dxdf);
2252     }
2253
2254     switch (virialHandling)
2255     {
2256         case VirialHandling::None:
2257             spreadForceForThread<VirialHandling::None>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2258             break;
2259         case VirialHandling::Pbc:
2260             spreadForceForThread<VirialHandling::Pbc>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2261             break;
2262         case VirialHandling::NonLinear:
2263             spreadForceForThread<VirialHandling::NonLinear>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2264             break;
2265     }
2266 }
2267
2268 //! Clears the task force buffer elements that are written by task idTask
2269 static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
2270 {
2271     int ntask = idTask->spreadTask.size();
2272     for (int ti = 0; ti < ntask; ti++)
2273     {
2274         const AtomIndex* atomList = &idTask->atomIndex[idTask->spreadTask[ti]];
2275         int              natom    = atomList->atom.size();
2276         RVec*            force    = idTask->force.data();
2277         for (int i = 0; i < natom; i++)
2278         {
2279             clear_rvec(force[atomList->atom[i]]);
2280         }
2281     }
2282 }
2283
2284 void VirtualSitesHandler::Impl::spreadForces(ArrayRef<const RVec> x,
2285                                              ArrayRef<RVec>       f,
2286                                              const VirialHandling virialHandling,
2287                                              ArrayRef<RVec>       fshift,
2288                                              matrix               virial,
2289                                              t_nrnb*              nrnb,
2290                                              const matrix         box,
2291                                              gmx_wallcycle*       wcycle)
2292 {
2293     wallcycle_start(wcycle, WallCycleCounter::VsiteSpread);
2294
2295     const bool useDomdec = domainInfo_.useDomdec();
2296
2297     t_pbc pbc, *pbc_null;
2298
2299     if (domainInfo_.useMolPbc_)
2300     {
2301         /* This is wasting some CPU time as we now do this multiple times
2302          * per MD step.
2303          */
2304         pbc_null = set_pbc_dd(
2305                 &pbc, domainInfo_.pbcType_, useDomdec ? domainInfo_.domdec_->numCells : nullptr, FALSE, box);
2306     }
2307     else
2308     {
2309         pbc_null = nullptr;
2310     }
2311
2312     if (useDomdec)
2313     {
2314         dd_clear_f_vsites(*domainInfo_.domdec_, f);
2315     }
2316
2317     const int numThreads = threadingInfo_.numThreads();
2318
2319     if (numThreads == 1)
2320     {
2321         matrix dxdf;
2322         spreadForceWrapper(x, f, virialHandling, fshift, dxdf, true, iparams_, ilists_, pbc_null);
2323
2324         if (virialHandling == VirialHandling::NonLinear)
2325         {
2326             for (int i = 0; i < DIM; i++)
2327             {
2328                 for (int j = 0; j < DIM; j++)
2329                 {
2330                     virial[i][j] += -0.5 * dxdf[i][j];
2331                 }
2332             }
2333         }
2334     }
2335     else
2336     {
2337         /* First spread the vsites that might depend on non-local vsites */
2338         auto& nlDependentVSites = threadingInfo_.threadDataNonLocalDependent();
2339         spreadForceWrapper(x,
2340                            f,
2341                            virialHandling,
2342                            fshift,
2343                            nlDependentVSites.dxdf,
2344                            true,
2345                            iparams_,
2346                            nlDependentVSites.ilist,
2347                            pbc_null);
2348
2349 #pragma omp parallel num_threads(numThreads)
2350         {
2351             try
2352             {
2353                 int          thread = gmx_omp_get_thread_num();
2354                 VsiteThread& tData  = threadingInfo_.threadData(thread);
2355
2356                 ArrayRef<RVec> fshift_t;
2357                 if (virialHandling == VirialHandling::Pbc)
2358                 {
2359                     if (thread == 0)
2360                     {
2361                         fshift_t = fshift;
2362                     }
2363                     else
2364                     {
2365                         fshift_t = tData.fshift;
2366
2367                         for (int i = 0; i < c_numShiftVectors; i++)
2368                         {
2369                             clear_rvec(fshift_t[i]);
2370                         }
2371                     }
2372                 }
2373
2374                 if (tData.useInterdependentTask)
2375                 {
2376                     /* Spread the vsites that spread outside our local range.
2377                      * This is done using a thread-local force buffer force.
2378                      * First we need to copy the input vsite forces to force.
2379                      */
2380                     InterdependentTask* idTask = &tData.idTask;
2381
2382                     /* Clear the buffer elements set by our task during
2383                      * the last call to spread_vsite_f.
2384                      */
2385                     clearTaskForceBufferUsedElements(idTask);
2386
2387                     int nvsite = idTask->vsite.size();
2388                     for (int i = 0; i < nvsite; i++)
2389                     {
2390                         copy_rvec(f[idTask->vsite[i]], idTask->force[idTask->vsite[i]]);
2391                     }
2392                     spreadForceWrapper(x,
2393                                        idTask->force,
2394                                        virialHandling,
2395                                        fshift_t,
2396                                        tData.dxdf,
2397                                        true,
2398                                        iparams_,
2399                                        tData.idTask.ilist,
2400                                        pbc_null);
2401
2402                     /* We need a barrier before reducing forces below
2403                      * that have been produced by a different thread above.
2404                      */
2405 #pragma omp barrier
2406
2407                     /* Loop over all thread task and reduce forces they
2408                      * produced on atoms that fall in our range.
2409                      * Note that atomic reduction would be a simpler solution,
2410                      * but that might not have good support on all platforms.
2411                      */
2412                     int ntask = idTask->reduceTask.size();
2413                     for (int ti = 0; ti < ntask; ti++)
2414                     {
2415                         const InterdependentTask& idt_foreign =
2416                                 threadingInfo_.threadData(idTask->reduceTask[ti]).idTask;
2417                         const AtomIndex& atomList  = idt_foreign.atomIndex[thread];
2418                         const RVec*      f_foreign = idt_foreign.force.data();
2419
2420                         for (int ind : atomList.atom)
2421                         {
2422                             rvec_inc(f[ind], f_foreign[ind]);
2423                             /* Clearing of f_foreign is done at the next step */
2424                         }
2425                     }
2426                     /* Clear the vsite forces, both in f and force */
2427                     for (int i = 0; i < nvsite; i++)
2428                     {
2429                         int ind = tData.idTask.vsite[i];
2430                         clear_rvec(f[ind]);
2431                         clear_rvec(tData.idTask.force[ind]);
2432                     }
2433                 }
2434
2435                 /* Spread the vsites that spread locally only */
2436                 spreadForceWrapper(
2437                         x, f, virialHandling, fshift_t, tData.dxdf, false, iparams_, tData.ilist, pbc_null);
2438             }
2439             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2440         }
2441
2442         if (virialHandling == VirialHandling::Pbc)
2443         {
2444             for (int th = 1; th < numThreads; th++)
2445             {
2446                 for (int i = 0; i < c_numShiftVectors; i++)
2447                 {
2448                     rvec_inc(fshift[i], threadingInfo_.threadData(th).fshift[i]);
2449                 }
2450             }
2451         }
2452
2453         if (virialHandling == VirialHandling::NonLinear)
2454         {
2455             for (int th = 0; th < numThreads + 1; th++)
2456             {
2457                 /* MSVC doesn't like matrix references, so we use a pointer */
2458                 const matrix& dxdf = threadingInfo_.threadData(th).dxdf;
2459
2460                 for (int i = 0; i < DIM; i++)
2461                 {
2462                     for (int j = 0; j < DIM; j++)
2463                     {
2464                         virial[i][j] += -0.5 * dxdf[i][j];
2465                     }
2466                 }
2467             }
2468         }
2469     }
2470
2471     if (useDomdec)
2472     {
2473         dd_move_f_vsites(*domainInfo_.domdec_, f, fshift);
2474     }
2475
2476     inc_nrnb(nrnb, eNR_VSITE1, vsite_count(ilists_, F_VSITE1));
2477     inc_nrnb(nrnb, eNR_VSITE2, vsite_count(ilists_, F_VSITE2));
2478     inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(ilists_, F_VSITE2FD));
2479     inc_nrnb(nrnb, eNR_VSITE3, vsite_count(ilists_, F_VSITE3));
2480     inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(ilists_, F_VSITE3FD));
2481     inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(ilists_, F_VSITE3FAD));
2482     inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(ilists_, F_VSITE3OUT));
2483     inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(ilists_, F_VSITE4FD));
2484     inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(ilists_, F_VSITE4FDN));
2485     inc_nrnb(nrnb, eNR_VSITEN, vsite_count(ilists_, F_VSITEN));
2486
2487     wallcycle_stop(wcycle, WallCycleCounter::VsiteSpread);
2488 }
2489
2490 /*! \brief Returns the an array with group indices for each atom
2491  *
2492  * \param[in] grouping  The paritioning of the atom range into atom groups
2493  */
2494 static std::vector<int> makeAtomToGroupMapping(const gmx::RangePartitioning& grouping)
2495 {
2496     std::vector<int> atomToGroup(grouping.fullRange().end(), 0);
2497
2498     for (int group = 0; group < grouping.numBlocks(); group++)
2499     {
2500         auto block = grouping.block(group);
2501         std::fill(atomToGroup.begin() + block.begin(), atomToGroup.begin() + block.end(), group);
2502     }
2503
2504     return atomToGroup;
2505 }
2506
2507 int countNonlinearVsites(const gmx_mtop_t& mtop)
2508 {
2509     int numNonlinearVsites = 0;
2510     for (const gmx_molblock_t& molb : mtop.molblock)
2511     {
2512         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2513
2514         for (const auto& ilist : extractILists(molt.ilist, IF_VSITE))
2515         {
2516             if (ilist.functionType != F_VSITE2 && ilist.functionType != F_VSITE3
2517                 && ilist.functionType != F_VSITEN)
2518             {
2519                 numNonlinearVsites += molb.nmol * ilist.iatoms.size() / (1 + NRAL(ilist.functionType));
2520             }
2521         }
2522     }
2523
2524     return numNonlinearVsites;
2525 }
2526
2527 void VirtualSitesHandler::spreadForces(ArrayRef<const RVec> x,
2528                                        ArrayRef<RVec>       f,
2529                                        const VirialHandling virialHandling,
2530                                        ArrayRef<RVec>       fshift,
2531                                        matrix               virial,
2532                                        t_nrnb*              nrnb,
2533                                        const matrix         box,
2534                                        gmx_wallcycle*       wcycle)
2535 {
2536     impl_->spreadForces(x, f, virialHandling, fshift, virial, nrnb, box, wcycle);
2537 }
2538
2539 int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
2540                                 gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingsPerMoleculeType)
2541 {
2542     int n_intercg_vsite = 0;
2543     for (const gmx_molblock_t& molb : mtop.molblock)
2544     {
2545         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2546
2547         std::vector<int> atomToGroup;
2548         if (!updateGroupingsPerMoleculeType.empty())
2549         {
2550             atomToGroup = makeAtomToGroupMapping(updateGroupingsPerMoleculeType[molb.type]);
2551         }
2552         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2553         {
2554             const int              nral = NRAL(ftype);
2555             const InteractionList& il   = molt.ilist[ftype];
2556             for (int i = 0; i < il.size(); i += 1 + nral)
2557             {
2558                 bool isInterGroup = atomToGroup.empty();
2559                 if (!isInterGroup)
2560                 {
2561                     const int group = atomToGroup[il.iatoms[1 + i]];
2562                     for (int a = 1; a < nral; a++)
2563                     {
2564                         if (atomToGroup[il.iatoms[1 + a]] != group)
2565                         {
2566                             isInterGroup = true;
2567                             break;
2568                         }
2569                     }
2570                 }
2571                 if (isInterGroup)
2572                 {
2573                     n_intercg_vsite += molb.nmol;
2574                 }
2575             }
2576         }
2577     }
2578
2579     return n_intercg_vsite;
2580 }
2581
2582 std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
2583                                                              const t_commrec*  cr,
2584                                                              PbcType           pbcType,
2585                                                              ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType)
2586 {
2587     GMX_RELEASE_ASSERT(cr != nullptr, "We need a valid commrec");
2588
2589     std::unique_ptr<VirtualSitesHandler> vsite;
2590
2591     /* check if there are vsites */
2592     int nvsite = 0;
2593     for (int ftype = 0; ftype < F_NRE; ftype++)
2594     {
2595         if (interaction_function[ftype].flags & IF_VSITE)
2596         {
2597             GMX_ASSERT(ftype >= c_ftypeVsiteStart && ftype < c_ftypeVsiteEnd,
2598                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2599
2600             nvsite += gmx_mtop_ftype_count(mtop, ftype);
2601         }
2602         else
2603         {
2604             GMX_ASSERT(ftype < c_ftypeVsiteStart || ftype >= c_ftypeVsiteEnd,
2605                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2606         }
2607     }
2608
2609     if (nvsite == 0)
2610     {
2611         return vsite;
2612     }
2613
2614     return std::make_unique<VirtualSitesHandler>(mtop, cr->dd, pbcType, updateGroupingPerMoleculeType);
2615 }
2616
2617 ThreadingInfo::ThreadingInfo() : numThreads_(gmx_omp_nthreads_get(ModuleMultiThread::VirtualSite))
2618 {
2619     if (numThreads_ > 1)
2620     {
2621         /* We need one extra thread data structure for the overlap vsites */
2622         tData_.resize(numThreads_ + 1);
2623 #pragma omp parallel for num_threads(numThreads_) schedule(static)
2624         for (int thread = 0; thread < numThreads_; thread++)
2625         {
2626             try
2627             {
2628                 tData_[thread] = std::make_unique<VsiteThread>();
2629
2630                 InterdependentTask& idTask = tData_[thread]->idTask;
2631                 idTask.nuse                = 0;
2632                 idTask.atomIndex.resize(numThreads_);
2633             }
2634             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2635         }
2636         if (numThreads_ > 1)
2637         {
2638             tData_[numThreads_] = std::make_unique<VsiteThread>();
2639         }
2640     }
2641 }
2642
2643 VirtualSitesHandler::Impl::Impl(const gmx_mtop_t&                       mtop,
2644                                 gmx_domdec_t*                           domdec,
2645                                 const PbcType                           pbcType,
2646                                 const ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType) :
2647     numInterUpdategroupVirtualSites_(countInterUpdategroupVsites(mtop, updateGroupingPerMoleculeType)),
2648     domainInfo_({ pbcType, pbcType != PbcType::No && numInterUpdategroupVirtualSites_ > 0, domdec }),
2649     iparams_(mtop.ffparams.iparams)
2650 {
2651 }
2652
2653 VirtualSitesHandler::VirtualSitesHandler(const gmx_mtop_t&                       mtop,
2654                                          gmx_domdec_t*                           domdec,
2655                                          const PbcType                           pbcType,
2656                                          const ArrayRef<const RangePartitioning> updateGroupingPerMoleculeType) :
2657     impl_(new Impl(mtop, domdec, pbcType, updateGroupingPerMoleculeType))
2658 {
2659 }
2660
2661 //! Flag that atom \p atom which is home in another task, if it has not already been added before
2662 static inline void flagAtom(InterdependentTask* idTask, const int atom, const int numThreads, const int numAtomsPerThread)
2663 {
2664     if (!idTask->use[atom])
2665     {
2666         idTask->use[atom] = true;
2667         int thread        = atom / numAtomsPerThread;
2668         /* Assign all non-local atom force writes to thread 0 */
2669         if (thread >= numThreads)
2670         {
2671             thread = 0;
2672         }
2673         idTask->atomIndex[thread].atom.push_back(atom);
2674     }
2675 }
2676
2677 /*! \brief Here we try to assign all vsites that are in our local range.
2678  *
2679  * Our task local atom range is tData->rangeStart - tData->rangeEnd.
2680  * Vsites that depend only on local atoms, as indicated by taskIndex[]==thread,
2681  * are assigned to task tData->ilist. Vsites that depend on non-local atoms
2682  * but not on other vsites are assigned to task tData->id_task.ilist.
2683  * taskIndex[] is set for all vsites in our range, either to our local tasks
2684  * or to the single last task as taskIndex[]=2*nthreads.
2685  */
2686 static void assignVsitesToThread(VsiteThread*                    tData,
2687                                  int                             thread,
2688                                  int                             nthread,
2689                                  int                             natperthread,
2690                                  gmx::ArrayRef<int>              taskIndex,
2691                                  ArrayRef<const InteractionList> ilist,
2692                                  ArrayRef<const t_iparams>       ip,
2693                                  ArrayRef<const ParticleType>    ptype)
2694 {
2695     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2696     {
2697         tData->ilist[ftype].clear();
2698         tData->idTask.ilist[ftype].clear();
2699
2700         const int  nral1 = 1 + NRAL(ftype);
2701         const int* iat   = ilist[ftype].iatoms.data();
2702         for (int i = 0; i < ilist[ftype].size();)
2703         {
2704             /* Get the number of iatom entries in this virtual site.
2705              * The 3 below for F_VSITEN is from 1+NRAL(ftype)=3
2706              */
2707             const int numIAtoms = (ftype == F_VSITEN ? ip[iat[i]].vsiten.n * 3 : nral1);
2708
2709             if (iat[1 + i] < tData->rangeStart || iat[1 + i] >= tData->rangeEnd)
2710             {
2711                 /* This vsite belongs to a different thread */
2712                 i += numIAtoms;
2713                 continue;
2714             }
2715
2716             /* We would like to assign this vsite to task thread,
2717              * but it might depend on atoms outside the atom range of thread
2718              * or on another vsite not assigned to task thread.
2719              */
2720             int task = thread;
2721             if (ftype != F_VSITEN)
2722             {
2723                 for (int j = i + 2; j < i + nral1; j++)
2724                 {
2725                     /* Do a range check to avoid a harmless race on taskIndex */
2726                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2727                     {
2728                         if (!tData->useInterdependentTask || ptype[iat[j]] == ParticleType::VSite)
2729                         {
2730                             /* At least one constructing atom is a vsite
2731                              * that is not assigned to the same thread.
2732                              * Put this vsite into a separate task.
2733                              */
2734                             task = 2 * nthread;
2735                             break;
2736                         }
2737
2738                         /* There are constructing atoms outside our range,
2739                          * put this vsite into a second task to be executed
2740                          * on the same thread. During construction no barrier
2741                          * is needed between the two tasks on the same thread.
2742                          * During spreading we need to run this task with
2743                          * an additional thread-local intermediate force buffer
2744                          * (or atomic reduction) and a barrier between the two
2745                          * tasks.
2746                          */
2747                         task = nthread + thread;
2748                     }
2749                 }
2750             }
2751             else
2752             {
2753                 for (int j = i + 2; j < i + numIAtoms; j += 3)
2754                 {
2755                     /* Do a range check to avoid a harmless race on taskIndex */
2756                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2757                     {
2758                         GMX_ASSERT(ptype[iat[j]] != ParticleType::VSite,
2759                                    "A vsite to be assigned in assignVsitesToThread has a vsite as "
2760                                    "a constructing atom that does not belong to our task, such "
2761                                    "vsites should be assigned to the single 'master' task");
2762
2763                         if (tData->useInterdependentTask)
2764                         {
2765                             // Assign to the interdependent task
2766                             task = nthread + thread;
2767                         }
2768                         else
2769                         {
2770                             // Assign to the separate, non-parallel task
2771                             task = 2 * nthread;
2772                         }
2773                     }
2774                 }
2775             }
2776
2777             /* Update this vsite's thread index entry */
2778             taskIndex[iat[1 + i]] = task;
2779
2780             if (task == thread || task == nthread + thread)
2781             {
2782                 /* Copy this vsite to the thread data struct of thread */
2783                 InteractionList* il_task;
2784                 if (task == thread)
2785                 {
2786                     il_task = &tData->ilist[ftype];
2787                 }
2788                 else
2789                 {
2790                     il_task = &tData->idTask.ilist[ftype];
2791                 }
2792                 /* Copy the vsite data to the thread-task local array */
2793                 il_task->push_back(iat[i], numIAtoms - 1, iat + i + 1);
2794                 if (task == nthread + thread)
2795                 {
2796                     /* This vsite writes outside our own task force block.
2797                      * Put it into the interdependent task list and flag
2798                      * the atoms involved for reduction.
2799                      */
2800                     tData->idTask.vsite.push_back(iat[i + 1]);
2801                     if (ftype != F_VSITEN)
2802                     {
2803                         for (int j = i + 2; j < i + nral1; j++)
2804                         {
2805                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2806                         }
2807                     }
2808                     else
2809                     {
2810                         for (int j = i + 2; j < i + numIAtoms; j += 3)
2811                         {
2812                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2813                         }
2814                     }
2815                 }
2816             }
2817
2818             i += numIAtoms;
2819         }
2820     }
2821 }
2822
2823 /*! \brief Assign all vsites with taskIndex[]==task to task tData */
2824 static void assignVsitesToSingleTask(VsiteThread*                    tData,
2825                                      int                             task,
2826                                      gmx::ArrayRef<const int>        taskIndex,
2827                                      ArrayRef<const InteractionList> ilist,
2828                                      ArrayRef<const t_iparams>       ip)
2829 {
2830     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2831     {
2832         tData->ilist[ftype].clear();
2833         tData->idTask.ilist[ftype].clear();
2834
2835         int              nral1   = 1 + NRAL(ftype);
2836         int              inc     = nral1;
2837         const int*       iat     = ilist[ftype].iatoms.data();
2838         InteractionList* il_task = &tData->ilist[ftype];
2839
2840         for (int i = 0; i < ilist[ftype].size();)
2841         {
2842             if (ftype == F_VSITEN)
2843             {
2844                 /* The 3 below is from 1+NRAL(ftype)=3 */
2845                 inc = ip[iat[i]].vsiten.n * 3;
2846             }
2847             /* Check if the vsite is assigned to our task */
2848             if (taskIndex[iat[1 + i]] == task)
2849             {
2850                 /* Copy the vsite data to the thread-task local array */
2851                 il_task->push_back(iat[i], inc - 1, iat + i + 1);
2852             }
2853
2854             i += inc;
2855         }
2856     }
2857 }
2858
2859 void ThreadingInfo::setVirtualSites(ArrayRef<const InteractionList> ilists,
2860                                     ArrayRef<const t_iparams>       iparams,
2861                                     const int                       numAtoms,
2862                                     const int                       homenr,
2863                                     ArrayRef<const ParticleType>    ptype,
2864                                     const bool                      useDomdec)
2865 {
2866     if (numThreads_ <= 1)
2867     {
2868         /* Nothing to do */
2869         return;
2870     }
2871
2872     /* The current way of distributing the vsites over threads in primitive.
2873      * We divide the atom range 0 - natoms_in_vsite uniformly over threads,
2874      * without taking into account how the vsites are distributed.
2875      * Without domain decomposition we at least tighten the upper bound
2876      * of the range (useful for common systems such as a vsite-protein
2877      * in 3-site water).
2878      * With domain decomposition, as long as the vsites are distributed
2879      * uniformly in each domain along the major dimension, usually x,
2880      * it will also perform well.
2881      */
2882     int vsite_atom_range;
2883     int natperthread;
2884     if (!useDomdec)
2885     {
2886         vsite_atom_range = -1;
2887         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2888         {
2889             { // TODO remove me
2890                 if (ftype != F_VSITEN)
2891                 {
2892                     int                 nral1 = 1 + NRAL(ftype);
2893                     ArrayRef<const int> iat   = ilists[ftype].iatoms;
2894                     for (int i = 0; i < ilists[ftype].size(); i += nral1)
2895                     {
2896                         for (int j = i + 1; j < i + nral1; j++)
2897                         {
2898                             vsite_atom_range = std::max(vsite_atom_range, iat[j]);
2899                         }
2900                     }
2901                 }
2902                 else
2903                 {
2904                     int vs_ind_end;
2905
2906                     ArrayRef<const int> iat = ilists[ftype].iatoms;
2907
2908                     int i = 0;
2909                     while (i < ilists[ftype].size())
2910                     {
2911                         /* The 3 below is from 1+NRAL(ftype)=3 */
2912                         vs_ind_end = i + iparams[iat[i]].vsiten.n * 3;
2913
2914                         vsite_atom_range = std::max(vsite_atom_range, iat[i + 1]);
2915                         while (i < vs_ind_end)
2916                         {
2917                             vsite_atom_range = std::max(vsite_atom_range, iat[i + 2]);
2918                             i += 3;
2919                         }
2920                     }
2921                 }
2922             }
2923         }
2924         vsite_atom_range++;
2925         natperthread = (vsite_atom_range + numThreads_ - 1) / numThreads_;
2926     }
2927     else
2928     {
2929         /* Any local or not local atom could be involved in virtual sites.
2930          * But since we usually have very few non-local virtual sites
2931          * (only non-local vsites that depend on local vsites),
2932          * we distribute the local atom range equally over the threads.
2933          * When assigning vsites to threads, we should take care that the last
2934          * threads also covers the non-local range.
2935          */
2936         vsite_atom_range = numAtoms;
2937         natperthread     = (homenr + numThreads_ - 1) / numThreads_;
2938     }
2939
2940     if (debug)
2941     {
2942         fprintf(debug,
2943                 "virtual site thread dist: natoms %d, range %d, natperthread %d\n",
2944                 numAtoms,
2945                 vsite_atom_range,
2946                 natperthread);
2947     }
2948
2949     /* To simplify the vsite assignment, we make an index which tells us
2950      * to which task particles, both non-vsites and vsites, are assigned.
2951      */
2952     taskIndex_.resize(numAtoms);
2953
2954     /* Initialize the task index array. Here we assign the non-vsite
2955      * particles to task=thread, so we easily figure out if vsites
2956      * depend on local and/or non-local particles in assignVsitesToThread.
2957      */
2958     {
2959         int thread = 0;
2960         for (int i = 0; i < numAtoms; i++)
2961         {
2962             if (ptype[i] == ParticleType::VSite)
2963             {
2964                 /* vsites are not assigned to a task yet */
2965                 taskIndex_[i] = -1;
2966             }
2967             else
2968             {
2969                 /* assign non-vsite particles to task thread */
2970                 taskIndex_[i] = thread;
2971             }
2972             if (i == (thread + 1) * natperthread && thread < numThreads_)
2973             {
2974                 thread++;
2975             }
2976         }
2977     }
2978
2979 #pragma omp parallel num_threads(numThreads_)
2980     {
2981         try
2982         {
2983             int          thread = gmx_omp_get_thread_num();
2984             VsiteThread& tData  = *tData_[thread];
2985
2986             /* Clear the buffer use flags that were set before */
2987             if (tData.useInterdependentTask)
2988             {
2989                 InterdependentTask& idTask = tData.idTask;
2990
2991                 /* To avoid an extra OpenMP barrier in spread_vsite_f,
2992                  * we clear the force buffer at the next step,
2993                  * so we need to do it here as well.
2994                  */
2995                 clearTaskForceBufferUsedElements(&idTask);
2996
2997                 idTask.vsite.resize(0);
2998                 for (int t = 0; t < numThreads_; t++)
2999                 {
3000                     AtomIndex& atomIndex = idTask.atomIndex[t];
3001                     int        natom     = atomIndex.atom.size();
3002                     for (int i = 0; i < natom; i++)
3003                     {
3004                         idTask.use[atomIndex.atom[i]] = false;
3005                     }
3006                     atomIndex.atom.resize(0);
3007                 }
3008                 idTask.nuse = 0;
3009             }
3010
3011             /* To avoid large f_buf allocations of #threads*vsite_atom_range
3012              * we don't use the interdependent tasks with more than 200000 atoms.
3013              * This doesn't affect performance, since with such a large range
3014              * relatively few vsites will end up in the separate task.
3015              * Note that useInterdependentTask should be the same for all threads.
3016              */
3017             const int c_maxNumLocalAtomsForInterdependentTask = 200000;
3018             tData.useInterdependentTask = (vsite_atom_range <= c_maxNumLocalAtomsForInterdependentTask);
3019             if (tData.useInterdependentTask)
3020             {
3021                 size_t              natoms_use_in_vsites = vsite_atom_range;
3022                 InterdependentTask& idTask               = tData.idTask;
3023                 /* To avoid resizing and re-clearing every nstlist steps,
3024                  * we never down size the force buffer.
3025                  */
3026                 if (natoms_use_in_vsites > idTask.force.size() || natoms_use_in_vsites > idTask.use.size())
3027                 {
3028                     idTask.force.resize(natoms_use_in_vsites, { 0, 0, 0 });
3029                     idTask.use.resize(natoms_use_in_vsites, false);
3030                 }
3031             }
3032
3033             /* Assign all vsites that can execute independently on threads */
3034             tData.rangeStart = thread * natperthread;
3035             if (thread < numThreads_ - 1)
3036             {
3037                 tData.rangeEnd = (thread + 1) * natperthread;
3038             }
3039             else
3040             {
3041                 /* The last thread should cover up to the end of the range */
3042                 tData.rangeEnd = numAtoms;
3043             }
3044             assignVsitesToThread(
3045                     &tData, thread, numThreads_, natperthread, taskIndex_, ilists, iparams, ptype);
3046
3047             if (tData.useInterdependentTask)
3048             {
3049                 /* In the worst case, all tasks write to force ranges of
3050                  * all other tasks, leading to #tasks^2 scaling (this is only
3051                  * the overhead, the actual flops remain constant).
3052                  * But in most cases there is far less coupling. To improve
3053                  * scaling at high thread counts we therefore construct
3054                  * an index to only loop over the actually affected tasks.
3055                  */
3056                 InterdependentTask& idTask = tData.idTask;
3057
3058                 /* Ensure assignVsitesToThread finished on other threads */
3059 #pragma omp barrier
3060
3061                 idTask.spreadTask.resize(0);
3062                 idTask.reduceTask.resize(0);
3063                 for (int t = 0; t < numThreads_; t++)
3064                 {
3065                     /* Do we write to the force buffer of task t? */
3066                     if (!idTask.atomIndex[t].atom.empty())
3067                     {
3068                         idTask.spreadTask.push_back(t);
3069                     }
3070                     /* Does task t write to our force buffer? */
3071                     if (!tData_[t]->idTask.atomIndex[thread].atom.empty())
3072                     {
3073                         idTask.reduceTask.push_back(t);
3074                     }
3075                 }
3076             }
3077         }
3078         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
3079     }
3080     /* Assign all remaining vsites, that will have taskIndex[]=2*vsite->nthreads,
3081      * to a single task that will not run in parallel with other tasks.
3082      */
3083     assignVsitesToSingleTask(tData_[numThreads_].get(), 2 * numThreads_, taskIndex_, ilists, iparams);
3084
3085     if (debug && numThreads_ > 1)
3086     {
3087         fprintf(debug,
3088                 "virtual site useInterdependentTask %d, nuse:\n",
3089                 static_cast<int>(tData_[0]->useInterdependentTask));
3090         for (int th = 0; th < numThreads_ + 1; th++)
3091         {
3092             fprintf(debug, " %4d", tData_[th]->idTask.nuse);
3093         }
3094         fprintf(debug, "\n");
3095
3096         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
3097         {
3098             if (!ilists[ftype].empty())
3099             {
3100                 fprintf(debug, "%-20s thread dist:", interaction_function[ftype].longname);
3101                 for (int th = 0; th < numThreads_ + 1; th++)
3102                 {
3103                     fprintf(debug,
3104                             " %4d %4d ",
3105                             tData_[th]->ilist[ftype].size(),
3106                             tData_[th]->idTask.ilist[ftype].size());
3107                 }
3108                 fprintf(debug, "\n");
3109             }
3110         }
3111     }
3112
3113 #ifndef NDEBUG
3114     int nrOrig     = vsiteIlistNrCount(ilists);
3115     int nrThreaded = 0;
3116     for (int th = 0; th < numThreads_ + 1; th++)
3117     {
3118         nrThreaded += vsiteIlistNrCount(tData_[th]->ilist) + vsiteIlistNrCount(tData_[th]->idTask.ilist);
3119     }
3120     GMX_ASSERT(nrThreaded == nrOrig,
3121                "The number of virtual sites assigned to all thread task has to match the total "
3122                "number of virtual sites");
3123 #endif
3124 }
3125
3126 void VirtualSitesHandler::Impl::setVirtualSites(ArrayRef<const InteractionList> ilists,
3127                                                 const int                       numAtoms,
3128                                                 const int                       homenr,
3129                                                 ArrayRef<const ParticleType>    ptype)
3130 {
3131     ilists_ = ilists;
3132
3133     threadingInfo_.setVirtualSites(ilists, iparams_, numAtoms, homenr, ptype, domainInfo_.useDomdec());
3134 }
3135
3136 void VirtualSitesHandler::setVirtualSites(ArrayRef<const InteractionList> ilists,
3137                                           const int                       numAtoms,
3138                                           const int                       homenr,
3139                                           ArrayRef<const ParticleType>    ptype)
3140 {
3141     impl_->setVirtualSites(ilists, numAtoms, homenr, ptype);
3142 }
3143
3144 } // namespace gmx