Merge branch release-2021
[alexxy/gromacs.git] / src / gromacs / mdlib / vsite.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 The GROMACS development team.
7  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /*! \internal \file
39  * \brief Implements the VirtualSitesHandler class and vsite standalone functions
40  *
41  * \author Berk Hess <hess@kth.se>
42  * \ingroup module_mdlib
43  */
44
45 #include "gmxpre.h"
46
47 #include "vsite.h"
48
49 #include <cstdio>
50
51 #include <algorithm>
52 #include <memory>
53 #include <vector>
54
55 #include "gromacs/domdec/domdec.h"
56 #include "gromacs/domdec/domdec_struct.h"
57 #include "gromacs/gmxlib/network.h"
58 #include "gromacs/gmxlib/nrnb.h"
59 #include "gromacs/math/functions.h"
60 #include "gromacs/math/vec.h"
61 #include "gromacs/mdlib/gmx_omp_nthreads.h"
62 #include "gromacs/mdtypes/commrec.h"
63 #include "gromacs/mdtypes/mdatom.h"
64 #include "gromacs/pbcutil/ishift.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/timing/wallcycle.h"
67 #include "gromacs/topology/ifunc.h"
68 #include "gromacs/topology/mtop_util.h"
69 #include "gromacs/topology/topology.h"
70 #include "gromacs/utility/exceptions.h"
71 #include "gromacs/utility/fatalerror.h"
72 #include "gromacs/utility/gmxassert.h"
73 #include "gromacs/utility/gmxomp.h"
74
75 /* The strategy used here for assigning virtual sites to (thread-)tasks
76  * is as follows:
77  *
78  * We divide the atom range that vsites operate on (natoms_local with DD,
79  * 0 - last atom involved in vsites without DD) equally over all threads.
80  *
81  * Vsites in the local range constructed from atoms in the local range
82  * and/or other vsites that are fully local are assigned to a simple,
83  * independent task.
84  *
85  * Vsites that are not assigned after using the above criterion get assigned
86  * to a so called "interdependent" thread task when none of the constructing
87  * atoms is a vsite. These tasks are called interdependent, because one task
88  * accesses atoms assigned to a different task/thread.
89  * Note that this option is turned off with large (local) atom counts
90  * to avoid high memory usage.
91  *
92  * Any remaining vsites are assigned to a separate master thread task.
93  */
94 namespace gmx
95 {
96
97 //! VirialHandling is often used outside VirtualSitesHandler class members
98 using VirialHandling = VirtualSitesHandler::VirialHandling;
99
100 /*! \brief Information on PBC and domain decomposition for virtual sites
101  */
102 struct DomainInfo
103 {
104 public:
105     //! Constructs without PBC and DD
106     DomainInfo() = default;
107
108     //! Constructs with PBC and DD, if !=nullptr
109     DomainInfo(PbcType pbcType, bool haveInterUpdateGroupVirtualSites, gmx_domdec_t* domdec) :
110         pbcType_(pbcType),
111         useMolPbc_(pbcType != PbcType::No && haveInterUpdateGroupVirtualSites),
112         domdec_(domdec)
113     {
114     }
115
116     //! Returns whether we are using domain decomposition with more than 1 DD rank
117     bool useDomdec() const { return (domdec_ != nullptr); }
118
119     //! The pbc type
120     const PbcType pbcType_ = PbcType::No;
121     //! Whether molecules are broken over PBC
122     const bool useMolPbc_ = false;
123     //! Pointer to the domain decomposition struct, nullptr without PP DD
124     const gmx_domdec_t* domdec_ = nullptr;
125 };
126
127 /*! \brief List of atom indices belonging to a task
128  */
129 struct AtomIndex
130 {
131     //! List of atom indices
132     std::vector<int> atom;
133 };
134
135 /*! \brief Data structure for thread tasks that use constructing atoms outside their own atom range
136  */
137 struct InterdependentTask
138 {
139     //! The interaction lists, only vsite entries are used
140     InteractionLists ilist;
141     //! Thread/task-local force buffer
142     std::vector<RVec> force;
143     //! The atom indices of the vsites of our task
144     std::vector<int> vsite;
145     //! Flags if elements in force are spread to or not
146     std::vector<bool> use;
147     //! The number of entries set to true in use
148     int nuse = 0;
149     //! Array of atoms indices, size nthreads, covering all nuse set elements in use
150     std::vector<AtomIndex> atomIndex;
151     //! List of tasks (force blocks) this task spread forces to
152     std::vector<int> spreadTask;
153     //! List of tasks that write to this tasks force block range
154     std::vector<int> reduceTask;
155 };
156
157 /*! \brief Vsite thread task data structure
158  */
159 struct VsiteThread
160 {
161     //! Start of atom range of this task
162     int rangeStart;
163     //! End of atom range of this task
164     int rangeEnd;
165     //! The interaction lists, only vsite entries are used
166     std::array<InteractionList, F_NRE> ilist;
167     //! Local fshift accumulation buffer
168     std::array<RVec, SHIFTS> fshift;
169     //! Local virial dx*df accumulation buffer
170     matrix dxdf;
171     //! Tells if interdependent task idTask should be used (in addition to the rest of this task), this bool has the same value on all threads
172     bool useInterdependentTask;
173     //! Data for vsites that involve constructing atoms in the atom range of other threads/tasks
174     InterdependentTask idTask;
175
176     /*! \brief Constructor */
177     VsiteThread()
178     {
179         rangeStart = -1;
180         rangeEnd   = -1;
181         for (auto& elem : fshift)
182         {
183             elem = { 0.0_real, 0.0_real, 0.0_real };
184         }
185         clear_mat(dxdf);
186         useInterdependentTask = false;
187     }
188 };
189
190
191 /*! \brief Information on how the virtual site work is divided over thread tasks
192  */
193 class ThreadingInfo
194 {
195 public:
196     //! Constructor, retrieves the number of threads to use from gmx_omp_nthreads.h
197     ThreadingInfo();
198
199     //! Returns the number of threads to use for vsite operations
200     int numThreads() const { return numThreads_; }
201
202     //! Returns the thread data for the given thread
203     const VsiteThread& threadData(int threadIndex) const { return *tData_[threadIndex]; }
204
205     //! Returns the thread data for the given thread
206     VsiteThread& threadData(int threadIndex) { return *tData_[threadIndex]; }
207
208     //! Returns the thread data for vsites that depend on non-local vsites
209     const VsiteThread& threadDataNonLocalDependent() const { return *tData_[numThreads_]; }
210
211     //! Returns the thread data for vsites that depend on non-local vsites
212     VsiteThread& threadDataNonLocalDependent() { return *tData_[numThreads_]; }
213
214     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
215     void setVirtualSites(ArrayRef<const InteractionList> ilist,
216                          ArrayRef<const t_iparams>       iparams,
217                          const t_mdatoms&                mdatoms,
218                          bool                            useDomdec);
219
220 private:
221     //! Number of threads used for vsite operations
222     const int numThreads_;
223     //! Thread local vsites and work structs
224     std::vector<std::unique_ptr<VsiteThread>> tData_;
225     //! Work array for dividing vsites over threads
226     std::vector<int> taskIndex_;
227 };
228
229 /*! \brief Impl class for VirtualSitesHandler
230  */
231 class VirtualSitesHandler::Impl
232 {
233 public:
234     //! Constructor, domdec should be nullptr without DD
235     Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, PbcType pbcType);
236
237     //! Returns the number of virtual sites acting over multiple update groups
238     int numInterUpdategroupVirtualSites() const { return numInterUpdategroupVirtualSites_; }
239
240     //! Set VSites and distribute VSite work over threads, should be called after DD partitioning
241     void setVirtualSites(ArrayRef<const InteractionList> ilist, const t_mdatoms& mdatoms);
242
243     /*! \brief Create positions of vsite atoms based for the local system
244      *
245      * \param[in,out] x        The coordinates
246      * \param[in]     dt       The time step
247      * \param[in,out] v        When != nullptr, velocities for vsites are set as displacement/dt
248      * \param[in]     box      The box
249      */
250     void construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const;
251
252     /*! \brief Spread the force operating on the vsite atoms on the surrounding atoms.
253      *
254      * vsite should point to a valid object.
255      * The virialHandling parameter determines how virial contributions are handled.
256      * If this is set to Linear, shift forces are accumulated into fshift.
257      * If this is set to NonLinear, non-linear contributions are added to virial.
258      * This non-linear correction is required when the virial is not calculated
259      * afterwards from the particle position and forces, but in a different way,
260      * as for instance for the PME mesh contribution.
261      */
262     void spreadForces(ArrayRef<const RVec> x,
263                       ArrayRef<RVec>       f,
264                       VirialHandling       virialHandling,
265                       ArrayRef<RVec>       fshift,
266                       matrix               virial,
267                       t_nrnb*              nrnb,
268                       const matrix         box,
269                       gmx_wallcycle*       wcycle);
270
271 private:
272     //! The number of vsites that cross update groups, when =0 no PBC treatment is needed
273     const int numInterUpdategroupVirtualSites_;
274     //! PBC and DD information
275     const DomainInfo domainInfo_;
276     //! The interaction parameters
277     const ArrayRef<const t_iparams> iparams_;
278     //! The interaction lists
279     ArrayRef<const InteractionList> ilists_;
280     //! Information for handling vsite threading
281     ThreadingInfo threadingInfo_;
282 };
283
284 VirtualSitesHandler::~VirtualSitesHandler() = default;
285
286 int VirtualSitesHandler::numInterUpdategroupVirtualSites() const
287 {
288     return impl_->numInterUpdategroupVirtualSites();
289 }
290
291 /*! \brief Returns the sum of the vsite ilist sizes over all vsite types
292  *
293  * \param[in] ilist  The interaction list
294  */
295 static int vsiteIlistNrCount(ArrayRef<const InteractionList> ilist)
296 {
297     int nr = 0;
298     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
299     {
300         nr += ilist[ftype].size();
301     }
302
303     return nr;
304 }
305
306 //! Computes the distance between xi and xj, pbc is used when pbc!=nullptr
307 static int pbc_rvec_sub(const t_pbc* pbc, const rvec xi, const rvec xj, rvec dx)
308 {
309     if (pbc)
310     {
311         return pbc_dx_aiuc(pbc, xi, xj, dx);
312     }
313     else
314     {
315         rvec_sub(xi, xj, dx);
316         return CENTRAL;
317     }
318 }
319
320 //! Returns the 1/norm(x)
321 static inline real inverseNorm(const rvec x)
322 {
323     return gmx::invsqrt(iprod(x, x));
324 }
325
326 #ifndef DOXYGEN
327 /* Vsite construction routines */
328
329 static void constr_vsite1(const rvec xi, rvec x)
330 {
331     copy_rvec(xi, x);
332
333     /* TOTAL: 0 flops */
334 }
335
336 static void constr_vsite2(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc)
337 {
338     real b = 1 - a;
339     /* 1 flop */
340
341     if (pbc)
342     {
343         rvec dx;
344         pbc_dx_aiuc(pbc, xj, xi, dx);
345         x[XX] = xi[XX] + a * dx[XX];
346         x[YY] = xi[YY] + a * dx[YY];
347         x[ZZ] = xi[ZZ] + a * dx[ZZ];
348     }
349     else
350     {
351         x[XX] = b * xi[XX] + a * xj[XX];
352         x[YY] = b * xi[YY] + a * xj[YY];
353         x[ZZ] = b * xi[ZZ] + a * xj[ZZ];
354         /* 9 Flops */
355     }
356
357     /* TOTAL: 10 flops */
358 }
359
360 static void constr_vsite2FD(const rvec xi, const rvec xj, rvec x, real a, const t_pbc* pbc)
361 {
362     rvec xij;
363     pbc_rvec_sub(pbc, xj, xi, xij);
364     /* 3 flops */
365
366     const real b = a * inverseNorm(xij);
367     /* 6 + 10 flops */
368
369     x[XX] = xi[XX] + b * xij[XX];
370     x[YY] = xi[YY] + b * xij[YY];
371     x[ZZ] = xi[ZZ] + b * xij[ZZ];
372     /* 6 Flops */
373
374     /* TOTAL: 25 flops */
375 }
376
377 static void constr_vsite3(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, const t_pbc* pbc)
378 {
379     real c = 1 - a - b;
380     /* 2 flops */
381
382     if (pbc)
383     {
384         rvec dxj, dxk;
385
386         pbc_dx_aiuc(pbc, xj, xi, dxj);
387         pbc_dx_aiuc(pbc, xk, xi, dxk);
388         x[XX] = xi[XX] + a * dxj[XX] + b * dxk[XX];
389         x[YY] = xi[YY] + a * dxj[YY] + b * dxk[YY];
390         x[ZZ] = xi[ZZ] + a * dxj[ZZ] + b * dxk[ZZ];
391     }
392     else
393     {
394         x[XX] = c * xi[XX] + a * xj[XX] + b * xk[XX];
395         x[YY] = c * xi[YY] + a * xj[YY] + b * xk[YY];
396         x[ZZ] = c * xi[ZZ] + a * xj[ZZ] + b * xk[ZZ];
397         /* 15 Flops */
398     }
399
400     /* TOTAL: 17 flops */
401 }
402
403 static void constr_vsite3FD(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, const t_pbc* pbc)
404 {
405     rvec xij, xjk, temp;
406     real c;
407
408     pbc_rvec_sub(pbc, xj, xi, xij);
409     pbc_rvec_sub(pbc, xk, xj, xjk);
410     /* 6 flops */
411
412     /* temp goes from i to a point on the line jk */
413     temp[XX] = xij[XX] + a * xjk[XX];
414     temp[YY] = xij[YY] + a * xjk[YY];
415     temp[ZZ] = xij[ZZ] + a * xjk[ZZ];
416     /* 6 flops */
417
418     c = b * inverseNorm(temp);
419     /* 6 + 10 flops */
420
421     x[XX] = xi[XX] + c * temp[XX];
422     x[YY] = xi[YY] + c * temp[YY];
423     x[ZZ] = xi[ZZ] + c * temp[ZZ];
424     /* 6 Flops */
425
426     /* TOTAL: 34 flops */
427 }
428
429 static void constr_vsite3FAD(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, const t_pbc* pbc)
430 {
431     rvec xij, xjk, xp;
432     real a1, b1, c1, invdij;
433
434     pbc_rvec_sub(pbc, xj, xi, xij);
435     pbc_rvec_sub(pbc, xk, xj, xjk);
436     /* 6 flops */
437
438     invdij = inverseNorm(xij);
439     c1     = invdij * invdij * iprod(xij, xjk);
440     xp[XX] = xjk[XX] - c1 * xij[XX];
441     xp[YY] = xjk[YY] - c1 * xij[YY];
442     xp[ZZ] = xjk[ZZ] - c1 * xij[ZZ];
443     a1     = a * invdij;
444     b1     = b * inverseNorm(xp);
445     /* 45 */
446
447     x[XX] = xi[XX] + a1 * xij[XX] + b1 * xp[XX];
448     x[YY] = xi[YY] + a1 * xij[YY] + b1 * xp[YY];
449     x[ZZ] = xi[ZZ] + a1 * xij[ZZ] + b1 * xp[ZZ];
450     /* 12 Flops */
451
452     /* TOTAL: 63 flops */
453 }
454
455 static void
456 constr_vsite3OUT(const rvec xi, const rvec xj, const rvec xk, rvec x, real a, real b, real c, const t_pbc* pbc)
457 {
458     rvec xij, xik, temp;
459
460     pbc_rvec_sub(pbc, xj, xi, xij);
461     pbc_rvec_sub(pbc, xk, xi, xik);
462     cprod(xij, xik, temp);
463     /* 15 Flops */
464
465     x[XX] = xi[XX] + a * xij[XX] + b * xik[XX] + c * temp[XX];
466     x[YY] = xi[YY] + a * xij[YY] + b * xik[YY] + c * temp[YY];
467     x[ZZ] = xi[ZZ] + a * xij[ZZ] + b * xik[ZZ] + c * temp[ZZ];
468     /* 18 Flops */
469
470     /* TOTAL: 33 flops */
471 }
472
473 static void constr_vsite4FD(const rvec   xi,
474                             const rvec   xj,
475                             const rvec   xk,
476                             const rvec   xl,
477                             rvec         x,
478                             real         a,
479                             real         b,
480                             real         c,
481                             const t_pbc* pbc)
482 {
483     rvec xij, xjk, xjl, temp;
484     real d;
485
486     pbc_rvec_sub(pbc, xj, xi, xij);
487     pbc_rvec_sub(pbc, xk, xj, xjk);
488     pbc_rvec_sub(pbc, xl, xj, xjl);
489     /* 9 flops */
490
491     /* temp goes from i to a point on the plane jkl */
492     temp[XX] = xij[XX] + a * xjk[XX] + b * xjl[XX];
493     temp[YY] = xij[YY] + a * xjk[YY] + b * xjl[YY];
494     temp[ZZ] = xij[ZZ] + a * xjk[ZZ] + b * xjl[ZZ];
495     /* 12 flops */
496
497     d = c * inverseNorm(temp);
498     /* 6 + 10 flops */
499
500     x[XX] = xi[XX] + d * temp[XX];
501     x[YY] = xi[YY] + d * temp[YY];
502     x[ZZ] = xi[ZZ] + d * temp[ZZ];
503     /* 6 Flops */
504
505     /* TOTAL: 43 flops */
506 }
507
508 static void constr_vsite4FDN(const rvec   xi,
509                              const rvec   xj,
510                              const rvec   xk,
511                              const rvec   xl,
512                              rvec         x,
513                              real         a,
514                              real         b,
515                              real         c,
516                              const t_pbc* pbc)
517 {
518     rvec xij, xik, xil, ra, rb, rja, rjb, rm;
519     real d;
520
521     pbc_rvec_sub(pbc, xj, xi, xij);
522     pbc_rvec_sub(pbc, xk, xi, xik);
523     pbc_rvec_sub(pbc, xl, xi, xil);
524     /* 9 flops */
525
526     ra[XX] = a * xik[XX];
527     ra[YY] = a * xik[YY];
528     ra[ZZ] = a * xik[ZZ];
529
530     rb[XX] = b * xil[XX];
531     rb[YY] = b * xil[YY];
532     rb[ZZ] = b * xil[ZZ];
533
534     /* 6 flops */
535
536     rvec_sub(ra, xij, rja);
537     rvec_sub(rb, xij, rjb);
538     /* 6 flops */
539
540     cprod(rja, rjb, rm);
541     /* 9 flops */
542
543     d = c * inverseNorm(rm);
544     /* 5+5+1 flops */
545
546     x[XX] = xi[XX] + d * rm[XX];
547     x[YY] = xi[YY] + d * rm[YY];
548     x[ZZ] = xi[ZZ] + d * rm[ZZ];
549     /* 6 Flops */
550
551     /* TOTAL: 47 flops */
552 }
553
554 static int constr_vsiten(const t_iatom* ia, ArrayRef<const t_iparams> ip, ArrayRef<RVec> x, const t_pbc* pbc)
555 {
556     rvec x1, dx;
557     dvec dsum;
558     int  n3, av, ai;
559     real a;
560
561     n3 = 3 * ip[ia[0]].vsiten.n;
562     av = ia[1];
563     ai = ia[2];
564     copy_rvec(x[ai], x1);
565     clear_dvec(dsum);
566     for (int i = 3; i < n3; i += 3)
567     {
568         ai = ia[i + 2];
569         a  = ip[ia[i]].vsiten.a;
570         if (pbc)
571         {
572             pbc_dx_aiuc(pbc, x[ai], x1, dx);
573         }
574         else
575         {
576             rvec_sub(x[ai], x1, dx);
577         }
578         dsum[XX] += a * dx[XX];
579         dsum[YY] += a * dx[YY];
580         dsum[ZZ] += a * dx[ZZ];
581         /* 9 Flops */
582     }
583
584     x[av][XX] = x1[XX] + dsum[XX];
585     x[av][YY] = x1[YY] + dsum[YY];
586     x[av][ZZ] = x1[ZZ] + dsum[ZZ];
587
588     return n3;
589 }
590
591 #endif // DOXYGEN
592
593 //! PBC modes for vsite construction and spreading
594 enum class PbcMode
595 {
596     all, //!< Apply normal, simple PBC for all vsites
597     none //!< No PBC treatment needed
598 };
599
600 /*! \brief Returns the PBC mode based on the system PBC and vsite properties
601  *
602  * \param[in] pbcPtr  A pointer to a PBC struct or nullptr when no PBC treatment is required
603  */
604 static PbcMode getPbcMode(const t_pbc* pbcPtr)
605 {
606     if (pbcPtr == nullptr)
607     {
608         return PbcMode::none;
609     }
610     else
611     {
612         return PbcMode::all;
613     }
614 }
615
616 /*! \brief Executes the vsite construction task for a single thread
617  *
618  * \param[in,out] x   Coordinates to construct vsites for
619  * \param[in]     dt  Time step, needed when v is not empty
620  * \param[in,out] v   When not empty, velocities are generated for virtual sites
621  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
622  * \param[in]     ilist  The interaction lists, only vsites are usesd
623  * \param[in]     pbc_null  PBC struct, used for PBC distance calculations when !=nullptr
624  */
625 static void construct_vsites_thread(ArrayRef<RVec>                  x,
626                                     const real                      dt,
627                                     ArrayRef<RVec>                  v,
628                                     ArrayRef<const t_iparams>       ip,
629                                     ArrayRef<const InteractionList> ilist,
630                                     const t_pbc*                    pbc_null)
631 {
632     real inv_dt;
633     if (!v.empty())
634     {
635         inv_dt = 1.0 / dt;
636     }
637     else
638     {
639         inv_dt = 1.0;
640     }
641
642     const PbcMode pbcMode = getPbcMode(pbc_null);
643     /* We need another pbc pointer, as with charge groups we switch per vsite */
644     const t_pbc* pbc_null2 = pbc_null;
645
646     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
647     {
648         if (ilist[ftype].empty())
649         {
650             continue;
651         }
652
653         { // TODO remove me
654             int nra = interaction_function[ftype].nratoms;
655             int inc = 1 + nra;
656             int nr  = ilist[ftype].size();
657
658             const t_iatom* ia = ilist[ftype].iatoms.data();
659
660             for (int i = 0; i < nr;)
661             {
662                 int tp = ia[0];
663                 /* The vsite and constructing atoms */
664                 int avsite = ia[1];
665                 int ai     = ia[2];
666                 /* Constants for constructing vsites */
667                 real a1 = ip[tp].vsite.a;
668                 /* Copy the old position */
669                 rvec xv;
670                 copy_rvec(x[avsite], xv);
671
672                 /* Construct the vsite depending on type */
673                 int  aj, ak, al;
674                 real b1, c1;
675                 switch (ftype)
676                 {
677                     case F_VSITE1: constr_vsite1(x[ai], x[avsite]); break;
678                     case F_VSITE2:
679                         aj = ia[3];
680                         constr_vsite2(x[ai], x[aj], x[avsite], a1, pbc_null2);
681                         break;
682                     case F_VSITE2FD:
683                         aj = ia[3];
684                         constr_vsite2FD(x[ai], x[aj], x[avsite], a1, pbc_null2);
685                         break;
686                     case F_VSITE3:
687                         aj = ia[3];
688                         ak = ia[4];
689                         b1 = ip[tp].vsite.b;
690                         constr_vsite3(x[ai], x[aj], x[ak], x[avsite], a1, b1, pbc_null2);
691                         break;
692                     case F_VSITE3FD:
693                         aj = ia[3];
694                         ak = ia[4];
695                         b1 = ip[tp].vsite.b;
696                         constr_vsite3FD(x[ai], x[aj], x[ak], x[avsite], a1, b1, pbc_null2);
697                         break;
698                     case F_VSITE3FAD:
699                         aj = ia[3];
700                         ak = ia[4];
701                         b1 = ip[tp].vsite.b;
702                         constr_vsite3FAD(x[ai], x[aj], x[ak], x[avsite], a1, b1, pbc_null2);
703                         break;
704                     case F_VSITE3OUT:
705                         aj = ia[3];
706                         ak = ia[4];
707                         b1 = ip[tp].vsite.b;
708                         c1 = ip[tp].vsite.c;
709                         constr_vsite3OUT(x[ai], x[aj], x[ak], x[avsite], a1, b1, c1, pbc_null2);
710                         break;
711                     case F_VSITE4FD:
712                         aj = ia[3];
713                         ak = ia[4];
714                         al = ia[5];
715                         b1 = ip[tp].vsite.b;
716                         c1 = ip[tp].vsite.c;
717                         constr_vsite4FD(x[ai], x[aj], x[ak], x[al], x[avsite], a1, b1, c1, pbc_null2);
718                         break;
719                     case F_VSITE4FDN:
720                         aj = ia[3];
721                         ak = ia[4];
722                         al = ia[5];
723                         b1 = ip[tp].vsite.b;
724                         c1 = ip[tp].vsite.c;
725                         constr_vsite4FDN(x[ai], x[aj], x[ak], x[al], x[avsite], a1, b1, c1, pbc_null2);
726                         break;
727                     case F_VSITEN: inc = constr_vsiten(ia, ip, x, pbc_null2); break;
728                     default:
729                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
730                 }
731
732                 if (pbcMode == PbcMode::all)
733                 {
734                     /* Keep the vsite in the same periodic image as before */
735                     rvec dx;
736                     int  ishift = pbc_dx_aiuc(pbc_null, x[avsite], xv, dx);
737                     if (ishift != CENTRAL)
738                     {
739                         rvec_add(xv, dx, x[avsite]);
740                     }
741                 }
742                 if (!v.empty())
743                 {
744                     /* Calculate velocity of vsite... */
745                     rvec vv;
746                     rvec_sub(x[avsite], xv, vv);
747                     svmul(inv_dt, vv, v[avsite]);
748                 }
749
750                 /* Increment loop variables */
751                 i += inc;
752                 ia += inc;
753             }
754         }
755     }
756 }
757
758 /*! \brief Dispatch the vsite construction tasks for all threads
759  *
760  * \param[in]     threadingInfo  Used to divide work over threads when != nullptr
761  * \param[in,out] x   Coordinates to construct vsites for
762  * \param[in]     dt  Time step, needed when v is not empty
763  * \param[in,out] v   When not empty, velocities are generated for virtual sites
764  * \param[in]     ip  Interaction parameters for all interaction, only vsite parameters are used
765  * \param[in]     ilist  The interaction lists, only vsites are usesd
766  * \param[in]     domainInfo  Information about PBC and DD
767  * \param[in]     box  Used for PBC when PBC is set in domainInfo
768  */
769 static void construct_vsites(const ThreadingInfo*            threadingInfo,
770                              ArrayRef<RVec>                  x,
771                              real                            dt,
772                              ArrayRef<RVec>                  v,
773                              ArrayRef<const t_iparams>       ip,
774                              ArrayRef<const InteractionList> ilist,
775                              const DomainInfo&               domainInfo,
776                              const matrix                    box)
777 {
778     const bool useDomdec = domainInfo.useDomdec();
779
780     t_pbc pbc, *pbc_null;
781
782     /* We only need to do pbc when we have inter update-group vsites.
783      * Note that with domain decomposition we do not need to apply PBC here
784      * when we have at least 3 domains along each dimension. Currently we
785      * do not optimize this case.
786      */
787     if (domainInfo.pbcType_ != PbcType::No && domainInfo.useMolPbc_)
788     {
789         /* This is wasting some CPU time as we now do this multiple times
790          * per MD step.
791          */
792         ivec null_ivec;
793         clear_ivec(null_ivec);
794         pbc_null = set_pbc_dd(
795                 &pbc, domainInfo.pbcType_, useDomdec ? domainInfo.domdec_->numCells : null_ivec, FALSE, box);
796     }
797     else
798     {
799         pbc_null = nullptr;
800     }
801
802     if (useDomdec)
803     {
804         dd_move_x_vsites(*domainInfo.domdec_, box, as_rvec_array(x.data()));
805     }
806
807     if (threadingInfo == nullptr || threadingInfo->numThreads() == 1)
808     {
809         construct_vsites_thread(x, dt, v, ip, ilist, pbc_null);
810     }
811     else
812     {
813 #pragma omp parallel num_threads(threadingInfo->numThreads())
814         {
815             try
816             {
817                 const int          th    = gmx_omp_get_thread_num();
818                 const VsiteThread& tData = threadingInfo->threadData(th);
819                 GMX_ASSERT(tData.rangeStart >= 0,
820                            "The thread data should be initialized before calling construct_vsites");
821
822                 construct_vsites_thread(x, dt, v, ip, tData.ilist, pbc_null);
823                 if (tData.useInterdependentTask)
824                 {
825                     /* Here we don't need a barrier (unlike the spreading),
826                      * since both tasks only construct vsites from particles,
827                      * or local vsites, not from non-local vsites.
828                      */
829                     construct_vsites_thread(x, dt, v, ip, tData.idTask.ilist, pbc_null);
830                 }
831             }
832             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
833         }
834         /* Now we can construct the vsites that might depend on other vsites */
835         construct_vsites_thread(x, dt, v, ip, threadingInfo->threadDataNonLocalDependent().ilist, pbc_null);
836     }
837 }
838
839 void VirtualSitesHandler::Impl::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
840 {
841     construct_vsites(&threadingInfo_, x, dt, v, iparams_, ilists_, domainInfo_, box);
842 }
843
844 void VirtualSitesHandler::construct(ArrayRef<RVec> x, real dt, ArrayRef<RVec> v, const matrix box) const
845 {
846     impl_->construct(x, dt, v, box);
847 }
848
849 void constructVirtualSites(ArrayRef<RVec> x, ArrayRef<const t_iparams> ip, ArrayRef<const InteractionList> ilist)
850
851 {
852     // No PBC, no DD
853     const DomainInfo domainInfo;
854     construct_vsites(nullptr, x, 0, {}, ip, ilist, domainInfo, nullptr);
855 }
856
857 #ifndef DOXYGEN
858 /* Force spreading routines */
859
860 static void spread_vsite1(const t_iatom ia[], ArrayRef<RVec> f)
861 {
862     const int av = ia[1];
863     const int ai = ia[2];
864
865     f[av] += f[ai];
866 }
867
868 template<VirialHandling virialHandling>
869 static void spread_vsite2(const t_iatom        ia[],
870                           real                 a,
871                           ArrayRef<const RVec> x,
872                           ArrayRef<RVec>       f,
873                           ArrayRef<RVec>       fshift,
874                           const t_pbc*         pbc)
875 {
876     rvec    fi, fj, dx;
877     t_iatom av, ai, aj;
878
879     av = ia[1];
880     ai = ia[2];
881     aj = ia[3];
882
883     svmul(1 - a, f[av], fi);
884     svmul(a, f[av], fj);
885     /* 7 flop */
886
887     rvec_inc(f[ai], fi);
888     rvec_inc(f[aj], fj);
889     /* 6 Flops */
890
891     if (virialHandling == VirialHandling::Pbc)
892     {
893         int siv;
894         int sij;
895         if (pbc)
896         {
897             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
898             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
899         }
900         else
901         {
902             siv = CENTRAL;
903             sij = CENTRAL;
904         }
905
906         if (siv != CENTRAL || sij != CENTRAL)
907         {
908             rvec_inc(fshift[siv], f[av]);
909             rvec_dec(fshift[CENTRAL], fi);
910             rvec_dec(fshift[sij], fj);
911         }
912     }
913
914     /* TOTAL: 13 flops */
915 }
916
917 void constructVirtualSitesGlobal(const gmx_mtop_t& mtop, gmx::ArrayRef<gmx::RVec> x)
918 {
919     GMX_ASSERT(x.ssize() >= mtop.natoms, "x should contain the whole system");
920     GMX_ASSERT(!mtop.moleculeBlockIndices.empty(),
921                "molblock indices are needed in constructVsitesGlobal");
922
923     for (size_t mb = 0; mb < mtop.molblock.size(); mb++)
924     {
925         const gmx_molblock_t& molb = mtop.molblock[mb];
926         const gmx_moltype_t&  molt = mtop.moltype[molb.type];
927         if (vsiteIlistNrCount(molt.ilist) > 0)
928         {
929             int atomOffset = mtop.moleculeBlockIndices[mb].globalAtomStart;
930             for (int mol = 0; mol < molb.nmol; mol++)
931             {
932                 constructVirtualSites(
933                         x.subArray(atomOffset, molt.atoms.nr), mtop.ffparams.iparams, molt.ilist);
934                 atomOffset += molt.atoms.nr;
935             }
936         }
937     }
938 }
939
940 template<VirialHandling virialHandling>
941 static void spread_vsite2FD(const t_iatom        ia[],
942                             real                 a,
943                             ArrayRef<const RVec> x,
944                             ArrayRef<RVec>       f,
945                             ArrayRef<RVec>       fshift,
946                             matrix               dxdf,
947                             const t_pbc*         pbc)
948 {
949     const int av = ia[1];
950     const int ai = ia[2];
951     const int aj = ia[3];
952     rvec      fv;
953     copy_rvec(f[av], fv);
954
955     rvec xij;
956     int  sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
957     /* 6 flops */
958
959     const real invDistance = inverseNorm(xij);
960     const real b           = a * invDistance;
961     /* 4 + ?10? flops */
962
963     const real fproj = iprod(xij, fv) * invDistance * invDistance;
964
965     rvec fj;
966     fj[XX] = b * (fv[XX] - fproj * xij[XX]);
967     fj[YY] = b * (fv[YY] - fproj * xij[YY]);
968     fj[ZZ] = b * (fv[ZZ] - fproj * xij[ZZ]);
969     /* 9 */
970
971     /* b is already calculated in constr_vsite2FD
972        storing b somewhere will save flops.     */
973
974     f[ai][XX] += fv[XX] - fj[XX];
975     f[ai][YY] += fv[YY] - fj[YY];
976     f[ai][ZZ] += fv[ZZ] - fj[ZZ];
977     f[aj][XX] += fj[XX];
978     f[aj][YY] += fj[YY];
979     f[aj][ZZ] += fj[ZZ];
980     /* 9 Flops */
981
982     if (virialHandling == VirialHandling::Pbc)
983     {
984         int svi;
985         if (pbc)
986         {
987             rvec xvi;
988             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
989         }
990         else
991         {
992             svi = CENTRAL;
993         }
994
995         if (svi != CENTRAL || sji != CENTRAL)
996         {
997             rvec_dec(fshift[svi], fv);
998             fshift[CENTRAL][XX] += fv[XX] - fj[XX];
999             fshift[CENTRAL][YY] += fv[YY] - fj[YY];
1000             fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ];
1001             fshift[sji][XX] += fj[XX];
1002             fshift[sji][YY] += fj[YY];
1003             fshift[sji][ZZ] += fj[ZZ];
1004         }
1005     }
1006
1007     if (virialHandling == VirialHandling::NonLinear)
1008     {
1009         /* Under this condition, the virial for the current forces is not
1010          * calculated from the redistributed forces. This means that
1011          * the effect of non-linear virtual site constructions on the virial
1012          * needs to be added separately. This contribution can be calculated
1013          * in many ways, but the simplest and cheapest way is to use
1014          * the first constructing atom ai as a reference position in space:
1015          * subtract (xv-xi)*fv and add (xj-xi)*fj.
1016          */
1017         rvec xiv;
1018
1019         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1020
1021         for (int i = 0; i < DIM; i++)
1022         {
1023             for (int j = 0; j < DIM; j++)
1024             {
1025                 /* As xix is a linear combination of j and k, use that here */
1026                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j];
1027             }
1028         }
1029     }
1030
1031     /* TOTAL: 38 flops */
1032 }
1033
1034 template<VirialHandling virialHandling>
1035 static void spread_vsite3(const t_iatom        ia[],
1036                           real                 a,
1037                           real                 b,
1038                           ArrayRef<const RVec> x,
1039                           ArrayRef<RVec>       f,
1040                           ArrayRef<RVec>       fshift,
1041                           const t_pbc*         pbc)
1042 {
1043     rvec fi, fj, fk, dx;
1044     int  av, ai, aj, ak;
1045
1046     av = ia[1];
1047     ai = ia[2];
1048     aj = ia[3];
1049     ak = ia[4];
1050
1051     svmul(1 - a - b, f[av], fi);
1052     svmul(a, f[av], fj);
1053     svmul(b, f[av], fk);
1054     /* 11 flops */
1055
1056     rvec_inc(f[ai], fi);
1057     rvec_inc(f[aj], fj);
1058     rvec_inc(f[ak], fk);
1059     /* 9 Flops */
1060
1061     if (virialHandling == VirialHandling::Pbc)
1062     {
1063         int siv;
1064         int sij;
1065         int sik;
1066         if (pbc)
1067         {
1068             siv = pbc_dx_aiuc(pbc, x[ai], x[av], dx);
1069             sij = pbc_dx_aiuc(pbc, x[ai], x[aj], dx);
1070             sik = pbc_dx_aiuc(pbc, x[ai], x[ak], dx);
1071         }
1072         else
1073         {
1074             siv = CENTRAL;
1075             sij = CENTRAL;
1076             sik = CENTRAL;
1077         }
1078
1079         if (siv != CENTRAL || sij != CENTRAL || sik != CENTRAL)
1080         {
1081             rvec_inc(fshift[siv], f[av]);
1082             rvec_dec(fshift[CENTRAL], fi);
1083             rvec_dec(fshift[sij], fj);
1084             rvec_dec(fshift[sik], fk);
1085         }
1086     }
1087
1088     /* TOTAL: 20 flops */
1089 }
1090
1091 template<VirialHandling virialHandling>
1092 static void spread_vsite3FD(const t_iatom        ia[],
1093                             real                 a,
1094                             real                 b,
1095                             ArrayRef<const RVec> x,
1096                             ArrayRef<RVec>       f,
1097                             ArrayRef<RVec>       fshift,
1098                             matrix               dxdf,
1099                             const t_pbc*         pbc)
1100 {
1101     real    fproj, a1;
1102     rvec    xvi, xij, xjk, xix, fv, temp;
1103     t_iatom av, ai, aj, ak;
1104     int     sji, skj;
1105
1106     av = ia[1];
1107     ai = ia[2];
1108     aj = ia[3];
1109     ak = ia[4];
1110     copy_rvec(f[av], fv);
1111
1112     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1113     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1114     /* 6 flops */
1115
1116     /* xix goes from i to point x on the line jk */
1117     xix[XX] = xij[XX] + a * xjk[XX];
1118     xix[YY] = xij[YY] + a * xjk[YY];
1119     xix[ZZ] = xij[ZZ] + a * xjk[ZZ];
1120     /* 6 flops */
1121
1122     const real invDistance = inverseNorm(xix);
1123     const real c           = b * invDistance;
1124     /* 4 + ?10? flops */
1125
1126     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1127
1128     temp[XX] = c * (fv[XX] - fproj * xix[XX]);
1129     temp[YY] = c * (fv[YY] - fproj * xix[YY]);
1130     temp[ZZ] = c * (fv[ZZ] - fproj * xix[ZZ]);
1131     /* 16 */
1132
1133     /* c is already calculated in constr_vsite3FD
1134        storing c somewhere will save 26 flops!     */
1135
1136     a1 = 1 - a;
1137     f[ai][XX] += fv[XX] - temp[XX];
1138     f[ai][YY] += fv[YY] - temp[YY];
1139     f[ai][ZZ] += fv[ZZ] - temp[ZZ];
1140     f[aj][XX] += a1 * temp[XX];
1141     f[aj][YY] += a1 * temp[YY];
1142     f[aj][ZZ] += a1 * temp[ZZ];
1143     f[ak][XX] += a * temp[XX];
1144     f[ak][YY] += a * temp[YY];
1145     f[ak][ZZ] += a * temp[ZZ];
1146     /* 19 Flops */
1147
1148     if (virialHandling == VirialHandling::Pbc)
1149     {
1150         int svi;
1151         if (pbc)
1152         {
1153             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1154         }
1155         else
1156         {
1157             svi = CENTRAL;
1158         }
1159
1160         if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
1161         {
1162             rvec_dec(fshift[svi], fv);
1163             fshift[CENTRAL][XX] += fv[XX] - (1 + a) * temp[XX];
1164             fshift[CENTRAL][YY] += fv[YY] - (1 + a) * temp[YY];
1165             fshift[CENTRAL][ZZ] += fv[ZZ] - (1 + a) * temp[ZZ];
1166             fshift[sji][XX] += temp[XX];
1167             fshift[sji][YY] += temp[YY];
1168             fshift[sji][ZZ] += temp[ZZ];
1169             fshift[skj][XX] += a * temp[XX];
1170             fshift[skj][YY] += a * temp[YY];
1171             fshift[skj][ZZ] += a * temp[ZZ];
1172         }
1173     }
1174
1175     if (virialHandling == VirialHandling::NonLinear)
1176     {
1177         /* Under this condition, the virial for the current forces is not
1178          * calculated from the redistributed forces. This means that
1179          * the effect of non-linear virtual site constructions on the virial
1180          * needs to be added separately. This contribution can be calculated
1181          * in many ways, but the simplest and cheapest way is to use
1182          * the first constructing atom ai as a reference position in space:
1183          * subtract (xv-xi)*fv and add (xj-xi)*fj + (xk-xi)*fk.
1184          */
1185         rvec xiv;
1186
1187         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1188
1189         for (int i = 0; i < DIM; i++)
1190         {
1191             for (int j = 0; j < DIM; j++)
1192             {
1193                 /* As xix is a linear combination of j and k, use that here */
1194                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1195             }
1196         }
1197     }
1198
1199     /* TOTAL: 61 flops */
1200 }
1201
1202 template<VirialHandling virialHandling>
1203 static void spread_vsite3FAD(const t_iatom        ia[],
1204                              real                 a,
1205                              real                 b,
1206                              ArrayRef<const RVec> x,
1207                              ArrayRef<RVec>       f,
1208                              ArrayRef<RVec>       fshift,
1209                              matrix               dxdf,
1210                              const t_pbc*         pbc)
1211 {
1212     rvec    xvi, xij, xjk, xperp, Fpij, Fppp, fv, f1, f2, f3;
1213     real    a1, b1, c1, c2, invdij, invdij2, invdp, fproj;
1214     t_iatom av, ai, aj, ak;
1215     int     sji, skj;
1216
1217     av = ia[1];
1218     ai = ia[2];
1219     aj = ia[3];
1220     ak = ia[4];
1221     copy_rvec(f[ia[1]], fv);
1222
1223     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1224     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1225     /* 6 flops */
1226
1227     invdij    = inverseNorm(xij);
1228     invdij2   = invdij * invdij;
1229     c1        = iprod(xij, xjk) * invdij2;
1230     xperp[XX] = xjk[XX] - c1 * xij[XX];
1231     xperp[YY] = xjk[YY] - c1 * xij[YY];
1232     xperp[ZZ] = xjk[ZZ] - c1 * xij[ZZ];
1233     /* xperp in plane ijk, perp. to ij */
1234     invdp = inverseNorm(xperp);
1235     a1    = a * invdij;
1236     b1    = b * invdp;
1237     /* 45 flops */
1238
1239     /* a1, b1 and c1 are already calculated in constr_vsite3FAD
1240        storing them somewhere will save 45 flops!     */
1241
1242     fproj = iprod(xij, fv) * invdij2;
1243     svmul(fproj, xij, Fpij);                              /* proj. f on xij */
1244     svmul(iprod(xperp, fv) * invdp * invdp, xperp, Fppp); /* proj. f on xperp */
1245     svmul(b1 * fproj, xperp, f3);
1246     /* 23 flops */
1247
1248     rvec_sub(fv, Fpij, f1); /* f1 = f - Fpij */
1249     rvec_sub(f1, Fppp, f2); /* f2 = f - Fpij - Fppp */
1250     for (int d = 0; d < DIM; d++)
1251     {
1252         f1[d] *= a1;
1253         f2[d] *= b1;
1254     }
1255     /* 12 flops */
1256
1257     c2 = 1 + c1;
1258     f[ai][XX] += fv[XX] - f1[XX] + c1 * f2[XX] + f3[XX];
1259     f[ai][YY] += fv[YY] - f1[YY] + c1 * f2[YY] + f3[YY];
1260     f[ai][ZZ] += fv[ZZ] - f1[ZZ] + c1 * f2[ZZ] + f3[ZZ];
1261     f[aj][XX] += f1[XX] - c2 * f2[XX] - f3[XX];
1262     f[aj][YY] += f1[YY] - c2 * f2[YY] - f3[YY];
1263     f[aj][ZZ] += f1[ZZ] - c2 * f2[ZZ] - f3[ZZ];
1264     f[ak][XX] += f2[XX];
1265     f[ak][YY] += f2[YY];
1266     f[ak][ZZ] += f2[ZZ];
1267     /* 30 Flops */
1268
1269     if (virialHandling == VirialHandling::Pbc)
1270     {
1271         int svi;
1272
1273         if (pbc)
1274         {
1275             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1276         }
1277         else
1278         {
1279             svi = CENTRAL;
1280         }
1281
1282         if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL)
1283         {
1284             rvec_dec(fshift[svi], fv);
1285             fshift[CENTRAL][XX] += fv[XX] - f1[XX] - (1 - c1) * f2[XX] + f3[XX];
1286             fshift[CENTRAL][YY] += fv[YY] - f1[YY] - (1 - c1) * f2[YY] + f3[YY];
1287             fshift[CENTRAL][ZZ] += fv[ZZ] - f1[ZZ] - (1 - c1) * f2[ZZ] + f3[ZZ];
1288             fshift[sji][XX] += f1[XX] - c1 * f2[XX] - f3[XX];
1289             fshift[sji][YY] += f1[YY] - c1 * f2[YY] - f3[YY];
1290             fshift[sji][ZZ] += f1[ZZ] - c1 * f2[ZZ] - f3[ZZ];
1291             fshift[skj][XX] += f2[XX];
1292             fshift[skj][YY] += f2[YY];
1293             fshift[skj][ZZ] += f2[ZZ];
1294         }
1295     }
1296
1297     if (virialHandling == VirialHandling::NonLinear)
1298     {
1299         rvec xiv;
1300         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1301
1302         for (int i = 0; i < DIM; i++)
1303         {
1304             for (int j = 0; j < DIM; j++)
1305             {
1306                 /* Note that xik=xij+xjk, so we have to add xij*f2 */
1307                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * (f1[j] + (1 - c2) * f2[j] - f3[j])
1308                               + xjk[i] * f2[j];
1309             }
1310         }
1311     }
1312
1313     /* TOTAL: 113 flops */
1314 }
1315
1316 template<VirialHandling virialHandling>
1317 static void spread_vsite3OUT(const t_iatom        ia[],
1318                              real                 a,
1319                              real                 b,
1320                              real                 c,
1321                              ArrayRef<const RVec> x,
1322                              ArrayRef<RVec>       f,
1323                              ArrayRef<RVec>       fshift,
1324                              matrix               dxdf,
1325                              const t_pbc*         pbc)
1326 {
1327     rvec xvi, xij, xik, fv, fj, fk;
1328     real cfx, cfy, cfz;
1329     int  av, ai, aj, ak;
1330     int  sji, ski;
1331
1332     av = ia[1];
1333     ai = ia[2];
1334     aj = ia[3];
1335     ak = ia[4];
1336
1337     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1338     ski = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1339     /* 6 Flops */
1340
1341     copy_rvec(f[av], fv);
1342
1343     cfx = c * fv[XX];
1344     cfy = c * fv[YY];
1345     cfz = c * fv[ZZ];
1346     /* 3 Flops */
1347
1348     fj[XX] = a * fv[XX] - xik[ZZ] * cfy + xik[YY] * cfz;
1349     fj[YY] = xik[ZZ] * cfx + a * fv[YY] - xik[XX] * cfz;
1350     fj[ZZ] = -xik[YY] * cfx + xik[XX] * cfy + a * fv[ZZ];
1351
1352     fk[XX] = b * fv[XX] + xij[ZZ] * cfy - xij[YY] * cfz;
1353     fk[YY] = -xij[ZZ] * cfx + b * fv[YY] + xij[XX] * cfz;
1354     fk[ZZ] = xij[YY] * cfx - xij[XX] * cfy + b * fv[ZZ];
1355     /* 30 Flops */
1356
1357     f[ai][XX] += fv[XX] - fj[XX] - fk[XX];
1358     f[ai][YY] += fv[YY] - fj[YY] - fk[YY];
1359     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1360     rvec_inc(f[aj], fj);
1361     rvec_inc(f[ak], fk);
1362     /* 15 Flops */
1363
1364     if (virialHandling == VirialHandling::Pbc)
1365     {
1366         int svi;
1367         if (pbc)
1368         {
1369             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1370         }
1371         else
1372         {
1373             svi = CENTRAL;
1374         }
1375
1376         if (svi != CENTRAL || sji != CENTRAL || ski != CENTRAL)
1377         {
1378             rvec_dec(fshift[svi], fv);
1379             fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX];
1380             fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY];
1381             fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ];
1382             rvec_inc(fshift[sji], fj);
1383             rvec_inc(fshift[ski], fk);
1384         }
1385     }
1386
1387     if (virialHandling == VirialHandling::NonLinear)
1388     {
1389         rvec xiv;
1390
1391         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1392
1393         for (int i = 0; i < DIM; i++)
1394         {
1395             for (int j = 0; j < DIM; j++)
1396             {
1397                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j];
1398             }
1399         }
1400     }
1401
1402     /* TOTAL: 54 flops */
1403 }
1404
1405 template<VirialHandling virialHandling>
1406 static void spread_vsite4FD(const t_iatom        ia[],
1407                             real                 a,
1408                             real                 b,
1409                             real                 c,
1410                             ArrayRef<const RVec> x,
1411                             ArrayRef<RVec>       f,
1412                             ArrayRef<RVec>       fshift,
1413                             matrix               dxdf,
1414                             const t_pbc*         pbc)
1415 {
1416     real fproj, a1;
1417     rvec xvi, xij, xjk, xjl, xix, fv, temp;
1418     int  av, ai, aj, ak, al;
1419     int  sji, skj, slj, m;
1420
1421     av = ia[1];
1422     ai = ia[2];
1423     aj = ia[3];
1424     ak = ia[4];
1425     al = ia[5];
1426
1427     sji = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1428     skj = pbc_rvec_sub(pbc, x[ak], x[aj], xjk);
1429     slj = pbc_rvec_sub(pbc, x[al], x[aj], xjl);
1430     /* 9 flops */
1431
1432     /* xix goes from i to point x on the plane jkl */
1433     for (m = 0; m < DIM; m++)
1434     {
1435         xix[m] = xij[m] + a * xjk[m] + b * xjl[m];
1436     }
1437     /* 12 flops */
1438
1439     const real invDistance = inverseNorm(xix);
1440     const real d           = c * invDistance;
1441     /* 4 + ?10? flops */
1442
1443     copy_rvec(f[av], fv);
1444
1445     fproj = iprod(xix, fv) * invDistance * invDistance; /* = (xix . f)/(xix . xix) */
1446
1447     for (m = 0; m < DIM; m++)
1448     {
1449         temp[m] = d * (fv[m] - fproj * xix[m]);
1450     }
1451     /* 16 */
1452
1453     /* c is already calculated in constr_vsite3FD
1454        storing c somewhere will save 35 flops!     */
1455
1456     a1 = 1 - a - b;
1457     for (m = 0; m < DIM; m++)
1458     {
1459         f[ai][m] += fv[m] - temp[m];
1460         f[aj][m] += a1 * temp[m];
1461         f[ak][m] += a * temp[m];
1462         f[al][m] += b * temp[m];
1463     }
1464     /* 26 Flops */
1465
1466     if (virialHandling == VirialHandling::Pbc)
1467     {
1468         int svi;
1469         if (pbc)
1470         {
1471             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1472         }
1473         else
1474         {
1475             svi = CENTRAL;
1476         }
1477
1478         if (svi != CENTRAL || sji != CENTRAL || skj != CENTRAL || slj != CENTRAL)
1479         {
1480             rvec_dec(fshift[svi], fv);
1481             for (m = 0; m < DIM; m++)
1482             {
1483                 fshift[CENTRAL][m] += fv[m] - (1 + a + b) * temp[m];
1484                 fshift[sji][m] += temp[m];
1485                 fshift[skj][m] += a * temp[m];
1486                 fshift[slj][m] += b * temp[m];
1487             }
1488         }
1489     }
1490
1491     if (virialHandling == VirialHandling::NonLinear)
1492     {
1493         rvec xiv;
1494         int  i, j;
1495
1496         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1497
1498         for (i = 0; i < DIM; i++)
1499         {
1500             for (j = 0; j < DIM; j++)
1501             {
1502                 dxdf[i][j] += -xiv[i] * fv[j] + xix[i] * temp[j];
1503             }
1504         }
1505     }
1506
1507     /* TOTAL: 77 flops */
1508 }
1509
1510 template<VirialHandling virialHandling>
1511 static void spread_vsite4FDN(const t_iatom        ia[],
1512                              real                 a,
1513                              real                 b,
1514                              real                 c,
1515                              ArrayRef<const RVec> x,
1516                              ArrayRef<RVec>       f,
1517                              ArrayRef<RVec>       fshift,
1518                              matrix               dxdf,
1519                              const t_pbc*         pbc)
1520 {
1521     rvec xvi, xij, xik, xil, ra, rb, rja, rjb, rab, rm, rt;
1522     rvec fv, fj, fk, fl;
1523     real invrm, denom;
1524     real cfx, cfy, cfz;
1525     int  av, ai, aj, ak, al;
1526     int  sij, sik, sil;
1527
1528     /* DEBUG: check atom indices */
1529     av = ia[1];
1530     ai = ia[2];
1531     aj = ia[3];
1532     ak = ia[4];
1533     al = ia[5];
1534
1535     copy_rvec(f[av], fv);
1536
1537     sij = pbc_rvec_sub(pbc, x[aj], x[ai], xij);
1538     sik = pbc_rvec_sub(pbc, x[ak], x[ai], xik);
1539     sil = pbc_rvec_sub(pbc, x[al], x[ai], xil);
1540     /* 9 flops */
1541
1542     ra[XX] = a * xik[XX];
1543     ra[YY] = a * xik[YY];
1544     ra[ZZ] = a * xik[ZZ];
1545
1546     rb[XX] = b * xil[XX];
1547     rb[YY] = b * xil[YY];
1548     rb[ZZ] = b * xil[ZZ];
1549
1550     /* 6 flops */
1551
1552     rvec_sub(ra, xij, rja);
1553     rvec_sub(rb, xij, rjb);
1554     rvec_sub(rb, ra, rab);
1555     /* 9 flops */
1556
1557     cprod(rja, rjb, rm);
1558     /* 9 flops */
1559
1560     invrm = inverseNorm(rm);
1561     denom = invrm * invrm;
1562     /* 5+5+2 flops */
1563
1564     cfx = c * invrm * fv[XX];
1565     cfy = c * invrm * fv[YY];
1566     cfz = c * invrm * fv[ZZ];
1567     /* 6 Flops */
1568
1569     cprod(rm, rab, rt);
1570     /* 9 flops */
1571
1572     rt[XX] *= denom;
1573     rt[YY] *= denom;
1574     rt[ZZ] *= denom;
1575     /* 3flops */
1576
1577     fj[XX] = (-rm[XX] * rt[XX]) * cfx + (rab[ZZ] - rm[YY] * rt[XX]) * cfy
1578              + (-rab[YY] - rm[ZZ] * rt[XX]) * cfz;
1579     fj[YY] = (-rab[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1580              + (rab[XX] - rm[ZZ] * rt[YY]) * cfz;
1581     fj[ZZ] = (rab[YY] - rm[XX] * rt[ZZ]) * cfx + (-rab[XX] - rm[YY] * rt[ZZ]) * cfy
1582              + (-rm[ZZ] * rt[ZZ]) * cfz;
1583     /* 30 flops */
1584
1585     cprod(rjb, rm, rt);
1586     /* 9 flops */
1587
1588     rt[XX] *= denom * a;
1589     rt[YY] *= denom * a;
1590     rt[ZZ] *= denom * a;
1591     /* 3flops */
1592
1593     fk[XX] = (-rm[XX] * rt[XX]) * cfx + (-a * rjb[ZZ] - rm[YY] * rt[XX]) * cfy
1594              + (a * rjb[YY] - rm[ZZ] * rt[XX]) * cfz;
1595     fk[YY] = (a * rjb[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1596              + (-a * rjb[XX] - rm[ZZ] * rt[YY]) * cfz;
1597     fk[ZZ] = (-a * rjb[YY] - rm[XX] * rt[ZZ]) * cfx + (a * rjb[XX] - rm[YY] * rt[ZZ]) * cfy
1598              + (-rm[ZZ] * rt[ZZ]) * cfz;
1599     /* 36 flops */
1600
1601     cprod(rm, rja, rt);
1602     /* 9 flops */
1603
1604     rt[XX] *= denom * b;
1605     rt[YY] *= denom * b;
1606     rt[ZZ] *= denom * b;
1607     /* 3flops */
1608
1609     fl[XX] = (-rm[XX] * rt[XX]) * cfx + (b * rja[ZZ] - rm[YY] * rt[XX]) * cfy
1610              + (-b * rja[YY] - rm[ZZ] * rt[XX]) * cfz;
1611     fl[YY] = (-b * rja[ZZ] - rm[XX] * rt[YY]) * cfx + (-rm[YY] * rt[YY]) * cfy
1612              + (b * rja[XX] - rm[ZZ] * rt[YY]) * cfz;
1613     fl[ZZ] = (b * rja[YY] - rm[XX] * rt[ZZ]) * cfx + (-b * rja[XX] - rm[YY] * rt[ZZ]) * cfy
1614              + (-rm[ZZ] * rt[ZZ]) * cfz;
1615     /* 36 flops */
1616
1617     f[ai][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
1618     f[ai][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
1619     f[ai][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
1620     rvec_inc(f[aj], fj);
1621     rvec_inc(f[ak], fk);
1622     rvec_inc(f[al], fl);
1623     /* 21 flops */
1624
1625     if (virialHandling == VirialHandling::Pbc)
1626     {
1627         int svi;
1628         if (pbc)
1629         {
1630             svi = pbc_rvec_sub(pbc, x[av], x[ai], xvi);
1631         }
1632         else
1633         {
1634             svi = CENTRAL;
1635         }
1636
1637         if (svi != CENTRAL || sij != CENTRAL || sik != CENTRAL || sil != CENTRAL)
1638         {
1639             rvec_dec(fshift[svi], fv);
1640             fshift[CENTRAL][XX] += fv[XX] - fj[XX] - fk[XX] - fl[XX];
1641             fshift[CENTRAL][YY] += fv[YY] - fj[YY] - fk[YY] - fl[YY];
1642             fshift[CENTRAL][ZZ] += fv[ZZ] - fj[ZZ] - fk[ZZ] - fl[ZZ];
1643             rvec_inc(fshift[sij], fj);
1644             rvec_inc(fshift[sik], fk);
1645             rvec_inc(fshift[sil], fl);
1646         }
1647     }
1648
1649     if (virialHandling == VirialHandling::NonLinear)
1650     {
1651         rvec xiv;
1652         int  i, j;
1653
1654         pbc_rvec_sub(pbc, x[av], x[ai], xiv);
1655
1656         for (i = 0; i < DIM; i++)
1657         {
1658             for (j = 0; j < DIM; j++)
1659             {
1660                 dxdf[i][j] += -xiv[i] * fv[j] + xij[i] * fj[j] + xik[i] * fk[j] + xil[i] * fl[j];
1661             }
1662         }
1663     }
1664
1665     /* Total: 207 flops (Yuck!) */
1666 }
1667
1668 template<VirialHandling virialHandling>
1669 static int spread_vsiten(const t_iatom             ia[],
1670                          ArrayRef<const t_iparams> ip,
1671                          ArrayRef<const RVec>      x,
1672                          ArrayRef<RVec>            f,
1673                          ArrayRef<RVec>            fshift,
1674                          const t_pbc*              pbc)
1675 {
1676     rvec xv, dx, fi;
1677     int  n3, av, i, ai;
1678     real a;
1679     int  siv;
1680
1681     n3 = 3 * ip[ia[0]].vsiten.n;
1682     av = ia[1];
1683     copy_rvec(x[av], xv);
1684
1685     for (i = 0; i < n3; i += 3)
1686     {
1687         ai = ia[i + 2];
1688         if (pbc)
1689         {
1690             siv = pbc_dx_aiuc(pbc, x[ai], xv, dx);
1691         }
1692         else
1693         {
1694             siv = CENTRAL;
1695         }
1696         a = ip[ia[i]].vsiten.a;
1697         svmul(a, f[av], fi);
1698         rvec_inc(f[ai], fi);
1699
1700         if (virialHandling == VirialHandling::Pbc && siv != CENTRAL)
1701         {
1702             rvec_inc(fshift[siv], fi);
1703             rvec_dec(fshift[CENTRAL], fi);
1704         }
1705         /* 6 Flops */
1706     }
1707
1708     return n3;
1709 }
1710
1711 #endif // DOXYGEN
1712
1713 //! Returns the number of virtual sites in the interaction list, for VSITEN the number of atoms
1714 static int vsite_count(ArrayRef<const InteractionList> ilist, int ftype)
1715 {
1716     if (ftype == F_VSITEN)
1717     {
1718         return ilist[ftype].size() / 3;
1719     }
1720     else
1721     {
1722         return ilist[ftype].size() / (1 + interaction_function[ftype].nratoms);
1723     }
1724 }
1725
1726 //! Executes the force spreading task for a single thread
1727 template<VirialHandling virialHandling>
1728 static void spreadForceForThread(ArrayRef<const RVec>            x,
1729                                  ArrayRef<RVec>                  f,
1730                                  ArrayRef<RVec>                  fshift,
1731                                  matrix                          dxdf,
1732                                  ArrayRef<const t_iparams>       ip,
1733                                  ArrayRef<const InteractionList> ilist,
1734                                  const t_pbc*                    pbc_null)
1735 {
1736     const PbcMode pbcMode = getPbcMode(pbc_null);
1737     /* We need another pbc pointer, as with charge groups we switch per vsite */
1738     const t_pbc*             pbc_null2 = pbc_null;
1739     gmx::ArrayRef<const int> vsite_pbc;
1740
1741     /* this loop goes backwards to be able to build *
1742      * higher type vsites from lower types         */
1743     for (int ftype = c_ftypeVsiteEnd - 1; ftype >= c_ftypeVsiteStart; ftype--)
1744     {
1745         if (ilist[ftype].empty())
1746         {
1747             continue;
1748         }
1749
1750         { // TODO remove me
1751             int nra = interaction_function[ftype].nratoms;
1752             int inc = 1 + nra;
1753             int nr  = ilist[ftype].size();
1754
1755             const t_iatom* ia = ilist[ftype].iatoms.data();
1756
1757             if (pbcMode == PbcMode::all)
1758             {
1759                 pbc_null2 = pbc_null;
1760             }
1761
1762             for (int i = 0; i < nr;)
1763             {
1764                 int tp = ia[0];
1765
1766                 /* Constants for constructing */
1767                 real a1, b1, c1;
1768                 a1 = ip[tp].vsite.a;
1769                 /* Construct the vsite depending on type */
1770                 switch (ftype)
1771                 {
1772                     case F_VSITE1: spread_vsite1(ia, f); break;
1773                     case F_VSITE2:
1774                         spread_vsite2<virialHandling>(ia, a1, x, f, fshift, pbc_null2);
1775                         break;
1776                     case F_VSITE2FD:
1777                         spread_vsite2FD<virialHandling>(ia, a1, x, f, fshift, dxdf, pbc_null2);
1778                         break;
1779                     case F_VSITE3:
1780                         b1 = ip[tp].vsite.b;
1781                         spread_vsite3<virialHandling>(ia, a1, b1, x, f, fshift, pbc_null2);
1782                         break;
1783                     case F_VSITE3FD:
1784                         b1 = ip[tp].vsite.b;
1785                         spread_vsite3FD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
1786                         break;
1787                     case F_VSITE3FAD:
1788                         b1 = ip[tp].vsite.b;
1789                         spread_vsite3FAD<virialHandling>(ia, a1, b1, x, f, fshift, dxdf, pbc_null2);
1790                         break;
1791                     case F_VSITE3OUT:
1792                         b1 = ip[tp].vsite.b;
1793                         c1 = ip[tp].vsite.c;
1794                         spread_vsite3OUT<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
1795                         break;
1796                     case F_VSITE4FD:
1797                         b1 = ip[tp].vsite.b;
1798                         c1 = ip[tp].vsite.c;
1799                         spread_vsite4FD<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
1800                         break;
1801                     case F_VSITE4FDN:
1802                         b1 = ip[tp].vsite.b;
1803                         c1 = ip[tp].vsite.c;
1804                         spread_vsite4FDN<virialHandling>(ia, a1, b1, c1, x, f, fshift, dxdf, pbc_null2);
1805                         break;
1806                     case F_VSITEN:
1807                         inc = spread_vsiten<virialHandling>(ia, ip, x, f, fshift, pbc_null2);
1808                         break;
1809                     default:
1810                         gmx_fatal(FARGS, "No such vsite type %d in %s, line %d", ftype, __FILE__, __LINE__);
1811                 }
1812                 clear_rvec(f[ia[1]]);
1813
1814                 /* Increment loop variables */
1815                 i += inc;
1816                 ia += inc;
1817             }
1818         }
1819     }
1820 }
1821
1822 //! Wrapper function for calling the templated thread-local spread function
1823 static void spreadForceWrapper(ArrayRef<const RVec>            x,
1824                                ArrayRef<RVec>                  f,
1825                                const VirialHandling            virialHandling,
1826                                ArrayRef<RVec>                  fshift,
1827                                matrix                          dxdf,
1828                                const bool                      clearDxdf,
1829                                ArrayRef<const t_iparams>       ip,
1830                                ArrayRef<const InteractionList> ilist,
1831                                const t_pbc*                    pbc_null)
1832 {
1833     if (virialHandling == VirialHandling::NonLinear && clearDxdf)
1834     {
1835         clear_mat(dxdf);
1836     }
1837
1838     switch (virialHandling)
1839     {
1840         case VirialHandling::None:
1841             spreadForceForThread<VirialHandling::None>(x, f, fshift, dxdf, ip, ilist, pbc_null);
1842             break;
1843         case VirialHandling::Pbc:
1844             spreadForceForThread<VirialHandling::Pbc>(x, f, fshift, dxdf, ip, ilist, pbc_null);
1845             break;
1846         case VirialHandling::NonLinear:
1847             spreadForceForThread<VirialHandling::NonLinear>(x, f, fshift, dxdf, ip, ilist, pbc_null);
1848             break;
1849     }
1850 }
1851
1852 //! Clears the task force buffer elements that are written by task idTask
1853 static void clearTaskForceBufferUsedElements(InterdependentTask* idTask)
1854 {
1855     int ntask = idTask->spreadTask.size();
1856     for (int ti = 0; ti < ntask; ti++)
1857     {
1858         const AtomIndex* atomList = &idTask->atomIndex[idTask->spreadTask[ti]];
1859         int              natom    = atomList->atom.size();
1860         RVec*            force    = idTask->force.data();
1861         for (int i = 0; i < natom; i++)
1862         {
1863             clear_rvec(force[atomList->atom[i]]);
1864         }
1865     }
1866 }
1867
1868 void VirtualSitesHandler::Impl::spreadForces(ArrayRef<const RVec> x,
1869                                              ArrayRef<RVec>       f,
1870                                              const VirialHandling virialHandling,
1871                                              ArrayRef<RVec>       fshift,
1872                                              matrix               virial,
1873                                              t_nrnb*              nrnb,
1874                                              const matrix         box,
1875                                              gmx_wallcycle*       wcycle)
1876 {
1877     wallcycle_start(wcycle, ewcVSITESPREAD);
1878
1879     const bool useDomdec = domainInfo_.useDomdec();
1880
1881     t_pbc pbc, *pbc_null;
1882
1883     if (domainInfo_.useMolPbc_)
1884     {
1885         /* This is wasting some CPU time as we now do this multiple times
1886          * per MD step.
1887          */
1888         pbc_null = set_pbc_dd(
1889                 &pbc, domainInfo_.pbcType_, useDomdec ? domainInfo_.domdec_->numCells : nullptr, FALSE, box);
1890     }
1891     else
1892     {
1893         pbc_null = nullptr;
1894     }
1895
1896     if (useDomdec)
1897     {
1898         dd_clear_f_vsites(*domainInfo_.domdec_, f);
1899     }
1900
1901     const int numThreads = threadingInfo_.numThreads();
1902
1903     if (numThreads == 1)
1904     {
1905         matrix dxdf;
1906         spreadForceWrapper(x, f, virialHandling, fshift, dxdf, true, iparams_, ilists_, pbc_null);
1907
1908         if (virialHandling == VirialHandling::NonLinear)
1909         {
1910             for (int i = 0; i < DIM; i++)
1911             {
1912                 for (int j = 0; j < DIM; j++)
1913                 {
1914                     virial[i][j] += -0.5 * dxdf[i][j];
1915                 }
1916             }
1917         }
1918     }
1919     else
1920     {
1921         /* First spread the vsites that might depend on non-local vsites */
1922         auto& nlDependentVSites = threadingInfo_.threadDataNonLocalDependent();
1923         spreadForceWrapper(x,
1924                            f,
1925                            virialHandling,
1926                            fshift,
1927                            nlDependentVSites.dxdf,
1928                            true,
1929                            iparams_,
1930                            nlDependentVSites.ilist,
1931                            pbc_null);
1932
1933 #pragma omp parallel num_threads(numThreads)
1934         {
1935             try
1936             {
1937                 int          thread = gmx_omp_get_thread_num();
1938                 VsiteThread& tData  = threadingInfo_.threadData(thread);
1939
1940                 ArrayRef<RVec> fshift_t;
1941                 if (virialHandling == VirialHandling::Pbc)
1942                 {
1943                     if (thread == 0)
1944                     {
1945                         fshift_t = fshift;
1946                     }
1947                     else
1948                     {
1949                         fshift_t = tData.fshift;
1950
1951                         for (int i = 0; i < SHIFTS; i++)
1952                         {
1953                             clear_rvec(fshift_t[i]);
1954                         }
1955                     }
1956                 }
1957
1958                 if (tData.useInterdependentTask)
1959                 {
1960                     /* Spread the vsites that spread outside our local range.
1961                      * This is done using a thread-local force buffer force.
1962                      * First we need to copy the input vsite forces to force.
1963                      */
1964                     InterdependentTask* idTask = &tData.idTask;
1965
1966                     /* Clear the buffer elements set by our task during
1967                      * the last call to spread_vsite_f.
1968                      */
1969                     clearTaskForceBufferUsedElements(idTask);
1970
1971                     int nvsite = idTask->vsite.size();
1972                     for (int i = 0; i < nvsite; i++)
1973                     {
1974                         copy_rvec(f[idTask->vsite[i]], idTask->force[idTask->vsite[i]]);
1975                     }
1976                     spreadForceWrapper(x,
1977                                        idTask->force,
1978                                        virialHandling,
1979                                        fshift_t,
1980                                        tData.dxdf,
1981                                        true,
1982                                        iparams_,
1983                                        tData.idTask.ilist,
1984                                        pbc_null);
1985
1986                     /* We need a barrier before reducing forces below
1987                      * that have been produced by a different thread above.
1988                      */
1989 #pragma omp barrier
1990
1991                     /* Loop over all thread task and reduce forces they
1992                      * produced on atoms that fall in our range.
1993                      * Note that atomic reduction would be a simpler solution,
1994                      * but that might not have good support on all platforms.
1995                      */
1996                     int ntask = idTask->reduceTask.size();
1997                     for (int ti = 0; ti < ntask; ti++)
1998                     {
1999                         const InterdependentTask& idt_foreign =
2000                                 threadingInfo_.threadData(idTask->reduceTask[ti]).idTask;
2001                         const AtomIndex& atomList  = idt_foreign.atomIndex[thread];
2002                         const RVec*      f_foreign = idt_foreign.force.data();
2003
2004                         for (int ind : atomList.atom)
2005                         {
2006                             rvec_inc(f[ind], f_foreign[ind]);
2007                             /* Clearing of f_foreign is done at the next step */
2008                         }
2009                     }
2010                     /* Clear the vsite forces, both in f and force */
2011                     for (int i = 0; i < nvsite; i++)
2012                     {
2013                         int ind = tData.idTask.vsite[i];
2014                         clear_rvec(f[ind]);
2015                         clear_rvec(tData.idTask.force[ind]);
2016                     }
2017                 }
2018
2019                 /* Spread the vsites that spread locally only */
2020                 spreadForceWrapper(
2021                         x, f, virialHandling, fshift_t, tData.dxdf, false, iparams_, tData.ilist, pbc_null);
2022             }
2023             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2024         }
2025
2026         if (virialHandling == VirialHandling::Pbc)
2027         {
2028             for (int th = 1; th < numThreads; th++)
2029             {
2030                 for (int i = 0; i < SHIFTS; i++)
2031                 {
2032                     rvec_inc(fshift[i], threadingInfo_.threadData(th).fshift[i]);
2033                 }
2034             }
2035         }
2036
2037         if (virialHandling == VirialHandling::NonLinear)
2038         {
2039             for (int th = 0; th < numThreads + 1; th++)
2040             {
2041                 /* MSVC doesn't like matrix references, so we use a pointer */
2042                 const matrix& dxdf = threadingInfo_.threadData(th).dxdf;
2043
2044                 for (int i = 0; i < DIM; i++)
2045                 {
2046                     for (int j = 0; j < DIM; j++)
2047                     {
2048                         virial[i][j] += -0.5 * dxdf[i][j];
2049                     }
2050                 }
2051             }
2052         }
2053     }
2054
2055     if (useDomdec)
2056     {
2057         dd_move_f_vsites(*domainInfo_.domdec_, f, fshift);
2058     }
2059
2060     inc_nrnb(nrnb, eNR_VSITE1, vsite_count(ilists_, F_VSITE1));
2061     inc_nrnb(nrnb, eNR_VSITE2, vsite_count(ilists_, F_VSITE2));
2062     inc_nrnb(nrnb, eNR_VSITE2FD, vsite_count(ilists_, F_VSITE2FD));
2063     inc_nrnb(nrnb, eNR_VSITE3, vsite_count(ilists_, F_VSITE3));
2064     inc_nrnb(nrnb, eNR_VSITE3FD, vsite_count(ilists_, F_VSITE3FD));
2065     inc_nrnb(nrnb, eNR_VSITE3FAD, vsite_count(ilists_, F_VSITE3FAD));
2066     inc_nrnb(nrnb, eNR_VSITE3OUT, vsite_count(ilists_, F_VSITE3OUT));
2067     inc_nrnb(nrnb, eNR_VSITE4FD, vsite_count(ilists_, F_VSITE4FD));
2068     inc_nrnb(nrnb, eNR_VSITE4FDN, vsite_count(ilists_, F_VSITE4FDN));
2069     inc_nrnb(nrnb, eNR_VSITEN, vsite_count(ilists_, F_VSITEN));
2070
2071     wallcycle_stop(wcycle, ewcVSITESPREAD);
2072 }
2073
2074 /*! \brief Returns the an array with group indices for each atom
2075  *
2076  * \param[in] grouping  The paritioning of the atom range into atom groups
2077  */
2078 static std::vector<int> makeAtomToGroupMapping(const gmx::RangePartitioning& grouping)
2079 {
2080     std::vector<int> atomToGroup(grouping.fullRange().end(), 0);
2081
2082     for (int group = 0; group < grouping.numBlocks(); group++)
2083     {
2084         auto block = grouping.block(group);
2085         std::fill(atomToGroup.begin() + block.begin(), atomToGroup.begin() + block.end(), group);
2086     }
2087
2088     return atomToGroup;
2089 }
2090
2091 int countNonlinearVsites(const gmx_mtop_t& mtop)
2092 {
2093     int numNonlinearVsites = 0;
2094     for (const gmx_molblock_t& molb : mtop.molblock)
2095     {
2096         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2097
2098         for (const auto& ilist : extractILists(molt.ilist, IF_VSITE))
2099         {
2100             if (ilist.functionType != F_VSITE2 && ilist.functionType != F_VSITE3
2101                 && ilist.functionType != F_VSITEN)
2102             {
2103                 numNonlinearVsites += molb.nmol * ilist.iatoms.size() / (1 + NRAL(ilist.functionType));
2104             }
2105         }
2106     }
2107
2108     return numNonlinearVsites;
2109 }
2110
2111 void VirtualSitesHandler::spreadForces(ArrayRef<const RVec> x,
2112                                        ArrayRef<RVec>       f,
2113                                        const VirialHandling virialHandling,
2114                                        ArrayRef<RVec>       fshift,
2115                                        matrix               virial,
2116                                        t_nrnb*              nrnb,
2117                                        const matrix         box,
2118                                        gmx_wallcycle*       wcycle)
2119 {
2120     impl_->spreadForces(x, f, virialHandling, fshift, virial, nrnb, box, wcycle);
2121 }
2122
2123 int countInterUpdategroupVsites(const gmx_mtop_t&                           mtop,
2124                                 gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype)
2125 {
2126     int n_intercg_vsite = 0;
2127     for (const gmx_molblock_t& molb : mtop.molblock)
2128     {
2129         const gmx_moltype_t& molt = mtop.moltype[molb.type];
2130
2131         std::vector<int> atomToGroup;
2132         if (!updateGroupingPerMoleculetype.empty())
2133         {
2134             atomToGroup = makeAtomToGroupMapping(updateGroupingPerMoleculetype[molb.type]);
2135         }
2136         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2137         {
2138             const int              nral = NRAL(ftype);
2139             const InteractionList& il   = molt.ilist[ftype];
2140             for (int i = 0; i < il.size(); i += 1 + nral)
2141             {
2142                 bool isInterGroup = atomToGroup.empty();
2143                 if (!isInterGroup)
2144                 {
2145                     const int group = atomToGroup[il.iatoms[1 + i]];
2146                     for (int a = 1; a < nral; a++)
2147                     {
2148                         if (atomToGroup[il.iatoms[1 + a]] != group)
2149                         {
2150                             isInterGroup = true;
2151                             break;
2152                         }
2153                     }
2154                 }
2155                 if (isInterGroup)
2156                 {
2157                     n_intercg_vsite += molb.nmol;
2158                 }
2159             }
2160         }
2161     }
2162
2163     return n_intercg_vsite;
2164 }
2165
2166 std::unique_ptr<VirtualSitesHandler> makeVirtualSitesHandler(const gmx_mtop_t& mtop,
2167                                                              const t_commrec*  cr,
2168                                                              PbcType           pbcType)
2169 {
2170     GMX_RELEASE_ASSERT(cr != nullptr, "We need a valid commrec");
2171
2172     std::unique_ptr<VirtualSitesHandler> vsite;
2173
2174     /* check if there are vsites */
2175     int nvsite = 0;
2176     for (int ftype = 0; ftype < F_NRE; ftype++)
2177     {
2178         if (interaction_function[ftype].flags & IF_VSITE)
2179         {
2180             GMX_ASSERT(ftype >= c_ftypeVsiteStart && ftype < c_ftypeVsiteEnd,
2181                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2182
2183             nvsite += gmx_mtop_ftype_count(&mtop, ftype);
2184         }
2185         else
2186         {
2187             GMX_ASSERT(ftype < c_ftypeVsiteStart || ftype >= c_ftypeVsiteEnd,
2188                        "c_ftypeVsiteStart and/or c_ftypeVsiteEnd do not have correct values");
2189         }
2190     }
2191
2192     if (nvsite == 0)
2193     {
2194         return vsite;
2195     }
2196
2197     return std::make_unique<VirtualSitesHandler>(mtop, cr->dd, pbcType);
2198 }
2199
2200 ThreadingInfo::ThreadingInfo() : numThreads_(gmx_omp_nthreads_get(emntVSITE))
2201 {
2202     if (numThreads_ > 1)
2203     {
2204         /* We need one extra thread data structure for the overlap vsites */
2205         tData_.resize(numThreads_ + 1);
2206 #pragma omp parallel for num_threads(numThreads_) schedule(static)
2207         for (int thread = 0; thread < numThreads_; thread++)
2208         {
2209             try
2210             {
2211                 tData_[thread] = std::make_unique<VsiteThread>();
2212
2213                 InterdependentTask& idTask = tData_[thread]->idTask;
2214                 idTask.nuse                = 0;
2215                 idTask.atomIndex.resize(numThreads_);
2216             }
2217             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2218         }
2219         if (numThreads_ > 1)
2220         {
2221             tData_[numThreads_] = std::make_unique<VsiteThread>();
2222         }
2223     }
2224 }
2225
2226 //! Returns the number of inter update-group vsites
2227 static int getNumInterUpdategroupVsites(const gmx_mtop_t& mtop, const gmx_domdec_t* domdec)
2228 {
2229     gmx::ArrayRef<const gmx::RangePartitioning> updateGroupingPerMoleculetype;
2230     if (domdec)
2231     {
2232         updateGroupingPerMoleculetype = getUpdateGroupingPerMoleculetype(*domdec);
2233     }
2234
2235     return countInterUpdategroupVsites(mtop, updateGroupingPerMoleculetype);
2236 }
2237
2238 VirtualSitesHandler::Impl::Impl(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
2239     numInterUpdategroupVirtualSites_(getNumInterUpdategroupVsites(mtop, domdec)),
2240     domainInfo_({ pbcType, pbcType != PbcType::No && numInterUpdategroupVirtualSites_ > 0, domdec }),
2241     iparams_(mtop.ffparams.iparams)
2242 {
2243 }
2244
2245 VirtualSitesHandler::VirtualSitesHandler(const gmx_mtop_t& mtop, gmx_domdec_t* domdec, const PbcType pbcType) :
2246     impl_(new Impl(mtop, domdec, pbcType))
2247 {
2248 }
2249
2250 //! Flag that atom \p atom which is home in another task, if it has not already been added before
2251 static inline void flagAtom(InterdependentTask* idTask, const int atom, const int numThreads, const int numAtomsPerThread)
2252 {
2253     if (!idTask->use[atom])
2254     {
2255         idTask->use[atom] = true;
2256         int thread        = atom / numAtomsPerThread;
2257         /* Assign all non-local atom force writes to thread 0 */
2258         if (thread >= numThreads)
2259         {
2260             thread = 0;
2261         }
2262         idTask->atomIndex[thread].atom.push_back(atom);
2263     }
2264 }
2265
2266 /*! \brief Here we try to assign all vsites that are in our local range.
2267  *
2268  * Our task local atom range is tData->rangeStart - tData->rangeEnd.
2269  * Vsites that depend only on local atoms, as indicated by taskIndex[]==thread,
2270  * are assigned to task tData->ilist. Vsites that depend on non-local atoms
2271  * but not on other vsites are assigned to task tData->id_task.ilist.
2272  * taskIndex[] is set for all vsites in our range, either to our local tasks
2273  * or to the single last task as taskIndex[]=2*nthreads.
2274  */
2275 static void assignVsitesToThread(VsiteThread*                    tData,
2276                                  int                             thread,
2277                                  int                             nthread,
2278                                  int                             natperthread,
2279                                  gmx::ArrayRef<int>              taskIndex,
2280                                  ArrayRef<const InteractionList> ilist,
2281                                  ArrayRef<const t_iparams>       ip,
2282                                  const unsigned short*           ptype)
2283 {
2284     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2285     {
2286         tData->ilist[ftype].clear();
2287         tData->idTask.ilist[ftype].clear();
2288
2289         const int  nral1 = 1 + NRAL(ftype);
2290         const int* iat   = ilist[ftype].iatoms.data();
2291         for (int i = 0; i < ilist[ftype].size();)
2292         {
2293             /* Get the number of iatom entries in this virtual site.
2294              * The 3 below for F_VSITEN is from 1+NRAL(ftype)=3
2295              */
2296             const int numIAtoms = (ftype == F_VSITEN ? ip[iat[i]].vsiten.n * 3 : nral1);
2297
2298             if (iat[1 + i] < tData->rangeStart || iat[1 + i] >= tData->rangeEnd)
2299             {
2300                 /* This vsite belongs to a different thread */
2301                 i += numIAtoms;
2302                 continue;
2303             }
2304
2305             /* We would like to assign this vsite to task thread,
2306              * but it might depend on atoms outside the atom range of thread
2307              * or on another vsite not assigned to task thread.
2308              */
2309             int task = thread;
2310             if (ftype != F_VSITEN)
2311             {
2312                 for (int j = i + 2; j < i + nral1; j++)
2313                 {
2314                     /* Do a range check to avoid a harmless race on taskIndex */
2315                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2316                     {
2317                         if (!tData->useInterdependentTask || ptype[iat[j]] == eptVSite)
2318                         {
2319                             /* At least one constructing atom is a vsite
2320                              * that is not assigned to the same thread.
2321                              * Put this vsite into a separate task.
2322                              */
2323                             task = 2 * nthread;
2324                             break;
2325                         }
2326
2327                         /* There are constructing atoms outside our range,
2328                          * put this vsite into a second task to be executed
2329                          * on the same thread. During construction no barrier
2330                          * is needed between the two tasks on the same thread.
2331                          * During spreading we need to run this task with
2332                          * an additional thread-local intermediate force buffer
2333                          * (or atomic reduction) and a barrier between the two
2334                          * tasks.
2335                          */
2336                         task = nthread + thread;
2337                     }
2338                 }
2339             }
2340             else
2341             {
2342                 for (int j = i + 2; j < i + numIAtoms; j += 3)
2343                 {
2344                     /* Do a range check to avoid a harmless race on taskIndex */
2345                     if (iat[j] < tData->rangeStart || iat[j] >= tData->rangeEnd || taskIndex[iat[j]] != thread)
2346                     {
2347                         GMX_ASSERT(ptype[iat[j]] != eptVSite,
2348                                    "A vsite to be assigned in assignVsitesToThread has a vsite as "
2349                                    "a constructing atom that does not belong to our task, such "
2350                                    "vsites should be assigned to the single 'master' task");
2351
2352                         task = nthread + thread;
2353                     }
2354                 }
2355             }
2356
2357             /* Update this vsite's thread index entry */
2358             taskIndex[iat[1 + i]] = task;
2359
2360             if (task == thread || task == nthread + thread)
2361             {
2362                 /* Copy this vsite to the thread data struct of thread */
2363                 InteractionList* il_task;
2364                 if (task == thread)
2365                 {
2366                     il_task = &tData->ilist[ftype];
2367                 }
2368                 else
2369                 {
2370                     il_task = &tData->idTask.ilist[ftype];
2371                 }
2372                 /* Copy the vsite data to the thread-task local array */
2373                 il_task->push_back(iat[i], numIAtoms - 1, iat + i + 1);
2374                 if (task == nthread + thread)
2375                 {
2376                     /* This vsite writes outside our own task force block.
2377                      * Put it into the interdependent task list and flag
2378                      * the atoms involved for reduction.
2379                      */
2380                     tData->idTask.vsite.push_back(iat[i + 1]);
2381                     if (ftype != F_VSITEN)
2382                     {
2383                         for (int j = i + 2; j < i + nral1; j++)
2384                         {
2385                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2386                         }
2387                     }
2388                     else
2389                     {
2390                         for (int j = i + 2; j < i + numIAtoms; j += 3)
2391                         {
2392                             flagAtom(&tData->idTask, iat[j], nthread, natperthread);
2393                         }
2394                     }
2395                 }
2396             }
2397
2398             i += numIAtoms;
2399         }
2400     }
2401 }
2402
2403 /*! \brief Assign all vsites with taskIndex[]==task to task tData */
2404 static void assignVsitesToSingleTask(VsiteThread*                    tData,
2405                                      int                             task,
2406                                      gmx::ArrayRef<const int>        taskIndex,
2407                                      ArrayRef<const InteractionList> ilist,
2408                                      ArrayRef<const t_iparams>       ip)
2409 {
2410     for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2411     {
2412         tData->ilist[ftype].clear();
2413         tData->idTask.ilist[ftype].clear();
2414
2415         int              nral1   = 1 + NRAL(ftype);
2416         int              inc     = nral1;
2417         const int*       iat     = ilist[ftype].iatoms.data();
2418         InteractionList* il_task = &tData->ilist[ftype];
2419
2420         for (int i = 0; i < ilist[ftype].size();)
2421         {
2422             if (ftype == F_VSITEN)
2423             {
2424                 /* The 3 below is from 1+NRAL(ftype)=3 */
2425                 inc = ip[iat[i]].vsiten.n * 3;
2426             }
2427             /* Check if the vsite is assigned to our task */
2428             if (taskIndex[iat[1 + i]] == task)
2429             {
2430                 /* Copy the vsite data to the thread-task local array */
2431                 il_task->push_back(iat[i], inc - 1, iat + i + 1);
2432             }
2433
2434             i += inc;
2435         }
2436     }
2437 }
2438
2439 void ThreadingInfo::setVirtualSites(ArrayRef<const InteractionList> ilists,
2440                                     ArrayRef<const t_iparams>       iparams,
2441                                     const t_mdatoms&                mdatoms,
2442                                     const bool                      useDomdec)
2443 {
2444     if (numThreads_ <= 1)
2445     {
2446         /* Nothing to do */
2447         return;
2448     }
2449
2450     /* The current way of distributing the vsites over threads in primitive.
2451      * We divide the atom range 0 - natoms_in_vsite uniformly over threads,
2452      * without taking into account how the vsites are distributed.
2453      * Without domain decomposition we at least tighten the upper bound
2454      * of the range (useful for common systems such as a vsite-protein
2455      * in 3-site water).
2456      * With domain decomposition, as long as the vsites are distributed
2457      * uniformly in each domain along the major dimension, usually x,
2458      * it will also perform well.
2459      */
2460     int vsite_atom_range;
2461     int natperthread;
2462     if (!useDomdec)
2463     {
2464         vsite_atom_range = -1;
2465         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2466         {
2467             { // TODO remove me
2468                 if (ftype != F_VSITEN)
2469                 {
2470                     int                 nral1 = 1 + NRAL(ftype);
2471                     ArrayRef<const int> iat   = ilists[ftype].iatoms;
2472                     for (int i = 0; i < ilists[ftype].size(); i += nral1)
2473                     {
2474                         for (int j = i + 1; j < i + nral1; j++)
2475                         {
2476                             vsite_atom_range = std::max(vsite_atom_range, iat[j]);
2477                         }
2478                     }
2479                 }
2480                 else
2481                 {
2482                     int vs_ind_end;
2483
2484                     ArrayRef<const int> iat = ilists[ftype].iatoms;
2485
2486                     int i = 0;
2487                     while (i < ilists[ftype].size())
2488                     {
2489                         /* The 3 below is from 1+NRAL(ftype)=3 */
2490                         vs_ind_end = i + iparams[iat[i]].vsiten.n * 3;
2491
2492                         vsite_atom_range = std::max(vsite_atom_range, iat[i + 1]);
2493                         while (i < vs_ind_end)
2494                         {
2495                             vsite_atom_range = std::max(vsite_atom_range, iat[i + 2]);
2496                             i += 3;
2497                         }
2498                     }
2499                 }
2500             }
2501         }
2502         vsite_atom_range++;
2503         natperthread = (vsite_atom_range + numThreads_ - 1) / numThreads_;
2504     }
2505     else
2506     {
2507         /* Any local or not local atom could be involved in virtual sites.
2508          * But since we usually have very few non-local virtual sites
2509          * (only non-local vsites that depend on local vsites),
2510          * we distribute the local atom range equally over the threads.
2511          * When assigning vsites to threads, we should take care that the last
2512          * threads also covers the non-local range.
2513          */
2514         vsite_atom_range = mdatoms.nr;
2515         natperthread     = (mdatoms.homenr + numThreads_ - 1) / numThreads_;
2516     }
2517
2518     if (debug)
2519     {
2520         fprintf(debug,
2521                 "virtual site thread dist: natoms %d, range %d, natperthread %d\n",
2522                 mdatoms.nr,
2523                 vsite_atom_range,
2524                 natperthread);
2525     }
2526
2527     /* To simplify the vsite assignment, we make an index which tells us
2528      * to which task particles, both non-vsites and vsites, are assigned.
2529      */
2530     taskIndex_.resize(mdatoms.nr);
2531
2532     /* Initialize the task index array. Here we assign the non-vsite
2533      * particles to task=thread, so we easily figure out if vsites
2534      * depend on local and/or non-local particles in assignVsitesToThread.
2535      */
2536     {
2537         int thread = 0;
2538         for (int i = 0; i < mdatoms.nr; i++)
2539         {
2540             if (mdatoms.ptype[i] == eptVSite)
2541             {
2542                 /* vsites are not assigned to a task yet */
2543                 taskIndex_[i] = -1;
2544             }
2545             else
2546             {
2547                 /* assign non-vsite particles to task thread */
2548                 taskIndex_[i] = thread;
2549             }
2550             if (i == (thread + 1) * natperthread && thread < numThreads_)
2551             {
2552                 thread++;
2553             }
2554         }
2555     }
2556
2557 #pragma omp parallel num_threads(numThreads_)
2558     {
2559         try
2560         {
2561             int          thread = gmx_omp_get_thread_num();
2562             VsiteThread& tData  = *tData_[thread];
2563
2564             /* Clear the buffer use flags that were set before */
2565             if (tData.useInterdependentTask)
2566             {
2567                 InterdependentTask& idTask = tData.idTask;
2568
2569                 /* To avoid an extra OpenMP barrier in spread_vsite_f,
2570                  * we clear the force buffer at the next step,
2571                  * so we need to do it here as well.
2572                  */
2573                 clearTaskForceBufferUsedElements(&idTask);
2574
2575                 idTask.vsite.resize(0);
2576                 for (int t = 0; t < numThreads_; t++)
2577                 {
2578                     AtomIndex& atomIndex = idTask.atomIndex[t];
2579                     int        natom     = atomIndex.atom.size();
2580                     for (int i = 0; i < natom; i++)
2581                     {
2582                         idTask.use[atomIndex.atom[i]] = false;
2583                     }
2584                     atomIndex.atom.resize(0);
2585                 }
2586                 idTask.nuse = 0;
2587             }
2588
2589             /* To avoid large f_buf allocations of #threads*vsite_atom_range
2590              * we don't use task2 with more than 200000 atoms. This doesn't
2591              * affect performance, since with such a large range relatively few
2592              * vsites will end up in the separate task.
2593              * Note that useTask2 should be the same for all threads.
2594              */
2595             tData.useInterdependentTask = (vsite_atom_range <= 200000);
2596             if (tData.useInterdependentTask)
2597             {
2598                 size_t              natoms_use_in_vsites = vsite_atom_range;
2599                 InterdependentTask& idTask               = tData.idTask;
2600                 /* To avoid resizing and re-clearing every nstlist steps,
2601                  * we never down size the force buffer.
2602                  */
2603                 if (natoms_use_in_vsites > idTask.force.size() || natoms_use_in_vsites > idTask.use.size())
2604                 {
2605                     idTask.force.resize(natoms_use_in_vsites, { 0, 0, 0 });
2606                     idTask.use.resize(natoms_use_in_vsites, false);
2607                 }
2608             }
2609
2610             /* Assign all vsites that can execute independently on threads */
2611             tData.rangeStart = thread * natperthread;
2612             if (thread < numThreads_ - 1)
2613             {
2614                 tData.rangeEnd = (thread + 1) * natperthread;
2615             }
2616             else
2617             {
2618                 /* The last thread should cover up to the end of the range */
2619                 tData.rangeEnd = mdatoms.nr;
2620             }
2621             assignVsitesToThread(
2622                     &tData, thread, numThreads_, natperthread, taskIndex_, ilists, iparams, mdatoms.ptype);
2623
2624             if (tData.useInterdependentTask)
2625             {
2626                 /* In the worst case, all tasks write to force ranges of
2627                  * all other tasks, leading to #tasks^2 scaling (this is only
2628                  * the overhead, the actual flops remain constant).
2629                  * But in most cases there is far less coupling. To improve
2630                  * scaling at high thread counts we therefore construct
2631                  * an index to only loop over the actually affected tasks.
2632                  */
2633                 InterdependentTask& idTask = tData.idTask;
2634
2635                 /* Ensure assignVsitesToThread finished on other threads */
2636 #pragma omp barrier
2637
2638                 idTask.spreadTask.resize(0);
2639                 idTask.reduceTask.resize(0);
2640                 for (int t = 0; t < numThreads_; t++)
2641                 {
2642                     /* Do we write to the force buffer of task t? */
2643                     if (!idTask.atomIndex[t].atom.empty())
2644                     {
2645                         idTask.spreadTask.push_back(t);
2646                     }
2647                     /* Does task t write to our force buffer? */
2648                     if (!tData_[t]->idTask.atomIndex[thread].atom.empty())
2649                     {
2650                         idTask.reduceTask.push_back(t);
2651                     }
2652                 }
2653             }
2654         }
2655         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
2656     }
2657     /* Assign all remaining vsites, that will have taskIndex[]=2*vsite->nthreads,
2658      * to a single task that will not run in parallel with other tasks.
2659      */
2660     assignVsitesToSingleTask(tData_[numThreads_].get(), 2 * numThreads_, taskIndex_, ilists, iparams);
2661
2662     if (debug && numThreads_ > 1)
2663     {
2664         fprintf(debug,
2665                 "virtual site useInterdependentTask %d, nuse:\n",
2666                 static_cast<int>(tData_[0]->useInterdependentTask));
2667         for (int th = 0; th < numThreads_ + 1; th++)
2668         {
2669             fprintf(debug, " %4d", tData_[th]->idTask.nuse);
2670         }
2671         fprintf(debug, "\n");
2672
2673         for (int ftype = c_ftypeVsiteStart; ftype < c_ftypeVsiteEnd; ftype++)
2674         {
2675             if (!ilists[ftype].empty())
2676             {
2677                 fprintf(debug, "%-20s thread dist:", interaction_function[ftype].longname);
2678                 for (int th = 0; th < numThreads_ + 1; th++)
2679                 {
2680                     fprintf(debug,
2681                             " %4d %4d ",
2682                             tData_[th]->ilist[ftype].size(),
2683                             tData_[th]->idTask.ilist[ftype].size());
2684                 }
2685                 fprintf(debug, "\n");
2686             }
2687         }
2688     }
2689
2690 #ifndef NDEBUG
2691     int nrOrig     = vsiteIlistNrCount(ilists);
2692     int nrThreaded = 0;
2693     for (int th = 0; th < numThreads_ + 1; th++)
2694     {
2695         nrThreaded += vsiteIlistNrCount(tData_[th]->ilist) + vsiteIlistNrCount(tData_[th]->idTask.ilist);
2696     }
2697     GMX_ASSERT(nrThreaded == nrOrig,
2698                "The number of virtual sites assigned to all thread task has to match the total "
2699                "number of virtual sites");
2700 #endif
2701 }
2702
2703 void VirtualSitesHandler::Impl::setVirtualSites(ArrayRef<const InteractionList> ilists,
2704                                                 const t_mdatoms&                mdatoms)
2705 {
2706     ilists_ = ilists;
2707
2708     threadingInfo_.setVirtualSites(ilists, iparams_, mdatoms, domainInfo_.useDomdec());
2709 }
2710
2711 void VirtualSitesHandler::setVirtualSites(ArrayRef<const InteractionList> ilists, const t_mdatoms& mdatoms)
2712 {
2713     impl_->setVirtualSites(ilists, mdatoms);
2714 }
2715
2716 } // namespace gmx