2c08f9788348096e813fb73deead23893112e6d2
[alexxy/gromacs.git] / src / gromacs / modularsimulator / propagator.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines the propagator element for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "propagator.h"
45
46 #include "gromacs/utility.h"
47 #include "gromacs/math/vec.h"
48 #include "gromacs/math/vectypes.h"
49 #include "gromacs/mdlib/gmx_omp_nthreads.h"
50 #include "gromacs/mdlib/mdatoms.h"
51 #include "gromacs/mdlib/update.h"
52 #include "gromacs/mdtypes/inputrec.h"
53 #include "gromacs/mdtypes/mdatom.h"
54 #include "gromacs/timing/wallcycle.h"
55
56 #include "modularsimulator.h"
57 #include "simulatoralgorithm.h"
58 #include "statepropagatordata.h"
59
60 namespace gmx
61 {
62 namespace
63 {
64 // Names of integration steps, only used locally for error messages
65 constexpr EnumerationArray<IntegrationStage, const char*> integrationStepNames = {
66     "IntegrationStage::PositionsOnly",   "IntegrationStage::VelocitiesOnly",
67     "IntegrationStage::LeapFrog",        "IntegrationStage::VelocityVerletPositionsAndVelocities",
68     "IntegrationStage::ScaleVelocities", "IntegrationStage::ScalePositions"
69 };
70 } // namespace
71
72 //! Update velocities
73 template<NumVelocityScalingValues        numStartVelocityScalingValues,
74          ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling,
75          NumVelocityScalingValues        numEndVelocityScalingValues>
76 static void inline updateVelocities(int         a,
77                                     real        dt,
78                                     real        lambdaStart,
79                                     real        lambdaEnd,
80                                     const rvec* gmx_restrict invMassPerDim,
81                                     rvec* gmx_restrict v,
82                                     const rvec* gmx_restrict f,
83                                     const rvec               diagPR,
84                                     const matrix             matrixPR)
85 {
86     for (int d = 0; d < DIM; d++)
87     {
88         // TODO: Extract this into policy classes
89         if (numStartVelocityScalingValues != NumVelocityScalingValues::None
90             && parrinelloRahmanVelocityScaling == ParrinelloRahmanVelocityScaling::No)
91         {
92             v[a][d] *= lambdaStart;
93         }
94         if (numStartVelocityScalingValues != NumVelocityScalingValues::None
95             && parrinelloRahmanVelocityScaling == ParrinelloRahmanVelocityScaling::Diagonal)
96         {
97             v[a][d] *= (lambdaStart - diagPR[d]);
98         }
99         if (numStartVelocityScalingValues != NumVelocityScalingValues::None
100             && parrinelloRahmanVelocityScaling == ParrinelloRahmanVelocityScaling::Full)
101         {
102             v[a][d] = lambdaStart * v[a][d] - iprod(matrixPR[d], v[a]);
103         }
104         if (numStartVelocityScalingValues == NumVelocityScalingValues::None
105             && parrinelloRahmanVelocityScaling == ParrinelloRahmanVelocityScaling::Diagonal)
106         {
107             v[a][d] *= (1 - diagPR[d]);
108         }
109         if (numStartVelocityScalingValues == NumVelocityScalingValues::None
110             && parrinelloRahmanVelocityScaling == ParrinelloRahmanVelocityScaling::Full)
111         {
112             v[a][d] -= iprod(matrixPR[d], v[a]);
113         }
114         v[a][d] += f[a][d] * invMassPerDim[a][d] * dt;
115         if (numEndVelocityScalingValues != NumVelocityScalingValues::None)
116         {
117             v[a][d] *= lambdaEnd;
118         }
119     }
120 }
121
122 //! Update positions
123 static void inline updatePositions(int         a,
124                                    real        dt,
125                                    const rvec* gmx_restrict x,
126                                    rvec* gmx_restrict xprime,
127                                    const rvec* gmx_restrict v)
128 {
129     for (int d = 0; d < DIM; d++)
130     {
131         xprime[a][d] = x[a][d] + v[a][d] * dt;
132     }
133 }
134
135 //! Scale velocities
136 template<NumVelocityScalingValues numStartVelocityScalingValues>
137 static void inline scaleVelocities(int a, real lambda, rvec* gmx_restrict v)
138 {
139     if (numStartVelocityScalingValues != NumVelocityScalingValues::None)
140     {
141         for (int d = 0; d < DIM; d++)
142         {
143             v[a][d] *= lambda;
144         }
145     }
146 }
147
148 //! Scale positions
149 template<NumPositionScalingValues numPositionScalingValues>
150 static void inline scalePositions(int a, real lambda, rvec* gmx_restrict x)
151 {
152     if (numPositionScalingValues != NumPositionScalingValues::None)
153     {
154         for (int d = 0; d < DIM; d++)
155         {
156             x[a][d] *= lambda;
157         }
158     }
159 }
160
161 //! Helper function diagonalizing the PR matrix if possible
162 template<ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling>
163 static inline bool diagonalizePRMatrix(matrix matrixPR, rvec diagPR)
164 {
165     if (parrinelloRahmanVelocityScaling != ParrinelloRahmanVelocityScaling::Full)
166     {
167         return false;
168     }
169     else
170     {
171         if (matrixPR[YY][XX] == 0 && matrixPR[ZZ][XX] == 0 && matrixPR[ZZ][YY] == 0)
172         {
173             diagPR[XX] = matrixPR[XX][XX];
174             diagPR[YY] = matrixPR[YY][YY];
175             diagPR[ZZ] = matrixPR[ZZ][ZZ];
176             return true;
177         }
178         else
179         {
180             return false;
181         }
182     }
183 }
184
185 //! Propagation (position only)
186 template<>
187 template<NumVelocityScalingValues        numStartVelocityScalingValues,
188          ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling,
189          NumVelocityScalingValues        numEndVelocityScalingValues,
190          NumPositionScalingValues        numPositionScalingValues>
191 void Propagator<IntegrationStage::PositionsOnly>::run()
192 {
193     wallcycle_start(wcycle_, WallCycleCounter::Update);
194
195     auto xp = as_rvec_array(statePropagatorData_->positionsView().paddedArrayRef().data());
196     auto x  = as_rvec_array(statePropagatorData_->constPositionsView().paddedArrayRef().data());
197     auto v  = as_rvec_array(statePropagatorData_->constVelocitiesView().paddedArrayRef().data());
198
199     int nth    = gmx_omp_nthreads_get(emntUpdate);
200     int homenr = mdAtoms_->mdatoms()->homenr;
201
202 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth, homenr, x, xp, v)
203     for (int th = 0; th < nth; th++)
204     {
205         try
206         {
207             int start_th, end_th;
208             getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
209
210             for (int a = start_th; a < end_th; a++)
211             {
212                 updatePositions(a, timestep_, x, xp, v);
213             }
214         }
215         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
216     }
217     wallcycle_stop(wcycle_, WallCycleCounter::Update);
218 }
219
220 //! Propagation (scale position only)
221 template<>
222 template<NumVelocityScalingValues        numStartVelocityScalingValues,
223          ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling,
224          NumVelocityScalingValues        numEndVelocityScalingValues,
225          NumPositionScalingValues        numPositionScalingValues>
226 void Propagator<IntegrationStage::ScalePositions>::run()
227 {
228     wallcycle_start(wcycle_, WallCycleCounter::Update);
229
230     auto* x = as_rvec_array(statePropagatorData_->positionsView().paddedArrayRef().data());
231
232     const real lambda =
233             (numPositionScalingValues == NumPositionScalingValues::Single) ? positionScaling_[0] : 1.0;
234
235     int nth    = gmx_omp_nthreads_get(emntUpdate);
236     int homenr = mdAtoms_->mdatoms()->homenr;
237
238 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth, homenr, x) \
239         firstprivate(lambda)
240     for (int th = 0; th < nth; th++)
241     {
242         try
243         {
244             int start_th, end_th;
245             getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
246
247             for (int a = start_th; a < end_th; a++)
248             {
249                 scalePositions<numPositionScalingValues>(
250                         a,
251                         (numPositionScalingValues == NumPositionScalingValues::Multiple)
252                                 ? positionScaling_[mdAtoms_->mdatoms()->cTC[a]]
253                                 : lambda,
254                         x);
255             }
256         }
257         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
258     }
259     wallcycle_stop(wcycle_, WallCycleCounter::Update);
260 }
261
262 //! Propagation (velocity only)
263 template<>
264 template<NumVelocityScalingValues        numStartVelocityScalingValues,
265          ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling,
266          NumVelocityScalingValues        numEndVelocityScalingValues,
267          NumPositionScalingValues        numPositionScalingValues>
268 void Propagator<IntegrationStage::VelocitiesOnly>::run()
269 {
270     wallcycle_start(wcycle_, WallCycleCounter::Update);
271
272     auto v = as_rvec_array(statePropagatorData_->velocitiesView().paddedArrayRef().data());
273     auto f = as_rvec_array(statePropagatorData_->constForcesView().force().data());
274     auto invMassPerDim = mdAtoms_->mdatoms()->invMassPerDim;
275
276     const real lambdaStart = (numStartVelocityScalingValues == NumVelocityScalingValues::Single)
277                                      ? startVelocityScaling_[0]
278                                      : 1.0;
279     const real lambdaEnd = (numEndVelocityScalingValues == NumVelocityScalingValues::Single)
280                                    ? endVelocityScaling_[0]
281                                    : 1.0;
282
283     const bool isFullScalingMatrixDiagonal =
284             diagonalizePRMatrix<parrinelloRahmanVelocityScaling>(matrixPR_, diagPR_);
285
286     const int nth    = gmx_omp_nthreads_get(emntUpdate);
287     const int homenr = mdAtoms_->mdatoms()->homenr;
288
289 // const variables could be shared, but gcc-8 & gcc-9 don't agree how to write that...
290 // https://www.gnu.org/software/gcc/gcc-9/porting_to.html -> OpenMP data sharing
291 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(v, f, invMassPerDim) \
292         firstprivate(nth, homenr, lambdaStart, lambdaEnd, isFullScalingMatrixDiagonal)
293     for (int th = 0; th < nth; th++)
294     {
295         try
296         {
297             int start_th, end_th;
298             getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
299
300             for (int a = start_th; a < end_th; a++)
301             {
302                 if (isFullScalingMatrixDiagonal)
303                 {
304                     updateVelocities<numStartVelocityScalingValues, ParrinelloRahmanVelocityScaling::Diagonal, numEndVelocityScalingValues>(
305                             a,
306                             timestep_,
307                             numStartVelocityScalingValues == NumVelocityScalingValues::Multiple
308                                     ? startVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
309                                     : lambdaStart,
310                             numEndVelocityScalingValues == NumVelocityScalingValues::Multiple
311                                     ? endVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
312                                     : lambdaEnd,
313                             invMassPerDim,
314                             v,
315                             f,
316                             diagPR_,
317                             matrixPR_);
318                 }
319                 else
320                 {
321                     updateVelocities<numStartVelocityScalingValues, parrinelloRahmanVelocityScaling, numEndVelocityScalingValues>(
322                             a,
323                             timestep_,
324                             numStartVelocityScalingValues == NumVelocityScalingValues::Multiple
325                                     ? startVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
326                                     : lambdaStart,
327                             numEndVelocityScalingValues == NumVelocityScalingValues::Multiple
328                                     ? endVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
329                                     : lambdaEnd,
330                             invMassPerDim,
331                             v,
332                             f,
333                             diagPR_,
334                             matrixPR_);
335                 }
336             }
337         }
338         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
339     }
340     wallcycle_stop(wcycle_, WallCycleCounter::Update);
341 }
342
343 //! Propagation (leapfrog case - position and velocity)
344 template<>
345 template<NumVelocityScalingValues        numStartVelocityScalingValues,
346          ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling,
347          NumVelocityScalingValues        numEndVelocityScalingValues,
348          NumPositionScalingValues        numPositionScalingValues>
349 void Propagator<IntegrationStage::LeapFrog>::run()
350 {
351     wallcycle_start(wcycle_, WallCycleCounter::Update);
352
353     auto xp = as_rvec_array(statePropagatorData_->positionsView().paddedArrayRef().data());
354     auto x  = as_rvec_array(statePropagatorData_->constPositionsView().paddedArrayRef().data());
355     auto v  = as_rvec_array(statePropagatorData_->velocitiesView().paddedArrayRef().data());
356     auto f  = as_rvec_array(statePropagatorData_->constForcesView().force().data());
357     auto invMassPerDim = mdAtoms_->mdatoms()->invMassPerDim;
358
359     const real lambdaStart = (numStartVelocityScalingValues == NumVelocityScalingValues::Single)
360                                      ? startVelocityScaling_[0]
361                                      : 1.0;
362     const real lambdaEnd = (numEndVelocityScalingValues == NumVelocityScalingValues::Single)
363                                    ? endVelocityScaling_[0]
364                                    : 1.0;
365
366     const bool isFullScalingMatrixDiagonal =
367             diagonalizePRMatrix<parrinelloRahmanVelocityScaling>(matrixPR_, diagPR_);
368
369     const int nth    = gmx_omp_nthreads_get(emntUpdate);
370     const int homenr = mdAtoms_->mdatoms()->homenr;
371
372 // const variables could be shared, but gcc-8 & gcc-9 don't agree how to write that...
373 // https://www.gnu.org/software/gcc/gcc-9/porting_to.html -> OpenMP data sharing
374 #pragma omp parallel for num_threads(nth) schedule(static) default(none) \
375         shared(x, xp, v, f, invMassPerDim)                               \
376                 firstprivate(nth, homenr, lambdaStart, lambdaEnd, isFullScalingMatrixDiagonal)
377     for (int th = 0; th < nth; th++)
378     {
379         try
380         {
381             int start_th, end_th;
382             getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
383
384             for (int a = start_th; a < end_th; a++)
385             {
386                 if (isFullScalingMatrixDiagonal)
387                 {
388                     updateVelocities<numStartVelocityScalingValues, ParrinelloRahmanVelocityScaling::Diagonal, numEndVelocityScalingValues>(
389                             a,
390                             timestep_,
391                             numStartVelocityScalingValues == NumVelocityScalingValues::Multiple
392                                     ? startVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
393                                     : lambdaStart,
394                             numEndVelocityScalingValues == NumVelocityScalingValues::Multiple
395                                     ? endVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
396                                     : lambdaEnd,
397                             invMassPerDim,
398                             v,
399                             f,
400                             diagPR_,
401                             matrixPR_);
402                 }
403                 else
404                 {
405                     updateVelocities<numStartVelocityScalingValues, parrinelloRahmanVelocityScaling, numEndVelocityScalingValues>(
406                             a,
407                             timestep_,
408                             numStartVelocityScalingValues == NumVelocityScalingValues::Multiple
409                                     ? startVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
410                                     : lambdaStart,
411                             numEndVelocityScalingValues == NumVelocityScalingValues::Multiple
412                                     ? endVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
413                                     : lambdaEnd,
414                             invMassPerDim,
415                             v,
416                             f,
417                             diagPR_,
418                             matrixPR_);
419                 }
420                 updatePositions(a, timestep_, x, xp, v);
421             }
422         }
423         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
424     }
425     wallcycle_stop(wcycle_, WallCycleCounter::Update);
426 }
427
428 //! Propagation (velocity verlet stage 2 - velocity and position)
429 template<>
430 template<NumVelocityScalingValues        numStartVelocityScalingValues,
431          ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling,
432          NumVelocityScalingValues        numEndVelocityScalingValues,
433          NumPositionScalingValues        numPositionScalingValues>
434 void Propagator<IntegrationStage::VelocityVerletPositionsAndVelocities>::run()
435 {
436     wallcycle_start(wcycle_, WallCycleCounter::Update);
437
438     auto xp = as_rvec_array(statePropagatorData_->positionsView().paddedArrayRef().data());
439     auto x  = as_rvec_array(statePropagatorData_->constPositionsView().paddedArrayRef().data());
440     auto v  = as_rvec_array(statePropagatorData_->velocitiesView().paddedArrayRef().data());
441     auto f  = as_rvec_array(statePropagatorData_->constForcesView().force().data());
442     auto invMassPerDim = mdAtoms_->mdatoms()->invMassPerDim;
443
444     const real lambdaStart = (numStartVelocityScalingValues == NumVelocityScalingValues::Single)
445                                      ? startVelocityScaling_[0]
446                                      : 1.0;
447     const real lambdaEnd = (numEndVelocityScalingValues == NumVelocityScalingValues::Single)
448                                    ? endVelocityScaling_[0]
449                                    : 1.0;
450
451     const bool isFullScalingMatrixDiagonal =
452             diagonalizePRMatrix<parrinelloRahmanVelocityScaling>(matrixPR_, diagPR_);
453
454     const int nth    = gmx_omp_nthreads_get(emntUpdate);
455     const int homenr = mdAtoms_->mdatoms()->homenr;
456
457 // const variables could be shared, but gcc-8 & gcc-9 don't agree how to write that...
458 // https://www.gnu.org/software/gcc/gcc-9/porting_to.html -> OpenMP data sharing
459 #pragma omp parallel for num_threads(nth) schedule(static) default(none) \
460         shared(x, xp, v, f, invMassPerDim)                               \
461                 firstprivate(nth, homenr, lambdaStart, lambdaEnd, isFullScalingMatrixDiagonal)
462     for (int th = 0; th < nth; th++)
463     {
464         try
465         {
466             int start_th, end_th;
467             getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
468
469             for (int a = start_th; a < end_th; a++)
470             {
471                 if (isFullScalingMatrixDiagonal)
472                 {
473                     updateVelocities<numStartVelocityScalingValues, ParrinelloRahmanVelocityScaling::Diagonal, numEndVelocityScalingValues>(
474                             a,
475                             0.5 * timestep_,
476                             numStartVelocityScalingValues == NumVelocityScalingValues::Multiple
477                                     ? startVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
478                                     : lambdaStart,
479                             numEndVelocityScalingValues == NumVelocityScalingValues::Multiple
480                                     ? endVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
481                                     : lambdaEnd,
482                             invMassPerDim,
483                             v,
484                             f,
485                             diagPR_,
486                             matrixPR_);
487                 }
488                 else
489                 {
490                     updateVelocities<numStartVelocityScalingValues, parrinelloRahmanVelocityScaling, numEndVelocityScalingValues>(
491                             a,
492                             0.5 * timestep_,
493                             numStartVelocityScalingValues == NumVelocityScalingValues::Multiple
494                                     ? startVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
495                                     : lambdaStart,
496                             numEndVelocityScalingValues == NumVelocityScalingValues::Multiple
497                                     ? endVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
498                                     : lambdaEnd,
499                             invMassPerDim,
500                             v,
501                             f,
502                             diagPR_,
503                             matrixPR_);
504                 }
505                 updatePositions(a, timestep_, x, xp, v);
506             }
507         }
508         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
509     }
510     wallcycle_stop(wcycle_, WallCycleCounter::Update);
511 }
512
513 //! Scaling (velocity scaling only)
514 template<>
515 template<NumVelocityScalingValues        numStartVelocityScalingValues,
516          ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling,
517          NumVelocityScalingValues        numEndVelocityScalingValues,
518          NumPositionScalingValues        numPositionScalingValues>
519 void Propagator<IntegrationStage::ScaleVelocities>::run()
520 {
521     if (numStartVelocityScalingValues == NumVelocityScalingValues::None)
522     {
523         return;
524     }
525     wallcycle_start(wcycle_, WallCycleCounter::Update);
526
527     auto* v = as_rvec_array(statePropagatorData_->velocitiesView().paddedArrayRef().data());
528
529     const real lambdaStart = (numStartVelocityScalingValues == NumVelocityScalingValues::Single)
530                                      ? startVelocityScaling_[0]
531                                      : 1.0;
532
533     const int nth    = gmx_omp_nthreads_get(emntUpdate);
534     const int homenr = mdAtoms_->mdatoms()->homenr;
535
536 // const variables could be shared, but gcc-8 & gcc-9 don't agree how to write that...
537 // https://www.gnu.org/software/gcc/gcc-9/porting_to.html -> OpenMP data sharing
538 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(v) \
539         firstprivate(nth, homenr, lambdaStart)
540     for (int th = 0; th < nth; th++)
541     {
542         try
543         {
544             int start_th = 0;
545             int end_th   = 0;
546             getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
547
548             for (int a = start_th; a < end_th; a++)
549             {
550                 scaleVelocities<numStartVelocityScalingValues>(
551                         a,
552                         numStartVelocityScalingValues == NumVelocityScalingValues::Multiple
553                                 ? startVelocityScaling_[mdAtoms_->mdatoms()->cTC[a]]
554                                 : lambdaStart,
555                         v);
556             }
557         }
558         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
559     }
560     wallcycle_stop(wcycle_, WallCycleCounter::Update);
561 }
562
563 template<IntegrationStage integrationStage>
564 Propagator<integrationStage>::Propagator(double               timestep,
565                                          StatePropagatorData* statePropagatorData,
566                                          const MDAtoms*       mdAtoms,
567                                          gmx_wallcycle*       wcycle) :
568     timestep_(timestep),
569     statePropagatorData_(statePropagatorData),
570     doSingleStartVelocityScaling_(false),
571     doGroupStartVelocityScaling_(false),
572     doSingleEndVelocityScaling_(false),
573     doGroupEndVelocityScaling_(false),
574     scalingStepVelocity_(-1),
575     diagPR_{ 0 },
576     matrixPR_{ { 0 } },
577     scalingStepPR_(-1),
578     mdAtoms_(mdAtoms),
579     wcycle_(wcycle)
580 {
581 }
582
583 template<IntegrationStage integrationStage>
584 void Propagator<integrationStage>::scheduleTask(Step step,
585                                                 Time gmx_unused            time,
586                                                 const RegisterRunFunction& registerRunFunction)
587 {
588     const bool doSingleVScalingThisStep =
589             (doSingleStartVelocityScaling_ && (step == scalingStepVelocity_));
590     const bool doGroupVScalingThisStep = (doGroupStartVelocityScaling_ && (step == scalingStepVelocity_));
591
592     if (integrationStage == IntegrationStage::ScaleVelocities)
593     {
594         // IntegrationStage::ScaleVelocities only needs to run if some kind of
595         // velocity scaling is needed on the current step.
596         if (!doSingleVScalingThisStep && !doGroupVScalingThisStep)
597         {
598             return;
599         }
600     }
601
602     if (integrationStage == IntegrationStage::ScalePositions)
603     {
604         // IntegrationStage::ScalePositions only needs to run if
605         // position scaling is needed on the current step.
606         if (step != scalingStepPosition_)
607         {
608             return;
609         }
610         // Since IntegrationStage::ScalePositions is the only stage for which position scaling
611         // is implemented we handle it here to avoid enlarging the decision tree below.
612         if (doSinglePositionScaling_)
613         {
614             registerRunFunction([this]() {
615                 run<NumVelocityScalingValues::None,
616                     ParrinelloRahmanVelocityScaling::No,
617                     NumVelocityScalingValues::None,
618                     NumPositionScalingValues::Single>();
619             });
620         }
621         else if (doGroupPositionScaling_)
622         {
623             registerRunFunction([this]() {
624                 run<NumVelocityScalingValues::None,
625                     ParrinelloRahmanVelocityScaling::No,
626                     NumVelocityScalingValues::None,
627                     NumPositionScalingValues::Multiple>();
628             });
629         }
630     }
631
632     const bool doParrinelloRahmanThisStep = (step == scalingStepPR_);
633
634     if (doSingleVScalingThisStep)
635     {
636         if (doParrinelloRahmanThisStep)
637         {
638             if (doSingleEndVelocityScaling_)
639             {
640                 registerRunFunction([this]() {
641                     run<NumVelocityScalingValues::Single,
642                         ParrinelloRahmanVelocityScaling::Full,
643                         NumVelocityScalingValues::Single,
644                         NumPositionScalingValues::None>();
645                 });
646             }
647             else
648             {
649                 registerRunFunction([this]() {
650                     run<NumVelocityScalingValues::Single,
651                         ParrinelloRahmanVelocityScaling::Full,
652                         NumVelocityScalingValues::None,
653                         NumPositionScalingValues::None>();
654                 });
655             }
656         }
657         else
658         {
659             if (doSingleEndVelocityScaling_)
660             {
661                 registerRunFunction([this]() {
662                     run<NumVelocityScalingValues::Single,
663                         ParrinelloRahmanVelocityScaling::No,
664                         NumVelocityScalingValues::Single,
665                         NumPositionScalingValues::None>();
666                 });
667             }
668             else
669             {
670                 registerRunFunction([this]() {
671                     run<NumVelocityScalingValues::Single,
672                         ParrinelloRahmanVelocityScaling::No,
673                         NumVelocityScalingValues::None,
674                         NumPositionScalingValues::None>();
675                 });
676             }
677         }
678     }
679     else if (doGroupVScalingThisStep)
680     {
681         if (doParrinelloRahmanThisStep)
682         {
683             if (doGroupEndVelocityScaling_)
684             {
685                 registerRunFunction([this]() {
686                     run<NumVelocityScalingValues::Multiple,
687                         ParrinelloRahmanVelocityScaling::Full,
688                         NumVelocityScalingValues::Multiple,
689                         NumPositionScalingValues::None>();
690                 });
691             }
692             else
693             {
694                 registerRunFunction([this]() {
695                     run<NumVelocityScalingValues::Multiple,
696                         ParrinelloRahmanVelocityScaling::Full,
697                         NumVelocityScalingValues::None,
698                         NumPositionScalingValues::None>();
699                 });
700             }
701         }
702         else
703         {
704             if (doGroupEndVelocityScaling_)
705             {
706                 registerRunFunction([this]() {
707                     run<NumVelocityScalingValues::Multiple,
708                         ParrinelloRahmanVelocityScaling::No,
709                         NumVelocityScalingValues::Multiple,
710                         NumPositionScalingValues::None>();
711                 });
712             }
713             else
714             {
715                 registerRunFunction([this]() {
716                     run<NumVelocityScalingValues::Multiple,
717                         ParrinelloRahmanVelocityScaling::No,
718                         NumVelocityScalingValues::None,
719                         NumPositionScalingValues::None>();
720                 });
721             }
722         }
723     }
724     else
725     {
726         if (doParrinelloRahmanThisStep)
727         {
728             registerRunFunction([this]() {
729                 run<NumVelocityScalingValues::None,
730                     ParrinelloRahmanVelocityScaling::Full,
731                     NumVelocityScalingValues::None,
732                     NumPositionScalingValues::None>();
733             });
734         }
735         else
736         {
737             registerRunFunction([this]() {
738                 run<NumVelocityScalingValues::None,
739                     ParrinelloRahmanVelocityScaling::No,
740                     NumVelocityScalingValues::None,
741                     NumPositionScalingValues::None>();
742             });
743         }
744     }
745 }
746
747 template<IntegrationStage integrationStage>
748 constexpr bool hasStartVelocityScaling()
749 {
750     return (integrationStage == IntegrationStage::VelocitiesOnly
751             || integrationStage == IntegrationStage::LeapFrog
752             || integrationStage == IntegrationStage::VelocityVerletPositionsAndVelocities
753             || integrationStage == IntegrationStage::ScaleVelocities);
754 }
755
756 template<IntegrationStage integrationStage>
757 constexpr bool hasEndVelocityScaling()
758 {
759     return (hasStartVelocityScaling<integrationStage>()
760             && integrationStage != IntegrationStage::ScaleVelocities);
761 }
762
763 template<IntegrationStage integrationStage>
764 constexpr bool hasPositionScaling()
765 {
766     return (integrationStage == IntegrationStage::ScalePositions);
767 }
768
769 template<IntegrationStage integrationStage>
770 constexpr bool hasParrinelloRahmanScaling()
771 {
772     return (integrationStage == IntegrationStage::VelocitiesOnly
773             || integrationStage == IntegrationStage::LeapFrog
774             || integrationStage == IntegrationStage::VelocityVerletPositionsAndVelocities);
775 }
776
777 template<IntegrationStage integrationStage>
778 void Propagator<integrationStage>::setNumVelocityScalingVariables(int numVelocityScalingVariables,
779                                                                   ScaleVelocities scaleVelocities)
780 {
781     GMX_RELEASE_ASSERT(
782             hasStartVelocityScaling<integrationStage>() || hasEndVelocityScaling<integrationStage>(),
783             formatString("Velocity scaling not implemented for %s", integrationStepNames[integrationStage])
784                     .c_str());
785     GMX_RELEASE_ASSERT(startVelocityScaling_.empty(),
786                        "Number of velocity scaling variables cannot be changed once set.");
787
788     const bool scaleEndVelocities = (scaleVelocities == ScaleVelocities::PreStepAndPostStep);
789     startVelocityScaling_.resize(numVelocityScalingVariables, 1.);
790     if (scaleEndVelocities)
791     {
792         endVelocityScaling_.resize(numVelocityScalingVariables, 1.);
793     }
794     doSingleStartVelocityScaling_ = numVelocityScalingVariables == 1;
795     doGroupStartVelocityScaling_  = numVelocityScalingVariables > 1;
796     doSingleEndVelocityScaling_   = doSingleStartVelocityScaling_ && scaleEndVelocities;
797     doGroupEndVelocityScaling_    = doGroupStartVelocityScaling_ && scaleEndVelocities;
798 }
799
800 template<IntegrationStage integrationStage>
801 void Propagator<integrationStage>::setNumPositionScalingVariables(int numPositionScalingVariables)
802 {
803     GMX_RELEASE_ASSERT(hasPositionScaling<integrationStage>(),
804                        formatString("Position scaling not implemented for %s",
805                                     integrationStepNames[integrationStage])
806                                .c_str());
807     GMX_RELEASE_ASSERT(positionScaling_.empty(),
808                        "Number of position scaling variables cannot be changed once set.");
809     positionScaling_.resize(numPositionScalingVariables, 1.);
810     doSinglePositionScaling_ = (numPositionScalingVariables == 1);
811     doGroupPositionScaling_  = (numPositionScalingVariables > 1);
812 }
813
814 template<IntegrationStage integrationStage>
815 ArrayRef<real> Propagator<integrationStage>::viewOnStartVelocityScaling()
816 {
817     GMX_RELEASE_ASSERT(hasStartVelocityScaling<integrationStage>(),
818                        formatString("Start velocity scaling not implemented for %s",
819                                     integrationStepNames[integrationStage])
820                                .c_str());
821     GMX_RELEASE_ASSERT(!startVelocityScaling_.empty(),
822                        "Number of velocity scaling variables not set.");
823
824     return startVelocityScaling_;
825 }
826
827 template<IntegrationStage integrationStage>
828 ArrayRef<real> Propagator<integrationStage>::viewOnEndVelocityScaling()
829 {
830     GMX_RELEASE_ASSERT(hasEndVelocityScaling<integrationStage>(),
831                        formatString("End velocity scaling not implemented for %s",
832                                     integrationStepNames[integrationStage])
833                                .c_str());
834     GMX_RELEASE_ASSERT(!endVelocityScaling_.empty(),
835                        "Number of velocity scaling variables not set.");
836
837     return endVelocityScaling_;
838 }
839
840 template<IntegrationStage integrationStage>
841 ArrayRef<real> Propagator<integrationStage>::viewOnPositionScaling()
842 {
843     GMX_RELEASE_ASSERT(hasPositionScaling<integrationStage>(),
844                        formatString("Position scaling not implemented for %s",
845                                     integrationStepNames[integrationStage])
846                                .c_str());
847     GMX_RELEASE_ASSERT(!positionScaling_.empty(), "Number of position scaling variables not set.");
848
849     return positionScaling_;
850 }
851
852 template<IntegrationStage integrationStage>
853 PropagatorCallback Propagator<integrationStage>::velocityScalingCallback()
854 {
855     GMX_RELEASE_ASSERT(
856             hasStartVelocityScaling<integrationStage>() || hasEndVelocityScaling<integrationStage>(),
857             formatString("Velocity scaling not implemented for %s", integrationStepNames[integrationStage])
858                     .c_str());
859
860     return [this](Step step) { scalingStepVelocity_ = step; };
861 }
862
863 template<IntegrationStage integrationStage>
864 PropagatorCallback Propagator<integrationStage>::positionScalingCallback()
865 {
866     GMX_RELEASE_ASSERT(hasPositionScaling<integrationStage>(),
867                        formatString("Position scaling not implemented for %s",
868                                     integrationStepNames[integrationStage])
869                                .c_str());
870
871     return [this](Step step) { scalingStepPosition_ = step; };
872 }
873
874 template<IntegrationStage integrationStage>
875 ArrayRef<rvec> Propagator<integrationStage>::viewOnPRScalingMatrix()
876 {
877     GMX_RELEASE_ASSERT(hasParrinelloRahmanScaling<integrationStage>(),
878                        formatString("Parrinello-Rahman scaling not implemented for %s",
879                                     integrationStepNames[integrationStage])
880                                .c_str());
881
882     clear_mat(matrixPR_);
883     // gcc-5 needs this to be explicit (all other tested compilers would be ok
884     // with simply returning matrixPR)
885     return ArrayRef<rvec>(matrixPR_);
886 }
887
888 template<IntegrationStage integrationStage>
889 PropagatorCallback Propagator<integrationStage>::prScalingCallback()
890 {
891     GMX_RELEASE_ASSERT(hasParrinelloRahmanScaling<integrationStage>(),
892                        formatString("Parrinello-Rahman scaling not implemented for %s",
893                                     integrationStepNames[integrationStage])
894                                .c_str());
895
896     return [this](Step step) { scalingStepPR_ = step; };
897 }
898
899 // doxygen is confused by the two definitions
900 //! \cond
901 template<IntegrationStage integrationStage>
902 ISimulatorElement* Propagator<integrationStage>::getElementPointerImpl(
903         LegacySimulatorData*                    legacySimulatorData,
904         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
905         StatePropagatorData*                    statePropagatorData,
906         EnergyData gmx_unused*     energyData,
907         FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
908         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
909         const PropagatorTag&                  propagatorTag,
910         double                                timestep)
911 {
912     GMX_RELEASE_ASSERT(!(integrationStage == IntegrationStage::ScaleVelocities
913                          || integrationStage == IntegrationStage::ScalePositions)
914                                || (timestep == 0.0),
915                        "Scaling elements don't propagate the system.");
916     auto* element    = builderHelper->storeElement(std::make_unique<Propagator<integrationStage>>(
917             timestep, statePropagatorData, legacySimulatorData->mdAtoms, legacySimulatorData->wcycle));
918     auto* propagator = static_cast<Propagator<integrationStage>*>(element);
919     builderHelper->registerWithThermostat(
920             { [propagator](int num, ScaleVelocities scaleVelocities) {
921                  propagator->setNumVelocityScalingVariables(num, scaleVelocities);
922              },
923               [propagator]() { return propagator->viewOnStartVelocityScaling(); },
924               [propagator]() { return propagator->viewOnEndVelocityScaling(); },
925               [propagator]() { return propagator->velocityScalingCallback(); },
926               propagatorTag });
927     builderHelper->registerWithBarostat(
928             { [propagator]() { return propagator->viewOnPRScalingMatrix(); },
929               [propagator]() { return propagator->prScalingCallback(); },
930               propagatorTag });
931     return element;
932 }
933
934 template<IntegrationStage integrationStage>
935 ISimulatorElement* Propagator<integrationStage>::getElementPointerImpl(
936         LegacySimulatorData*                    legacySimulatorData,
937         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
938         StatePropagatorData*                    statePropagatorData,
939         EnergyData*                             energyData,
940         FreeEnergyPerturbationData*             freeEnergyPerturbationData,
941         GlobalCommunicationHelper*              globalCommunicationHelper,
942         const PropagatorTag&                    propagatorTag)
943 {
944     GMX_RELEASE_ASSERT(
945             integrationStage == IntegrationStage::ScaleVelocities
946                     || integrationStage == IntegrationStage::ScalePositions,
947             "Adding a propagator without time step is only allowed for scaling elements");
948     return getElementPointerImpl(legacySimulatorData,
949                                  builderHelper,
950                                  statePropagatorData,
951                                  energyData,
952                                  freeEnergyPerturbationData,
953                                  globalCommunicationHelper,
954                                  propagatorTag,
955                                  0.0);
956 }
957 //! \endcond
958
959 // Explicit template initializations
960 template class Propagator<IntegrationStage::PositionsOnly>;
961 template class Propagator<IntegrationStage::VelocitiesOnly>;
962 template class Propagator<IntegrationStage::LeapFrog>;
963 template class Propagator<IntegrationStage::VelocityVerletPositionsAndVelocities>;
964 template class Propagator<IntegrationStage::ScaleVelocities>;
965 template class Propagator<IntegrationStage::ScalePositions>;
966
967 } // namespace gmx