9be178c243ddc82bbde88639c1c54b1df3f651ad
[alexxy/gromacs.git] / src / gromacs / mdlib / vsite.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 The GROMACS development team.
7  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /*! \internal \file
39  * \brief Implements the VirtualSitesHandler class and vsite standalone functions
40  *
41  * \author Berk Hess <hess@kth.se>
42  * \ingroup module_mdlib
43  */
44
45 #include "gmxpre.h"
46
47 #include "vsite.h"
48
49 #include <cstdio>
50
51 #include <algorithm>
52 #include <memory>
53 #include <vector>
54
55 #include "gromacs/domdec/domdec.h"
56 #include "gromacs/domdec/domdec_struct.h"
57 #include "gromacs/gmxlib/network.h"
58 #include "gromacs/gmxlib/nrnb.h"
59 #include "gromacs/math/functions.h"
60 #include "gromacs/math/vec.h"
61 #include "gromacs/mdlib/gmx_omp_nthreads.h"
62 #include "gromacs/mdtypes/commrec.h"
63 #include "gromacs/mdtypes/mdatom.h"
64 #include "gromacs/pbcutil/ishift.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/timing/wallcycle.h"
67 #include "gromacs/topology/ifunc.h"
68 #include "gromacs/topology/mtop_util.h"
69 #include "gromacs/topology/topology.h"
70 #include "gromacs/utility/exceptions.h"
71 #include "gromacs/utility/fatalerror.h"
72 #include "gromacs/utility/gmxassert.h"
73 #include "gromacs/utility/gmxomp.h"
74
75 /* The strategy used here for assigning virtual sites to (thread-)tasks
76  * is as follows:
77  *
78  * We divide the atom range that vsites operate on (natoms_local with DD,
79  * 0 - last atom involved in vsites without DD) equally over all threads.
80  *
81  * Vsites in the local range constructed from atoms in the local range
82  * and/or other vsites that are fully local are assigned to a simple,
83  * independent task.
84  *
85  * Vsites that are not assigned after using the above criterion get assigned
86  * to a so called "interdependent" thread task when none of the constructing
87  * atoms is a vsite. These tasks are called interdependent, because one task
88  * accesses atoms assigned to a different task/thread.
89  * Note that this option is turned off with large (local) atom counts
90  * to avoid high memory usage.
91  *
92  * Any remaining vsites are assigned to a separate master thread task.
93  */
94 namespace gmx
95 {
96
97 //! VirialHandling is often used outside VirtualSitesHandler class members
98 using VirialHandling = VirtualSitesHandler::VirialHandling;
99
100 /*! \brief Information on PBC and domain decomposition for virtual sites
101  */
102 struct DomainInfo
103 {
104 public:
105     //! Constructs without PBC and DD
106     DomainInfo() = default;
107
108     //! Constructs with PBC and DD, if !=nullptr
109     DomainInfo(PbcType pbcType, bool haveInterUpdateGroupVirtualSites, gmx_domdec_t* domdec) :
110         pbcType_(pbcType),
111         useMolPbc_(pbcType != PbcType::No && haveInterUpdateGroupVirtualSites),
112         domdec_(domdec)
113     {
114     }
115
116     //! Returns whether we are using domain decomposition with more than 1 DD rank
117     bool useDomdec() const { return (domdec_ != nullptr); }
118
119     //! The pbc type
120     const PbcType pbcType_ = PbcType::No;
121     //! Whether molecules are broken over PBC
122     const bool useMolPbc_ = false;
123     //! Pointer to the domain decomposition struct, nullptr without PP DD
124     const gmx_domdec_t* domdec_ = nullptr;
125 };
126
127 /*! \brief List of atom indices belonging to a task
128  */
129 struct AtomIndex
130 {
131     //! List of atom indices
132     std::vector<int> atom;
133 };
134
135 /*! \brief Data structure for thread tasks that use constructing atoms outside their own atom range
136  */
137 struct InterdependentTask
138 {
139     //! The interaction lists, only vsite entries are used
140     InteractionLists ilist;
141     //! Thread/task-local force buffer
142     std::vector<RVec> force;
143     //! The atom indices of the vsites of our task
144     std::vector<int> vsite;
145     //! Flags if elements in force are spread to or not
146     std::vector<bool> use;
147     //! The number of entries set to true in use
148     int nuse = 0;
149     //! Array of atoms indices, size nthreads, covering all nuse set elements in use
150     std::vector<AtomIndex> atomIndex;
151     //! List of tasks (force blocks) this task spread forces to
152     std::vector<int> spreadTask;
153     //! List of tasks that write to this tasks force block range
154     std::vector<int> reduceTask;
155 };
156
157 /*! \brief Vsite thread task data structure
158  */
159 struct VsiteThread
160 {
161     //! Start of atom range of this task
162     int rangeStart;
163     //! End of atom range of this task
164     int rangeEnd;
165     //! The interaction lists, only vsite entries are used
166     std::array<InteractionList, F_NRE> ilist;
167     //! Local fshift accumulation buffer
168     std::array<RVec, c_numShiftVectors> fshift;
169     //! Local virial dx*df accumulation buffer
170     matrix dxdf;
171     //! Tells if interdependent task idTask should be used (in addition to the rest of this task), this bool has the same value on all threads
172     bool useInterdependentTask;
173     //! Data for vsites that involve constructing atoms in the atom range of other threads/tasks
174     InterdependentTask idTask;
175
176     /*! \brief Constructor */
177     VsiteThread()
178     {
179         rangeStart = -1;
180         rangeEnd   = -1;
181         for (auto& elem : fshift)
182         {
183             elem = { 0.0_real, 0.0_real, 0.0_real };
184         }
185         clear_mat(dxdf);
186         useInterdependentTask = false;
187     }
188 };
189
190
191 /*! \brief Information on how the virtual site work is divided over thread tasks
192  */
193 class ThreadingInfo
194 {
195 public:
196     //! Constructor, retrieves the number of threads to use from gmx_omp_nthreads.h
197     ThreadingInfo();
198
199     //! Returns the number of threads to use for vsite operations
200     int numThreads() const { return numThreads_; }
201
202     //! Returns the thread data for the given thread
203     const VsiteThread& threadData(int threadIndex) const { return *tData_[threadIndex]; }
204
205     //! Returns the thread data for the given thread
206     VsiteThread& threadData(int threadIndex) { return *tData_[threadIndex]; }
207
208     //! Returns the thread data for vsites that depend on non-local vsites
209     const VsiteThread& threadDataNonLocalDependent() const { return *tData_[numThreads_]; }
210
211     //! Returns the thread data for vsites that depend on non-local vsites
212     VsiteThread& threadDataNonLocalDependent() { return *tData_[numThreads_]; }
213
214     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
215     void setVirtualSites(ArrayRef<const InteractionList> ilist,
216                          ArrayRef<const t_iparams>       iparams,
217                          const t_mdatoms&                mdatoms,
218                          bool                            useDomdec);
219
220 private:
221     //! Number of threads used for vsite operations
222     const int numThreads_;
223     //! Thread local vsites and work structs
224     std::vector<std::unique_ptr<VsiteThread>> tData_;
225     //! Work array for dividing vsites over threads
226     std::vector<int> taskIndex_;
227 };
228
229 /*! \brief Impl class for VirtualSitesHandler
230  */
231 class VirtualSitesHandler::Impl
232 {
233 public:
234     //! Constructor, domdec should be nullptr without DD
235     Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, PbcType pbcType);
236
237     //! Returns the number of virtual sites acting over multiple update groups
238     int numInterUpdategroupVirtualSites() const { return numInterUpdategroupVirtualSites_; }
239
240     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
241     void setVirtualSites(ArrayRef<const InteractionList> ilist, const t_mdatoms& mdatoms);
242
243     /*! \brief Create positions of vsite atoms based for the local system
244      *
245      * \param[in,out] x          The coordinates
246      * \param[in,out] v          The velocities, needed if operation requires it
247      * \param[in]     box        The box
248      * \param[in]     operation  Whether we calculate positions, velocities, or both
249      */
250     void construct(ArrayRef<RVec> x, ArrayRef<RVec> v, const matrix box, VSiteOperation operation) const;
251
252     /*! \brief Spread the force operating on the vsite atoms on the surrounding atoms.
253      *
254      * vsite should point to a valid object.
255      * The virialHandling parameter determines how virial contributions are handled.
256      * If this is set to Linear, shift forces are accumulated into fshift.
257      * If this is set to NonLinear, non-linear contributions are added to virial.
258      * This non-linear correction is required when the virial is not calculated
259      * afterwards from the particle position and forces, but in a different way,
260      * as for instance for the PME mesh contribution.
261      */
262     void spreadForces(ArrayRef<const RVec> x,
263                       ArrayRef<RVec>       f,
264                       VirialHandling       virialHandling,
265                       ArrayRef<RVec>       fshift,
266                       matrix               virial,
267                       t_nrnb*              nrnb,
268                       const matrix         box,
269                       gmx_wallcycle*       wcycle);
270
271 private:
272     //! The number of vsites that cross update groups, when =0 no PBC treatment is needed
273     const int numInterUpdategroupVirtualSites_;
274     //! PBC and DD information
275     const DomainInfo domainInfo_;
276     //! The interaction parameters
277     const ArrayRef<const t_iparams> iparams_;
278     //! The interaction lists
279     ArrayRef<const InteractionList> ilists_;
280     //! Information for handling vsite threading
281     ThreadingInfo threadingInfo_;
282 };
283
284 VirtualSitesHandler::~VirtualSitesHandler() = default;
285
286 int VirtualSitesHandler::numInterUpdategroupVirtualSites() const
287 {
288     return impl_->numInterUpdategroupVirtualSites();
289 }
290
291 /*! \brief Returns the sum of the vsite ilist sizes over all vsite types
292  *
293  * \param[in] ilist  The interaction list
294  */
295 static int vsiteIlistNrCount(ArrayRef<const InteractionList> ilist)
296 {
297     int nr = 0;
298     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
299     {
300         nr += ilist[ftype].size();
301     }
302
303     return nr;
304 }
305
306 //! Computes the distance between xi and xj, pbc is used when pbc!=nullptr
307 static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
308 {
309     if (pbc)
310     {
311         return pbc_dx_aiuc(pbc, xi, xj, dx);
312     }
313     else
314     {
315         rvec_sub(xi, xj, dx);
316         return c_centralShiftIndex;
317     }
318 }
319
320 //! Returns the 1/norm(x)
321 static inline real inverseNorm(const rvec x)
322 {
323     return gmx::invsqrt(iprod(x, x));
324 }
325
326 //! Whether we're calculating the virtual site position
327 enum class VSiteCalculatePosition
328 {
329     Yes,
330     No
331 };
332 //! Whether we're calculating the virtual site velocity
333 enum class VSiteCalculateVelocity
334 {
335     Yes,
336     No
337 };
338
339 #ifndef DOXYGEN
340 /* Vsite construction routines */
341
342 // GCC 8 falsely flags unused variables if constexpr prunes a code path, fixed in GCC 9
343 // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85827
344 // clang-format off
345 GCC_DIAGNOSTIC_IGNORE(-Wunused-but-set-parameter)
346 // clang-format on
347
348 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
349 static void constr_vsite1(const rvec xi, rvec x, const rvec vi, rvec v)
350 {
351     if (calculatePosition == VSiteCalculatePosition::Yes)
352     {
353         copy_rvec(xi, x);
354         /* TOTAL: 0 flops */
355     }
356     if (calculateVelocity == VSiteCalculateVelocity::Yes)
357     {
358         copy_rvec(vi, v);
359     }
360 }
361
362 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
363 static void
364 constr_vsite2(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc, const rvec vi, const rvec vj, rvec v)
365 {
366     const real b = 1 - a;
367     /* 1 flop */
368
369     if (calculatePosition == VSiteCalculatePosition::Yes)
370     {
371         if (pbc)
372         {
373             rvec dx;
374             pbc_dx_aiuc(pbc, xj, xi, dx);
375             x[XX] = xi[XX] + a * dx[XX];
376             x[YY] = xi[YY] + a * dx[YY];
377             x[ZZ] = xi[ZZ] + a * dx[ZZ];
378         }
379         else
380         {
381             x[XX] = b * xi[XX] + a * xj[XX];
382             x[YY] = b * xi[YY] + a * xj[YY];
383             x[ZZ] = b * xi[ZZ] + a * xj[ZZ];
384             /* 9 Flops */
385         }
386         /* TOTAL: 10 flops */
387     }
388     if (calculateVelocity == VSiteCalculateVelocity::Yes)
389     {
390         v[XX] = b * vi[XX] + a * vj[XX];
391         v[YY] = b * vi[YY] + a * vj[YY];
392         v[ZZ] = b * vi[ZZ] + a * vj[ZZ];
393     }
394 }
395
396 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
397 static void
398 constr_vsite2FD(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc, const rvec vi, const rvec vj, rvec v)
399 {
400     rvec xij = { 0 };
401     pbc_rvec_sub(pbc, xj, xi, xij);
402     /* 3 flops */
403
404     const real invNormXij = inverseNorm(xij);
405     const real b          = a * invNormXij;
406     /* 6 + 10 flops */
407
408     if (calculatePosition == VSiteCalculatePosition::Yes)
409     {
410         x[XX] = xi[XX] + b * xij[XX];
411         x[YY] = xi[YY] + b * xij[YY];
412         x[ZZ] = xi[ZZ] + b * xij[ZZ];
413         /* 6 Flops */
414         /* TOTAL: 25 flops */
415     }
416     if (calculateVelocity == VSiteCalculateVelocity::Yes)
417     {
418         rvec vij = { 0 };
419         rvec_sub(vj, vi, vij);
420         const real vijDotXij = iprod(vij, xij);
421
422         v[XX] = vi[XX] + b * (vij[XX] - xij[XX] * vijDotXij * invNormXij * invNormXij);
423         v[YY] = vi[YY] + b * (vij[YY] - xij[YY] * vijDotXij * invNormXij * invNormXij);
424         v[ZZ] = vi[ZZ] + b * (vij[ZZ] - xij[ZZ] * vijDotXij * invNormXij * invNormXij);
425     }
426 }
427
428 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
429 static void constr_vsite3(const rvec   xi,
430                           const rvec   xj,
431                           const rvec   xk,
432                           rvec         x,
433                           real         a,
434                           real         b,
435                           const t_pbc* pbc,
436                           const rvec   vi,
437                           const rvec   vj,
438                           const rvec   vk,
439                           rvec         v)
440 {
441     const real c = 1 - a - b;
442     /* 2 flops */
443
444     if (calculatePosition == VSiteCalculatePosition::Yes)
445     {
446         if (pbc)
447         {
448             rvec dxj, dxk;
449
450             pbc_dx_aiuc(pbc, xj, xi, dxj);
451             pbc_dx_aiuc(pbc, xk, xi, dxk);
452             x[XX] = xi[XX] + a * dxj[XX] + b * dxk[XX];
453             x[YY] = xi[YY] + a * dxj[YY] + b * dxk[YY];
454             x[ZZ] = xi[ZZ] + a * dxj[ZZ] + b * dxk[ZZ];
455         }
456         else
457         {
458             x[XX] = c * xi[XX] + a * xj[XX] + b * xk[XX];
459             x[YY] = c * xi[YY] + a * xj[YY] + b * xk[YY];
460             x[ZZ] = c * xi[ZZ] + a * xj[ZZ] + b * xk[ZZ];
461             /* 15 Flops */
462         }
463         /* TOTAL: 17 flops */
464     }
465     if (calculateVelocity == VSiteCalculateVelocity::Yes)
466     {
467         v[XX] = c * vi[XX] + a * vj[XX] + b * vk[XX];
468         v[YY] = c * vi[YY] + a * vj[YY] + b * vk[YY];
469         v[ZZ] = c * vi[ZZ] + a * vj[ZZ] + b * vk[ZZ];
470     }
471 }
472
473 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
474 static void constr_vsite3FD(const rvec   xi,
475                             const rvec   xj,
476                             const rvec   xk,
477                             rvec         x,
478                             real         a,
479                             real         b,
480                             const t_pbc* pbc,
481                             const rvec   vi,
482                             const rvec   vj,
483                             const rvec   vk,
484                             rvec         v)
485 {
486     rvec xij, xjk, temp;
487
488     pbc_rvec_sub(pbc, xj, xi, xij);
489     pbc_rvec_sub(pbc, xk, xj, xjk);
490     /* 6 flops */
491
492     /* temp goes from i to a point on the line jk */
493     temp[XX] = xij[XX] + a * xjk[XX];
494     temp[YY] = xij[YY] + a * xjk[YY];
495     temp[ZZ] = xij[ZZ] + a * xjk[ZZ];
496     /* 6 flops */
497
498     const real invNormTemp = inverseNorm(temp);
499     const real c           = b * invNormTemp;
500     /* 6 + 10 flops */
501
502     if (calculatePosition == VSiteCalculatePosition::Yes)
503     {
504         x[XX] = xi[XX] + c * temp[XX];
505         x[YY] = xi[YY] + c * temp[YY];
506         x[ZZ] = xi[ZZ] + c * temp[ZZ];
507         /* 6 Flops */
508         /* TOTAL: 34 flops */
509     }
510     if (calculateVelocity == VSiteCalculateVelocity::Yes)
511     {
512         rvec vij = { 0 };
513         rvec vjk = { 0 };
514         rvec_sub(vj, vi, vij);
515         rvec_sub(vk, vj, vjk);
516         const rvec tempV = { vij[XX] + a * vjk[XX], vij[YY] + a * vjk[YY], vij[ZZ] + a * vjk[ZZ] };
517         const real tempDotTempV = iprod(temp, tempV);
518
519         v[XX] = vi[XX] + c * (tempV[XX] - temp[XX] * tempDotTempV * invNormTemp * invNormTemp);
520         v[YY] = vi[YY] + c * (tempV[YY] - temp[YY] * tempDotTempV * invNormTemp * invNormTemp);
521         v[ZZ] = vi[ZZ] + c * (tempV[ZZ] - temp[ZZ] * tempDotTempV * invNormTemp * invNormTemp);
522     }
523 }
524
525 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
526 static void constr_vsite3FAD(const rvec   xi,
527                              const rvec   xj,
528                              const rvec   xk,
529                              rvec         x,
530                              real         a,
531                              real         b,
532                              const t_pbc* pbc,
533                              const rvec   vi,
534                              const rvec   vj,
535                              const rvec   vk,
536                              rvec         v)
537 { // Note: a = d * cos(theta)
538     //       b = d * sin(theta)
539     rvec xij, xjk, xp;
540
541     pbc_rvec_sub(pbc, xj, xi, xij);
542     pbc_rvec_sub(pbc, xk, xj, xjk);
543     /* 6 flops */
544
545     const real invdij    = inverseNorm(xij);
546     const real xijDotXjk = iprod(xij, xjk);
547     const real c1        = invdij * invdij * xijDotXjk;
548     xp[XX]               = xjk[XX] - c1 * xij[XX];
549     xp[YY]               = xjk[YY] - c1 * xij[YY];
550     xp[ZZ]               = xjk[ZZ] - c1 * xij[ZZ];
551     const real a1        = a * invdij;
552     const real invNormXp = inverseNorm(xp);
553     const real b1        = b * invNormXp;
554     /* 45 */
555
556     if (calculatePosition == VSiteCalculatePosition::Yes)
557     {
558         x[XX] = xi[XX] + a1 * xij[XX] + b1 * xp[XX];
559         x[YY] = xi[YY] + a1 * xij[YY] + b1 * xp[YY];
560         x[ZZ] = xi[ZZ] + a1 * xij[ZZ] + b1 * xp[ZZ];
561         /* 12 Flops */
562         /* TOTAL: 63 flops */
563     }
564
565     if (calculateVelocity == VSiteCalculateVelocity::Yes)
566     {
567         rvec vij = { 0 };
568         rvec vjk = { 0 };
569         rvec_sub(vj, vi, vij);
570         rvec_sub(vk, vj, vjk);
571
572         const real vijDotXjkPlusXijDotVjk = iprod(vij, xjk) + iprod(xij, vjk);
573         const real xijDotVij              = iprod(xij, vij);
574         const real invNormXij2            = invdij * invdij;
575
576         rvec vp = { 0 };
577         vp[XX]  = vjk[XX]
578                  - xij[XX] * invNormXij2
579                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
580                  - vij[XX] * xijDotXjk * invNormXij2;
581         vp[YY] = vjk[YY]
582                  - xij[YY] * invNormXij2
583                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
584                  - vij[YY] * xijDotXjk * invNormXij2;
585         vp[ZZ] = vjk[ZZ]
586                  - xij[ZZ] * invNormXij2
587                            * (vijDotXjkPlusXijDotVjk - invNormXij2 * xijDotXjk * xijDotVij * 2)
588                  - vij[ZZ] * xijDotXjk * invNormXij2;
589
590         const real xpDotVp = iprod(xp, vp);
591
592         v[XX] = vi[XX] + a1 * (vij[XX] - xij[XX] * xijDotVij * invdij * invdij)
593                 + b1 * (vp[XX] - xp[XX] * xpDotVp * invNormXp * invNormXp);
594         v[YY] = vi[YY] + a1 * (vij[YY] - xij[YY] * xijDotVij * invdij * invdij)
595                 + b1 * (vp[YY] - xp[YY] * xpDotVp * invNormXp * invNormXp);
596         v[ZZ] = vi[ZZ] + a1 * (vij[ZZ] - xij[ZZ] * xijDotVij * invdij * invdij)
597                 + b1 * (vp[ZZ] - xp[ZZ] * xpDotVp * invNormXp * invNormXp);
598     }
599 }
600
601 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
602 static void constr_vsite3OUT(const rvec   xi,
603                              const rvec   xj,
604                              const rvec   xk,
605                              rvec         x,
606                              real         a,
607                              real         b,
608                              real         c,
609                              const t_pbc* pbc,
610                              const rvec   vi,
611                              const rvec   vj,
612                              const rvec   vk,
613                              rvec         v)
614 {
615     rvec xij, xik, temp;
616
617     pbc_rvec_sub(pbc, xj, xi, xij);
618     pbc_rvec_sub(pbc, xk, xi, xik);
619     cprod(xij, xik, temp);
620     /* 15 Flops */
621
622     if (calculatePosition == VSiteCalculatePosition::Yes)
623     {
624         x[XX] = xi[XX] + a * xij[XX] + b * xik[XX] + c * temp[XX];
625         x[YY] = xi[YY] + a * xij[YY] + b * xik[YY] + c * temp[YY];
626         x[ZZ] = xi[ZZ] + a * xij[ZZ] + b * xik[ZZ] + c * temp[ZZ];
627         /* 18 Flops */
628         /* TOTAL: 33 flops */
629     }
630
631     if (calculateVelocity == VSiteCalculateVelocity::Yes)
632     {
633         rvec vij = { 0 };
634         rvec vik = { 0 };
635         rvec_sub(vj, vi, vij);
636         rvec_sub(vk, vi, vik);
637
638         rvec temp1 = { 0 };
639         rvec temp2 = { 0 };
640         cprod(vij, xik, temp1);
641         cprod(xij, vik, temp2);
642
643         v[XX] = vi[XX] + a * vij[XX] + b * vik[XX] + c * (temp1[XX] + temp2[XX]);
644         v[YY] = vi[YY] + a * vij[YY] + b * vik[YY] + c * (temp1[YY] + temp2[YY]);
645         v[ZZ] = vi[ZZ] + a * vij[ZZ] + b * vik[ZZ] + c * (temp1[ZZ] + temp2[ZZ]);
646     }
647 }
648
649 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
650 static void constr_vsite4FD(const rvec   xi,
651                             const rvec   xj,
652                             const rvec   xk,
653                             const rvec   xl,
654                             rvec         x,
655                             real         a,
656                             real         b,
657                             real         c,
658                             const t_pbc* pbc,
659                             const rvec   vi,
660                             const rvec   vj,
661                             const rvec   vk,
662                             const rvec   vl,
663                             rvec         v)
664 {
665     rvec xij, xjk, xjl, temp;
666     real d;
667
668     pbc_rvec_sub(pbc, xj, xi, xij);
669     pbc_rvec_sub(pbc, xk, xj, xjk);
670     pbc_rvec_sub(pbc, xl, xj, xjl);
671     /* 9 flops */
672
673     /* temp goes from i to a point on the plane jkl */
674     temp[XX] = xij[XX] + a * xjk[XX] + b * xjl[XX];
675     temp[YY] = xij[YY] + a * xjk[YY] + b * xjl[YY];
676     temp[ZZ] = xij[ZZ] + a * xjk[ZZ] + b * xjl[ZZ];
677     /* 12 flops */
678
679     const real invRm = inverseNorm(temp);
680     d                = c * invRm;
681     /* 6 + 10 flops */
682
683     if (calculatePosition == VSiteCalculatePosition::Yes)
684     {
685         x[XX] = xi[XX] + d * temp[XX];
686         x[YY] = xi[YY] + d * temp[YY];
687         x[ZZ] = xi[ZZ] + d * temp[ZZ];
688         /* 6 Flops */
689         /* TOTAL: 43 flops */
690     }
691     if (calculateVelocity == VSiteCalculateVelocity::Yes)
692     {
693         rvec vij = { 0 };
694         rvec vjk = { 0 };
695         rvec vjl = { 0 };
696
697         rvec_sub(vj, vi, vij);
698         rvec_sub(vk, vj, vjk);
699         rvec_sub(vl, vj, vjl);
700
701         rvec vm = { 0 };
702         vm[XX]  = vij[XX] + a * vjk[XX] + b * vjl[XX];
703         vm[YY]  = vij[YY] + a * vjk[YY] + b * vjl[YY];
704         vm[ZZ]  = vij[ZZ] + a * vjk[ZZ] + b * vjl[ZZ];
705
706         const real vmDotRm = iprod(vm, temp);
707         v[XX]              = vi[XX] + d * (vm[XX] - temp[XX] * vmDotRm * invRm * invRm);
708         v[YY]              = vi[YY] + d * (vm[YY] - temp[YY] * vmDotRm * invRm * invRm);
709         v[ZZ]              = vi[ZZ] + d * (vm[ZZ] - temp[ZZ] * vmDotRm * invRm * invRm);
710     }
711 }
712
713 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
714 static void constr_vsite4FDN(const rvec   xi,
715                              const rvec   xj,
716                              const rvec   xk,
717                              const rvec   xl,
718                              rvec         x,
719                              real         a,
720                              real         b,
721                              real         c,
722                              const t_pbc* pbc,
723                              const rvec   vi,
724                              const rvec   vj,
725                              const rvec   vk,
726                              const rvec   vl,
727                              rvec         v)
728 {
729     rvec xij, xik, xil, ra, rb, rja, rjb, rm;
730     real d;
731
732     pbc_rvec_sub(pbc, xj, xi, xij);
733     pbc_rvec_sub(pbc, xk, xi, xik);
734     pbc_rvec_sub(pbc, xl, xi, xil);
735     /* 9 flops */
736
737     ra[XX] = a * xik[XX];
738     ra[YY] = a * xik[YY];
739     ra[ZZ] = a * xik[ZZ];
740
741     rb[XX] = b * xil[XX];
742     rb[YY] = b * xil[YY];
743     rb[ZZ] = b * xil[ZZ];
744
745     /* 6 flops */
746
747     rvec_sub(ra, xij, rja);
748     rvec_sub(rb, xij, rjb);
749     /* 6 flops */
750
751     cprod(rja, rjb, rm);
752     /* 9 flops */
753
754     const real invNormRm = inverseNorm(rm);
755     d                    = c * invNormRm;
756     /* 5+5+1 flops */
757
758     if (calculatePosition == VSiteCalculatePosition::Yes)
759     {
760         x[XX] = xi[XX] + d * rm[XX];
761         x[YY] = xi[YY] + d * rm[YY];
762         x[ZZ] = xi[ZZ] + d * rm[ZZ];
763         /* 6 Flops */
764         /* TOTAL: 47 flops */
765     }
766
767     if (calculateVelocity == VSiteCalculateVelocity::Yes)
768     {
769         rvec vij = { 0 };
770         rvec vik = { 0 };
771         rvec vil = { 0 };
772         rvec_sub(vj, vi, vij);
773         rvec_sub(vk, vi, vik);
774         rvec_sub(vl, vi, vil);
775
776         rvec vja = { 0 };
777         rvec vjb = { 0 };
778
779         vja[XX] = a * vik[XX] - vij[XX];
780         vja[YY] = a * vik[YY] - vij[YY];
781         vja[ZZ] = a * vik[ZZ] - vij[ZZ];
782         vjb[XX] = b * vil[XX] - vij[XX];
783         vjb[YY] = b * vil[YY] - vij[YY];
784         vjb[ZZ] = b * vil[ZZ] - vij[ZZ];
785
786         rvec temp1 = { 0 };
787         rvec temp2 = { 0 };
788         cprod(vja, rjb, temp1);
789         cprod(rja, vjb, temp2);
790
791         rvec vm = { 0 };
792         vm[XX]  = temp1[XX] + temp2[XX];
793         vm[YY]  = temp1[YY] + temp2[YY];
794         vm[ZZ]  = temp1[ZZ] + temp2[ZZ];
795
796         const real rmDotVm = iprod(rm, vm);
797         v[XX]              = vi[XX] + d * (vm[XX] - rm[XX] * rmDotVm * invNormRm * invNormRm);
798         v[YY]              = vi[YY] + d * (vm[YY] - rm[YY] * rmDotVm * invNormRm * invNormRm);
799         v[ZZ]              = vi[ZZ] + d * (vm[ZZ] - rm[ZZ] * rmDotVm * invNormRm * invNormRm);
800     }
801 }
802
803 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
804 static int constr_vsiten(const t_iatom*            ia,
805                          ArrayRef<const t_iparams> ip,
806                          ArrayRef<RVec>            x,
807                          const t_pbc*              pbc,
808                          ArrayRef<RVec>            v)
809 {
810     rvec x1, dx;
811     dvec dsum;
812     real a;
813     dvec dvsum = { 0 };
814     rvec v1    = { 0 };
815
816     const int n3 = 3 * ip[ia[0]].vsiten.n;
817     const int av = ia[1];
818     int       ai = ia[2];
819     copy_rvec(x[ai], x1);
820     copy_rvec(v[ai], v1);
821     clear_dvec(dsum);
822     for (int i = 3; i < n3; i += 3)
823     {
824         ai = ia[i + 2];
825         a  = ip[ia[i]].vsiten.a;
826         if (calculatePosition == VSiteCalculatePosition::Yes)
827         {
828             if (pbc)
829             {
830                 pbc_dx_aiuc(pbc, x[ai], x1, dx);
831             }
832             else
833             {
834                 rvec_sub(x[ai], x1, dx);
835             }
836             dsum[XX] += a * dx[XX];
837             dsum[YY] += a * dx[YY];
838             dsum[ZZ] += a * dx[ZZ];
839             /* 9 Flops */
840         }
841         if (calculateVelocity == VSiteCalculateVelocity::Yes)
842         {
843             rvec_sub(v[ai], v1, dx);
844             dvsum[XX] += a * dx[XX];
845             dvsum[YY] += a * dx[YY];
846             dvsum[ZZ] += a * dx[ZZ];
847             /* 9 Flops */
848         }
849     }
850
851     if (calculatePosition == VSiteCalculatePosition::Yes)
852     {
853         x[av][XX] = x1[XX] + dsum[XX];
854         x[av][YY] = x1[YY] + dsum[YY];
855         x[av][ZZ] = x1[ZZ] + dsum[ZZ];
856     }
857
858     if (calculateVelocity == VSiteCalculateVelocity::Yes)
859     {
860         v[av][XX] = v1[XX] + dvsum[XX];
861         v[av][YY] = v1[YY] + dvsum[YY];
862         v[av][ZZ] = v1[ZZ] + dvsum[ZZ];
863     }
864
865     return n3;
866 }
867 // End GCC 8 bug
868 GCC_DIAGNOSTIC_RESET
869
870 #endif // DOXYGEN
871
872 //! PBC modes for vsite construction and spreading
873 enum class PbcMode
874 {
875     all, //!< Apply normal, simple PBC for all vsites
876     none //!< No PBC treatment needed
877 };
878
879 /*! \brief Returns the PBC mode based on the system PBC and vsite properties
880  *
881  * \param[in] pbcPtr  A pointer to a PBC struct or nullptr when no PBC treatment is required
882  */
883 static PbcMode getPbcMode(const t_pbc* pbcPtr)
884 {
885     if (pbcPtr == nullptr)
886     {
887         return PbcMode::none;
888     }
889     else
890     {
891         return PbcMode::all;
892     }
893 }
894
895 /*! \brief Executes the vsite construction task for a single thread
896  *
897  * \tparam        operation  Whether we are calculating positions, velocities, or both
898  * \param[in,out] x   Coordinates to construct vsites for
899  * \param[in,out] v   Velocities are generated for virtual sites if `operation` requires it
900  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
901  * \param[in]     ilist  The interaction lists, only vsites are usesd
902  * \param[in]     pbc_null  PBC struct, used for PBC distance calculations when !=nullptr
903  */
904 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
905 static void construct_vsites_thread(ArrayRef<RVec>                  x,
906                                     ArrayRef<RVec>                  v,
907                                     ArrayRef<const t_iparams>       ip,
908                                     ArrayRef<const InteractionList> ilist,
909                                     const t_pbc*                    pbc_null)
910 {
911     if (calculateVelocity == VSiteCalculateVelocity::Yes)
912     {
913         GMX_RELEASE_ASSERT(!v.empty(),
914                            "Can't calculate velocities without access to velocity vector.");
915     }
916
917     // Work around clang bug (unfixed as of Feb 2021)
918     // https://bugs.llvm.org/show_bug.cgi?id=35450
919     // clang-format off
920     CLANG_DIAGNOSTIC_IGNORE(-Wunused-lambda-capture)
921     // clang-format on
922     // GCC 8 falsely flags unused variables if constexpr prunes a code path, fixed in GCC 9
923     // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85827
924     // clang-format off
925     GCC_DIAGNOSTIC_IGNORE(-Wunused-but-set-parameter)
926     // clang-format on
927     // getVOrNull returns a velocity rvec if we need it, nullptr otherwise.
928     auto getVOrNull = [v](int idx) -> real* {
929         if (calculateVelocity == VSiteCalculateVelocity::Yes)
930         {
931             return v[idx].as_vec();
932         }
933         else
934         {
935             return nullptr;
936         }
937     };
938     GCC_DIAGNOSTIC_RESET
939     CLANG_DIAGNOSTIC_RESET
940
941     const PbcMode pbcMode = getPbcMode(pbc_null);
942     /* We need another pbc pointer, as with charge groups we switch per vsite */
943     const t_pbc* pbc_null2 = pbc_null;
944
945     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
946     {
947         if (ilist[ftype].empty())
948         {
949             continue;
950         }
951
952         { // TODO remove me
953             int nra = interaction_function[ftype].nratoms;
954             int inc = 1 + nra;
955             int nr  = ilist[ftype].size();
956
957             const t_iatom* ia = ilist[ftype].iatoms.data();
958
959             for (int i = 0; i < nr;)
960             {
961                 int tp = ia[0];
962                 /* The vsite and constructing atoms */
963                 int avsite = ia[1];
964                 int ai     = ia[2];
965                 /* Constants for constructing vsites */
966                 real a1 = ip[tp].vsite.a;
967                 /* Copy the old position */
968                 rvec xv;
969                 copy_rvec(x[avsite], xv);
970
971                 /* Construct the vsite depending on type */
972                 int  aj, ak, al;
973                 real b1, c1;
974                 switch (ftype)
975                 {
976                     case F_VSITE1:
977                         constr_vsite1<calculatePosition, calculateVelocity>(
978                                 x[ai], x[avsite], getVOrNull(ai), getVOrNull(avsite));
979                         break;
980                     case F_VSITE2:
981                         aj = ia[3];
982                         constr_vsite2<calculatePosition, calculateVelocity>(x[ai],
983                                                                             x[aj],
984                                                                             x[avsite],
985                                                                             a1,
986                                                                             pbc_null2,
987                                                                             getVOrNull(ai),
988                                                                             getVOrNull(aj),
989                                                                             getVOrNull(avsite));
990                         break;
991                     case F_VSITE2FD:
992                         aj = ia[3];
993                         constr_vsite2FD<calculatePosition, calculateVelocity>(x[ai],
994                                                                               x[aj],
995                                                                               x[avsite],
996                                                                               a1,
997                                                                               pbc_null2,
998                                                                               getVOrNull(ai),
999                                                                               getVOrNull(aj),
1000                                                                               getVOrNull(avsite));
1001                         break;
1002                     case F_VSITE3:
1003                         aj = ia[3];
1004                         ak = ia[4];
1005                         b1 = ip[tp].vsite.b;
1006                         constr_vsite3<calculatePosition, calculateVelocity>(x[ai],
1007                                                                             x[aj],
1008                                                                             x[ak],
1009                                                                             x[avsite],
1010                                                                             a1,
1011                                                                             b1,
1012                                                                             pbc_null2,
1013                                                                             getVOrNull(ai),
1014                                                                             getVOrNull(aj),
1015                                                                             getVOrNull(ak),
1016                                                                             getVOrNull(avsite));
1017                         break;
1018                     case F_VSITE3FD:
1019                         aj = ia[3];
1020                         ak = ia[4];
1021                         b1 = ip[tp].vsite.b;
1022                         constr_vsite3FD<calculatePosition, calculateVelocity>(x[ai],
1023                                                                               x[aj],
1024                                                                               x[ak],
1025                                                                               x[avsite],
1026                                                                               a1,
1027                                                                               b1,
1028                                                                               pbc_null2,
1029                                                                               getVOrNull(ai),
1030                                                                               getVOrNull(aj),
1031                                                                               getVOrNull(ak),
1032                                                                               getVOrNull(avsite));
1033                         break;
1034                     case F_VSITE3FAD:
1035                         aj = ia[3];
1036                         ak = ia[4];
1037                         b1 = ip[tp].vsite.b;
1038                         constr_vsite3FAD<calculatePosition, calculateVelocity>(x[ai],
1039                                                                                x[aj],
1040                                                                                x[ak],
1041                                                                                x[avsite],
1042                                                                                a1,
1043                                                                                b1,
1044                                                                                pbc_null2,
1045                                                                                getVOrNull(ai),
1046                                                                                getVOrNull(aj),
1047                                                                                getVOrNull(ak),
1048                                                                                getVOrNull(avsite));
1049                         break;
1050                     case F_VSITE3OUT:
1051                         aj = ia[3];
1052                         ak = ia[4];
1053                         b1 = ip[tp].vsite.b;
1054                         c1 = ip[tp].vsite.c;
1055                         constr_vsite3OUT<calculatePosition, calculateVelocity>(x[ai],
1056                                                                                x[aj],
1057                                                                                x[ak],
1058                                                                                x[avsite],
1059                                                                                a1,
1060                                                                                b1,
1061                                                                                c1,
1062                                                                                pbc_null2,
1063                                                                                getVOrNull(ai),
1064                                                                                getVOrNull(aj),
1065                                                                                getVOrNull(ak),
1066                                                                                getVOrNull(avsite));
1067                         break;
1068                     case F_VSITE4FD:
1069                         aj = ia[3];
1070                         ak = ia[4];
1071                         al = ia[5];
1072                         b1 = ip[tp].vsite.b;
1073                         c1 = ip[tp].vsite.c;
1074                         constr_vsite4FD<calculatePosition, calculateVelocity>(x[ai],
1075                                                                               x[aj],
1076                                                                               x[ak],
1077                                                                               x[al],
1078                                                                               x[avsite],
1079                                                                               a1,
1080                                                                               b1,
1081                                                                               c1,
1082                                                                               pbc_null2,
1083                                                                               getVOrNull(ai),
1084                                                                               getVOrNull(aj),
1085                                                                               getVOrNull(ak),
1086                                                                               getVOrNull(al),
1087                                                                               getVOrNull(avsite));
1088                         break;
1089                     case F_VSITE4FDN:
1090                         aj = ia[3];
1091                         ak = ia[4];
1092                         al = ia[5];
1093                         b1 = ip[tp].vsite.b;
1094                         c1 = ip[tp].vsite.c;
1095                         constr_vsite4FDN<calculatePosition, calculateVelocity>(x[ai],
1096                                                                                x[aj],
1097                                                                                x[ak],
1098                                                                                x[al],
1099                                                                                x[avsite],
1100                                                                                a1,
1101                                                                                b1,
1102                                                                                c1,
1103                                                                                pbc_null2,
1104                                                                                getVOrNull(ai),
1105                                                                                getVOrNull(aj),
1106                                                                                getVOrNull(ak),
1107                                                                                getVOrNull(al),
1108                                                                                getVOrNull(avsite));
1109                         break;
1110                     case F_VSITEN:
1111                         inc = constr_vsiten<calculatePosition, calculateVelocity>(ia, ip, x, pbc_null2, v);
1112                         break;
1113                     default:
1114                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
1115                 }
1116
1117                 if (pbcMode == PbcMode::all)
1118                 {
1119                     /* Keep the vsite in the same periodic image as before */
1120                     rvec dx;
1121                     int  ishift = pbc_dx_aiuc(pbc_null, x[avsite], xv, dx);
1122                     if (ishift != c_centralShiftIndex)
1123                     {
1124                         rvec_add(xv, dx, x[avsite]);
1125                     }
1126                 }
1127
1128                 /* Increment loop variables */
1129                 i += inc;
1130                 ia += inc;
1131             }
1132         }
1133     }
1134 }
1135
1136 /*! \brief Dispatch the vsite construction tasks for all threads
1137  *
1138  * \param[in]     threadingInfo  Used to divide work over threads when != nullptr
1139  * \param[in,out] x   Coordinates to construct vsites for
1140  * \param[in,out] v   When not empty, velocities are generated for virtual sites
1141  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
1142  * \param[in]     ilist  The interaction lists, only vsites are usesd
1143  * \param[in]     domainInfo  Information about PBC and DD
1144  * \param[in]     box  Used for PBC when PBC is set in domainInfo
1145  */
1146 template<VSiteCalculatePosition calculatePosition, VSiteCalculateVelocity calculateVelocity>
1147 static void construct_vsites(const ThreadingInfo*            threadingInfo,
1148                              ArrayRef<RVec>                  x,
1149                              ArrayRef<RVec>                  v,
1150                              ArrayRef<const t_iparams>       ip,
1151                              ArrayRef<const InteractionList> ilist,
1152                              const DomainInfo&               domainInfo,
1153                              const matrix                    box)
1154 {
1155     const bool useDomdec = domainInfo.useDomdec();
1156
1157     t_pbc pbc, *pbc_null;
1158
1159     /* We only need to do pbc when we have inter update-group vsites.
1160      * Note that with domain decomposition we do not need to apply PBC here
1161      * when we have at least 3 domains along each dimension. Currently we
1162      * do not optimize this case.
1163      */
1164     if (domainInfo.pbcType_ != PbcType::No && domainInfo.useMolPbc_)
1165     {
1166         /* This is wasting some CPU time as we now do this multiple times
1167          * per MD step.
1168          */
1169         ivec null_ivec;
1170         clear_ivec(null_ivec);
1171         pbc_null = set_pbc_dd(
1172                 &pbc, domainInfo.pbcType_, useDomdec ? domainInfo.domdec_->numCells : null_ivec, FALSE, box);
1173     }
1174     else
1175     {
1176         pbc_null = nullptr;
1177     }
1178
1179     if (useDomdec)
1180     {
1181         if (calculateVelocity == VSiteCalculateVelocity::Yes)
1182         {
1183             dd_move_x_and_v_vsites(
1184                     *domainInfo.domdec_, box, as_rvec_array(x.data()), as_rvec_array(v.data()));
1185         }
1186         else
1187         {
1188             dd_move_x_vsites(*domainInfo.domdec_, box, as_rvec_array(x.data()));
1189         }
1190     }
1191
1192     if (threadingInfo == nullptr || threadingInfo->numThreads() == 1)
1193     {
1194         construct_vsites_thread<calculatePosition, calculateVelocity>(x, v, ip, ilist, pbc_null);
1195     }
1196     else
1197     {
1198 #pragma omp parallel num_threads(threadingInfo->numThreads())
1199         {
1200             try
1201             {
1202                 const int          th    = gmx_omp_get_thread_num();
1203                 const VsiteThread& tData = threadingInfo->threadData(th);
1204                 GMX_ASSERT(tData.rangeStart >= 0,
1205                            "The thread data should be initialized before calling construct_vsites");
1206
1207                 construct_vsites_thread<calculatePosition, calculateVelocity>(
1208                         x, v, ip, tData.ilist, pbc_null);
1209                 if (tData.useInterdependentTask)
1210                 {
1211                     /* Here we don't need a barrier (unlike the spreading),
1212                      * since both tasks only construct vsites from particles,
1213                      * or local vsites, not from non-local vsites.
1214                      */
1215                     construct_vsites_thread<calculatePosition, calculateVelocity>(
1216                             x, v, ip, tData.idTask.ilist, pbc_null);
1217                 }
1218             }
1219             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
1220         }
1221         /* Now we can construct the vsites that might depend on other vsites */
1222         construct_vsites_thread<calculatePosition, calculateVelocity>(
1223                 x, v, ip, threadingInfo->threadDataNonLocalDependent().ilist, pbc_null);
1224     }
1225 }
1226
1227 void VirtualSitesHandler::Impl::construct(ArrayRef<RVec> x,
1228                                           ArrayRef<RVec> v,
1229                                           const matrix   box,
1230                                           VSiteOperation operation) const
1231 {
1232     switch (operation)
1233     {
1234         case VSiteOperation::Positions:
1235             construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::No>(
1236                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1237             break;
1238         case VSiteOperation::Velocities:
1239             construct_vsites<VSiteCalculatePosition::No, VSiteCalculateVelocity::Yes>(
1240                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1241             break;
1242         case VSiteOperation::PositionsAndVelocities:
1243             construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::Yes>(
1244                     &threadingInfo_, x, v, iparams_, ilists_, domainInfo_, box);
1245             break;
1246         default: gmx_fatal(FARGS, "Unknown virtual site operation");
1247     }
1248 }
1249
1250 void VirtualSitesHandler::construct(ArrayRef<RVec> x, ArrayRef<RVec> v, const matrix box, VSiteOperation operation) const
1251 {
1252     impl_->construct(x, v, box, operation);
1253 }
1254
1255 void constructVirtualSites(ArrayRef<RVec> x, ArrayRef<const t_iparams> ip, ArrayRef<const InteractionList> ilist)
1256
1257 {
1258     // No PBC, no DD
1259     const DomainInfo domainInfo;
1260     construct_vsites<VSiteCalculatePosition::Yes, VSiteCalculateVelocity::No>(
1261             nullptr, x, {}, ip, ilist, domainInfo, nullptr);
1262 }
1263
1264 #ifndef DOXYGEN
1265 /* Force spreading routines */
1266
1267 static void spread_vsite1(const t_iatom ia[], ArrayRef<RVec> f)
1268 {
1269     const int av = ia[1];
1270     const int ai = ia[2];
1271
1272     f[av] += f[ai];
1273 }
1274
1275 template<VirialHandling virialHandling>
1276 static void spread_vsite2(const t_iatom        ia[],
1277                           real                 a,
1278                           ArrayRef<const RVec> x,
1279                           ArrayRef<RVec>       f,
1280                           ArrayRef<RVec>       fshift,
1281                           const t_pbc*         pbc)
1282 {
1283     rvec    fi, fj, dx;
1284     t_iatom av, ai, aj;
1285
1286     av = ia[1];
1287     ai = ia[2];
1288     aj = ia[3];
1289
1290     svmul(1 - a, f[av], fi);
1291     svmul(a, f[av], fj);
1292     /* 7 flop */
1293
1294     rvec_inc(f[ai], fi);
1295     rvec_inc(f[aj], fj);
1296     /* 6 Flops */
1297
1298     if (virialHandling == VirialHandling::Pbc)
1299     {
1300         int siv;
1301         int sij;
1302         if (pbc)
1303         {
1304             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1305             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1306         }
1307         else
1308         {
1309             siv = c_centralShiftIndex;
1310             sij = c_centralShiftIndex;
1311         }
1312
1313         if (siv != c_centralShiftIndex || sij != c_centralShiftIndex)
1314         {
1315             rvec_inc(fshift[siv], f[av]);
1316             rvec_dec(fshift[c_centralShiftIndex], fi);
1317             rvec_dec(fshift[sij], fj);
1318         }
1319     }
1320
1321     /* TOTAL: 13 flops */
1322 }
1323
1324 void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
1325 {
1326     GMX_ASSERT(x.ssize() >= mtop.natoms, "x should contain the whole system");
1327     GMX_ASSERT(!mtop.moleculeBlockIndices.empty(),
1328                "molblock indices are needed in constructVsitesGlobal");
1329
1330     for (size_t mb = 0; mb < mtop.molblock.size(); mb++)
1331     {
1332         const gmx_molblock_t& molb = mtop.molblock[mb];
1333         const gmx_moltype_t&  molt = mtop.moltype[molb.type];
1334         if (vsiteIlistNrCount(molt.ilist) > 0)
1335         {
1336             int atomOffset = mtop.moleculeBlockIndices[mb].globalAtomStart;
1337             for (int mol = 0; mol < molb.nmol; mol++)
1338             {
1339                 constructVirtualSites(
1340                         x.subArray(atomOffset, molt.atoms.nr), mtop.ffparams.iparams, molt.ilist);
1341                 atomOffset += molt.atoms.nr;
1342             }
1343         }
1344     }
1345 }
1346
1347 template<VirialHandling virialHandling>
1348 static void spread_vsite2FD(const t_iatom        ia[],
1349                             real                 a,
1350                             ArrayRef<const RVec> x,
1351                             ArrayRef<RVec>       f,
1352                             ArrayRef<RVec>       fshift,
1353                             matrix               dxdf,
1354                             const t_pbc*         pbc)
1355 {
1356     const int av = ia[1];
1357     const int ai = ia[2];
1358     const int aj = ia[3];
1359     rvec      fv;
1360     copy_rvec(f[av], fv);
1361
1362     rvec xij;
1363     int  sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1364     /* 6 flops */
1365
1366     const real invDistance = inverseNorm(xij);
1367     const real b           = a * invDistance;
1368     /* 4 + ?10? flops */
1369
1370     const real fproj = iprod(xij, fv) * invDistance * invDistance;
1371
1372     rvec fj;
1373     fj[XX] = b * (fv[XX] - fproj * xij[XX]);
1374     fj[YY] = b * (fv[YY] - fproj * xij[YY]);
1375     fj[ZZ] = b * (fv[ZZ] - fproj * xij[ZZ]);
1376     /* 9 */
1377
1378     /* b is already calculated in constr_vsite2FD
1379        storing b somewhere will save flops.     */
1380
1381     f[ai][XX] += fv[XX] - fj[XX];
1382     f[ai][YY] += fv[YY] - fj[YY];
1383     f[ai][ZZ] += fv[ZZ] - fj[ZZ];
1384     f[aj][XX] += fj[XX];
1385     f[aj][YY] += fj[YY];
1386     f[aj][ZZ] += fj[ZZ];
1387     /* 9 Flops */
1388
1389     if (virialHandling == VirialHandling::Pbc)
1390     {
1391         int svi;
1392         if (pbc)
1393         {
1394             rvec xvi;
1395             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1396         }
1397         else
1398         {
1399             svi = c_centralShiftIndex;
1400         }
1401
1402         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex)
1403         {
1404             rvec_dec(fshift[svi], fv);
1405             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX];
1406             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY];
1407             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ];
1408             fshift[sji][XX] += fj[XX];
1409             fshift[sji][YY] += fj[YY];
1410             fshift[sji][ZZ] += fj[ZZ];
1411         }
1412     }
1413
1414     if (virialHandling == VirialHandling::NonLinear)
1415     {
1416         /* Under this condition, the virial for the current forces is not
1417          * calculated from the redistributed forces. This means that
1418          * the effect of non-linear virtual site constructions on the virial
1419          * needs to be added separately. This contribution can be calculated
1420          * in many ways, but the simplest and cheapest way is to use
1421          * the first constructing atom ai as a reference position in space:
1422          * subtract (xv-xi)*fv and add (xj-xi)*fj.
1423          */
1424         rvec xiv;
1425
1426         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1427
1428         for (int i = 0; i < DIM; i++)
1429         {
1430             for (int j = 0; j < DIM; j++)
1431             {
1432                 /* As xix is a linear combination of j and k, use that here */
1433                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j];
1434             }
1435         }
1436     }
1437
1438     /* TOTAL: 38 flops */
1439 }
1440
1441 template<VirialHandling virialHandling>
1442 static void spread_vsite3(const t_iatom        ia[],
1443                           real                 a,
1444                           real                 b,
1445                           ArrayRef<const RVec> x,
1446                           ArrayRef<RVec>       f,
1447                           ArrayRef<RVec>       fshift,
1448                           const t_pbc*         pbc)
1449 {
1450     rvec fi, fj, fk, dx;
1451     int  av, ai, aj, ak;
1452
1453     av = ia[1];
1454     ai = ia[2];
1455     aj = ia[3];
1456     ak = ia[4];
1457
1458     svmul(1 - a - b, f[av], fi);
1459     svmul(a, f[av], fj);
1460     svmul(b, f[av], fk);
1461     /* 11 flops */
1462
1463     rvec_inc(f[ai], fi);
1464     rvec_inc(f[aj], fj);
1465     rvec_inc(f[ak], fk);
1466     /* 9 Flops */
1467
1468     if (virialHandling == VirialHandling::Pbc)
1469     {
1470         int siv;
1471         int sij;
1472         int sik;
1473         if (pbc)
1474         {
1475             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1476             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1477             sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
1478         }
1479         else
1480         {
1481             siv = c_centralShiftIndex;
1482             sij = c_centralShiftIndex;
1483             sik = c_centralShiftIndex;
1484         }
1485
1486         if (siv != c_centralShiftIndex || sij != c_centralShiftIndex || sik != c_centralShiftIndex)
1487         {
1488             rvec_inc(fshift[siv], f[av]);
1489             rvec_dec(fshift[c_centralShiftIndex], fi);
1490             rvec_dec(fshift[sij], fj);
1491             rvec_dec(fshift[sik], fk);
1492         }
1493     }
1494
1495     /* TOTAL: 20 flops */
1496 }
1497
1498 template<VirialHandling virialHandling>
1499 static void spread_vsite3FD(const t_iatom        ia[],
1500                             real                 a,
1501                             real                 b,
1502                             ArrayRef<const RVec> x,
1503                             ArrayRef<RVec>       f,
1504                             ArrayRef<RVec>       fshift,
1505                             matrix               dxdf,
1506                             const t_pbc*         pbc)
1507 {
1508     real    fproj, a1;
1509     rvec    xvi, xij, xjk, xix, fv, temp;
1510     t_iatom av, ai, aj, ak;
1511     int     sji, skj;
1512
1513     av = ia[1];
1514     ai = ia[2];
1515     aj = ia[3];
1516     ak = ia[4];
1517     copy_rvec(f[av], fv);
1518
1519     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1520     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1521     /* 6 flops */
1522
1523     /* xix goes from i to point x on the line jk */
1524     xix[XX] = xij[XX] + a * xjk[XX];
1525     xix[YY] = xij[YY] + a * xjk[YY];
1526     xix[ZZ] = xij[ZZ] + a * xjk[ZZ];
1527     /* 6 flops */
1528
1529     const real invDistance = inverseNorm(xix);
1530     const real c           = b * invDistance;
1531     /* 4 + ?10? flops */
1532
1533     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1534
1535     temp[XX] = c * (fv[XX] - fproj * xix[XX]);
1536     temp[YY] = c * (fv[YY] - fproj * xix[YY]);
1537     temp[ZZ] = c * (fv[ZZ] - fproj * xix[ZZ]);
1538     /* 16 */
1539
1540     /* c is already calculated in constr_vsite3FD
1541        storing c somewhere will save 26 flops!     */
1542
1543     a1 = 1 - a;
1544     f[ai][XX] += fv[XX] - temp[XX];
1545     f[ai][YY] += fv[YY] - temp[YY];
1546     f[ai][ZZ] += fv[ZZ] - temp[ZZ];
1547     f[aj][XX] += a1 * temp[XX];
1548     f[aj][YY] += a1 * temp[YY];
1549     f[aj][ZZ] += a1 * temp[ZZ];
1550     f[ak][XX] += a * temp[XX];
1551     f[ak][YY] += a * temp[YY];
1552     f[ak][ZZ] += a * temp[ZZ];
1553     /* 19 Flops */
1554
1555     if (virialHandling == VirialHandling::Pbc)
1556     {
1557         int svi;
1558         if (pbc)
1559         {
1560             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1561         }
1562         else
1563         {
1564             svi = c_centralShiftIndex;
1565         }
1566
1567         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex)
1568         {
1569             rvec_dec(fshift[svi], fv);
1570             fshift[c_centralShiftIndex][XX] += fv[XX] - (1 + a) * temp[XX];
1571             fshift[c_centralShiftIndex][YY] += fv[YY] - (1 + a) * temp[YY];
1572             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
1573             fshift[sji][XX] += temp[XX];
1574             fshift[sji][YY] += temp[YY];
1575             fshift[sji][ZZ] += temp[ZZ];
1576             fshift[skj][XX] += a * temp[XX];
1577             fshift[skj][YY] += a * temp[YY];
1578             fshift[skj][ZZ] += a * temp[ZZ];
1579         }
1580     }
1581
1582     if (virialHandling == VirialHandling::NonLinear)
1583     {
1584         /* Under this condition, the virial for the current forces is not
1585          * calculated from the redistributed forces. This means that
1586          * the effect of non-linear virtual site constructions on the virial
1587          * needs to be added separately. This contribution can be calculated
1588          * in many ways, but the simplest and cheapest way is to use
1589          * the first constructing atom ai as a reference position in space:
1590          * subtract (xv-xi)*fv and add (xj-xi)*fj + (xk-xi)*fk.
1591          */
1592         rvec xiv;
1593
1594         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1595
1596         for (int i = 0; i < DIM; i++)
1597         {
1598             for (int j = 0; j < DIM; j++)
1599             {
1600                 /* As xix is a linear combination of j and k, use that here */
1601                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1602             }
1603         }
1604     }
1605
1606     /* TOTAL: 61 flops */
1607 }
1608
1609 template<VirialHandling virialHandling>
1610 static void spread_vsite3FAD(const t_iatom        ia[],
1611                              real                 a,
1612                              real                 b,
1613                              ArrayRef<const RVec> x,
1614                              ArrayRef<RVec>       f,
1615                              ArrayRef<RVec>       fshift,
1616                              matrix               dxdf,
1617                              const t_pbc*         pbc)
1618 {
1619     rvec    xvi, xij, xjk, xperp, Fpij, Fppp, fv, f1, f2, f3;
1620     real    a1, b1, c1, c2, invdij, invdij2, invdp, fproj;
1621     t_iatom av, ai, aj, ak;
1622     int     sji, skj;
1623
1624     av = ia[1];
1625     ai = ia[2];
1626     aj = ia[3];
1627     ak = ia[4];
1628     copy_rvec(f[ia[1]], fv);
1629
1630     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1631     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1632     /* 6 flops */
1633
1634     invdij    = inverseNorm(xij);
1635     invdij2   = invdij * invdij;
1636     c1        = iprod(xij, xjk) * invdij2;
1637     xperp[XX] = xjk[XX] - c1 * xij[XX];
1638     xperp[YY] = xjk[YY] - c1 * xij[YY];
1639     xperp[ZZ] = xjk[ZZ] - c1 * xij[ZZ];
1640     /* xperp in plane ijk, perp. to ij */
1641     invdp = inverseNorm(xperp);
1642     a1    = a * invdij;
1643     b1    = b * invdp;
1644     /* 45 flops */
1645
1646     /* a1, b1 and c1 are already calculated in constr_vsite3FAD
1647        storing them somewhere will save 45 flops!     */
1648
1649     fproj = iprod(xij, fv) * invdij2;
1650     svmul(fproj, xij, Fpij);                              /* proj. f on xij */
1651     svmul(iprod(xperp, fv) * invdp * invdp, xperp, Fppp); /* proj. f on xperp */
1652     svmul(b1 * fproj, xperp, f3);
1653     /* 23 flops */
1654
1655     rvec_sub(fv, Fpij, f1); /* f1 = f - Fpij */
1656     rvec_sub(f1, Fppp, f2); /* f2 = f - Fpij - Fppp */
1657     for (int d = 0; d < DIM; d++)
1658     {
1659         f1[d] *= a1;
1660         f2[d] *= b1;
1661     }
1662     /* 12 flops */
1663
1664     c2 = 1 + c1;
1665     f[ai][XX] += fv[XX] - f1[XX] + c1 * f2[XX] + f3[XX];
1666     f[ai][YY] += fv[YY] - f1[YY] + c1 * f2[YY] + f3[YY];
1667     f[ai][ZZ] += fv[ZZ] - f1[ZZ] + c1 * f2[ZZ] + f3[ZZ];
1668     f[aj][XX] += f1[XX] - c2 * f2[XX] - f3[XX];
1669     f[aj][YY] += f1[YY] - c2 * f2[YY] - f3[YY];
1670     f[aj][ZZ] += f1[ZZ] - c2 * f2[ZZ] - f3[ZZ];
1671     f[ak][XX] += f2[XX];
1672     f[ak][YY] += f2[YY];
1673     f[ak][ZZ] += f2[ZZ];
1674     /* 30 Flops */
1675
1676     if (virialHandling == VirialHandling::Pbc)
1677     {
1678         int svi;
1679
1680         if (pbc)
1681         {
1682             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1683         }
1684         else
1685         {
1686             svi = c_centralShiftIndex;
1687         }
1688
1689         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex)
1690         {
1691             rvec_dec(fshift[svi], fv);
1692             fshift[c_centralShiftIndex][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
1693             fshift[c_centralShiftIndex][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
1694             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
1695             fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
1696             fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
1697             fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
1698             fshift[skj][XX] += f2[XX];
1699             fshift[skj][YY] += f2[YY];
1700             fshift[skj][ZZ] += f2[ZZ];
1701         }
1702     }
1703
1704     if (virialHandling == VirialHandling::NonLinear)
1705     {
1706         rvec xiv;
1707         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1708
1709         for (int i = 0; i < DIM; i++)
1710         {
1711             for (int j = 0; j < DIM; j++)
1712             {
1713                 /* Note that xik=xij+xjk, so we have to add xij*f2 */
1714                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * (f1[j] + (1 - c2) * f2[j] - f3[j])
1715                               + xjk[i] * f2[j];
1716             }
1717         }
1718     }
1719
1720     /* TOTAL: 113 flops */
1721 }
1722
1723 template<VirialHandling virialHandling>
1724 static void spread_vsite3OUT(const t_iatom        ia[],
1725                              real                 a,
1726                              real                 b,
1727                              real                 c,
1728                              ArrayRef<const RVec> x,
1729                              ArrayRef<RVec>       f,
1730                              ArrayRef<RVec>       fshift,
1731                              matrix               dxdf,
1732                              const t_pbc*         pbc)
1733 {
1734     rvec xvi, xij, xik, fv, fj, fk;
1735     real cfx, cfy, cfz;
1736     int  av, ai, aj, ak;
1737     int  sji, ski;
1738
1739     av = ia[1];
1740     ai = ia[2];
1741     aj = ia[3];
1742     ak = ia[4];
1743
1744     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1745     ski = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1746     /* 6 Flops */
1747
1748     copy_rvec(f[av], fv);
1749
1750     cfx = c * fv[XX];
1751     cfy = c * fv[YY];
1752     cfz = c * fv[ZZ];
1753     /* 3 Flops */
1754
1755     fj[XX] = a * fv[XX] - xik[ZZ] * cfy + xik[YY] * cfz;
1756     fj[YY] = xik[ZZ] * cfx + a * fv[YY] - xik[XX] * cfz;
1757     fj[ZZ] = -xik[YY] * cfx + xik[XX] * cfy + a * fv[ZZ];
1758
1759     fk[XX] = b * fv[XX] + xij[ZZ] * cfy - xij[YY] * cfz;
1760     fk[YY] = -xij[ZZ] * cfx + b * fv[YY] + xij[XX] * cfz;
1761     fk[ZZ] = xij[YY] * cfx - xij[XX] * cfy + b * fv[ZZ];
1762     /* 30 Flops */
1763
1764     f[ai][XX] += fv[XX] - fj[XX] - fk[XX];
1765     f[ai][YY] += fv[YY] - fj[YY] - fk[YY];
1766     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1767     rvec_inc(f[aj], fj);
1768     rvec_inc(f[ak], fk);
1769     /* 15 Flops */
1770
1771     if (virialHandling == VirialHandling::Pbc)
1772     {
1773         int svi;
1774         if (pbc)
1775         {
1776             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1777         }
1778         else
1779         {
1780             svi = c_centralShiftIndex;
1781         }
1782
1783         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || ski != c_centralShiftIndex)
1784         {
1785             rvec_dec(fshift[svi], fv);
1786             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX] - fk[XX];
1787             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY] - fk[YY];
1788             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1789             rvec_inc(fshift[sji], fj);
1790             rvec_inc(fshift[ski], fk);
1791         }
1792     }
1793
1794     if (virialHandling == VirialHandling::NonLinear)
1795     {
1796         rvec xiv;
1797
1798         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1799
1800         for (int i = 0; i < DIM; i++)
1801         {
1802             for (int j = 0; j < DIM; j++)
1803             {
1804                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j];
1805             }
1806         }
1807     }
1808
1809     /* TOTAL: 54 flops */
1810 }
1811
1812 template<VirialHandling virialHandling>
1813 static void spread_vsite4FD(const t_iatom        ia[],
1814                             real                 a,
1815                             real                 b,
1816                             real                 c,
1817                             ArrayRef<const RVec> x,
1818                             ArrayRef<RVec>       f,
1819                             ArrayRef<RVec>       fshift,
1820                             matrix               dxdf,
1821                             const t_pbc*         pbc)
1822 {
1823     real fproj, a1;
1824     rvec xvi, xij, xjk, xjl, xix, fv, temp;
1825     int  av, ai, aj, ak, al;
1826     int  sji, skj, slj, m;
1827
1828     av = ia[1];
1829     ai = ia[2];
1830     aj = ia[3];
1831     ak = ia[4];
1832     al = ia[5];
1833
1834     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1835     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1836     slj = pbc_rvec_sub(pbc, x[al], x[aj], xjl);
1837     /* 9 flops */
1838
1839     /* xix goes from i to point x on the plane jkl */
1840     for (m = 0; m < DIM; m++)
1841     {
1842         xix[m] = xij[m] + a * xjk[m] + b * xjl[m];
1843     }
1844     /* 12 flops */
1845
1846     const real invDistance = inverseNorm(xix);
1847     const real d           = c * invDistance;
1848     /* 4 + ?10? flops */
1849
1850     copy_rvec(f[av], fv);
1851
1852     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1853
1854     for (m = 0; m < DIM; m++)
1855     {
1856         temp[m] = d * (fv[m] - fproj * xix[m]);
1857     }
1858     /* 16 */
1859
1860     /* c is already calculated in constr_vsite3FD
1861        storing c somewhere will save 35 flops!     */
1862
1863     a1 = 1 - a - b;
1864     for (m = 0; m < DIM; m++)
1865     {
1866         f[ai][m] += fv[m] - temp[m];
1867         f[aj][m] += a1 * temp[m];
1868         f[ak][m] += a * temp[m];
1869         f[al][m] += b * temp[m];
1870     }
1871     /* 26 Flops */
1872
1873     if (virialHandling == VirialHandling::Pbc)
1874     {
1875         int svi;
1876         if (pbc)
1877         {
1878             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1879         }
1880         else
1881         {
1882             svi = c_centralShiftIndex;
1883         }
1884
1885         if (svi != c_centralShiftIndex || sji != c_centralShiftIndex || skj != c_centralShiftIndex
1886             || slj != c_centralShiftIndex)
1887         {
1888             rvec_dec(fshift[svi], fv);
1889             for (m = 0; m < DIM; m++)
1890             {
1891                 fshift[c_centralShiftIndex][m] += fv[m] - (1 + a + b) * temp[m];
1892                 fshift[sji][m] += temp[m];
1893                 fshift[skj][m] += a * temp[m];
1894                 fshift[slj][m] += b * temp[m];
1895             }
1896         }
1897     }
1898
1899     if (virialHandling == VirialHandling::NonLinear)
1900     {
1901         rvec xiv;
1902         int  i, j;
1903
1904         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1905
1906         for (i = 0; i < DIM; i++)
1907         {
1908             for (j = 0; j < DIM; j++)
1909             {
1910                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1911             }
1912         }
1913     }
1914
1915     /* TOTAL: 77 flops */
1916 }
1917
1918 template<VirialHandling virialHandling>
1919 static void spread_vsite4FDN(const t_iatom        ia[],
1920                              real                 a,
1921                              real                 b,
1922                              real                 c,
1923                              ArrayRef<const RVec> x,
1924                              ArrayRef<RVec>       f,
1925                              ArrayRef<RVec>       fshift,
1926                              matrix               dxdf,
1927                              const t_pbc*         pbc)
1928 {
1929     rvec xvi, xij, xik, xil, ra, rb, rja, rjb, rab, rm, rt;
1930     rvec fv, fj, fk, fl;
1931     real invrm, denom;
1932     real cfx, cfy, cfz;
1933     int  av, ai, aj, ak, al;
1934     int  sij, sik, sil;
1935
1936     /* DEBUG: check atom indices */
1937     av = ia[1];
1938     ai = ia[2];
1939     aj = ia[3];
1940     ak = ia[4];
1941     al = ia[5];
1942
1943     copy_rvec(f[av], fv);
1944
1945     sij = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1946     sik = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1947     sil = pbc_rvec_sub(pbc, x[al], x[ai], xil);
1948     /* 9 flops */
1949
1950     ra[XX] = a * xik[XX];
1951     ra[YY] = a * xik[YY];
1952     ra[ZZ] = a * xik[ZZ];
1953
1954     rb[XX] = b * xil[XX];
1955     rb[YY] = b * xil[YY];
1956     rb[ZZ] = b * xil[ZZ];
1957
1958     /* 6 flops */
1959
1960     rvec_sub(ra, xij, rja);
1961     rvec_sub(rb, xij, rjb);
1962     rvec_sub(rb, ra, rab);
1963     /* 9 flops */
1964
1965     cprod(rja, rjb, rm);
1966     /* 9 flops */
1967
1968     invrm = inverseNorm(rm);
1969     denom = invrm * invrm;
1970     /* 5+5+2 flops */
1971
1972     cfx = c * invrm * fv[XX];
1973     cfy = c * invrm * fv[YY];
1974     cfz = c * invrm * fv[ZZ];
1975     /* 6 Flops */
1976
1977     cprod(rm, rab, rt);
1978     /* 9 flops */
1979
1980     rt[XX] *= denom;
1981     rt[YY] *= denom;
1982     rt[ZZ] *= denom;
1983     /* 3flops */
1984
1985     fj[XX] = (-rm[XX] * rt[XX]) * cfx + (rab[ZZ] - rm[YY] * rt[XX]) * cfy
1986              + (-rab[YY] - rm[ZZ] * rt[XX]) * cfz;
1987     fj[YY] = (-rab[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1988              + (rab[XX] - rm[ZZ] * rt[YY]) * cfz;
1989     fj[ZZ] = (rab[YY] - rm[XX] * rt[ZZ]) * cfx + (-rab[XX] - rm[YY] * rt[ZZ]) * cfy
1990              + (-rm[ZZ] * rt[ZZ]) * cfz;
1991     /* 30 flops */
1992
1993     cprod(rjb, rm, rt);
1994     /* 9 flops */
1995
1996     rt[XX] *= denom * a;
1997     rt[YY] *= denom * a;
1998     rt[ZZ] *= denom * a;
1999     /* 3flops */
2000
2001     fk[XX] = (-rm[XX] * rt[XX]) * cfx + (-a * rjb[ZZ] - rm[YY] * rt[XX]) * cfy
2002              + (a * rjb[YY] - rm[ZZ] * rt[XX]) * cfz;
2003     fk[YY] = (a * rjb[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
2004              + (-a * rjb[XX] - rm[ZZ] * rt[YY]) * cfz;
2005     fk[ZZ] = (-a * rjb[YY] - rm[XX] * rt[ZZ]) * cfx + (a * rjb[XX] - rm[YY] * rt[ZZ]) * cfy
2006              + (-rm[ZZ] * rt[ZZ]) * cfz;
2007     /* 36 flops */
2008
2009     cprod(rm, rja, rt);
2010     /* 9 flops */
2011
2012     rt[XX] *= denom * b;
2013     rt[YY] *= denom * b;
2014     rt[ZZ] *= denom * b;
2015     /* 3flops */
2016
2017     fl[XX] = (-rm[XX] * rt[XX]) * cfx + (b * rja[ZZ] - rm[YY] * rt[XX]) * cfy
2018              + (-b * rja[YY] - rm[ZZ] * rt[XX]) * cfz;
2019     fl[YY] = (-b * rja[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
2020              + (b * rja[XX] - rm[ZZ] * rt[YY]) * cfz;
2021     fl[ZZ] = (b * rja[YY] - rm[XX] * rt[ZZ]) * cfx + (-b * rja[XX] - rm[YY] * rt[ZZ]) * cfy
2022              + (-rm[ZZ] * rt[ZZ]) * cfz;
2023     /* 36 flops */
2024
2025     f[ai][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
2026     f[ai][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
2027     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
2028     rvec_inc(f[aj], fj);
2029     rvec_inc(f[ak], fk);
2030     rvec_inc(f[al], fl);
2031     /* 21 flops */
2032
2033     if (virialHandling == VirialHandling::Pbc)
2034     {
2035         int svi;
2036         if (pbc)
2037         {
2038             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
2039         }
2040         else
2041         {
2042             svi = c_centralShiftIndex;
2043         }
2044
2045         if (svi != c_centralShiftIndex || sij != c_centralShiftIndex || sik != c_centralShiftIndex
2046             || sil != c_centralShiftIndex)
2047         {
2048             rvec_dec(fshift[svi], fv);
2049             fshift[c_centralShiftIndex][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
2050             fshift[c_centralShiftIndex][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
2051             fshift[c_centralShiftIndex][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
2052             rvec_inc(fshift[sij], fj);
2053             rvec_inc(fshift[sik], fk);
2054             rvec_inc(fshift[sil], fl);
2055         }
2056     }
2057
2058     if (virialHandling == VirialHandling::NonLinear)
2059     {
2060         rvec xiv;
2061         int  i, j;
2062
2063         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
2064
2065         for (i = 0; i < DIM; i++)
2066         {
2067             for (j = 0; j < DIM; j++)
2068             {
2069                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j] + xil[i] * fl[j];
2070             }
2071         }
2072     }
2073
2074     /* Total: 207 flops (Yuck!) */
2075 }
2076
2077 template<VirialHandling virialHandling>
2078 static int spread_vsiten(const t_iatom             ia[],
2079                          ArrayRef<const t_iparams> ip,
2080                          ArrayRef<const RVec>      x,
2081                          ArrayRef<RVec>            f,
2082                          ArrayRef<RVec>            fshift,
2083                          const t_pbc*              pbc)
2084 {
2085     rvec xv, dx, fi;
2086     int  n3, av, i, ai;
2087     real a;
2088     int  siv;
2089
2090     n3 = 3 * ip[ia[0]].vsiten.n;
2091     av = ia[1];
2092     copy_rvec(x[av], xv);
2093
2094     for (i = 0; i < n3; i += 3)
2095     {
2096         ai = ia[i + 2];
2097         if (pbc)
2098         {
2099             siv = pbc_dx_aiuc(pbc, x[ai], xv, dx);
2100         }
2101         else
2102         {
2103             siv = c_centralShiftIndex;
2104         }
2105         a = ip[ia[i]].vsiten.a;
2106         svmul(a, f[av], fi);
2107         rvec_inc(f[ai], fi);
2108
2109         if (virialHandling == VirialHandling::Pbc && siv != c_centralShiftIndex)
2110         {
2111             rvec_inc(fshift[siv], fi);
2112             rvec_dec(fshift[c_centralShiftIndex], fi);
2113         }
2114         /* 6 Flops */
2115     }
2116
2117     return n3;
2118 }
2119
2120 #endif // DOXYGEN
2121
2122 //! Returns the number of virtual sites in the interaction list, for VSITEN the number of atoms
2123 static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
2124 {
2125     if (ftype == F_VSITEN)
2126     {
2127         return ilist[ftype].size() / 3;
2128     }
2129     else
2130     {
2131         return ilist[ftype].size() / (1 + interaction_function[ftype].nratoms);
2132     }
2133 }
2134
2135 //! Executes the force spreading task for a single thread
2136 template<VirialHandling virialHandling>
2137 static void spreadForceForThread(ArrayRef<const RVec>            x,
2138                                  ArrayRef<RVec>                  f,
2139                                  ArrayRef<RVec>                  fshift,
2140                                  matrix                          dxdf,
2141                                  ArrayRef<const t_iparams>       ip,
2142                                  ArrayRef<const InteractionList> ilist,
2143                                  const t_pbc*                    pbc_null)
2144 {
2145     const PbcMode pbcMode = getPbcMode(pbc_null);
2146     /* We need another pbc pointer, as with charge groups we switch per vsite */
2147     const t_pbc*             pbc_null2 = pbc_null;
2148     gmx::ArrayRef<const int> vsite_pbc;
2149
2150     /* this loop goes backwards to be able to build *
2151      * higher type vsites from lower types         */
2152     for (int ftype = c_ftypeVsiteEnd - 1; ftype >= c_ftypeVsiteStart; ftype--)
2153     {
2154         if (ilist[ftype].empty())
2155         {
2156             continue;
2157         }
2158
2159         { // TODO remove me
2160             int nra = interaction_function[ftype].nratoms;
2161             int inc = 1 + nra;
2162             int nr  = ilist[ftype].size();
2163
2164             const t_iatom* ia = ilist[ftype].iatoms.data();
2165
2166             if (pbcMode == PbcMode::all)
2167             {
2168                 pbc_null2 = pbc_null;
2169             }
2170
2171             for (int i = 0; i < nr;)
2172             {
2173                 int tp = ia[0];
2174
2175                 /* Constants for constructing */
2176                 real a1, b1, c1;
2177                 a1 = ip[tp].vsite.a;
2178                 /* Construct the vsite depending on type */
2179                 switch (ftype)
2180                 {
2181                     case F_VSITE1: spread_vsite1(ia, f); break;
2182                     case F_VSITE2:
2183                         spread_vsite2<virialHandling>(ia, a1, x, f, fshift, pbc_null2);
2184                         break;
2185                     case F_VSITE2FD:
2186                         spread_vsite2FD<virialHandling>(ia, a1, x, f, fshift, dxdf, pbc_null2);
2187                         break;
2188                     case F_VSITE3:
2189                         b1 = ip[tp].vsite.b;
2190                         spread_vsite3<virialHandling>(ia, a1, b1, x, f, fshift, pbc_null2);
2191                         break;
2192                     case F_VSITE3FD:
2193                         b1 = ip[tp].vsite.b;
2194                         spread_vsite3FD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
2195                         break;
2196                     case F_VSITE3FAD:
2197                         b1 = ip[tp].vsite.b;
2198                         spread_vsite3FAD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
2199                         break;
2200                     case F_VSITE3OUT:
2201                         b1 = ip[tp].vsite.b;
2202                         c1 = ip[tp].vsite.c;
2203                         spread_vsite3OUT<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2204                         break;
2205                     case F_VSITE4FD:
2206                         b1 = ip[tp].vsite.b;
2207                         c1 = ip[tp].vsite.c;
2208                         spread_vsite4FD<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2209                         break;
2210                     case F_VSITE4FDN:
2211                         b1 = ip[tp].vsite.b;
2212                         c1 = ip[tp].vsite.c;
2213                         spread_vsite4FDN<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
2214                         break;
2215                     case F_VSITEN:
2216                         inc = spread_vsiten<virialHandling>(ia, ip, x, f, fshift, pbc_null2);
2217                         break;
2218                     default:
2219                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
2220                 }
2221                 clear_rvec(f[ia[1]]);
2222
2223                 /* Increment loop variables */
2224                 i += inc;
2225                 ia += inc;
2226             }
2227         }
2228     }
2229 }
2230
2231 //! Wrapper function for calling the templated thread-local spread function
2232 static void spreadForceWrapper(ArrayRef<const RVec>            x,
2233                                ArrayRef<RVec>                  f,
2234                                const VirialHandling            virialHandling,
2235                                ArrayRef<RVec>                  fshift,
2236                                matrix                          dxdf,
2237                                const bool                      clearDxdf,
2238                                ArrayRef<const t_iparams>       ip,
2239                                ArrayRef<const InteractionList> ilist,
2240                                const t_pbc*                    pbc_null)
2241 {
2242     if (virialHandling == VirialHandling::NonLinear && clearDxdf)
2243     {
2244         clear_mat(dxdf);
2245     }
2246
2247     switch (virialHandling)
2248     {
2249         case VirialHandling::None:
2250             spreadForceForThread<VirialHandling::None>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2251             break;
2252         case VirialHandling::Pbc:
2253             spreadForceForThread<VirialHandling::Pbc>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2254             break;
2255         case VirialHandling::NonLinear:
2256             spreadForceForThread<VirialHandling::NonLinear>(x, f, fshift, dxdf, ip, ilist, pbc_null);
2257             break;
2258     }
2259 }
2260
2261 //! Clears the task force buffer elements that are written by task idTask
2262 static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
2263 {
2264     int ntask = idTask->spreadTask.size();
2265     for (int ti = 0; ti < ntask; ti++)
2266     {
2267         const AtomIndex* atomList = &idTask->atomIndex[idTask->spreadTask[ti]];
2268         int              natom    = atomList->atom.size();
2269         RVec*            force    = idTask->force.data();
2270         for (int i = 0; i < natom; i++)
2271         {
2272             clear_rvec(force[atomList->atom[i]]);
2273         }
2274     }
2275 }
2276
2277 void VirtualSitesHandler::Impl::spreadForces(ArrayRef<const RVec> x,
2278                                              ArrayRef<RVec>       f,
2279                                              const VirialHandling virialHandling,
2280                                              ArrayRef<RVec>       fshift,
2281                                              matrix               virial,
2282                                              t_nrnb*              nrnb,
2283                                              const matrix         box,
2284                                              gmx_wallcycle*       wcycle)
2285 {
2286     wallcycle_start(wcycle, WallCycleCounter::VsiteSpread);
2287
2288     const bool useDomdec = domainInfo_.useDomdec();
2289
2290     t_pbc pbc, *pbc_null;
2291
2292     if (domainInfo_.useMolPbc_)
2293     {
2294         /* This is wasting some CPU time as we now do this multiple times
2295          * per MD step.
2296          */
2297         pbc_null = set_pbc_dd(
2298                 &pbc, domainInfo_.pbcType_, useDomdec ? domainInfo_.domdec_->numCells : nullptr, FALSE, box);
2299     }
2300     else
2301     {
2302         pbc_null = nullptr;
2303     }
2304
2305     if (useDomdec)
2306     {
2307         dd_clear_f_vsites(*domainInfo_.domdec_, f);
2308     }
2309
2310     const int numThreads = threadingInfo_.numThreads();
2311
2312     if (numThreads == 1)
2313     {
2314         matrix dxdf;
2315         spreadForceWrapper(x, f, virialHandling, fshift, dxdf, true, iparams_, ilists_, pbc_null);
2316
2317         if (virialHandling == VirialHandling::NonLinear)
2318         {
2319             for (int i = 0; i < DIM; i++)
2320             {
2321                 for (int j = 0; j < DIM; j++)
2322                 {
2323                     virial[i][j] += -0.5 * dxdf[i][j];
2324                 }
2325             }
2326         }
2327     }
2328     else
2329     {
2330         /* First spread the vsites that might depend on non-local vsites */
2331         auto& nlDependentVSites = threadingInfo_.threadDataNonLocalDependent();
2332         spreadForceWrapper(x,
2333                            f,
2334                            virialHandling,
2335                            fshift,
2336                            nlDependentVSites.dxdf,
2337                            true,
2338                            iparams_,
2339                            nlDependentVSites.ilist,
2340                            pbc_null);
2341
2342 #pragma omp parallel num_threads(numThreads)
2343         {
2344             try
2345             {
2346                 int          thread = gmx_omp_get_thread_num();
2347                 VsiteThread& tData  = threadingInfo_.threadData(thread);
2348
2349                 ArrayRef<RVec> fshift_t;
2350                 if (virialHandling == VirialHandling::Pbc)
2351                 {
2352                     if (thread == 0)
2353                     {
2354                         fshift_t = fshift;
2355                     }
2356                     else
2357                     {
2358                         fshift_t = tData.fshift;
2359
2360                         for (int i = 0; i < c_numShiftVectors; i++)
2361                         {
2362                             clear_rvec(fshift_t[i]);
2363                         }
2364                     }
2365                 }
2366
2367                 if (tData.useInterdependentTask)
2368                 {
2369                     /* Spread the vsites that spread outside our local range.
2370                      * This is done using a thread-local force buffer force.
2371                      * First we need to copy the input vsite forces to force.
2372                      */
2373                     InterdependentTask* idTask = &tData.idTask;
2374
2375                     /* Clear the buffer elements set by our task during
2376                      * the last call to spread_vsite_f.
2377                      */
2378                     clearTaskForceBufferUsedElements(idTask);
2379
2380                     int nvsite = idTask->vsite.size();
2381                     for (int i = 0; i < nvsite; i++)
2382                     {
2383                         copy_rvec(f[idTask->vsite[i]], idTask->force[idTask->vsite[i]]);
2384                     }
2385                     spreadForceWrapper(x,
2386                                        idTask->force,
2387                                        virialHandling,
2388                                        fshift_t,
2389                                        tData.dxdf,
2390                                        true,
2391                                        iparams_,
2392                                        tData.idTask.ilist,
2393                                        pbc_null);
2394
2395                     /* We need a barrier before reducing forces below
2396                      * that have been produced by a different thread above.
2397                      */
2398 #pragma omp barrier
2399
2400                     /* Loop over all thread task and reduce forces they
2401                      * produced on atoms that fall in our range.
2402                      * Note that atomic reduction would be a simpler solution,
2403                      * but that might not have good support on all platforms.
2404                      */
2405                     int ntask = idTask->reduceTask.size();
2406                     for (int ti = 0; ti < ntask; ti++)
2407                     {
2408                         const InterdependentTask& idt_foreign =
2409                                 threadingInfo_.threadData(idTask->reduceTask[ti]).idTask;
2410                         const AtomIndex& atomList  = idt_foreign.atomIndex[thread];
2411                         const RVec*      f_foreign = idt_foreign.force.data();
2412
2413                         for (int ind : atomList.atom)
2414                         {
2415                             rvec_inc(f[ind], f_foreign[ind]);
2416                             /* Clearing of f_foreign is done at the next step */
2417                         }
2418                     }
2419                     /* Clear the vsite forces, both in f and force */
2420                     for (int i = 0; i < nvsite; i++)
2421                     {
2422                         int ind = tData.idTask.vsite[i];
2423                         clear_rvec(f[ind]);
2424                         clear_rvec(tData.idTask.force[ind]);
2425                     }
2426                 }
2427
2428                 /* Spread the vsites that spread locally only */
2429                 spreadForceWrapper(
2430                         x, f, virialHandling, fshift_t, tData.dxdf, false, iparams_, tData.ilist, pbc_null);
2431             }
2432             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2433         }
2434
2435         if (virialHandling == VirialHandling::Pbc)
2436         {
2437             for (int th = 1; th < numThreads; th++)
2438             {
2439                 for (int i = 0; i < c_numShiftVectors; i++)
2440                 {
2441                     rvec_inc(fshift[i], threadingInfo_.threadData(th).fshift[i]);
2442                 }
2443             }
2444         }
2445
2446         if (virialHandling == VirialHandling::NonLinear)
2447         {
2448             for (int th = 0; th < numThreads + 1; th++)
2449             {
2450                 /* MSVC doesn't like matrix references, so we use a pointer */
2451                 const matrix& dxdf = threadingInfo_.threadData(th).dxdf;
2452
2453                 for (int i = 0; i < DIM; i++)
2454                 {
2455                     for (int j = 0; j < DIM; j++)
2456                     {
2457                         virial[i][j] += -0.5 * dxdf[i][j];
2458                     }
2459                 }
2460             }
2461         }
2462     }
2463
2464     if (useDomdec)
2465     {
2466         dd_move_f_vsites(*domainInfo_.domdec_, f, fshift);
2467     }
2468
2469     inc_nrnb(nrnb, eNR_VSITE1, vsite_count(ilists_, F_VSITE1));
2470     inc_nrnb(nrnb, eNR_VSITE2, vsite_count(ilists_, F_VSITE2));
2471     inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(ilists_, F_VSITE2FD));
2472     inc_nrnb(nrnb, eNR_VSITE3, vsite_count(ilists_, F_VSITE3));
2473     inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(ilists_, F_VSITE3FD));
2474     inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(ilists_, F_VSITE3FAD));
2475     inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(ilists_, F_VSITE3OUT));
2476     inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(ilists_, F_VSITE4FD));
2477     inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(ilists_, F_VSITE4FDN));
2478     inc_nrnb(nrnb, eNR_VSITEN, vsite_count(ilists_, F_VSITEN));
2479
2480     wallcycle_stop(wcycle, WallCycleCounter::VsiteSpread);
2481 }
2482
2483 /*! \brief Returns the an array with group indices for each atom
2484  *
2485  * \param[in] grouping  The paritioning of the atom range into atom groups
2486  */
2487 static std::vector<int> makeAtomToGroupMapping(const gmx::RangePartitioning& grouping)
2488 {
2489     std::vector<int> atomToGroup(grouping.fullRange().end(), 0);
2490
2491     for (int group = 0; group < grouping.numBlocks(); group++)
2492     {
2493         auto block = grouping.block(group);
2494         std::fill(atomToGroup.begin() + block.begin(), atomToGroup.begin() + block.end(), group);
2495     }
2496
2497     return atomToGroup;
2498 }
2499
2500 int countNonlinearVsites(const gmx_mtop_t& mtop)
2501 {
2502     int numNonlinearVsites = 0;
2503     for (const gmx_molblock_t& molb : mtop.molblock)
2504     {
2505         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2506
2507         for (const auto& ilist : extractILists(molt.ilist, IF_VSITE))
2508         {
2509             if (ilist.functionType != F_VSITE2 && ilist.functionType != F_VSITE3
2510                 && ilist.functionType != F_VSITEN)
2511             {
2512                 numNonlinearVsites += molb.nmol * ilist.iatoms.size() / (1 + NRAL(ilist.functionType));
2513             }
2514         }
2515     }
2516
2517     return numNonlinearVsites;
2518 }
2519
2520 void VirtualSitesHandler::spreadForces(ArrayRef<const RVec> x,
2521                                        ArrayRef<RVec>       f,
2522                                        const VirialHandling virialHandling,
2523                                        ArrayRef<RVec>       fshift,
2524                                        matrix               virial,
2525                                        t_nrnb*              nrnb,
2526                                        const matrix         box,
2527                                        gmx_wallcycle*       wcycle)
2528 {
2529     impl_->spreadForces(x, f, virialHandling, fshift, virial, nrnb, box, wcycle);
2530 }
2531
2532 int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
2533                                 gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype)
2534 {
2535     int n_intercg_vsite = 0;
2536     for (const gmx_molblock_t& molb : mtop.molblock)
2537     {
2538         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2539
2540         std::vector<int> atomToGroup;
2541         if (!updateGroupingPerMoleculetype.empty())
2542         {
2543             atomToGroup = makeAtomToGroupMapping(updateGroupingPerMoleculetype[molb.type]);
2544         }
2545         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2546         {
2547             const int              nral = NRAL(ftype);
2548             const InteractionList& il   = molt.ilist[ftype];
2549             for (int i = 0; i < il.size(); i += 1 + nral)
2550             {
2551                 bool isInterGroup = atomToGroup.empty();
2552                 if (!isInterGroup)
2553                 {
2554                     const int group = atomToGroup[il.iatoms[1 + i]];
2555                     for (int a = 1; a < nral; a++)
2556                     {
2557                         if (atomToGroup[il.iatoms[1 + a]] != group)
2558                         {
2559                             isInterGroup = true;
2560                             break;
2561                         }
2562                     }
2563                 }
2564                 if (isInterGroup)
2565                 {
2566                     n_intercg_vsite += molb.nmol;
2567                 }
2568             }
2569         }
2570     }
2571
2572     return n_intercg_vsite;
2573 }
2574
2575 std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
2576                                                              const t_commrec*  cr,
2577                                                              PbcType           pbcType)
2578 {
2579     GMX_RELEASE_ASSERT(cr != nullptr, "We need a valid commrec");
2580
2581     std::unique_ptr<VirtualSitesHandler> vsite;
2582
2583     /* check if there are vsites */
2584     int nvsite = 0;
2585     for (int ftype = 0; ftype < F_NRE; ftype++)
2586     {
2587         if (interaction_function[ftype].flags & IF_VSITE)
2588         {
2589             GMX_ASSERT(ftype >= c_ftypeVsiteStart && ftype < c_ftypeVsiteEnd,
2590                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2591
2592             nvsite += gmx_mtop_ftype_count(mtop, ftype);
2593         }
2594         else
2595         {
2596             GMX_ASSERT(ftype < c_ftypeVsiteStart || ftype >= c_ftypeVsiteEnd,
2597                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2598         }
2599     }
2600
2601     if (nvsite == 0)
2602     {
2603         return vsite;
2604     }
2605
2606     return std::make_unique<VirtualSitesHandler>(mtop, cr->dd, pbcType);
2607 }
2608
2609 ThreadingInfo::ThreadingInfo() : numThreads_(gmx_omp_nthreads_get(emntVSITE))
2610 {
2611     if (numThreads_ > 1)
2612     {
2613         /* We need one extra thread data structure for the overlap vsites */
2614         tData_.resize(numThreads_ + 1);
2615 #pragma omp parallel for num_threads(numThreads_) schedule(static)
2616         for (int thread = 0; thread < numThreads_; thread++)
2617         {
2618             try
2619             {
2620                 tData_[thread] = std::make_unique<VsiteThread>();
2621
2622                 InterdependentTask& idTask = tData_[thread]->idTask;
2623                 idTask.nuse                = 0;
2624                 idTask.atomIndex.resize(numThreads_);
2625             }
2626             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2627         }
2628         if (numThreads_ > 1)
2629         {
2630             tData_[numThreads_] = std::make_unique<VsiteThread>();
2631         }
2632     }
2633 }
2634
2635 //! Returns the number of inter update-group vsites
2636 static int getNumInterUpdategroupVsites(const gmx_mtop_t& mtop, const gmx_domdec_t* domdec)
2637 {
2638     gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype;
2639     if (domdec)
2640     {
2641         updateGroupingPerMoleculetype = getUpdateGroupingPerMoleculetype(*domdec);
2642     }
2643
2644     return countInterUpdategroupVsites(mtop, updateGroupingPerMoleculetype);
2645 }
2646
2647 VirtualSitesHandler::Impl::Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
2648     numInterUpdategroupVirtualSites_(getNumInterUpdategroupVsites(mtop, domdec)),
2649     domainInfo_({ pbcType, pbcType != PbcType::No && numInterUpdategroupVirtualSites_ > 0, domdec }),
2650     iparams_(mtop.ffparams.iparams)
2651 {
2652 }
2653
2654 VirtualSitesHandler::VirtualSitesHandler(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
2655     impl_(new Impl(mtop, domdec, pbcType))
2656 {
2657 }
2658
2659 //! Flag that atom \p atom which is home in another task, if it has not already been added before
2660 static inline void flagAtom(InterdependentTask* idTask, const int atom, const int numThreads, const int numAtomsPerThread)
2661 {
2662     if (!idTask->use[atom])
2663     {
2664         idTask->use[atom] = true;
2665         int thread        = atom / numAtomsPerThread;
2666         /* Assign all non-local atom force writes to thread 0 */
2667         if (thread >= numThreads)
2668         {
2669             thread = 0;
2670         }
2671         idTask->atomIndex[thread].atom.push_back(atom);
2672     }
2673 }
2674
2675 /*! \brief Here we try to assign all vsites that are in our local range.
2676  *
2677  * Our task local atom range is tData->rangeStart - tData->rangeEnd.
2678  * Vsites that depend only on local atoms, as indicated by taskIndex[]==thread,
2679  * are assigned to task tData->ilist. Vsites that depend on non-local atoms
2680  * but not on other vsites are assigned to task tData->id_task.ilist.
2681  * taskIndex[] is set for all vsites in our range, either to our local tasks
2682  * or to the single last task as taskIndex[]=2*nthreads.
2683  */
2684 static void assignVsitesToThread(VsiteThread*                    tData,
2685                                  int                             thread,
2686                                  int                             nthread,
2687                                  int                             natperthread,
2688                                  gmx::ArrayRef<int>              taskIndex,
2689                                  ArrayRef<const InteractionList> ilist,
2690                                  ArrayRef<const t_iparams>       ip,
2691                                  const ParticleType*             ptype)
2692 {
2693     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2694     {
2695         tData->ilist[ftype].clear();
2696         tData->idTask.ilist[ftype].clear();
2697
2698         const int  nral1 = 1 + NRAL(ftype);
2699         const int* iat   = ilist[ftype].iatoms.data();
2700         for (int i = 0; i < ilist[ftype].size();)
2701         {
2702             /* Get the number of iatom entries in this virtual site.
2703              * The 3 below for F_VSITEN is from 1+NRAL(ftype)=3
2704              */
2705             const int numIAtoms = (ftype == F_VSITEN ? ip[iat[i]].vsiten.n * 3 : nral1);
2706
2707             if (iat[1 + i] < tData->rangeStart || iat[1 + i] >= tData->rangeEnd)
2708             {
2709                 /* This vsite belongs to a different thread */
2710                 i += numIAtoms;
2711                 continue;
2712             }
2713
2714             /* We would like to assign this vsite to task thread,
2715              * but it might depend on atoms outside the atom range of thread
2716              * or on another vsite not assigned to task thread.
2717              */
2718             int task = thread;
2719             if (ftype != F_VSITEN)
2720             {
2721                 for (int j = i + 2; j < i + nral1; j++)
2722                 {
2723                     /* Do a range check to avoid a harmless race on taskIndex */
2724                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2725                     {
2726                         if (!tData->useInterdependentTask || ptype[iat[j]] == ParticleType::VSite)
2727                         {
2728                             /* At least one constructing atom is a vsite
2729                              * that is not assigned to the same thread.
2730                              * Put this vsite into a separate task.
2731                              */
2732                             task = 2 * nthread;
2733                             break;
2734                         }
2735
2736                         /* There are constructing atoms outside our range,
2737                          * put this vsite into a second task to be executed
2738                          * on the same thread. During construction no barrier
2739                          * is needed between the two tasks on the same thread.
2740                          * During spreading we need to run this task with
2741                          * an additional thread-local intermediate force buffer
2742                          * (or atomic reduction) and a barrier between the two
2743                          * tasks.
2744                          */
2745                         task = nthread + thread;
2746                     }
2747                 }
2748             }
2749             else
2750             {
2751                 for (int j = i + 2; j < i + numIAtoms; j += 3)
2752                 {
2753                     /* Do a range check to avoid a harmless race on taskIndex */
2754                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2755                     {
2756                         GMX_ASSERT(ptype[iat[j]] != ParticleType::VSite,
2757                                    "A vsite to be assigned in assignVsitesToThread has a vsite as "
2758                                    "a constructing atom that does not belong to our task, such "
2759                                    "vsites should be assigned to the single 'master' task");
2760
2761                         task = nthread + thread;
2762                     }
2763                 }
2764             }
2765
2766             /* Update this vsite's thread index entry */
2767             taskIndex[iat[1 + i]] = task;
2768
2769             if (task == thread || task == nthread + thread)
2770             {
2771                 /* Copy this vsite to the thread data struct of thread */
2772                 InteractionList* il_task;
2773                 if (task == thread)
2774                 {
2775                     il_task = &tData->ilist[ftype];
2776                 }
2777                 else
2778                 {
2779                     il_task = &tData->idTask.ilist[ftype];
2780                 }
2781                 /* Copy the vsite data to the thread-task local array */
2782                 il_task->push_back(iat[i], numIAtoms - 1, iat + i + 1);
2783                 if (task == nthread + thread)
2784                 {
2785                     /* This vsite writes outside our own task force block.
2786                      * Put it into the interdependent task list and flag
2787                      * the atoms involved for reduction.
2788                      */
2789                     tData->idTask.vsite.push_back(iat[i + 1]);
2790                     if (ftype != F_VSITEN)
2791                     {
2792                         for (int j = i + 2; j < i + nral1; j++)
2793                         {
2794                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2795                         }
2796                     }
2797                     else
2798                     {
2799                         for (int j = i + 2; j < i + numIAtoms; j += 3)
2800                         {
2801                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2802                         }
2803                     }
2804                 }
2805             }
2806
2807             i += numIAtoms;
2808         }
2809     }
2810 }
2811
2812 /*! \brief Assign all vsites with taskIndex[]==task to task tData */
2813 static void assignVsitesToSingleTask(VsiteThread*                    tData,
2814                                      int                             task,
2815                                      gmx::ArrayRef<const int>        taskIndex,
2816                                      ArrayRef<const InteractionList> ilist,
2817                                      ArrayRef<const t_iparams>       ip)
2818 {
2819     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2820     {
2821         tData->ilist[ftype].clear();
2822         tData->idTask.ilist[ftype].clear();
2823
2824         int              nral1   = 1 + NRAL(ftype);
2825         int              inc     = nral1;
2826         const int*       iat     = ilist[ftype].iatoms.data();
2827         InteractionList* il_task = &tData->ilist[ftype];
2828
2829         for (int i = 0; i < ilist[ftype].size();)
2830         {
2831             if (ftype == F_VSITEN)
2832             {
2833                 /* The 3 below is from 1+NRAL(ftype)=3 */
2834                 inc = ip[iat[i]].vsiten.n * 3;
2835             }
2836             /* Check if the vsite is assigned to our task */
2837             if (taskIndex[iat[1 + i]] == task)
2838             {
2839                 /* Copy the vsite data to the thread-task local array */
2840                 il_task->push_back(iat[i], inc - 1, iat + i + 1);
2841             }
2842
2843             i += inc;
2844         }
2845     }
2846 }
2847
2848 void ThreadingInfo::setVirtualSites(ArrayRef<const InteractionList> ilists,
2849                                     ArrayRef<const t_iparams>       iparams,
2850                                     const t_mdatoms&                mdatoms,
2851                                     const bool                      useDomdec)
2852 {
2853     if (numThreads_ <= 1)
2854     {
2855         /* Nothing to do */
2856         return;
2857     }
2858
2859     /* The current way of distributing the vsites over threads in primitive.
2860      * We divide the atom range 0 - natoms_in_vsite uniformly over threads,
2861      * without taking into account how the vsites are distributed.
2862      * Without domain decomposition we at least tighten the upper bound
2863      * of the range (useful for common systems such as a vsite-protein
2864      * in 3-site water).
2865      * With domain decomposition, as long as the vsites are distributed
2866      * uniformly in each domain along the major dimension, usually x,
2867      * it will also perform well.
2868      */
2869     int vsite_atom_range;
2870     int natperthread;
2871     if (!useDomdec)
2872     {
2873         vsite_atom_range = -1;
2874         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2875         {
2876             { // TODO remove me
2877                 if (ftype != F_VSITEN)
2878                 {
2879                     int                 nral1 = 1 + NRAL(ftype);
2880                     ArrayRef<const int> iat   = ilists[ftype].iatoms;
2881                     for (int i = 0; i < ilists[ftype].size(); i += nral1)
2882                     {
2883                         for (int j = i + 1; j < i + nral1; j++)
2884                         {
2885                             vsite_atom_range = std::max(vsite_atom_range, iat[j]);
2886                         }
2887                     }
2888                 }
2889                 else
2890                 {
2891                     int vs_ind_end;
2892
2893                     ArrayRef<const int> iat = ilists[ftype].iatoms;
2894
2895                     int i = 0;
2896                     while (i < ilists[ftype].size())
2897                     {
2898                         /* The 3 below is from 1+NRAL(ftype)=3 */
2899                         vs_ind_end = i + iparams[iat[i]].vsiten.n * 3;
2900
2901                         vsite_atom_range = std::max(vsite_atom_range, iat[i + 1]);
2902                         while (i < vs_ind_end)
2903                         {
2904                             vsite_atom_range = std::max(vsite_atom_range, iat[i + 2]);
2905                             i += 3;
2906                         }
2907                     }
2908                 }
2909             }
2910         }
2911         vsite_atom_range++;
2912         natperthread = (vsite_atom_range + numThreads_ - 1) / numThreads_;
2913     }
2914     else
2915     {
2916         /* Any local or not local atom could be involved in virtual sites.
2917          * But since we usually have very few non-local virtual sites
2918          * (only non-local vsites that depend on local vsites),
2919          * we distribute the local atom range equally over the threads.
2920          * When assigning vsites to threads, we should take care that the last
2921          * threads also covers the non-local range.
2922          */
2923         vsite_atom_range = mdatoms.nr;
2924         natperthread     = (mdatoms.homenr + numThreads_ - 1) / numThreads_;
2925     }
2926
2927     if (debug)
2928     {
2929         fprintf(debug,
2930                 "virtual site thread dist: natoms %d, range %d, natperthread %d\n",
2931                 mdatoms.nr,
2932                 vsite_atom_range,
2933                 natperthread);
2934     }
2935
2936     /* To simplify the vsite assignment, we make an index which tells us
2937      * to which task particles, both non-vsites and vsites, are assigned.
2938      */
2939     taskIndex_.resize(mdatoms.nr);
2940
2941     /* Initialize the task index array. Here we assign the non-vsite
2942      * particles to task=thread, so we easily figure out if vsites
2943      * depend on local and/or non-local particles in assignVsitesToThread.
2944      */
2945     {
2946         int thread = 0;
2947         for (int i = 0; i < mdatoms.nr; i++)
2948         {
2949             if (mdatoms.ptype[i] == ParticleType::VSite)
2950             {
2951                 /* vsites are not assigned to a task yet */
2952                 taskIndex_[i] = -1;
2953             }
2954             else
2955             {
2956                 /* assign non-vsite particles to task thread */
2957                 taskIndex_[i] = thread;
2958             }
2959             if (i == (thread + 1) * natperthread && thread < numThreads_)
2960             {
2961                 thread++;
2962             }
2963         }
2964     }
2965
2966 #pragma omp parallel num_threads(numThreads_)
2967     {
2968         try
2969         {
2970             int          thread = gmx_omp_get_thread_num();
2971             VsiteThread& tData  = *tData_[thread];
2972
2973             /* Clear the buffer use flags that were set before */
2974             if (tData.useInterdependentTask)
2975             {
2976                 InterdependentTask& idTask = tData.idTask;
2977
2978                 /* To avoid an extra OpenMP barrier in spread_vsite_f,
2979                  * we clear the force buffer at the next step,
2980                  * so we need to do it here as well.
2981                  */
2982                 clearTaskForceBufferUsedElements(&idTask);
2983
2984                 idTask.vsite.resize(0);
2985                 for (int t = 0; t < numThreads_; t++)
2986                 {
2987                     AtomIndex& atomIndex = idTask.atomIndex[t];
2988                     int        natom     = atomIndex.atom.size();
2989                     for (int i = 0; i < natom; i++)
2990                     {
2991                         idTask.use[atomIndex.atom[i]] = false;
2992                     }
2993                     atomIndex.atom.resize(0);
2994                 }
2995                 idTask.nuse = 0;
2996             }
2997
2998             /* To avoid large f_buf allocations of #threads*vsite_atom_range
2999              * we don't use task2 with more than 200000 atoms. This doesn't
3000              * affect performance, since with such a large range relatively few
3001              * vsites will end up in the separate task.
3002              * Note that useTask2 should be the same for all threads.
3003              */
3004             tData.useInterdependentTask = (vsite_atom_range <= 200000);
3005             if (tData.useInterdependentTask)
3006             {
3007                 size_t              natoms_use_in_vsites = vsite_atom_range;
3008                 InterdependentTask& idTask               = tData.idTask;
3009                 /* To avoid resizing and re-clearing every nstlist steps,
3010                  * we never down size the force buffer.
3011                  */
3012                 if (natoms_use_in_vsites > idTask.force.size() || natoms_use_in_vsites > idTask.use.size())
3013                 {
3014                     idTask.force.resize(natoms_use_in_vsites, { 0, 0, 0 });
3015                     idTask.use.resize(natoms_use_in_vsites, false);
3016                 }
3017             }
3018
3019             /* Assign all vsites that can execute independently on threads */
3020             tData.rangeStart = thread * natperthread;
3021             if (thread < numThreads_ - 1)
3022             {
3023                 tData.rangeEnd = (thread + 1) * natperthread;
3024             }
3025             else
3026             {
3027                 /* The last thread should cover up to the end of the range */
3028                 tData.rangeEnd = mdatoms.nr;
3029             }
3030             assignVsitesToThread(
3031                     &tData, thread, numThreads_, natperthread, taskIndex_, ilists, iparams, mdatoms.ptype);
3032
3033             if (tData.useInterdependentTask)
3034             {
3035                 /* In the worst case, all tasks write to force ranges of
3036                  * all other tasks, leading to #tasks^2 scaling (this is only
3037                  * the overhead, the actual flops remain constant).
3038                  * But in most cases there is far less coupling. To improve
3039                  * scaling at high thread counts we therefore construct
3040                  * an index to only loop over the actually affected tasks.
3041                  */
3042                 InterdependentTask& idTask = tData.idTask;
3043
3044                 /* Ensure assignVsitesToThread finished on other threads */
3045 #pragma omp barrier
3046
3047                 idTask.spreadTask.resize(0);
3048                 idTask.reduceTask.resize(0);
3049                 for (int t = 0; t < numThreads_; t++)
3050                 {
3051                     /* Do we write to the force buffer of task t? */
3052                     if (!idTask.atomIndex[t].atom.empty())
3053                     {
3054                         idTask.spreadTask.push_back(t);
3055                     }
3056                     /* Does task t write to our force buffer? */
3057                     if (!tData_[t]->idTask.atomIndex[thread].atom.empty())
3058                     {
3059                         idTask.reduceTask.push_back(t);
3060                     }
3061                 }
3062             }
3063         }
3064         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
3065     }
3066     /* Assign all remaining vsites, that will have taskIndex[]=2*vsite->nthreads,
3067      * to a single task that will not run in parallel with other tasks.
3068      */
3069     assignVsitesToSingleTask(tData_[numThreads_].get(), 2 * numThreads_, taskIndex_, ilists, iparams);
3070
3071     if (debug && numThreads_ > 1)
3072     {
3073         fprintf(debug,
3074                 "virtual site useInterdependentTask %d, nuse:\n",
3075                 static_cast<int>(tData_[0]->useInterdependentTask));
3076         for (int th = 0; th < numThreads_ + 1; th++)
3077         {
3078             fprintf(debug, " %4d", tData_[th]->idTask.nuse);
3079         }
3080         fprintf(debug, "\n");
3081
3082         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
3083         {
3084             if (!ilists[ftype].empty())
3085             {
3086                 fprintf(debug, "%-20s thread dist:", interaction_function[ftype].longname);
3087                 for (int th = 0; th < numThreads_ + 1; th++)
3088                 {
3089                     fprintf(debug,
3090                             " %4d %4d ",
3091                             tData_[th]->ilist[ftype].size(),
3092                             tData_[th]->idTask.ilist[ftype].size());
3093                 }
3094                 fprintf(debug, "\n");
3095             }
3096         }
3097     }
3098
3099 #ifndef NDEBUG
3100     int nrOrig     = vsiteIlistNrCount(ilists);
3101     int nrThreaded = 0;
3102     for (int th = 0; th < numThreads_ + 1; th++)
3103     {
3104         nrThreaded += vsiteIlistNrCount(tData_[th]->ilist) + vsiteIlistNrCount(tData_[th]->idTask.ilist);
3105     }
3106     GMX_ASSERT(nrThreaded == nrOrig,
3107                "The number of virtual sites assigned to all thread task has to match the total "
3108                "number of virtual sites");
3109 #endif
3110 }
3111
3112 void VirtualSitesHandler::Impl::setVirtualSites(ArrayRef<const InteractionList> ilists,
3113                                                 const t_mdatoms&                mdatoms)
3114 {
3115     ilists_ = ilists;
3116
3117     threadingInfo_.setVirtualSites(ilists, iparams_, mdatoms, domainInfo_.useDomdec());
3118 }
3119
3120 void VirtualSitesHandler::setVirtualSites(ArrayRef<const InteractionList> ilists, const t_mdatoms& mdatoms)
3121 {
3122     impl_->setVirtualSites(ilists, mdatoms);
3123 }
3124
3125 } // namespace gmx