f2c8b0ef666388eff2c6ea76dc6c5e5be4e588ab
[alexxy/gromacs.git] / src / gromacs / mdlib / vsite.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 The GROMACS development team.
7  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /*! \libinternal \file
39  * \brief Implements the VirtualSitesHandler class and vsite standalone functions
40  *
41  * \author Berk Hess <hess@kth.se>
42  * \ingroup module_mdlib
43  * \inlibraryapi
44  */
45
46 #include "gmxpre.h"
47
48 #include "vsite.h"
49
50 #include <cstdio>
51
52 #include <algorithm>
53 #include <memory>
54 #include <vector>
55
56 #include "gromacs/domdec/domdec.h"
57 #include "gromacs/domdec/domdec_struct.h"
58 #include "gromacs/gmxlib/network.h"
59 #include "gromacs/gmxlib/nrnb.h"
60 #include "gromacs/math/functions.h"
61 #include "gromacs/math/vec.h"
62 #include "gromacs/mdlib/gmx_omp_nthreads.h"
63 #include "gromacs/mdtypes/commrec.h"
64 #include "gromacs/mdtypes/mdatom.h"
65 #include "gromacs/pbcutil/ishift.h"
66 #include "gromacs/pbcutil/pbc.h"
67 #include "gromacs/timing/wallcycle.h"
68 #include "gromacs/topology/ifunc.h"
69 #include "gromacs/topology/mtop_util.h"
70 #include "gromacs/topology/topology.h"
71 #include "gromacs/utility/exceptions.h"
72 #include "gromacs/utility/fatalerror.h"
73 #include "gromacs/utility/gmxassert.h"
74 #include "gromacs/utility/gmxomp.h"
75
76 /* The strategy used here for assigning virtual sites to (thread-)tasks
77  * is as follows:
78  *
79  * We divide the atom range that vsites operate on (natoms_local with DD,
80  * 0 - last atom involved in vsites without DD) equally over all threads.
81  *
82  * Vsites in the local range constructed from atoms in the local range
83  * and/or other vsites that are fully local are assigned to a simple,
84  * independent task.
85  *
86  * Vsites that are not assigned after using the above criterion get assigned
87  * to a so called "interdependent" thread task when none of the constructing
88  * atoms is a vsite. These tasks are called interdependent, because one task
89  * accesses atoms assigned to a different task/thread.
90  * Note that this option is turned off with large (local) atom counts
91  * to avoid high memory usage.
92  *
93  * Any remaining vsites are assigned to a separate master thread task.
94  */
95
96 namespace gmx
97 {
98
99 //! VirialHandling is often used outside VirtualSitesHandler class members
100 using VirialHandling = VirtualSitesHandler::VirialHandling;
101
102 /*! \libinternal
103  * \brief Information on PBC and domain decomposition for virtual sites
104  */
105 struct DomainInfo
106 {
107 public:
108     //! Constructs without PBC and DD
109     DomainInfo() = default;
110
111     //! Constructs with PBC and DD, if !=nullptr
112     DomainInfo(PbcType pbcType, bool haveInterUpdateGroupVirtualSites, gmx_domdec_t* domdec) :
113         pbcType_(pbcType),
114         useMolPbc_(pbcType != PbcType::No && haveInterUpdateGroupVirtualSites),
115         domdec_(domdec)
116     {
117     }
118
119     //! Returns whether we are using domain decomposition with more than 1 DD rank
120     bool useDomdec() const { return (domdec_ != nullptr); }
121
122     //! The pbc type
123     const PbcType pbcType_ = PbcType::No;
124     //! Whether molecules are broken over PBC
125     const bool useMolPbc_ = false;
126     //! Pointer to the domain decomposition struct, nullptr without PP DD
127     const gmx_domdec_t* domdec_ = nullptr;
128 };
129
130 /*! \libinternal
131  * \brief List of atom indices belonging to a task
132  */
133 struct AtomIndex
134 {
135     //! List of atom indices
136     std::vector<int> atom;
137 };
138
139 /*! \libinternal
140  * \brief Data structure for thread tasks that use constructing atoms outside their own atom range
141  */
142 struct InterdependentTask
143 {
144     //! The interaction lists, only vsite entries are used
145     InteractionLists ilist;
146     //! Thread/task-local force buffer
147     std::vector<RVec> force;
148     //! The atom indices of the vsites of our task
149     std::vector<int> vsite;
150     //! Flags if elements in force are spread to or not
151     std::vector<bool> use;
152     //! The number of entries set to true in use
153     int nuse = 0;
154     //! Array of atoms indices, size nthreads, covering all nuse set elements in use
155     std::vector<AtomIndex> atomIndex;
156     //! List of tasks (force blocks) this task spread forces to
157     std::vector<int> spreadTask;
158     //! List of tasks that write to this tasks force block range
159     std::vector<int> reduceTask;
160 };
161
162 /*! \libinternal
163  * \brief Vsite thread task data structure
164  */
165 struct VsiteThread
166 {
167     //! Start of atom range of this task
168     int rangeStart;
169     //! End of atom range of this task
170     int rangeEnd;
171     //! The interaction lists, only vsite entries are used
172     std::array<InteractionList, F_NRE> ilist;
173     //! Local fshift accumulation buffer
174     std::array<RVec, SHIFTS> fshift;
175     //! Local virial dx*df accumulation buffer
176     matrix dxdf;
177     //! Tells if interdependent task idTask should be used (in addition to the rest of this task), this bool has the same value on all threads
178     bool useInterdependentTask;
179     //! Data for vsites that involve constructing atoms in the atom range of other threads/tasks
180     InterdependentTask idTask;
181
182     /*! \brief Constructor */
183     VsiteThread()
184     {
185         rangeStart = -1;
186         rangeEnd   = -1;
187         for (auto& elem : fshift)
188         {
189             elem = { 0.0_real, 0.0_real, 0.0_real };
190         }
191         clear_mat(dxdf);
192         useInterdependentTask = false;
193     }
194 };
195
196
197 /*! \libinternal
198  * \brief Information on how the virtual site work is divided over thread tasks
199  */
200 class ThreadingInfo
201 {
202 public:
203     //! Constructor, retrieves the number of threads to use from gmx_omp_nthreads.h
204     ThreadingInfo();
205
206     //! Returns the number of threads to use for vsite operations
207     int numThreads() const { return numThreads_; }
208
209     //! Returns the thread data for the given thread
210     const VsiteThread& threadData(int threadIndex) const { return *tData_[threadIndex]; }
211
212     //! Returns the thread data for the given thread
213     VsiteThread& threadData(int threadIndex) { return *tData_[threadIndex]; }
214
215     //! Returns the thread data for vsites that depend on non-local vsites
216     const VsiteThread& threadDataNonLocalDependent() const { return *tData_[numThreads_]; }
217
218     //! Returns the thread data for vsites that depend on non-local vsites
219     VsiteThread& threadDataNonLocalDependent() { return *tData_[numThreads_]; }
220
221     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
222     void setVirtualSites(ArrayRef<const InteractionList> ilist,
223                          ArrayRef<const t_iparams>       iparams,
224                          const t_mdatoms&                mdatoms,
225                          bool                            useDomdec);
226
227 private:
228     //! Number of threads used for vsite operations
229     const int numThreads_;
230     //! Thread local vsites and work structs
231     std::vector<std::unique_ptr<VsiteThread>> tData_;
232     //! Work array for dividing vsites over threads
233     std::vector<int> taskIndex_;
234 };
235
236 /*! \libinternal
237  * \brief Impl class for VirtualSitesHandler
238  */
239 class VirtualSitesHandler::Impl
240 {
241 public:
242     //! Constructor, domdec should be nullptr without DD
243     Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, PbcType pbcType);
244
245     //! Returns the number of virtual sites acting over multiple update groups
246     int numInterUpdategroupVirtualSites() const { return numInterUpdategroupVirtualSites_; }
247
248     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
249     void setVirtualSites(ArrayRef<const InteractionList> ilist, const t_mdatoms& mdatoms);
250
251     /*! \brief Create positions of vsite atoms based for the local system
252      *
253      * \param[in,out] x        The coordinates
254      * \param[in]     dt       The time step
255      * \param[in,out] v        When != nullptr, velocities for vsites are set as displacement/dt
256      * \param[in]     box      The box
257      */
258     void construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const;
259
260     /*! \brief Spread the force operating on the vsite atoms on the surrounding atoms.
261      *
262      * vsite should point to a valid object.
263      * The virialHandling parameter determines how virial contributions are handled.
264      * If this is set to Linear, shift forces are accumulated into fshift.
265      * If this is set to NonLinear, non-linear contributions are added to virial.
266      * This non-linear correction is required when the virial is not calculated
267      * afterwards from the particle position and forces, but in a different way,
268      * as for instance for the PME mesh contribution.
269      */
270     void spreadForces(ArrayRef<const RVec> x,
271                       ArrayRef<RVec>       f,
272                       VirialHandling       virialHandling,
273                       ArrayRef<RVec>       fshift,
274                       matrix               virial,
275                       t_nrnb*              nrnb,
276                       const matrix         box,
277                       gmx_wallcycle*       wcycle);
278
279 private:
280     //! The number of vsites that cross update groups, when =0 no PBC treatment is needed
281     const int numInterUpdategroupVirtualSites_;
282     //! PBC and DD information
283     const DomainInfo domainInfo_;
284     //! The interaction parameters
285     const ArrayRef<const t_iparams> iparams_;
286     //! The interaction lists
287     ArrayRef<const InteractionList> ilists_;
288     //! Information for handling vsite threading
289     ThreadingInfo threadingInfo_;
290 };
291
292 VirtualSitesHandler::~VirtualSitesHandler() = default;
293
294 int VirtualSitesHandler::numInterUpdategroupVirtualSites() const
295 {
296     return impl_->numInterUpdategroupVirtualSites();
297 }
298
299 /*! \brief Returns the sum of the vsite ilist sizes over all vsite types
300  *
301  * \param[in] ilist  The interaction list
302  */
303 static int vsiteIlistNrCount(ArrayRef<const InteractionList> ilist)
304 {
305     int nr = 0;
306     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
307     {
308         nr += ilist[ftype].size();
309     }
310
311     return nr;
312 }
313
314 //! Computes the distance between xi and xj, pbc is used when pbc!=nullptr
315 static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
316 {
317     if (pbc)
318     {
319         return pbc_dx_aiuc(pbc, xi, xj, dx);
320     }
321     else
322     {
323         rvec_sub(xi, xj, dx);
324         return CENTRAL;
325     }
326 }
327
328 //! Returns the 1/norm(x)
329 static inline real inverseNorm(const rvec x)
330 {
331     return gmx::invsqrt(iprod(x, x));
332 }
333
334 #ifndef DOXYGEN
335 /* Vsite construction routines */
336
337 static void constr_vsite2(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc)
338 {
339     real b = 1 - a;
340     /* 1 flop */
341
342     if (pbc)
343     {
344         rvec dx;
345         pbc_dx_aiuc(pbc, xj, xi, dx);
346         x[XX] = xi[XX] + a * dx[XX];
347         x[YY] = xi[YY] + a * dx[YY];
348         x[ZZ] = xi[ZZ] + a * dx[ZZ];
349     }
350     else
351     {
352         x[XX] = b * xi[XX] + a * xj[XX];
353         x[YY] = b * xi[YY] + a * xj[YY];
354         x[ZZ] = b * xi[ZZ] + a * xj[ZZ];
355         /* 9 Flops */
356     }
357
358     /* TOTAL: 10 flops */
359 }
360
361 static void constr_vsite2FD(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc)
362 {
363     rvec xij;
364     pbc_rvec_sub(pbc, xj, xi, xij);
365     /* 3 flops */
366
367     const real b = a * inverseNorm(xij);
368     /* 6 + 10 flops */
369
370     x[XX] = xi[XX] + b * xij[XX];
371     x[YY] = xi[YY] + b * xij[YY];
372     x[ZZ] = xi[ZZ] + b * xij[ZZ];
373     /* 6 Flops */
374
375     /* TOTAL: 25 flops */
376 }
377
378 static void constr_vsite3(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, const t_pbc* pbc)
379 {
380     real c = 1 - a - b;
381     /* 2 flops */
382
383     if (pbc)
384     {
385         rvec dxj, dxk;
386
387         pbc_dx_aiuc(pbc, xj, xi, dxj);
388         pbc_dx_aiuc(pbc, xk, xi, dxk);
389         x[XX] = xi[XX] + a * dxj[XX] + b * dxk[XX];
390         x[YY] = xi[YY] + a * dxj[YY] + b * dxk[YY];
391         x[ZZ] = xi[ZZ] + a * dxj[ZZ] + b * dxk[ZZ];
392     }
393     else
394     {
395         x[XX] = c * xi[XX] + a * xj[XX] + b * xk[XX];
396         x[YY] = c * xi[YY] + a * xj[YY] + b * xk[YY];
397         x[ZZ] = c * xi[ZZ] + a * xj[ZZ] + b * xk[ZZ];
398         /* 15 Flops */
399     }
400
401     /* TOTAL: 17 flops */
402 }
403
404 static void constr_vsite3FD(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, const t_pbc* pbc)
405 {
406     rvec xij, xjk, temp;
407     real c;
408
409     pbc_rvec_sub(pbc, xj, xi, xij);
410     pbc_rvec_sub(pbc, xk, xj, xjk);
411     /* 6 flops */
412
413     /* temp goes from i to a point on the line jk */
414     temp[XX] = xij[XX] + a * xjk[XX];
415     temp[YY] = xij[YY] + a * xjk[YY];
416     temp[ZZ] = xij[ZZ] + a * xjk[ZZ];
417     /* 6 flops */
418
419     c = b * inverseNorm(temp);
420     /* 6 + 10 flops */
421
422     x[XX] = xi[XX] + c * temp[XX];
423     x[YY] = xi[YY] + c * temp[YY];
424     x[ZZ] = xi[ZZ] + c * temp[ZZ];
425     /* 6 Flops */
426
427     /* TOTAL: 34 flops */
428 }
429
430 static void constr_vsite3FAD(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, const t_pbc* pbc)
431 {
432     rvec xij, xjk, xp;
433     real a1, b1, c1, invdij;
434
435     pbc_rvec_sub(pbc, xj, xi, xij);
436     pbc_rvec_sub(pbc, xk, xj, xjk);
437     /* 6 flops */
438
439     invdij = inverseNorm(xij);
440     c1     = invdij * invdij * iprod(xij, xjk);
441     xp[XX] = xjk[XX] - c1 * xij[XX];
442     xp[YY] = xjk[YY] - c1 * xij[YY];
443     xp[ZZ] = xjk[ZZ] - c1 * xij[ZZ];
444     a1     = a * invdij;
445     b1     = b * inverseNorm(xp);
446     /* 45 */
447
448     x[XX] = xi[XX] + a1 * xij[XX] + b1 * xp[XX];
449     x[YY] = xi[YY] + a1 * xij[YY] + b1 * xp[YY];
450     x[ZZ] = xi[ZZ] + a1 * xij[ZZ] + b1 * xp[ZZ];
451     /* 12 Flops */
452
453     /* TOTAL: 63 flops */
454 }
455
456 static void
457 constr_vsite3OUT(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, real c, const t_pbc* pbc)
458 {
459     rvec xij, xik, temp;
460
461     pbc_rvec_sub(pbc, xj, xi, xij);
462     pbc_rvec_sub(pbc, xk, xi, xik);
463     cprod(xij, xik, temp);
464     /* 15 Flops */
465
466     x[XX] = xi[XX] + a * xij[XX] + b * xik[XX] + c * temp[XX];
467     x[YY] = xi[YY] + a * xij[YY] + b * xik[YY] + c * temp[YY];
468     x[ZZ] = xi[ZZ] + a * xij[ZZ] + b * xik[ZZ] + c * temp[ZZ];
469     /* 18 Flops */
470
471     /* TOTAL: 33 flops */
472 }
473
474 static void constr_vsite4FD(const rvec   xi,
475                             const rvec   xj,
476                             const rvec   xk,
477                             const rvec   xl,
478                             rvec         x,
479                             real         a,
480                             real         b,
481                             real         c,
482                             const t_pbc* pbc)
483 {
484     rvec xij, xjk, xjl, temp;
485     real d;
486
487     pbc_rvec_sub(pbc, xj, xi, xij);
488     pbc_rvec_sub(pbc, xk, xj, xjk);
489     pbc_rvec_sub(pbc, xl, xj, xjl);
490     /* 9 flops */
491
492     /* temp goes from i to a point on the plane jkl */
493     temp[XX] = xij[XX] + a * xjk[XX] + b * xjl[XX];
494     temp[YY] = xij[YY] + a * xjk[YY] + b * xjl[YY];
495     temp[ZZ] = xij[ZZ] + a * xjk[ZZ] + b * xjl[ZZ];
496     /* 12 flops */
497
498     d = c * inverseNorm(temp);
499     /* 6 + 10 flops */
500
501     x[XX] = xi[XX] + d * temp[XX];
502     x[YY] = xi[YY] + d * temp[YY];
503     x[ZZ] = xi[ZZ] + d * temp[ZZ];
504     /* 6 Flops */
505
506     /* TOTAL: 43 flops */
507 }
508
509 static void constr_vsite4FDN(const rvec   xi,
510                              const rvec   xj,
511                              const rvec   xk,
512                              const rvec   xl,
513                              rvec         x,
514                              real         a,
515                              real         b,
516                              real         c,
517                              const t_pbc* pbc)
518 {
519     rvec xij, xik, xil, ra, rb, rja, rjb, rm;
520     real d;
521
522     pbc_rvec_sub(pbc, xj, xi, xij);
523     pbc_rvec_sub(pbc, xk, xi, xik);
524     pbc_rvec_sub(pbc, xl, xi, xil);
525     /* 9 flops */
526
527     ra[XX] = a * xik[XX];
528     ra[YY] = a * xik[YY];
529     ra[ZZ] = a * xik[ZZ];
530
531     rb[XX] = b * xil[XX];
532     rb[YY] = b * xil[YY];
533     rb[ZZ] = b * xil[ZZ];
534
535     /* 6 flops */
536
537     rvec_sub(ra, xij, rja);
538     rvec_sub(rb, xij, rjb);
539     /* 6 flops */
540
541     cprod(rja, rjb, rm);
542     /* 9 flops */
543
544     d = c * inverseNorm(rm);
545     /* 5+5+1 flops */
546
547     x[XX] = xi[XX] + d * rm[XX];
548     x[YY] = xi[YY] + d * rm[YY];
549     x[ZZ] = xi[ZZ] + d * rm[ZZ];
550     /* 6 Flops */
551
552     /* TOTAL: 47 flops */
553 }
554
555 static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, ArrayRef<RVec> x, const t_pbc* pbc)
556 {
557     rvec x1, dx;
558     dvec dsum;
559     int  n3, av, ai;
560     real a;
561
562     n3 = 3 * ip[ia[0]].vsiten.n;
563     av = ia[1];
564     ai = ia[2];
565     copy_rvec(x[ai], x1);
566     clear_dvec(dsum);
567     for (int i = 3; i < n3; i += 3)
568     {
569         ai = ia[i + 2];
570         a  = ip[ia[i]].vsiten.a;
571         if (pbc)
572         {
573             pbc_dx_aiuc(pbc, x[ai], x1, dx);
574         }
575         else
576         {
577             rvec_sub(x[ai], x1, dx);
578         }
579         dsum[XX] += a * dx[XX];
580         dsum[YY] += a * dx[YY];
581         dsum[ZZ] += a * dx[ZZ];
582         /* 9 Flops */
583     }
584
585     x[av][XX] = x1[XX] + dsum[XX];
586     x[av][YY] = x1[YY] + dsum[YY];
587     x[av][ZZ] = x1[ZZ] + dsum[ZZ];
588
589     return n3;
590 }
591
592 #endif // DOXYGEN
593
594 //! PBC modes for vsite construction and spreading
595 enum class PbcMode
596 {
597     all, //!< Apply normal, simple PBC for all vsites
598     none //!< No PBC treatment needed
599 };
600
601 /*! \brief Returns the PBC mode based on the system PBC and vsite properties
602  *
603  * \param[in] pbcPtr  A pointer to a PBC struct or nullptr when no PBC treatment is required
604  */
605 static PbcMode getPbcMode(const t_pbc* pbcPtr)
606 {
607     if (pbcPtr == nullptr)
608     {
609         return PbcMode::none;
610     }
611     else
612     {
613         return PbcMode::all;
614     }
615 }
616
617 /*! \brief Executes the vsite construction task for a single thread
618  *
619  * \param[in,out] x   Coordinates to construct vsites for
620  * \param[in]     dt  Time step, needed when v is not empty
621  * \param[in,out] v   When not empty, velocities are generated for virtual sites
622  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
623  * \param[in]     ilist  The interaction lists, only vsites are usesd
624  * \param[in]     pbc_null  PBC struct, used for PBC distance calculations when !=nullptr
625  */
626 static void construct_vsites_thread(ArrayRef<RVec>                  x,
627                                     const real                      dt,
628                                     ArrayRef<RVec>                  v,
629                                     ArrayRef<const t_iparams>       ip,
630                                     ArrayRef<const InteractionList> ilist,
631                                     const t_pbc*                    pbc_null)
632 {
633     real inv_dt;
634     if (!v.empty())
635     {
636         inv_dt = 1.0 / dt;
637     }
638     else
639     {
640         inv_dt = 1.0;
641     }
642
643     const PbcMode pbcMode = getPbcMode(pbc_null);
644     /* We need another pbc pointer, as with charge groups we switch per vsite */
645     const t_pbc* pbc_null2 = pbc_null;
646
647     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
648     {
649         if (ilist[ftype].empty())
650         {
651             continue;
652         }
653
654         { // TODO remove me
655             int nra = interaction_function[ftype].nratoms;
656             int inc = 1 + nra;
657             int nr  = ilist[ftype].size();
658
659             const t_iatom* ia = ilist[ftype].iatoms.data();
660
661             for (int i = 0; i < nr;)
662             {
663                 int tp = ia[0];
664                 /* The vsite and constructing atoms */
665                 int avsite = ia[1];
666                 int ai     = ia[2];
667                 /* Constants for constructing vsites */
668                 real a1 = ip[tp].vsite.a;
669                 /* Copy the old position */
670                 rvec xv;
671                 copy_rvec(x[avsite], xv);
672
673                 /* Construct the vsite depending on type */
674                 int  aj, ak, al;
675                 real b1, c1;
676                 switch (ftype)
677                 {
678                     case F_VSITE2:
679                         aj = ia[3];
680                         constr_vsite2(x[ai], x[aj], x[avsite], a1, pbc_null2);
681                         break;
682                     case F_VSITE2FD:
683                         aj = ia[3];
684                         constr_vsite2FD(x[ai], x[aj], x[avsite], a1, pbc_null2);
685                         break;
686                     case F_VSITE3:
687                         aj = ia[3];
688                         ak = ia[4];
689                         b1 = ip[tp].vsite.b;
690                         constr_vsite3(x[ai], x[aj], x[ak], x[avsite], a1, b1, pbc_null2);
691                         break;
692                     case F_VSITE3FD:
693                         aj = ia[3];
694                         ak = ia[4];
695                         b1 = ip[tp].vsite.b;
696                         constr_vsite3FD(x[ai], x[aj], x[ak], x[avsite], a1, b1, pbc_null2);
697                         break;
698                     case F_VSITE3FAD:
699                         aj = ia[3];
700                         ak = ia[4];
701                         b1 = ip[tp].vsite.b;
702                         constr_vsite3FAD(x[ai], x[aj], x[ak], x[avsite], a1, b1, pbc_null2);
703                         break;
704                     case F_VSITE3OUT:
705                         aj = ia[3];
706                         ak = ia[4];
707                         b1 = ip[tp].vsite.b;
708                         c1 = ip[tp].vsite.c;
709                         constr_vsite3OUT(x[ai], x[aj], x[ak], x[avsite], a1, b1, c1, pbc_null2);
710                         break;
711                     case F_VSITE4FD:
712                         aj = ia[3];
713                         ak = ia[4];
714                         al = ia[5];
715                         b1 = ip[tp].vsite.b;
716                         c1 = ip[tp].vsite.c;
717                         constr_vsite4FD(x[ai], x[aj], x[ak], x[al], x[avsite], a1, b1, c1, pbc_null2);
718                         break;
719                     case F_VSITE4FDN:
720                         aj = ia[3];
721                         ak = ia[4];
722                         al = ia[5];
723                         b1 = ip[tp].vsite.b;
724                         c1 = ip[tp].vsite.c;
725                         constr_vsite4FDN(x[ai], x[aj], x[ak], x[al], x[avsite], a1, b1, c1, pbc_null2);
726                         break;
727                     case F_VSITEN: inc = constr_vsiten(ia, ip, x, pbc_null2); break;
728                     default:
729                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
730                 }
731
732                 if (pbcMode == PbcMode::all)
733                 {
734                     /* Keep the vsite in the same periodic image as before */
735                     rvec dx;
736                     int  ishift = pbc_dx_aiuc(pbc_null, x[avsite], xv, dx);
737                     if (ishift != CENTRAL)
738                     {
739                         rvec_add(xv, dx, x[avsite]);
740                     }
741                 }
742                 if (!v.empty())
743                 {
744                     /* Calculate velocity of vsite... */
745                     rvec vv;
746                     rvec_sub(x[avsite], xv, vv);
747                     svmul(inv_dt, vv, v[avsite]);
748                 }
749
750                 /* Increment loop variables */
751                 i += inc;
752                 ia += inc;
753             }
754         }
755     }
756 }
757
758 /*! \brief Dispatch the vsite construction tasks for all threads
759  *
760  * \param[in]     threadingInfo  Used to divide work over threads when != nullptr
761  * \param[in,out] x   Coordinates to construct vsites for
762  * \param[in]     dt  Time step, needed when v is not empty
763  * \param[in,out] v   When not empty, velocities are generated for virtual sites
764  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
765  * \param[in]     ilist  The interaction lists, only vsites are usesd
766  * \param[in]     domainInfo  Information about PBC and DD
767  * \param[in]     box  Used for PBC when PBC is set in domainInfo
768  */
769 static void construct_vsites(const ThreadingInfo*            threadingInfo,
770                              ArrayRef<RVec>                  x,
771                              real                            dt,
772                              ArrayRef<RVec>                  v,
773                              ArrayRef<const t_iparams>       ip,
774                              ArrayRef<const InteractionList> ilist,
775                              const DomainInfo&               domainInfo,
776                              const matrix                    box)
777 {
778     const bool useDomdec = domainInfo.useDomdec();
779
780     t_pbc pbc, *pbc_null;
781
782     /* We only need to do pbc when we have inter update-group vsites.
783      * Note that with domain decomposition we do not need to apply PBC here
784      * when we have at least 3 domains along each dimension. Currently we
785      * do not optimize this case.
786      */
787     if (domainInfo.pbcType_ != PbcType::No && domainInfo.useMolPbc_)
788     {
789         /* This is wasting some CPU time as we now do this multiple times
790          * per MD step.
791          */
792         ivec null_ivec;
793         clear_ivec(null_ivec);
794         pbc_null = set_pbc_dd(&pbc, domainInfo.pbcType_,
795                               useDomdec ? domainInfo.domdec_->numCells : null_ivec, FALSE, box);
796     }
797     else
798     {
799         pbc_null = nullptr;
800     }
801
802     if (useDomdec)
803     {
804         dd_move_x_vsites(*domainInfo.domdec_, box, as_rvec_array(x.data()));
805     }
806
807     if (threadingInfo == nullptr || threadingInfo->numThreads() == 1)
808     {
809         construct_vsites_thread(x, dt, v, ip, ilist, pbc_null);
810     }
811     else
812     {
813 #pragma omp parallel num_threads(threadingInfo->numThreads())
814         {
815             try
816             {
817                 const int          th    = gmx_omp_get_thread_num();
818                 const VsiteThread& tData = threadingInfo->threadData(th);
819                 GMX_ASSERT(tData.rangeStart >= 0,
820                            "The thread data should be initialized before calling construct_vsites");
821
822                 construct_vsites_thread(x, dt, v, ip, tData.ilist, pbc_null);
823                 if (tData.useInterdependentTask)
824                 {
825                     /* Here we don't need a barrier (unlike the spreading),
826                      * since both tasks only construct vsites from particles,
827                      * or local vsites, not from non-local vsites.
828                      */
829                     construct_vsites_thread(x, dt, v, ip, tData.idTask.ilist, pbc_null);
830                 }
831             }
832             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
833         }
834         /* Now we can construct the vsites that might depend on other vsites */
835         construct_vsites_thread(x, dt, v, ip, threadingInfo->threadDataNonLocalDependent().ilist, pbc_null);
836     }
837 }
838
839 void VirtualSitesHandler::Impl::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
840 {
841     construct_vsites(&threadingInfo_, x, dt, v, iparams_, ilists_, domainInfo_, box);
842 }
843
844 void VirtualSitesHandler::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
845 {
846     impl_->construct(x, dt, v, box);
847 }
848
849 void constructVirtualSites(ArrayRef<RVec> x, ArrayRef<const t_iparams> ip, ArrayRef<const InteractionList> ilist)
850
851 {
852     // No PBC, no DD
853     const DomainInfo domainInfo;
854     construct_vsites(nullptr, x, 0, {}, ip, ilist, domainInfo, nullptr);
855 }
856
857 #ifndef DOXYGEN
858 /* Force spreading routines */
859
860 template<VirialHandling virialHandling>
861 static void spread_vsite2(const t_iatom        ia[],
862                           real                 a,
863                           ArrayRef<const RVec> x,
864                           ArrayRef<RVec>       f,
865                           ArrayRef<RVec>       fshift,
866                           const t_pbc*         pbc)
867 {
868     rvec    fi, fj, dx;
869     t_iatom av, ai, aj;
870
871     av = ia[1];
872     ai = ia[2];
873     aj = ia[3];
874
875     svmul(1 - a, f[av], fi);
876     svmul(a, f[av], fj);
877     /* 7 flop */
878
879     rvec_inc(f[ai], fi);
880     rvec_inc(f[aj], fj);
881     /* 6 Flops */
882
883     if (virialHandling == VirialHandling::Pbc)
884     {
885         int siv;
886         int sij;
887         if (pbc)
888         {
889             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
890             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
891         }
892         else
893         {
894             siv = CENTRAL;
895             sij = CENTRAL;
896         }
897
898         if (siv != CENTRAL || sij != CENTRAL)
899         {
900             rvec_inc(fshift[siv], f[av]);
901             rvec_dec(fshift[CENTRAL], fi);
902             rvec_dec(fshift[sij], fj);
903         }
904     }
905
906     /* TOTAL: 13 flops */
907 }
908
909 void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
910 {
911     GMX_ASSERT(x.ssize() >= mtop.natoms, "x should contain the whole system");
912     GMX_ASSERT(!mtop.moleculeBlockIndices.empty(),
913                "molblock indices are needed in constructVsitesGlobal");
914
915     for (size_t mb = 0; mb < mtop.molblock.size(); mb++)
916     {
917         const gmx_molblock_t& molb = mtop.molblock[mb];
918         const gmx_moltype_t&  molt = mtop.moltype[molb.type];
919         if (vsiteIlistNrCount(molt.ilist) > 0)
920         {
921             int atomOffset = mtop.moleculeBlockIndices[mb].globalAtomStart;
922             for (int mol = 0; mol < molb.nmol; mol++)
923             {
924                 constructVirtualSites(x.subArray(atomOffset, molt.atoms.nr), mtop.ffparams.iparams,
925                                       molt.ilist);
926                 atomOffset += molt.atoms.nr;
927             }
928         }
929     }
930 }
931
932 template<VirialHandling virialHandling>
933 static void spread_vsite2FD(const t_iatom        ia[],
934                             real                 a,
935                             ArrayRef<const RVec> x,
936                             ArrayRef<RVec>       f,
937                             ArrayRef<RVec>       fshift,
938                             matrix               dxdf,
939                             const t_pbc*         pbc)
940 {
941     const int av = ia[1];
942     const int ai = ia[2];
943     const int aj = ia[3];
944     rvec      fv;
945     copy_rvec(f[av], fv);
946
947     rvec xij;
948     int  sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
949     /* 6 flops */
950
951     const real invDistance = inverseNorm(xij);
952     const real b           = a * invDistance;
953     /* 4 + ?10? flops */
954
955     const real fproj = iprod(xij, fv) * invDistance * invDistance;
956
957     rvec fj;
958     fj[XX] = b * (fv[XX] - fproj * xij[XX]);
959     fj[YY] = b * (fv[YY] - fproj * xij[YY]);
960     fj[ZZ] = b * (fv[ZZ] - fproj * xij[ZZ]);
961     /* 9 */
962
963     /* b is already calculated in constr_vsite2FD
964        storing b somewhere will save flops.     */
965
966     f[ai][XX] += fv[XX] - fj[XX];
967     f[ai][YY] += fv[YY] - fj[YY];
968     f[ai][ZZ] += fv[ZZ] - fj[ZZ];
969     f[aj][XX] += fj[XX];
970     f[aj][YY] += fj[YY];
971     f[aj][ZZ] += fj[ZZ];
972     /* 9 Flops */
973
974     if (virialHandling == VirialHandling::Pbc)
975     {
976         int svi;
977         if (pbc)
978         {
979             rvec xvi;
980             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
981         }
982         else
983         {
984             svi = CENTRAL;
985         }
986
987         if (svi != CENTRAL || sji != CENTRAL)
988         {
989             rvec_dec(fshift[svi], fv);
990             fshift[CENTRAL][XX] += fv[XX] - fj[XX];
991             fshift[CENTRAL][YY] += fv[YY] - fj[YY];
992             fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ];
993             fshift[sji][XX] += fj[XX];
994             fshift[sji][YY] += fj[YY];
995             fshift[sji][ZZ] += fj[ZZ];
996         }
997     }
998
999     if (virialHandling == VirialHandling::NonLinear)
1000     {
1001         /* Under this condition, the virial for the current forces is not
1002          * calculated from the redistributed forces. This means that
1003          * the effect of non-linear virtual site constructions on the virial
1004          * needs to be added separately. This contribution can be calculated
1005          * in many ways, but the simplest and cheapest way is to use
1006          * the first constructing atom ai as a reference position in space:
1007          * subtract (xv-xi)*fv and add (xj-xi)*fj.
1008          */
1009         rvec xiv;
1010
1011         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1012
1013         for (int i = 0; i < DIM; i++)
1014         {
1015             for (int j = 0; j < DIM; j++)
1016             {
1017                 /* As xix is a linear combination of j and k, use that here */
1018                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j];
1019             }
1020         }
1021     }
1022
1023     /* TOTAL: 38 flops */
1024 }
1025
1026 template<VirialHandling virialHandling>
1027 static void spread_vsite3(const t_iatom        ia[],
1028                           real                 a,
1029                           real                 b,
1030                           ArrayRef<const RVec> x,
1031                           ArrayRef<RVec>       f,
1032                           ArrayRef<RVec>       fshift,
1033                           const t_pbc*         pbc)
1034 {
1035     rvec fi, fj, fk, dx;
1036     int  av, ai, aj, ak;
1037
1038     av = ia[1];
1039     ai = ia[2];
1040     aj = ia[3];
1041     ak = ia[4];
1042
1043     svmul(1 - a - b, f[av], fi);
1044     svmul(a, f[av], fj);
1045     svmul(b, f[av], fk);
1046     /* 11 flops */
1047
1048     rvec_inc(f[ai], fi);
1049     rvec_inc(f[aj], fj);
1050     rvec_inc(f[ak], fk);
1051     /* 9 Flops */
1052
1053     if (virialHandling == VirialHandling::Pbc)
1054     {
1055         int siv;
1056         int sij;
1057         int sik;
1058         if (pbc)
1059         {
1060             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1061             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1062             sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
1063         }
1064         else
1065         {
1066             siv = CENTRAL;
1067             sij = CENTRAL;
1068             sik = CENTRAL;
1069         }
1070
1071         if (siv != CENTRAL || sij != CENTRAL || sik != CENTRAL)
1072         {
1073             rvec_inc(fshift[siv], f[av]);
1074             rvec_dec(fshift[CENTRAL], fi);
1075             rvec_dec(fshift[sij], fj);
1076             rvec_dec(fshift[sik], fk);
1077         }
1078     }
1079
1080     /* TOTAL: 20 flops */
1081 }
1082
1083 template<VirialHandling virialHandling>
1084 static void spread_vsite3FD(const t_iatom        ia[],
1085                             real                 a,
1086                             real                 b,
1087                             ArrayRef<const RVec> x,
1088                             ArrayRef<RVec>       f,
1089                             ArrayRef<RVec>       fshift,
1090                             matrix               dxdf,
1091                             const t_pbc*         pbc)
1092 {
1093     real    fproj, a1;
1094     rvec    xvi, xij, xjk, xix, fv, temp;
1095     t_iatom av, ai, aj, ak;
1096     int     sji, skj;
1097
1098     av = ia[1];
1099     ai = ia[2];
1100     aj = ia[3];
1101     ak = ia[4];
1102     copy_rvec(f[av], fv);
1103
1104     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1105     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1106     /* 6 flops */
1107
1108     /* xix goes from i to point x on the line jk */
1109     xix[XX] = xij[XX] + a * xjk[XX];
1110     xix[YY] = xij[YY] + a * xjk[YY];
1111     xix[ZZ] = xij[ZZ] + a * xjk[ZZ];
1112     /* 6 flops */
1113
1114     const real invDistance = inverseNorm(xix);
1115     const real c           = b * invDistance;
1116     /* 4 + ?10? flops */
1117
1118     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1119
1120     temp[XX] = c * (fv[XX] - fproj * xix[XX]);
1121     temp[YY] = c * (fv[YY] - fproj * xix[YY]);
1122     temp[ZZ] = c * (fv[ZZ] - fproj * xix[ZZ]);
1123     /* 16 */
1124
1125     /* c is already calculated in constr_vsite3FD
1126        storing c somewhere will save 26 flops!     */
1127
1128     a1 = 1 - a;
1129     f[ai][XX] += fv[XX] - temp[XX];
1130     f[ai][YY] += fv[YY] - temp[YY];
1131     f[ai][ZZ] += fv[ZZ] - temp[ZZ];
1132     f[aj][XX] += a1 * temp[XX];
1133     f[aj][YY] += a1 * temp[YY];
1134     f[aj][ZZ] += a1 * temp[ZZ];
1135     f[ak][XX] += a * temp[XX];
1136     f[ak][YY] += a * temp[YY];
1137     f[ak][ZZ] += a * temp[ZZ];
1138     /* 19 Flops */
1139
1140     if (virialHandling == VirialHandling::Pbc)
1141     {
1142         int svi;
1143         if (pbc)
1144         {
1145             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1146         }
1147         else
1148         {
1149             svi = CENTRAL;
1150         }
1151
1152         if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
1153         {
1154             rvec_dec(fshift[svi], fv);
1155             fshift[CENTRAL][XX] += fv[XX] - (1 + a) * temp[XX];
1156             fshift[CENTRAL][YY] += fv[YY] - (1 + a) * temp[YY];
1157             fshift[CENTRAL][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
1158             fshift[sji][XX] += temp[XX];
1159             fshift[sji][YY] += temp[YY];
1160             fshift[sji][ZZ] += temp[ZZ];
1161             fshift[skj][XX] += a * temp[XX];
1162             fshift[skj][YY] += a * temp[YY];
1163             fshift[skj][ZZ] += a * temp[ZZ];
1164         }
1165     }
1166
1167     if (virialHandling == VirialHandling::NonLinear)
1168     {
1169         /* Under this condition, the virial for the current forces is not
1170          * calculated from the redistributed forces. This means that
1171          * the effect of non-linear virtual site constructions on the virial
1172          * needs to be added separately. This contribution can be calculated
1173          * in many ways, but the simplest and cheapest way is to use
1174          * the first constructing atom ai as a reference position in space:
1175          * subtract (xv-xi)*fv and add (xj-xi)*fj + (xk-xi)*fk.
1176          */
1177         rvec xiv;
1178
1179         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1180
1181         for (int i = 0; i < DIM; i++)
1182         {
1183             for (int j = 0; j < DIM; j++)
1184             {
1185                 /* As xix is a linear combination of j and k, use that here */
1186                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1187             }
1188         }
1189     }
1190
1191     /* TOTAL: 61 flops */
1192 }
1193
1194 template<VirialHandling virialHandling>
1195 static void spread_vsite3FAD(const t_iatom        ia[],
1196                              real                 a,
1197                              real                 b,
1198                              ArrayRef<const RVec> x,
1199                              ArrayRef<RVec>       f,
1200                              ArrayRef<RVec>       fshift,
1201                              matrix               dxdf,
1202                              const t_pbc*         pbc)
1203 {
1204     rvec    xvi, xij, xjk, xperp, Fpij, Fppp, fv, f1, f2, f3;
1205     real    a1, b1, c1, c2, invdij, invdij2, invdp, fproj;
1206     t_iatom av, ai, aj, ak;
1207     int     sji, skj;
1208
1209     av = ia[1];
1210     ai = ia[2];
1211     aj = ia[3];
1212     ak = ia[4];
1213     copy_rvec(f[ia[1]], fv);
1214
1215     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1216     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1217     /* 6 flops */
1218
1219     invdij    = inverseNorm(xij);
1220     invdij2   = invdij * invdij;
1221     c1        = iprod(xij, xjk) * invdij2;
1222     xperp[XX] = xjk[XX] - c1 * xij[XX];
1223     xperp[YY] = xjk[YY] - c1 * xij[YY];
1224     xperp[ZZ] = xjk[ZZ] - c1 * xij[ZZ];
1225     /* xperp in plane ijk, perp. to ij */
1226     invdp = inverseNorm(xperp);
1227     a1    = a * invdij;
1228     b1    = b * invdp;
1229     /* 45 flops */
1230
1231     /* a1, b1 and c1 are already calculated in constr_vsite3FAD
1232        storing them somewhere will save 45 flops!     */
1233
1234     fproj = iprod(xij, fv) * invdij2;
1235     svmul(fproj, xij, Fpij);                              /* proj. f on xij */
1236     svmul(iprod(xperp, fv) * invdp * invdp, xperp, Fppp); /* proj. f on xperp */
1237     svmul(b1 * fproj, xperp, f3);
1238     /* 23 flops */
1239
1240     rvec_sub(fv, Fpij, f1); /* f1 = f - Fpij */
1241     rvec_sub(f1, Fppp, f2); /* f2 = f - Fpij - Fppp */
1242     for (int d = 0; d < DIM; d++)
1243     {
1244         f1[d] *= a1;
1245         f2[d] *= b1;
1246     }
1247     /* 12 flops */
1248
1249     c2 = 1 + c1;
1250     f[ai][XX] += fv[XX] - f1[XX] + c1 * f2[XX] + f3[XX];
1251     f[ai][YY] += fv[YY] - f1[YY] + c1 * f2[YY] + f3[YY];
1252     f[ai][ZZ] += fv[ZZ] - f1[ZZ] + c1 * f2[ZZ] + f3[ZZ];
1253     f[aj][XX] += f1[XX] - c2 * f2[XX] - f3[XX];
1254     f[aj][YY] += f1[YY] - c2 * f2[YY] - f3[YY];
1255     f[aj][ZZ] += f1[ZZ] - c2 * f2[ZZ] - f3[ZZ];
1256     f[ak][XX] += f2[XX];
1257     f[ak][YY] += f2[YY];
1258     f[ak][ZZ] += f2[ZZ];
1259     /* 30 Flops */
1260
1261     if (virialHandling == VirialHandling::Pbc)
1262     {
1263         int svi;
1264
1265         if (pbc)
1266         {
1267             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1268         }
1269         else
1270         {
1271             svi = CENTRAL;
1272         }
1273
1274         if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
1275         {
1276             rvec_dec(fshift[svi], fv);
1277             fshift[CENTRAL][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
1278             fshift[CENTRAL][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
1279             fshift[CENTRAL][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
1280             fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
1281             fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
1282             fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
1283             fshift[skj][XX] += f2[XX];
1284             fshift[skj][YY] += f2[YY];
1285             fshift[skj][ZZ] += f2[ZZ];
1286         }
1287     }
1288
1289     if (virialHandling == VirialHandling::NonLinear)
1290     {
1291         rvec xiv;
1292         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1293
1294         for (int i = 0; i < DIM; i++)
1295         {
1296             for (int j = 0; j < DIM; j++)
1297             {
1298                 /* Note that xik=xij+xjk, so we have to add xij*f2 */
1299                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * (f1[j] + (1 - c2) * f2[j] - f3[j])
1300                               + xjk[i] * f2[j];
1301             }
1302         }
1303     }
1304
1305     /* TOTAL: 113 flops */
1306 }
1307
1308 template<VirialHandling virialHandling>
1309 static void spread_vsite3OUT(const t_iatom        ia[],
1310                              real                 a,
1311                              real                 b,
1312                              real                 c,
1313                              ArrayRef<const RVec> x,
1314                              ArrayRef<RVec>       f,
1315                              ArrayRef<RVec>       fshift,
1316                              matrix               dxdf,
1317                              const t_pbc*         pbc)
1318 {
1319     rvec xvi, xij, xik, fv, fj, fk;
1320     real cfx, cfy, cfz;
1321     int  av, ai, aj, ak;
1322     int  sji, ski;
1323
1324     av = ia[1];
1325     ai = ia[2];
1326     aj = ia[3];
1327     ak = ia[4];
1328
1329     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1330     ski = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1331     /* 6 Flops */
1332
1333     copy_rvec(f[av], fv);
1334
1335     cfx = c * fv[XX];
1336     cfy = c * fv[YY];
1337     cfz = c * fv[ZZ];
1338     /* 3 Flops */
1339
1340     fj[XX] = a * fv[XX] - xik[ZZ] * cfy + xik[YY] * cfz;
1341     fj[YY] = xik[ZZ] * cfx + a * fv[YY] - xik[XX] * cfz;
1342     fj[ZZ] = -xik[YY] * cfx + xik[XX] * cfy + a * fv[ZZ];
1343
1344     fk[XX] = b * fv[XX] + xij[ZZ] * cfy - xij[YY] * cfz;
1345     fk[YY] = -xij[ZZ] * cfx + b * fv[YY] + xij[XX] * cfz;
1346     fk[ZZ] = xij[YY] * cfx - xij[XX] * cfy + b * fv[ZZ];
1347     /* 30 Flops */
1348
1349     f[ai][XX] += fv[XX] - fj[XX] - fk[XX];
1350     f[ai][YY] += fv[YY] - fj[YY] - fk[YY];
1351     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1352     rvec_inc(f[aj], fj);
1353     rvec_inc(f[ak], fk);
1354     /* 15 Flops */
1355
1356     if (virialHandling == VirialHandling::Pbc)
1357     {
1358         int svi;
1359         if (pbc)
1360         {
1361             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1362         }
1363         else
1364         {
1365             svi = CENTRAL;
1366         }
1367
1368         if (svi != CENTRAL || sji != CENTRAL || ski != CENTRAL)
1369         {
1370             rvec_dec(fshift[svi], fv);
1371             fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX];
1372             fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY];
1373             fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1374             rvec_inc(fshift[sji], fj);
1375             rvec_inc(fshift[ski], fk);
1376         }
1377     }
1378
1379     if (virialHandling == VirialHandling::NonLinear)
1380     {
1381         rvec xiv;
1382
1383         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1384
1385         for (int i = 0; i < DIM; i++)
1386         {
1387             for (int j = 0; j < DIM; j++)
1388             {
1389                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j];
1390             }
1391         }
1392     }
1393
1394     /* TOTAL: 54 flops */
1395 }
1396
1397 template<VirialHandling virialHandling>
1398 static void spread_vsite4FD(const t_iatom        ia[],
1399                             real                 a,
1400                             real                 b,
1401                             real                 c,
1402                             ArrayRef<const RVec> x,
1403                             ArrayRef<RVec>       f,
1404                             ArrayRef<RVec>       fshift,
1405                             matrix               dxdf,
1406                             const t_pbc*         pbc)
1407 {
1408     real fproj, a1;
1409     rvec xvi, xij, xjk, xjl, xix, fv, temp;
1410     int  av, ai, aj, ak, al;
1411     int  sji, skj, slj, m;
1412
1413     av = ia[1];
1414     ai = ia[2];
1415     aj = ia[3];
1416     ak = ia[4];
1417     al = ia[5];
1418
1419     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1420     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1421     slj = pbc_rvec_sub(pbc, x[al], x[aj], xjl);
1422     /* 9 flops */
1423
1424     /* xix goes from i to point x on the plane jkl */
1425     for (m = 0; m < DIM; m++)
1426     {
1427         xix[m] = xij[m] + a * xjk[m] + b * xjl[m];
1428     }
1429     /* 12 flops */
1430
1431     const real invDistance = inverseNorm(xix);
1432     const real d           = c * invDistance;
1433     /* 4 + ?10? flops */
1434
1435     copy_rvec(f[av], fv);
1436
1437     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1438
1439     for (m = 0; m < DIM; m++)
1440     {
1441         temp[m] = d * (fv[m] - fproj * xix[m]);
1442     }
1443     /* 16 */
1444
1445     /* c is already calculated in constr_vsite3FD
1446        storing c somewhere will save 35 flops!     */
1447
1448     a1 = 1 - a - b;
1449     for (m = 0; m < DIM; m++)
1450     {
1451         f[ai][m] += fv[m] - temp[m];
1452         f[aj][m] += a1 * temp[m];
1453         f[ak][m] += a * temp[m];
1454         f[al][m] += b * temp[m];
1455     }
1456     /* 26 Flops */
1457
1458     if (virialHandling == VirialHandling::Pbc)
1459     {
1460         int svi;
1461         if (pbc)
1462         {
1463             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1464         }
1465         else
1466         {
1467             svi = CENTRAL;
1468         }
1469
1470         if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL || slj != CENTRAL)
1471         {
1472             rvec_dec(fshift[svi], fv);
1473             for (m = 0; m < DIM; m++)
1474             {
1475                 fshift[CENTRAL][m] += fv[m] - (1 + a + b) * temp[m];
1476                 fshift[sji][m] += temp[m];
1477                 fshift[skj][m] += a * temp[m];
1478                 fshift[slj][m] += b * temp[m];
1479             }
1480         }
1481     }
1482
1483     if (virialHandling == VirialHandling::NonLinear)
1484     {
1485         rvec xiv;
1486         int  i, j;
1487
1488         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1489
1490         for (i = 0; i < DIM; i++)
1491         {
1492             for (j = 0; j < DIM; j++)
1493             {
1494                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1495             }
1496         }
1497     }
1498
1499     /* TOTAL: 77 flops */
1500 }
1501
1502 template<VirialHandling virialHandling>
1503 static void spread_vsite4FDN(const t_iatom        ia[],
1504                              real                 a,
1505                              real                 b,
1506                              real                 c,
1507                              ArrayRef<const RVec> x,
1508                              ArrayRef<RVec>       f,
1509                              ArrayRef<RVec>       fshift,
1510                              matrix               dxdf,
1511                              const t_pbc*         pbc)
1512 {
1513     rvec xvi, xij, xik, xil, ra, rb, rja, rjb, rab, rm, rt;
1514     rvec fv, fj, fk, fl;
1515     real invrm, denom;
1516     real cfx, cfy, cfz;
1517     int  av, ai, aj, ak, al;
1518     int  sij, sik, sil;
1519
1520     /* DEBUG: check atom indices */
1521     av = ia[1];
1522     ai = ia[2];
1523     aj = ia[3];
1524     ak = ia[4];
1525     al = ia[5];
1526
1527     copy_rvec(f[av], fv);
1528
1529     sij = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1530     sik = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1531     sil = pbc_rvec_sub(pbc, x[al], x[ai], xil);
1532     /* 9 flops */
1533
1534     ra[XX] = a * xik[XX];
1535     ra[YY] = a * xik[YY];
1536     ra[ZZ] = a * xik[ZZ];
1537
1538     rb[XX] = b * xil[XX];
1539     rb[YY] = b * xil[YY];
1540     rb[ZZ] = b * xil[ZZ];
1541
1542     /* 6 flops */
1543
1544     rvec_sub(ra, xij, rja);
1545     rvec_sub(rb, xij, rjb);
1546     rvec_sub(rb, ra, rab);
1547     /* 9 flops */
1548
1549     cprod(rja, rjb, rm);
1550     /* 9 flops */
1551
1552     invrm = inverseNorm(rm);
1553     denom = invrm * invrm;
1554     /* 5+5+2 flops */
1555
1556     cfx = c * invrm * fv[XX];
1557     cfy = c * invrm * fv[YY];
1558     cfz = c * invrm * fv[ZZ];
1559     /* 6 Flops */
1560
1561     cprod(rm, rab, rt);
1562     /* 9 flops */
1563
1564     rt[XX] *= denom;
1565     rt[YY] *= denom;
1566     rt[ZZ] *= denom;
1567     /* 3flops */
1568
1569     fj[XX] = (-rm[XX] * rt[XX]) * cfx + (rab[ZZ] - rm[YY] * rt[XX]) * cfy
1570              + (-rab[YY] - rm[ZZ] * rt[XX]) * cfz;
1571     fj[YY] = (-rab[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1572              + (rab[XX] - rm[ZZ] * rt[YY]) * cfz;
1573     fj[ZZ] = (rab[YY] - rm[XX] * rt[ZZ]) * cfx + (-rab[XX] - rm[YY] * rt[ZZ]) * cfy
1574              + (-rm[ZZ] * rt[ZZ]) * cfz;
1575     /* 30 flops */
1576
1577     cprod(rjb, rm, rt);
1578     /* 9 flops */
1579
1580     rt[XX] *= denom * a;
1581     rt[YY] *= denom * a;
1582     rt[ZZ] *= denom * a;
1583     /* 3flops */
1584
1585     fk[XX] = (-rm[XX] * rt[XX]) * cfx + (-a * rjb[ZZ] - rm[YY] * rt[XX]) * cfy
1586              + (a * rjb[YY] - rm[ZZ] * rt[XX]) * cfz;
1587     fk[YY] = (a * rjb[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1588              + (-a * rjb[XX] - rm[ZZ] * rt[YY]) * cfz;
1589     fk[ZZ] = (-a * rjb[YY] - rm[XX] * rt[ZZ]) * cfx + (a * rjb[XX] - rm[YY] * rt[ZZ]) * cfy
1590              + (-rm[ZZ] * rt[ZZ]) * cfz;
1591     /* 36 flops */
1592
1593     cprod(rm, rja, rt);
1594     /* 9 flops */
1595
1596     rt[XX] *= denom * b;
1597     rt[YY] *= denom * b;
1598     rt[ZZ] *= denom * b;
1599     /* 3flops */
1600
1601     fl[XX] = (-rm[XX] * rt[XX]) * cfx + (b * rja[ZZ] - rm[YY] * rt[XX]) * cfy
1602              + (-b * rja[YY] - rm[ZZ] * rt[XX]) * cfz;
1603     fl[YY] = (-b * rja[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1604              + (b * rja[XX] - rm[ZZ] * rt[YY]) * cfz;
1605     fl[ZZ] = (b * rja[YY] - rm[XX] * rt[ZZ]) * cfx + (-b * rja[XX] - rm[YY] * rt[ZZ]) * cfy
1606              + (-rm[ZZ] * rt[ZZ]) * cfz;
1607     /* 36 flops */
1608
1609     f[ai][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
1610     f[ai][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
1611     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
1612     rvec_inc(f[aj], fj);
1613     rvec_inc(f[ak], fk);
1614     rvec_inc(f[al], fl);
1615     /* 21 flops */
1616
1617     if (virialHandling == VirialHandling::Pbc)
1618     {
1619         int svi;
1620         if (pbc)
1621         {
1622             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1623         }
1624         else
1625         {
1626             svi = CENTRAL;
1627         }
1628
1629         if (svi != CENTRAL || sij != CENTRAL || sik != CENTRAL || sil != CENTRAL)
1630         {
1631             rvec_dec(fshift[svi], fv);
1632             fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
1633             fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
1634             fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
1635             rvec_inc(fshift[sij], fj);
1636             rvec_inc(fshift[sik], fk);
1637             rvec_inc(fshift[sil], fl);
1638         }
1639     }
1640
1641     if (virialHandling == VirialHandling::NonLinear)
1642     {
1643         rvec xiv;
1644         int  i, j;
1645
1646         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1647
1648         for (i = 0; i < DIM; i++)
1649         {
1650             for (j = 0; j < DIM; j++)
1651             {
1652                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j] + xil[i] * fl[j];
1653             }
1654         }
1655     }
1656
1657     /* Total: 207 flops (Yuck!) */
1658 }
1659
1660 template<VirialHandling virialHandling>
1661 static int spread_vsiten(const t_iatom             ia[],
1662                          ArrayRef<const t_iparams> ip,
1663                          ArrayRef<const RVec>      x,
1664                          ArrayRef<RVec>            f,
1665                          ArrayRef<RVec>            fshift,
1666                          const t_pbc*              pbc)
1667 {
1668     rvec xv, dx, fi;
1669     int  n3, av, i, ai;
1670     real a;
1671     int  siv;
1672
1673     n3 = 3 * ip[ia[0]].vsiten.n;
1674     av = ia[1];
1675     copy_rvec(x[av], xv);
1676
1677     for (i = 0; i < n3; i += 3)
1678     {
1679         ai = ia[i + 2];
1680         if (pbc)
1681         {
1682             siv = pbc_dx_aiuc(pbc, x[ai], xv, dx);
1683         }
1684         else
1685         {
1686             siv = CENTRAL;
1687         }
1688         a = ip[ia[i]].vsiten.a;
1689         svmul(a, f[av], fi);
1690         rvec_inc(f[ai], fi);
1691
1692         if (virialHandling == VirialHandling::Pbc && siv != CENTRAL)
1693         {
1694             rvec_inc(fshift[siv], fi);
1695             rvec_dec(fshift[CENTRAL], fi);
1696         }
1697         /* 6 Flops */
1698     }
1699
1700     return n3;
1701 }
1702
1703 #endif // DOXYGEN
1704
1705 //! Returns the number of virtual sites in the interaction list, for VSITEN the number of atoms
1706 static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
1707 {
1708     if (ftype == F_VSITEN)
1709     {
1710         return ilist[ftype].size() / 3;
1711     }
1712     else
1713     {
1714         return ilist[ftype].size() / (1 + interaction_function[ftype].nratoms);
1715     }
1716 }
1717
1718 //! Executes the force spreading task for a single thread
1719 template<VirialHandling virialHandling>
1720 static void spreadForceForThread(ArrayRef<const RVec>            x,
1721                                  ArrayRef<RVec>                  f,
1722                                  ArrayRef<RVec>                  fshift,
1723                                  matrix                          dxdf,
1724                                  ArrayRef<const t_iparams>       ip,
1725                                  ArrayRef<const InteractionList> ilist,
1726                                  const t_pbc*                    pbc_null)
1727 {
1728     const PbcMode pbcMode = getPbcMode(pbc_null);
1729     /* We need another pbc pointer, as with charge groups we switch per vsite */
1730     const t_pbc*             pbc_null2 = pbc_null;
1731     gmx::ArrayRef<const int> vsite_pbc;
1732
1733     /* this loop goes backwards to be able to build *
1734      * higher type vsites from lower types         */
1735     for (int ftype = c_ftypeVsiteEnd - 1; ftype >= c_ftypeVsiteStart; ftype--)
1736     {
1737         if (ilist[ftype].empty())
1738         {
1739             continue;
1740         }
1741
1742         { // TODO remove me
1743             int nra = interaction_function[ftype].nratoms;
1744             int inc = 1 + nra;
1745             int nr  = ilist[ftype].size();
1746
1747             const t_iatom* ia = ilist[ftype].iatoms.data();
1748
1749             if (pbcMode == PbcMode::all)
1750             {
1751                 pbc_null2 = pbc_null;
1752             }
1753
1754             for (int i = 0; i < nr;)
1755             {
1756                 int tp = ia[0];
1757
1758                 /* Constants for constructing */
1759                 real a1, b1, c1;
1760                 a1 = ip[tp].vsite.a;
1761                 /* Construct the vsite depending on type */
1762                 switch (ftype)
1763                 {
1764                     case F_VSITE2:
1765                         spread_vsite2<virialHandling>(ia, a1, x, f, fshift, pbc_null2);
1766                         break;
1767                     case F_VSITE2FD:
1768                         spread_vsite2FD<virialHandling>(ia, a1, x, f, fshift, dxdf, pbc_null2);
1769                         break;
1770                     case F_VSITE3:
1771                         b1 = ip[tp].vsite.b;
1772                         spread_vsite3<virialHandling>(ia, a1, b1, x, f, fshift, pbc_null2);
1773                         break;
1774                     case F_VSITE3FD:
1775                         b1 = ip[tp].vsite.b;
1776                         spread_vsite3FD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
1777                         break;
1778                     case F_VSITE3FAD:
1779                         b1 = ip[tp].vsite.b;
1780                         spread_vsite3FAD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
1781                         break;
1782                     case F_VSITE3OUT:
1783                         b1 = ip[tp].vsite.b;
1784                         c1 = ip[tp].vsite.c;
1785                         spread_vsite3OUT<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
1786                         break;
1787                     case F_VSITE4FD:
1788                         b1 = ip[tp].vsite.b;
1789                         c1 = ip[tp].vsite.c;
1790                         spread_vsite4FD<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
1791                         break;
1792                     case F_VSITE4FDN:
1793                         b1 = ip[tp].vsite.b;
1794                         c1 = ip[tp].vsite.c;
1795                         spread_vsite4FDN<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
1796                         break;
1797                     case F_VSITEN:
1798                         inc = spread_vsiten<virialHandling>(ia, ip, x, f, fshift, pbc_null2);
1799                         break;
1800                     default:
1801                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
1802                 }
1803                 clear_rvec(f[ia[1]]);
1804
1805                 /* Increment loop variables */
1806                 i += inc;
1807                 ia += inc;
1808             }
1809         }
1810     }
1811 }
1812
1813 //! Wrapper function for calling the templated thread-local spread function
1814 static void spreadForceWrapper(ArrayRef<const RVec>            x,
1815                                ArrayRef<RVec>                  f,
1816                                const VirialHandling            virialHandling,
1817                                ArrayRef<RVec>                  fshift,
1818                                matrix                          dxdf,
1819                                const bool                      clearDxdf,
1820                                ArrayRef<const t_iparams>       ip,
1821                                ArrayRef<const InteractionList> ilist,
1822                                const t_pbc*                    pbc_null)
1823 {
1824     if (virialHandling == VirialHandling::NonLinear && clearDxdf)
1825     {
1826         clear_mat(dxdf);
1827     }
1828
1829     switch (virialHandling)
1830     {
1831         case VirialHandling::None:
1832             spreadForceForThread<VirialHandling::None>(x, f, fshift, dxdf, ip, ilist, pbc_null);
1833             break;
1834         case VirialHandling::Pbc:
1835             spreadForceForThread<VirialHandling::Pbc>(x, f, fshift, dxdf, ip, ilist, pbc_null);
1836             break;
1837         case VirialHandling::NonLinear:
1838             spreadForceForThread<VirialHandling::NonLinear>(x, f, fshift, dxdf, ip, ilist, pbc_null);
1839             break;
1840     }
1841 }
1842
1843 //! Clears the task force buffer elements that are written by task idTask
1844 static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
1845 {
1846     int ntask = idTask->spreadTask.size();
1847     for (int ti = 0; ti < ntask; ti++)
1848     {
1849         const AtomIndex* atomList = &idTask->atomIndex[idTask->spreadTask[ti]];
1850         int              natom    = atomList->atom.size();
1851         RVec*            force    = idTask->force.data();
1852         for (int i = 0; i < natom; i++)
1853         {
1854             clear_rvec(force[atomList->atom[i]]);
1855         }
1856     }
1857 }
1858
1859 void VirtualSitesHandler::Impl::spreadForces(ArrayRef<const RVec> x,
1860                                              ArrayRef<RVec>       f,
1861                                              const VirialHandling virialHandling,
1862                                              ArrayRef<RVec>       fshift,
1863                                              matrix               virial,
1864                                              t_nrnb*              nrnb,
1865                                              const matrix         box,
1866                                              gmx_wallcycle*       wcycle)
1867 {
1868     wallcycle_start(wcycle, ewcVSITESPREAD);
1869
1870     const bool useDomdec = domainInfo_.useDomdec();
1871
1872     t_pbc pbc, *pbc_null;
1873
1874     if (domainInfo_.useMolPbc_)
1875     {
1876         /* This is wasting some CPU time as we now do this multiple times
1877          * per MD step.
1878          */
1879         pbc_null = set_pbc_dd(&pbc, domainInfo_.pbcType_,
1880                               useDomdec ? domainInfo_.domdec_->numCells : nullptr, FALSE, box);
1881     }
1882     else
1883     {
1884         pbc_null = nullptr;
1885     }
1886
1887     if (useDomdec)
1888     {
1889         dd_clear_f_vsites(*domainInfo_.domdec_, f);
1890     }
1891
1892     const int numThreads = threadingInfo_.numThreads();
1893
1894     if (numThreads == 1)
1895     {
1896         matrix dxdf;
1897         spreadForceWrapper(x, f, virialHandling, fshift, dxdf, true, iparams_, ilists_, pbc_null);
1898
1899         if (virialHandling == VirialHandling::NonLinear)
1900         {
1901             for (int i = 0; i < DIM; i++)
1902             {
1903                 for (int j = 0; j < DIM; j++)
1904                 {
1905                     virial[i][j] += -0.5 * dxdf[i][j];
1906                 }
1907             }
1908         }
1909     }
1910     else
1911     {
1912         /* First spread the vsites that might depend on non-local vsites */
1913         auto& nlDependentVSites = threadingInfo_.threadDataNonLocalDependent();
1914         spreadForceWrapper(x, f, virialHandling, fshift, nlDependentVSites.dxdf, true, iparams_,
1915                            nlDependentVSites.ilist, pbc_null);
1916
1917 #pragma omp parallel num_threads(numThreads)
1918         {
1919             try
1920             {
1921                 int          thread = gmx_omp_get_thread_num();
1922                 VsiteThread& tData  = threadingInfo_.threadData(thread);
1923
1924                 ArrayRef<RVec> fshift_t;
1925                 if (virialHandling == VirialHandling::Pbc)
1926                 {
1927                     if (thread == 0)
1928                     {
1929                         fshift_t = fshift;
1930                     }
1931                     else
1932                     {
1933                         fshift_t = tData.fshift;
1934
1935                         for (int i = 0; i < SHIFTS; i++)
1936                         {
1937                             clear_rvec(fshift_t[i]);
1938                         }
1939                     }
1940                 }
1941
1942                 if (tData.useInterdependentTask)
1943                 {
1944                     /* Spread the vsites that spread outside our local range.
1945                      * This is done using a thread-local force buffer force.
1946                      * First we need to copy the input vsite forces to force.
1947                      */
1948                     InterdependentTask* idTask = &tData.idTask;
1949
1950                     /* Clear the buffer elements set by our task during
1951                      * the last call to spread_vsite_f.
1952                      */
1953                     clearTaskForceBufferUsedElements(idTask);
1954
1955                     int nvsite = idTask->vsite.size();
1956                     for (int i = 0; i < nvsite; i++)
1957                     {
1958                         copy_rvec(f[idTask->vsite[i]], idTask->force[idTask->vsite[i]]);
1959                     }
1960                     spreadForceWrapper(x, idTask->force, virialHandling, fshift_t, tData.dxdf, true,
1961                                        iparams_, tData.idTask.ilist, pbc_null);
1962
1963                     /* We need a barrier before reducing forces below
1964                      * that have been produced by a different thread above.
1965                      */
1966 #pragma omp barrier
1967
1968                     /* Loop over all thread task and reduce forces they
1969                      * produced on atoms that fall in our range.
1970                      * Note that atomic reduction would be a simpler solution,
1971                      * but that might not have good support on all platforms.
1972                      */
1973                     int ntask = idTask->reduceTask.size();
1974                     for (int ti = 0; ti < ntask; ti++)
1975                     {
1976                         const InterdependentTask& idt_foreign =
1977                                 threadingInfo_.threadData(idTask->reduceTask[ti]).idTask;
1978                         const AtomIndex& atomList  = idt_foreign.atomIndex[thread];
1979                         const RVec*      f_foreign = idt_foreign.force.data();
1980
1981                         for (int ind : atomList.atom)
1982                         {
1983                             rvec_inc(f[ind], f_foreign[ind]);
1984                             /* Clearing of f_foreign is done at the next step */
1985                         }
1986                     }
1987                     /* Clear the vsite forces, both in f and force */
1988                     for (int i = 0; i < nvsite; i++)
1989                     {
1990                         int ind = tData.idTask.vsite[i];
1991                         clear_rvec(f[ind]);
1992                         clear_rvec(tData.idTask.force[ind]);
1993                     }
1994                 }
1995
1996                 /* Spread the vsites that spread locally only */
1997                 spreadForceWrapper(x, f, virialHandling, fshift_t, tData.dxdf, false, iparams_,
1998                                    tData.ilist, pbc_null);
1999             }
2000             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2001         }
2002
2003         if (virialHandling == VirialHandling::Pbc)
2004         {
2005             for (int th = 1; th < numThreads; th++)
2006             {
2007                 for (int i = 0; i < SHIFTS; i++)
2008                 {
2009                     rvec_inc(fshift[i], threadingInfo_.threadData(th).fshift[i]);
2010                 }
2011             }
2012         }
2013
2014         if (virialHandling == VirialHandling::NonLinear)
2015         {
2016             for (int th = 0; th < numThreads + 1; th++)
2017             {
2018                 /* MSVC doesn't like matrix references, so we use a pointer */
2019                 const matrix& dxdf = threadingInfo_.threadData(th).dxdf;
2020
2021                 for (int i = 0; i < DIM; i++)
2022                 {
2023                     for (int j = 0; j < DIM; j++)
2024                     {
2025                         virial[i][j] += -0.5 * dxdf[i][j];
2026                     }
2027                 }
2028             }
2029         }
2030     }
2031
2032     if (useDomdec)
2033     {
2034         dd_move_f_vsites(*domainInfo_.domdec_, f, fshift);
2035     }
2036
2037     inc_nrnb(nrnb, eNR_VSITE2, vsite_count(ilists_, F_VSITE2));
2038     inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(ilists_, F_VSITE2FD));
2039     inc_nrnb(nrnb, eNR_VSITE3, vsite_count(ilists_, F_VSITE3));
2040     inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(ilists_, F_VSITE3FD));
2041     inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(ilists_, F_VSITE3FAD));
2042     inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(ilists_, F_VSITE3OUT));
2043     inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(ilists_, F_VSITE4FD));
2044     inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(ilists_, F_VSITE4FDN));
2045     inc_nrnb(nrnb, eNR_VSITEN, vsite_count(ilists_, F_VSITEN));
2046
2047     wallcycle_stop(wcycle, ewcVSITESPREAD);
2048 }
2049
2050 /*! \brief Returns the an array with group indices for each atom
2051  *
2052  * \param[in] grouping  The paritioning of the atom range into atom groups
2053  */
2054 static std::vector<int> makeAtomToGroupMapping(const gmx::RangePartitioning& grouping)
2055 {
2056     std::vector<int> atomToGroup(grouping.fullRange().end(), 0);
2057
2058     for (int group = 0; group < grouping.numBlocks(); group++)
2059     {
2060         auto block = grouping.block(group);
2061         std::fill(atomToGroup.begin() + block.begin(), atomToGroup.begin() + block.end(), group);
2062     }
2063
2064     return atomToGroup;
2065 }
2066
2067 int countNonlinearVsites(const gmx_mtop_t& mtop)
2068 {
2069     int numNonlinearVsites = 0;
2070     for (const gmx_molblock_t& molb : mtop.molblock)
2071     {
2072         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2073
2074         for (const auto& ilist : extractILists(molt.ilist, IF_VSITE))
2075         {
2076             if (ilist.functionType != F_VSITE2 && ilist.functionType != F_VSITE3
2077                 && ilist.functionType != F_VSITEN)
2078             {
2079                 numNonlinearVsites += molb.nmol * ilist.iatoms.size() / (1 + NRAL(ilist.functionType));
2080             }
2081         }
2082     }
2083
2084     return numNonlinearVsites;
2085 }
2086
2087 void VirtualSitesHandler::spreadForces(ArrayRef<const RVec> x,
2088                                        ArrayRef<RVec>       f,
2089                                        const VirialHandling virialHandling,
2090                                        ArrayRef<RVec>       fshift,
2091                                        matrix               virial,
2092                                        t_nrnb*              nrnb,
2093                                        const matrix         box,
2094                                        gmx_wallcycle*       wcycle)
2095 {
2096     impl_->spreadForces(x, f, virialHandling, fshift, virial, nrnb, box, wcycle);
2097 }
2098
2099 int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
2100                                 gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype)
2101 {
2102     int n_intercg_vsite = 0;
2103     for (const gmx_molblock_t& molb : mtop.molblock)
2104     {
2105         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2106
2107         std::vector<int> atomToGroup;
2108         if (!updateGroupingPerMoleculetype.empty())
2109         {
2110             atomToGroup = makeAtomToGroupMapping(updateGroupingPerMoleculetype[molb.type]);
2111         }
2112         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2113         {
2114             const int              nral = NRAL(ftype);
2115             const InteractionList& il   = molt.ilist[ftype];
2116             for (int i = 0; i < il.size(); i += 1 + nral)
2117             {
2118                 bool isInterGroup = atomToGroup.empty();
2119                 if (!isInterGroup)
2120                 {
2121                     const int group = atomToGroup[il.iatoms[1 + i]];
2122                     for (int a = 1; a < nral; a++)
2123                     {
2124                         if (atomToGroup[il.iatoms[1 + a]] != group)
2125                         {
2126                             isInterGroup = true;
2127                             break;
2128                         }
2129                     }
2130                 }
2131                 if (isInterGroup)
2132                 {
2133                     n_intercg_vsite += molb.nmol;
2134                 }
2135             }
2136         }
2137     }
2138
2139     return n_intercg_vsite;
2140 }
2141
2142 std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
2143                                                              const t_commrec*  cr,
2144                                                              PbcType           pbcType)
2145 {
2146     GMX_RELEASE_ASSERT(cr != nullptr, "We need a valid commrec");
2147
2148     std::unique_ptr<VirtualSitesHandler> vsite;
2149
2150     /* check if there are vsites */
2151     int nvsite = 0;
2152     for (int ftype = 0; ftype < F_NRE; ftype++)
2153     {
2154         if (interaction_function[ftype].flags & IF_VSITE)
2155         {
2156             GMX_ASSERT(ftype >= c_ftypeVsiteStart && ftype < c_ftypeVsiteEnd,
2157                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2158
2159             nvsite += gmx_mtop_ftype_count(&mtop, ftype);
2160         }
2161         else
2162         {
2163             GMX_ASSERT(ftype < c_ftypeVsiteStart || ftype >= c_ftypeVsiteEnd,
2164                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2165         }
2166     }
2167
2168     if (nvsite == 0)
2169     {
2170         return vsite;
2171     }
2172
2173     return std::make_unique<VirtualSitesHandler>(mtop, cr->dd, pbcType);
2174 }
2175
2176 ThreadingInfo::ThreadingInfo() : numThreads_(gmx_omp_nthreads_get(emntVSITE))
2177 {
2178     if (numThreads_ > 1)
2179     {
2180         /* We need one extra thread data structure for the overlap vsites */
2181         tData_.resize(numThreads_ + 1);
2182 #pragma omp parallel for num_threads(numThreads_) schedule(static)
2183         for (int thread = 0; thread < numThreads_; thread++)
2184         {
2185             try
2186             {
2187                 tData_[thread] = std::make_unique<VsiteThread>();
2188
2189                 InterdependentTask& idTask = tData_[thread]->idTask;
2190                 idTask.nuse                = 0;
2191                 idTask.atomIndex.resize(numThreads_);
2192             }
2193             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2194         }
2195         if (numThreads_ > 1)
2196         {
2197             tData_[numThreads_] = std::make_unique<VsiteThread>();
2198         }
2199     }
2200 }
2201
2202 //! Returns the number of inter update-group vsites
2203 static int getNumInterUpdategroupVsites(const gmx_mtop_t& mtop, const gmx_domdec_t* domdec)
2204 {
2205     gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype;
2206     if (domdec)
2207     {
2208         updateGroupingPerMoleculetype = getUpdateGroupingPerMoleculetype(*domdec);
2209     }
2210
2211     return countInterUpdategroupVsites(mtop, updateGroupingPerMoleculetype);
2212 }
2213
2214 VirtualSitesHandler::Impl::Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
2215     numInterUpdategroupVirtualSites_(getNumInterUpdategroupVsites(mtop, domdec)),
2216     domainInfo_({ pbcType, pbcType != PbcType::No && numInterUpdategroupVirtualSites_ > 0, domdec }),
2217     iparams_(mtop.ffparams.iparams)
2218 {
2219 }
2220
2221 VirtualSitesHandler::VirtualSitesHandler(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
2222     impl_(new Impl(mtop, domdec, pbcType))
2223 {
2224 }
2225
2226 //! Flag that atom \p atom which is home in another task, if it has not already been added before
2227 static inline void flagAtom(InterdependentTask* idTask, const int atom, const int numThreads, const int numAtomsPerThread)
2228 {
2229     if (!idTask->use[atom])
2230     {
2231         idTask->use[atom] = true;
2232         int thread        = atom / numAtomsPerThread;
2233         /* Assign all non-local atom force writes to thread 0 */
2234         if (thread >= numThreads)
2235         {
2236             thread = 0;
2237         }
2238         idTask->atomIndex[thread].atom.push_back(atom);
2239     }
2240 }
2241
2242 /*! \brief Here we try to assign all vsites that are in our local range.
2243  *
2244  * Our task local atom range is tData->rangeStart - tData->rangeEnd.
2245  * Vsites that depend only on local atoms, as indicated by taskIndex[]==thread,
2246  * are assigned to task tData->ilist. Vsites that depend on non-local atoms
2247  * but not on other vsites are assigned to task tData->id_task.ilist.
2248  * taskIndex[] is set for all vsites in our range, either to our local tasks
2249  * or to the single last task as taskIndex[]=2*nthreads.
2250  */
2251 static void assignVsitesToThread(VsiteThread*                    tData,
2252                                  int                             thread,
2253                                  int                             nthread,
2254                                  int                             natperthread,
2255                                  gmx::ArrayRef<int>              taskIndex,
2256                                  ArrayRef<const InteractionList> ilist,
2257                                  ArrayRef<const t_iparams>       ip,
2258                                  const unsigned short*           ptype)
2259 {
2260     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2261     {
2262         tData->ilist[ftype].clear();
2263         tData->idTask.ilist[ftype].clear();
2264
2265         int        nral1 = 1 + NRAL(ftype);
2266         int        inc   = nral1;
2267         const int* iat   = ilist[ftype].iatoms.data();
2268         for (int i = 0; i < ilist[ftype].size();)
2269         {
2270             if (ftype == F_VSITEN)
2271             {
2272                 /* The 3 below is from 1+NRAL(ftype)=3 */
2273                 inc = ip[iat[i]].vsiten.n * 3;
2274             }
2275
2276             if (iat[1 + i] < tData->rangeStart || iat[1 + i] >= tData->rangeEnd)
2277             {
2278                 /* This vsite belongs to a different thread */
2279                 i += inc;
2280                 continue;
2281             }
2282
2283             /* We would like to assign this vsite to task thread,
2284              * but it might depend on atoms outside the atom range of thread
2285              * or on another vsite not assigned to task thread.
2286              */
2287             int task = thread;
2288             if (ftype != F_VSITEN)
2289             {
2290                 for (int j = i + 2; j < i + nral1; j++)
2291                 {
2292                     /* Do a range check to avoid a harmless race on taskIndex */
2293                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2294                     {
2295                         if (!tData->useInterdependentTask || ptype[iat[j]] == eptVSite)
2296                         {
2297                             /* At least one constructing atom is a vsite
2298                              * that is not assigned to the same thread.
2299                              * Put this vsite into a separate task.
2300                              */
2301                             task = 2 * nthread;
2302                             break;
2303                         }
2304
2305                         /* There are constructing atoms outside our range,
2306                          * put this vsite into a second task to be executed
2307                          * on the same thread. During construction no barrier
2308                          * is needed between the two tasks on the same thread.
2309                          * During spreading we need to run this task with
2310                          * an additional thread-local intermediate force buffer
2311                          * (or atomic reduction) and a barrier between the two
2312                          * tasks.
2313                          */
2314                         task = nthread + thread;
2315                     }
2316                 }
2317             }
2318             else
2319             {
2320                 for (int j = i + 2; j < i + inc; j += 3)
2321                 {
2322                     /* Do a range check to avoid a harmless race on taskIndex */
2323                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2324                     {
2325                         GMX_ASSERT(ptype[iat[j]] != eptVSite,
2326                                    "A vsite to be assigned in assignVsitesToThread has a vsite as "
2327                                    "a constructing atom that does not belong to our task, such "
2328                                    "vsites should be assigned to the single 'master' task");
2329
2330                         task = nthread + thread;
2331                     }
2332                 }
2333             }
2334
2335             /* Update this vsite's thread index entry */
2336             taskIndex[iat[1 + i]] = task;
2337
2338             if (task == thread || task == nthread + thread)
2339             {
2340                 /* Copy this vsite to the thread data struct of thread */
2341                 InteractionList* il_task;
2342                 if (task == thread)
2343                 {
2344                     il_task = &tData->ilist[ftype];
2345                 }
2346                 else
2347                 {
2348                     il_task = &tData->idTask.ilist[ftype];
2349                 }
2350                 /* Copy the vsite data to the thread-task local array */
2351                 il_task->push_back(iat[i], nral1 - 1, iat + i + 1);
2352                 if (task == nthread + thread)
2353                 {
2354                     /* This vsite write outside our own task force block.
2355                      * Put it into the interdependent task list and flag
2356                      * the atoms involved for reduction.
2357                      */
2358                     tData->idTask.vsite.push_back(iat[i + 1]);
2359                     if (ftype != F_VSITEN)
2360                     {
2361                         for (int j = i + 2; j < i + nral1; j++)
2362                         {
2363                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2364                         }
2365                     }
2366                     else
2367                     {
2368                         for (int j = i + 2; j < i + inc; j += 3)
2369                         {
2370                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2371                         }
2372                     }
2373                 }
2374             }
2375
2376             i += inc;
2377         }
2378     }
2379 }
2380
2381 /*! \brief Assign all vsites with taskIndex[]==task to task tData */
2382 static void assignVsitesToSingleTask(VsiteThread*                    tData,
2383                                      int                             task,
2384                                      gmx::ArrayRef<const int>        taskIndex,
2385                                      ArrayRef<const InteractionList> ilist,
2386                                      ArrayRef<const t_iparams>       ip)
2387 {
2388     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2389     {
2390         tData->ilist[ftype].clear();
2391         tData->idTask.ilist[ftype].clear();
2392
2393         int              nral1   = 1 + NRAL(ftype);
2394         int              inc     = nral1;
2395         const int*       iat     = ilist[ftype].iatoms.data();
2396         InteractionList* il_task = &tData->ilist[ftype];
2397
2398         for (int i = 0; i < ilist[ftype].size();)
2399         {
2400             if (ftype == F_VSITEN)
2401             {
2402                 /* The 3 below is from 1+NRAL(ftype)=3 */
2403                 inc = ip[iat[i]].vsiten.n * 3;
2404             }
2405             /* Check if the vsite is assigned to our task */
2406             if (taskIndex[iat[1 + i]] == task)
2407             {
2408                 /* Copy the vsite data to the thread-task local array */
2409                 il_task->push_back(iat[i], inc - 1, iat + i + 1);
2410             }
2411
2412             i += inc;
2413         }
2414     }
2415 }
2416
2417 void ThreadingInfo::setVirtualSites(ArrayRef<const InteractionList> ilists,
2418                                     ArrayRef<const t_iparams>       iparams,
2419                                     const t_mdatoms&                mdatoms,
2420                                     const bool                      useDomdec)
2421 {
2422     if (numThreads_ <= 1)
2423     {
2424         /* Nothing to do */
2425         return;
2426     }
2427
2428     /* The current way of distributing the vsites over threads in primitive.
2429      * We divide the atom range 0 - natoms_in_vsite uniformly over threads,
2430      * without taking into account how the vsites are distributed.
2431      * Without domain decomposition we at least tighten the upper bound
2432      * of the range (useful for common systems such as a vsite-protein
2433      * in 3-site water).
2434      * With domain decomposition, as long as the vsites are distributed
2435      * uniformly in each domain along the major dimension, usually x,
2436      * it will also perform well.
2437      */
2438     int vsite_atom_range;
2439     int natperthread;
2440     if (!useDomdec)
2441     {
2442         vsite_atom_range = -1;
2443         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2444         {
2445             { // TODO remove me
2446                 if (ftype != F_VSITEN)
2447                 {
2448                     int                 nral1 = 1 + NRAL(ftype);
2449                     ArrayRef<const int> iat   = ilists[ftype].iatoms;
2450                     for (int i = 0; i < ilists[ftype].size(); i += nral1)
2451                     {
2452                         for (int j = i + 1; j < i + nral1; j++)
2453                         {
2454                             vsite_atom_range = std::max(vsite_atom_range, iat[j]);
2455                         }
2456                     }
2457                 }
2458                 else
2459                 {
2460                     int vs_ind_end;
2461
2462                     ArrayRef<const int> iat = ilists[ftype].iatoms;
2463
2464                     int i = 0;
2465                     while (i < ilists[ftype].size())
2466                     {
2467                         /* The 3 below is from 1+NRAL(ftype)=3 */
2468                         vs_ind_end = i + iparams[iat[i]].vsiten.n * 3;
2469
2470                         vsite_atom_range = std::max(vsite_atom_range, iat[i + 1]);
2471                         while (i < vs_ind_end)
2472                         {
2473                             vsite_atom_range = std::max(vsite_atom_range, iat[i + 2]);
2474                             i += 3;
2475                         }
2476                     }
2477                 }
2478             }
2479         }
2480         vsite_atom_range++;
2481         natperthread = (vsite_atom_range + numThreads_ - 1) / numThreads_;
2482     }
2483     else
2484     {
2485         /* Any local or not local atom could be involved in virtual sites.
2486          * But since we usually have very few non-local virtual sites
2487          * (only non-local vsites that depend on local vsites),
2488          * we distribute the local atom range equally over the threads.
2489          * When assigning vsites to threads, we should take care that the last
2490          * threads also covers the non-local range.
2491          */
2492         vsite_atom_range = mdatoms.nr;
2493         natperthread     = (mdatoms.homenr + numThreads_ - 1) / numThreads_;
2494     }
2495
2496     if (debug)
2497     {
2498         fprintf(debug, "virtual site thread dist: natoms %d, range %d, natperthread %d\n",
2499                 mdatoms.nr, vsite_atom_range, natperthread);
2500     }
2501
2502     /* To simplify the vsite assignment, we make an index which tells us
2503      * to which task particles, both non-vsites and vsites, are assigned.
2504      */
2505     taskIndex_.resize(mdatoms.nr);
2506
2507     /* Initialize the task index array. Here we assign the non-vsite
2508      * particles to task=thread, so we easily figure out if vsites
2509      * depend on local and/or non-local particles in assignVsitesToThread.
2510      */
2511     {
2512         int thread = 0;
2513         for (int i = 0; i < mdatoms.nr; i++)
2514         {
2515             if (mdatoms.ptype[i] == eptVSite)
2516             {
2517                 /* vsites are not assigned to a task yet */
2518                 taskIndex_[i] = -1;
2519             }
2520             else
2521             {
2522                 /* assign non-vsite particles to task thread */
2523                 taskIndex_[i] = thread;
2524             }
2525             if (i == (thread + 1) * natperthread && thread < numThreads_)
2526             {
2527                 thread++;
2528             }
2529         }
2530     }
2531
2532 #pragma omp parallel num_threads(numThreads_)
2533     {
2534         try
2535         {
2536             int          thread = gmx_omp_get_thread_num();
2537             VsiteThread& tData  = *tData_[thread];
2538
2539             /* Clear the buffer use flags that were set before */
2540             if (tData.useInterdependentTask)
2541             {
2542                 InterdependentTask& idTask = tData.idTask;
2543
2544                 /* To avoid an extra OpenMP barrier in spread_vsite_f,
2545                  * we clear the force buffer at the next step,
2546                  * so we need to do it here as well.
2547                  */
2548                 clearTaskForceBufferUsedElements(&idTask);
2549
2550                 idTask.vsite.resize(0);
2551                 for (int t = 0; t < numThreads_; t++)
2552                 {
2553                     AtomIndex& atomIndex = idTask.atomIndex[t];
2554                     int        natom     = atomIndex.atom.size();
2555                     for (int i = 0; i < natom; i++)
2556                     {
2557                         idTask.use[atomIndex.atom[i]] = false;
2558                     }
2559                     atomIndex.atom.resize(0);
2560                 }
2561                 idTask.nuse = 0;
2562             }
2563
2564             /* To avoid large f_buf allocations of #threads*vsite_atom_range
2565              * we don't use task2 with more than 200000 atoms. This doesn't
2566              * affect performance, since with such a large range relatively few
2567              * vsites will end up in the separate task.
2568              * Note that useTask2 should be the same for all threads.
2569              */
2570             tData.useInterdependentTask = (vsite_atom_range <= 200000);
2571             if (tData.useInterdependentTask)
2572             {
2573                 size_t              natoms_use_in_vsites = vsite_atom_range;
2574                 InterdependentTask& idTask               = tData.idTask;
2575                 /* To avoid resizing and re-clearing every nstlist steps,
2576                  * we never down size the force buffer.
2577                  */
2578                 if (natoms_use_in_vsites > idTask.force.size() || natoms_use_in_vsites > idTask.use.size())
2579                 {
2580                     idTask.force.resize(natoms_use_in_vsites, { 0, 0, 0 });
2581                     idTask.use.resize(natoms_use_in_vsites, false);
2582                 }
2583             }
2584
2585             /* Assign all vsites that can execute independently on threads */
2586             tData.rangeStart = thread * natperthread;
2587             if (thread < numThreads_ - 1)
2588             {
2589                 tData.rangeEnd = (thread + 1) * natperthread;
2590             }
2591             else
2592             {
2593                 /* The last thread should cover up to the end of the range */
2594                 tData.rangeEnd = mdatoms.nr;
2595             }
2596             assignVsitesToThread(&tData, thread, numThreads_, natperthread, taskIndex_, ilists,
2597                                  iparams, mdatoms.ptype);
2598
2599             if (tData.useInterdependentTask)
2600             {
2601                 /* In the worst case, all tasks write to force ranges of
2602                  * all other tasks, leading to #tasks^2 scaling (this is only
2603                  * the overhead, the actual flops remain constant).
2604                  * But in most cases there is far less coupling. To improve
2605                  * scaling at high thread counts we therefore construct
2606                  * an index to only loop over the actually affected tasks.
2607                  */
2608                 InterdependentTask& idTask = tData.idTask;
2609
2610                 /* Ensure assignVsitesToThread finished on other threads */
2611 #pragma omp barrier
2612
2613                 idTask.spreadTask.resize(0);
2614                 idTask.reduceTask.resize(0);
2615                 for (int t = 0; t < numThreads_; t++)
2616                 {
2617                     /* Do we write to the force buffer of task t? */
2618                     if (!idTask.atomIndex[t].atom.empty())
2619                     {
2620                         idTask.spreadTask.push_back(t);
2621                     }
2622                     /* Does task t write to our force buffer? */
2623                     if (!tData_[t]->idTask.atomIndex[thread].atom.empty())
2624                     {
2625                         idTask.reduceTask.push_back(t);
2626                     }
2627                 }
2628             }
2629         }
2630         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2631     }
2632     /* Assign all remaining vsites, that will have taskIndex[]=2*vsite->nthreads,
2633      * to a single task that will not run in parallel with other tasks.
2634      */
2635     assignVsitesToSingleTask(tData_[numThreads_].get(), 2 * numThreads_, taskIndex_, ilists, iparams);
2636
2637     if (debug && numThreads_ > 1)
2638     {
2639         fprintf(debug, "virtual site useInterdependentTask %d, nuse:\n",
2640                 static_cast<int>(tData_[0]->useInterdependentTask));
2641         for (int th = 0; th < numThreads_ + 1; th++)
2642         {
2643             fprintf(debug, " %4d", tData_[th]->idTask.nuse);
2644         }
2645         fprintf(debug, "\n");
2646
2647         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2648         {
2649             if (!ilists[ftype].empty())
2650             {
2651                 fprintf(debug, "%-20s thread dist:", interaction_function[ftype].longname);
2652                 for (int th = 0; th < numThreads_ + 1; th++)
2653                 {
2654                     fprintf(debug, " %4d %4d ", tData_[th]->ilist[ftype].size(),
2655                             tData_[th]->idTask.ilist[ftype].size());
2656                 }
2657                 fprintf(debug, "\n");
2658             }
2659         }
2660     }
2661
2662 #ifndef NDEBUG
2663     int nrOrig     = vsiteIlistNrCount(ilists);
2664     int nrThreaded = 0;
2665     for (int th = 0; th < numThreads_ + 1; th++)
2666     {
2667         nrThreaded += vsiteIlistNrCount(tData_[th]->ilist) + vsiteIlistNrCount(tData_[th]->idTask.ilist);
2668     }
2669     GMX_ASSERT(nrThreaded == nrOrig,
2670                "The number of virtual sites assigned to all thread task has to match the total "
2671                "number of virtual sites");
2672 #endif
2673 }
2674
2675 void VirtualSitesHandler::Impl::setVirtualSites(ArrayRef<const InteractionList> ilists,
2676                                                 const t_mdatoms&                mdatoms)
2677 {
2678     ilists_ = ilists;
2679
2680     threadingInfo_.setVirtualSites(ilists, iparams_, mdatoms, domainInfo_.useDomdec());
2681 }
2682
2683 void VirtualSitesHandler::setVirtualSites(ArrayRef<const InteractionList> ilists, const t_mdatoms& mdatoms)
2684 {
2685     impl_->setVirtualSites(ilists, mdatoms);
2686 }
2687
2688 } // namespace gmx