Introduce plumbing for ObservablesReducer
[alexxy/gromacs.git] / src / gromacs / modularsimulator / nosehooverchains.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines classes related to Nose-Hoover chains for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
#include "gmxpre.h"

#include "nosehooverchains.h"

#include <algorithm>
#include <cmath>
#include <numeric>

#include "gromacs/domdec/domdec_network.h"
#include "gromacs/math/functions.h"
#include "gromacs/math/units.h"
#include "gromacs/math/vec.h"
#include "gromacs/mdlib/stat.h"
#include "gromacs/mdtypes/commrec.h"
#include "gromacs/mdtypes/group.h"
#include "gromacs/mdtypes/inputrec.h"
#include "gromacs/utility/fatalerror.h"
#include "gromacs/utility/strconvert.h"

#include "energydata.h"
#include "mttk.h"
#include "simulatoralgorithm.h"
#include "trotterhelperfunctions.h"
#include "velocityscalingtemperaturecoupling.h"
64
65 namespace gmx
66 {
67 // Names of the NHC usage options
68 static constexpr EnumerationArray<NhcUsage, const char*> nhcUsageNames = { "System", "Barostat" };
69
70 //! The current state of the Nose-Hoover chain degree of freedom for a temperature group
71 class NoseHooverGroup final
72 {
73 public:
74     //! Constructor
75     NoseHooverGroup(int      chainLength,
76                     real     referenceTemperature,
77                     real     numDegreesOfFreedom,
78                     real     couplingTime,
79                     real     couplingTimeStep,
80                     NhcUsage nhcUsage);
81
82     //! Trotter operator for the NHC degrees of freedom
83     real applyNhc(real currentKineticEnergy, real couplingTimeStep);
84
85     //! Save to or restore from a CheckpointData object
86     template<CheckpointDataOperation operation>
87     void doCheckpoint(CheckpointData<operation>* checkpointData);
88     //! Broadcast values read from checkpoint over DD ranks
89     void broadcastCheckpointValues(const gmx_domdec_t* dd);
90
91     //! Whether the coordinate time is at a full coupling time step
92     bool isAtFullCouplingTimeStep() const;
93     //! Update the value of the NHC integral with the current coordinates
94     void calculateIntegral();
95     //! Return the current NHC integral for the group
96     double integral() const;
97     //! Return the current time of the NHC integral for the group
98     real integralTime() const;
99
100     //! Set the reference temperature
101     void updateReferenceTemperature(real temperature);
102
103 private:
104     //! Increment coordinate time and update integral if applicable
105     void finalizeUpdate(real couplingTimeStep);
106
107     //! The reference temperature of this group
108     real referenceTemperature_;
109     //! The coupling time of this group
110     const real couplingTime_;
111     //! The number of degrees of freedom in this group
112     const real numDegreesOfFreedom_;
113     //! The chain length of this group
114     const int chainLength_;
115     //! The coupling time step, indicates when the coordinates are at a full step
116     const real couplingTimeStep_;
117     //! The thermostat degree of freedom
118     std::vector<real> xi_;
119     //! Velocity of the thermostat dof
120     std::vector<real> xiVelocities_;
121     //! Work exerted by thermostat per group
122     double temperatureCouplingIntegral_;
123     //! Inverse mass of the thermostat dof
124     std::vector<real> invXiMass_;
125     //! The current time of xi and xiVelocities_
126     real coordinateTime_;
127     //! The current time of the temperature integral
128     real integralTime_;
129 };
130
131 NoseHooverGroup::NoseHooverGroup(int      chainLength,
132                                  real     referenceTemperature,
133                                  real     numDegreesOfFreedom,
134                                  real     couplingTime,
135                                  real     couplingTimeStep,
136                                  NhcUsage nhcUsage) :
137     referenceTemperature_(referenceTemperature),
138     couplingTime_(couplingTime),
139     numDegreesOfFreedom_(numDegreesOfFreedom),
140     chainLength_(chainLength),
141     couplingTimeStep_(couplingTimeStep),
142     xi_(chainLength, 0),
143     xiVelocities_(chainLength, 0),
144     temperatureCouplingIntegral_(0),
145     invXiMass_(chainLength, 0),
146     coordinateTime_(0),
147     integralTime_(0)
148 {
149     if (referenceTemperature > 0 && couplingTime > 0 && numDegreesOfFreedom > 0)
150     {
151         for (auto chainPosition = 0; chainPosition < chainLength; ++chainPosition)
152         {
153             const real numDof = ((chainPosition == 0) ? numDegreesOfFreedom : 1);
154             invXiMass_[chainPosition] =
155                     1.0 / (gmx::square(couplingTime / M_2PI) * referenceTemperature * numDof * c_boltz);
156             if (nhcUsage == NhcUsage::Barostat && chainPosition == 0)
157             {
158                 invXiMass_[chainPosition] /= DIM * DIM;
159             }
160         }
161     }
162 }
163
164 NoseHooverChainsData::NoseHooverChainsData(int                  numTemperatureGroups,
165                                            real                 couplingTimeStep,
166                                            int                  chainLength,
167                                            ArrayRef<const real> referenceTemperature,
168                                            ArrayRef<const real> couplingTime,
169                                            ArrayRef<const real> numDegreesOfFreedom,
170                                            NhcUsage             nhcUsage) :
171     identifier_(formatString("NoseHooverChainsData-%s", nhcUsageNames[nhcUsage])),
172     numTemperatureGroups_(numTemperatureGroups)
173 {
174     if (nhcUsage == NhcUsage::System)
175     {
176         for (auto temperatureGroup = 0; temperatureGroup < numTemperatureGroups; ++temperatureGroup)
177         {
178             noseHooverGroups_.emplace_back(chainLength,
179                                            referenceTemperature[temperatureGroup],
180                                            numDegreesOfFreedom[temperatureGroup],
181                                            couplingTime[temperatureGroup],
182                                            couplingTimeStep,
183                                            nhcUsage);
184         }
185     }
186     else if (nhcUsage == NhcUsage::Barostat)
187     {
188         GMX_RELEASE_ASSERT(numTemperatureGroups == 1,
189                            "There can only be one barostat for the system");
190         // Barostat has a single degree of freedom
191         const int degreesOfFreedom = 1;
192         noseHooverGroups_.emplace_back(
193                 chainLength, referenceTemperature[0], degreesOfFreedom, couplingTime[0], couplingTimeStep, nhcUsage);
194     }
195 }
196
197 NoseHooverChainsData::NoseHooverChainsData(const NoseHooverChainsData& other) :
198     identifier_(other.identifier_),
199     noseHooverGroups_(other.noseHooverGroups_),
200     numTemperatureGroups_(other.numTemperatureGroups_)
201 {
202 }
203
204 void NoseHooverChainsData::build(NhcUsage                                nhcUsage,
205                                  LegacySimulatorData*                    legacySimulatorData,
206                                  ModularSimulatorAlgorithmBuilderHelper* builderHelper,
207                                  EnergyData*                             energyData)
208 {
209     // The caller of this method now owns the data and will handle its lifetime (see ModularSimulatorAlgorithm)
210     if (nhcUsage == NhcUsage::System)
211     {
212         builderHelper->storeSimulationData(
213                 NoseHooverChainsData::dataID(nhcUsage),
214                 NoseHooverChainsData(
215                         legacySimulatorData->inputrec->opts.ngtc,
216                         legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nsttcouple,
217                         legacySimulatorData->inputrec->opts.nhchainlength,
218                         constArrayRefFromArray(legacySimulatorData->inputrec->opts.ref_t,
219                                                legacySimulatorData->inputrec->opts.ngtc),
220                         constArrayRefFromArray(legacySimulatorData->inputrec->opts.tau_t,
221                                                legacySimulatorData->inputrec->opts.ngtc),
222                         constArrayRefFromArray(legacySimulatorData->inputrec->opts.nrdf,
223                                                legacySimulatorData->inputrec->opts.ngtc),
224                         nhcUsage));
225     }
226     else
227     {
228         const int numTemperatureGroups = 1;
229         builderHelper->storeSimulationData(
230                 NoseHooverChainsData::dataID(nhcUsage),
231                 NoseHooverChainsData(
232                         numTemperatureGroups,
233                         legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nstpcouple,
234                         legacySimulatorData->inputrec->opts.nhchainlength,
235                         constArrayRefFromArray(legacySimulatorData->inputrec->opts.ref_t, 1),
236                         constArrayRefFromArray(legacySimulatorData->inputrec->opts.tau_t, 1),
237                         ArrayRef<real>(),
238                         nhcUsage));
239     }
240     auto* nhcDataPtr =
241             builderHelper
242                     ->simulationData<NoseHooverChainsData>(NoseHooverChainsData::dataID(nhcUsage))
243                     .value();
244     builderHelper->registerReferenceTemperatureUpdate(
245             [nhcDataPtr](ArrayRef<const real> temperatures, ReferenceTemperatureChangeAlgorithm algorithm) {
246                 nhcDataPtr->updateReferenceTemperature(temperatures, algorithm);
247             });
248
249     const auto* ptrToDataObject =
250             builderHelper
251                     ->simulationData<NoseHooverChainsData>(NoseHooverChainsData::dataID(nhcUsage))
252                     .value();
253     energyData->addConservedEnergyContribution([ptrToDataObject](Step /*unused*/, Time time) {
254         return ptrToDataObject->temperatureCouplingIntegral(time);
255     });
256 }
257
258 void NoseHooverGroup::finalizeUpdate(real couplingTimeStep)
259 {
260     coordinateTime_ += couplingTimeStep;
261     if (isAtFullCouplingTimeStep())
262     {
263         calculateIntegral();
264     }
265 }
266
267 inline int NoseHooverChainsData::numTemperatureGroups() const
268 {
269     return numTemperatureGroups_;
270 }
271
272 inline bool NoseHooverChainsData::isAtFullCouplingTimeStep() const
273 {
274     return std::all_of(noseHooverGroups_.begin(), noseHooverGroups_.end(), [](const auto& group) {
275         return group.isAtFullCouplingTimeStep();
276     });
277 }
278
279 void NoseHooverGroup::calculateIntegral()
280 {
281     // Calculate current value of thermostat integral
282     temperatureCouplingIntegral_ = 0.0;
283     for (auto chainPosition = 0; chainPosition < chainLength_; ++chainPosition)
284     {
285         // Chain thermostats have only one degree of freedom
286         const real numDegreesOfFreedomThisPosition = (chainPosition == 0) ? numDegreesOfFreedom_ : 1;
287         temperatureCouplingIntegral_ +=
288                 0.5 * gmx::square(xiVelocities_[chainPosition]) / invXiMass_[chainPosition]
289                 + numDegreesOfFreedomThisPosition * xi_[chainPosition] * c_boltz * referenceTemperature_;
290     }
291     integralTime_ = coordinateTime_;
292 }
293
294 inline double NoseHooverGroup::integral() const
295 {
296     return temperatureCouplingIntegral_;
297 }
298
299 inline real NoseHooverGroup::integralTime() const
300 {
301     return integralTime_;
302 }
303
304 double NoseHooverChainsData::temperatureCouplingIntegral(Time gmx_used_in_debug time) const
305 {
306     /* When using nsttcouple >= nstcalcenergy, we accept that the coupling
307      * integral might be ahead of the current energy calculation step. The
308      * extended system degrees of freedom are either in sync or ahead of the
309      * rest of the system.
310      */
311     GMX_ASSERT(!std::any_of(noseHooverGroups_.begin(),
312                             noseHooverGroups_.end(),
313                             [time](const auto& group) {
314                                 return !(time <= group.integralTime()
315                                          || timesClose(group.integralTime(), time));
316                             }),
317                "NoseHooverChainsData conserved energy time mismatch.");
318     double result = 0;
319     std::for_each(noseHooverGroups_.begin(), noseHooverGroups_.end(), [&result](const auto& group) {
320         result += group.integral();
321     });
322     return result;
323 }
324
325 inline bool NoseHooverGroup::isAtFullCouplingTimeStep() const
326 {
327     // Check whether coordinate time divided by the time step is close to integer
328     return timesClose(std::lround(coordinateTime_ / couplingTimeStep_) * couplingTimeStep_, coordinateTime_);
329 }
330
331 void NoseHooverChainsData::updateReferenceTemperature(ArrayRef<const real> temperatures,
332                                                       ReferenceTemperatureChangeAlgorithm gmx_unused algorithm)
333 {
334     // Currently, we don't know about any temperature change algorithms, so we assert this never gets called
335     GMX_ASSERT(false, "NoseHooverChainsData: Unknown ReferenceTemperatureChangeAlgorithm.");
336     for (auto temperatureGroup = 0; temperatureGroup < numTemperatureGroups_; ++temperatureGroup)
337     {
338         noseHooverGroups_[temperatureGroup].updateReferenceTemperature(temperatures[temperatureGroup]);
339         if (noseHooverGroups_[temperatureGroup].isAtFullCouplingTimeStep())
340         {
341             noseHooverGroups_[temperatureGroup].calculateIntegral();
342         }
343     }
344 }
345
346 void NoseHooverGroup::updateReferenceTemperature(real temperature)
347 {
348     const bool newTemperatureIsValid = (temperature > 0 && couplingTime_ > 0 && numDegreesOfFreedom_ > 0);
349     const bool oldTemperatureIsValid =
350             (referenceTemperature_ > 0 && couplingTime_ > 0 && numDegreesOfFreedom_ > 0);
351     GMX_RELEASE_ASSERT(newTemperatureIsValid == oldTemperatureIsValid,
352                        "Cannot turn temperature coupling on / off during simulation run.");
353     if (oldTemperatureIsValid && newTemperatureIsValid)
354     {
355         const real velocityFactor = std::sqrt(temperature / referenceTemperature_);
356         for (auto chainPosition = 0; chainPosition < chainLength_; ++chainPosition)
357         {
358             invXiMass_[chainPosition] *= (referenceTemperature_ / temperature);
359             xiVelocities_[chainPosition] *= velocityFactor;
360         }
361     }
362     referenceTemperature_ = temperature;
363     if (isAtFullCouplingTimeStep())
364     {
365         calculateIntegral();
366     }
367 }
368
namespace
{
/*!
 * \brief Versions of the content NoseHooverChainsData writes to modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! Version written by this code: the last entry before Count
constexpr CheckpointVersion c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
384
385 template<CheckpointDataOperation operation>
386 void NoseHooverChainsData::doCheckpointData(CheckpointData<operation>* checkpointData)
387 {
388     checkpointVersion(checkpointData, "NoseHooverChainsData version", c_currentVersion);
389
390     for (int temperatureGroup = 0; temperatureGroup < numTemperatureGroups_; ++temperatureGroup)
391     {
392         const auto temperatureGroupStr = "T-group #" + toString(temperatureGroup);
393         auto       groupCheckpointData = checkpointData->subCheckpointData(temperatureGroupStr);
394         noseHooverGroups_[temperatureGroup].doCheckpoint(&groupCheckpointData);
395     }
396 }
397
398 template<CheckpointDataOperation operation>
399 void NoseHooverGroup::doCheckpoint(CheckpointData<operation>* checkpointData)
400 {
401     checkpointData->arrayRef("xi", makeCheckpointArrayRef<operation>(xi_));
402     checkpointData->arrayRef("xi velocities", makeCheckpointArrayRef<operation>(xiVelocities_));
403     checkpointData->scalar("Coordinate time", &coordinateTime_);
404 }
405
406 //! Broadcast values read from checkpoint over DD ranks
407 void NoseHooverGroup::broadcastCheckpointValues(const gmx_domdec_t* dd)
408 {
409     dd_bcast(dd, ssize(xi_) * int(sizeof(real)), xi_.data());
410     dd_bcast(dd, ssize(xiVelocities_) * int(sizeof(real)), xiVelocities_.data());
411     dd_bcast(dd, int(sizeof(real)), &coordinateTime_);
412 }
413
414 void NoseHooverChainsData::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
415                                                const t_commrec*                   cr)
416 {
417     if (MASTER(cr))
418     {
419         doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
420     }
421 }
422
423 void NoseHooverChainsData::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
424                                                   const t_commrec*                  cr)
425 {
426     if (MASTER(cr))
427     {
428         doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
429     }
430     for (auto& group : noseHooverGroups_)
431     {
432         if (DOMAINDECOMP(cr))
433         {
434             group.broadcastCheckpointValues(cr->dd);
435         }
436         group.calculateIntegral();
437     }
438 }
439
440 const std::string& NoseHooverChainsData::clientID()
441 {
442     return identifier_;
443 }
444
445 std::string NoseHooverChainsData::dataID(NhcUsage nhcUsage)
446 {
447     return formatString("NoseHooverChainsData%s", nhcUsageNames[nhcUsage]);
448 }
449
450 /* This follows Tuckerman et al. 2006
451  *
452  * In NVT, the Trotter decomposition reads
453  *   exp[iL dt] = exp[iLT dt/2] exp[iLv dt/2] exp[iLx dt] exp[iLv dt/2] exp[iLT dt/2]
454  * iLv denotes the velocity propagation, iLx the position propagation
455  * iLT denotes the thermostat propagation implemented here:
456  *     v_xi[i](t-dt/2) = v_xi[i](t-dt) + dt_xi/2 * a_xi[i](t-dt);
457  *     xi[i](t) = xi[i](t-dt) + dt_xi * v_xi[i](t-dt/2);
458  *     v_sys *= exp(-dt/2 * v_xi[1](t-dt/2))
459  *     v_xi[i](t) = v_xi[i](t-dt/2) + dt_xi/2 * a_xi[i](t);
460  * where i = 1 ... N_chain, and
461  *     a[i](t) = (M_xi * v_xi[i+1](t)^2 - 2*K_ref) / M_xi , i = 2 ... N_chain
462  *     a[1](t) = (K_sys - K_ref) / M_xi
463  * Note, iLT contains a term scaling the system velocities!
464  *
465  * In the legacy GROMACS simulator, the top of the loop marks the simulation
466  * state at x(t), v(t-dt/2), f(t-1), mirroring the leap-frog implementation.
467  * The loop then proceeds to calculate the forces at time t, followed by a
468  * velocity half step (corresponding to the second exp[iLv dt/2] above).
469  * For Tuckerman NHC NVT, this is followed by a thermostat propagation to reach
470  * the full timestep t state. This is the state which is printed to file, so
471  * we need to scale the velocities.
472  * After writing to file, the next step effectively starts, by moving the thermostat
473  * variables (half step), the velocities (half step) and the positions (full step),
474  * which is equivalent to the first three terms of the Trotter decomposition above.
475  * Currently, modular simulator is replicating the division of the simulator loop
476  * used by the legacy simulator. The implementation here is independent of these
477  * assumptions, but the builder of the simulator must be careful to ensure that
478  * velocity scaling is applied before re-using the velocities after the thermostat.
479  *
480  * The time-scale separation between the particles and the thermostat requires the
481  * NHC operator to have a higher-order factorization. The method used is the
482  * Suzuki-Yoshida scheme which uses weighted time steps chosen to cancel out
483  * lower-order error terms. Here, the fifth order SY scheme is used.
484  */
485 real NoseHooverGroup::applyNhc(real currentKineticEnergy, const real couplingTimeStep)
486 {
487     if (currentKineticEnergy < 0)
488     {
489         finalizeUpdate(couplingTimeStep);
490         return 1.0;
491     }
492
493     constexpr unsigned int c_suzukiYoshidaOrder                         = 5;
494     constexpr double       c_suzukiYoshidaWeights[c_suzukiYoshidaOrder] = {
495         0.2967324292201065, 0.2967324292201065, -0.186929716880426, 0.2967324292201065, 0.2967324292201065
496     };
497
498     real velocityScalingFactor = 1.0;
499
500     // Apply Suzuki-Yoshida scheme
501     for (unsigned int syOuterLoop = 0; syOuterLoop < c_suzukiYoshidaOrder; ++syOuterLoop)
502     {
503         for (unsigned int syInnerLoop = 0; syInnerLoop < c_suzukiYoshidaOrder; ++syInnerLoop)
504         {
505             const real timeStep =
506                     couplingTimeStep * c_suzukiYoshidaWeights[syInnerLoop] / c_suzukiYoshidaOrder;
507
508             // Reverse loop - start from last thermostat in chain to update velocities,
509             // because we need the new velocity to scale the next thermostat in the chain
510             for (auto chainPosition = chainLength_ - 1; chainPosition >= 0; --chainPosition)
511             {
512                 const real kineticEnergy2 =
513                         ((chainPosition == 0) ? 2 * currentKineticEnergy
514                                               : gmx::square(xiVelocities_[chainPosition - 1])
515                                                         / invXiMass_[chainPosition - 1]);
516                 // DOF of temperature group or chain member
517                 const real numDof         = ((chainPosition == 0) ? numDegreesOfFreedom_ : 1);
518                 const real xiAcceleration = invXiMass_[chainPosition]
519                                             * (kineticEnergy2 - numDof * c_boltz * referenceTemperature_);
520
521                 // We scale based on the next thermostat in chain.
522                 // Last thermostat in chain doesn't get scaled.
523                 const real localScalingFactor =
524                         (chainPosition < chainLength_ - 1)
525                                 ? exp(-0.25 * timeStep * xiVelocities_[chainPosition + 1])
526                                 : 1.0;
527                 xiVelocities_[chainPosition] = localScalingFactor
528                                                * (xiVelocities_[chainPosition] * localScalingFactor
529                                                   + 0.5 * timeStep * xiAcceleration);
530             }
531
532             // Calculate the new system scaling factor
533             const real systemScalingFactor = std::exp(-timeStep * xiVelocities_[0]);
534             velocityScalingFactor *= systemScalingFactor;
535             currentKineticEnergy *= systemScalingFactor * systemScalingFactor;
536
537             // Forward loop - start from the system thermostat
538             for (auto chainPosition = 0; chainPosition < chainLength_; ++chainPosition)
539             {
540                 // Update thermostat positions
541                 xi_[chainPosition] += timeStep * xiVelocities_[chainPosition];
542
543                 // Kinetic energy of system or previous chain member
544                 const real kineticEnergy2 =
545                         ((chainPosition == 0) ? 2 * currentKineticEnergy
546                                               : gmx::square(xiVelocities_[chainPosition - 1])
547                                                         / invXiMass_[chainPosition - 1]);
548                 // DOF of temperature group or chain member
549                 const real numDof         = ((chainPosition == 0) ? numDegreesOfFreedom_ : 1);
550                 const real xiAcceleration = invXiMass_[chainPosition]
551                                             * (kineticEnergy2 - numDof * c_boltz * referenceTemperature_);
552
553                 // We scale based on the next thermostat in chain.
554                 // Last thermostat in chain doesn't get scaled.
555                 const real localScalingFactor =
556                         (chainPosition < chainLength_ - 1)
557                                 ? exp(-0.25 * timeStep * xiVelocities_[chainPosition + 1])
558                                 : 1.0;
559                 xiVelocities_[chainPosition] = localScalingFactor
560                                                * (xiVelocities_[chainPosition] * localScalingFactor
561                                                   + 0.5 * timeStep * xiAcceleration);
562             }
563         }
564     }
565     finalizeUpdate(couplingTimeStep);
566     return velocityScalingFactor;
567 }
568
569 real NoseHooverChainsData::applyNhc(int temperatureGroup, double propagationTimeStep, real currentKineticEnergy)
570 {
571     return noseHooverGroups_[temperatureGroup].applyNhc(currentKineticEnergy, propagationTimeStep);
572 }
573
574 /*!
575  * \brief Calculate the current kinetic energy
576  *
577  * \param tcstat  The group's kinetic energy structure
578  * \return real   The current kinetic energy
579  */
580 inline real NoseHooverChainsElement::currentKineticEnergy(const t_grp_tcstat& tcstat)
581 {
582     if (nhcUsage_ == NhcUsage::System)
583     {
584         if (useFullStepKE_ == UseFullStepKE::Yes)
585         {
586             return trace(tcstat.ekinf) * tcstat.ekinscalef_nhc;
587         }
588         else
589         {
590             return trace(tcstat.ekinh) * tcstat.ekinscaleh_nhc;
591         }
592     }
593     else if (nhcUsage_ == NhcUsage::Barostat)
594     {
595         GMX_RELEASE_ASSERT(useFullStepKE_ == UseFullStepKE::Yes,
596                            "Barostat NHC only works with full step KE.");
597         return mttkData_->kineticEnergy();
598     }
599     else
600     {
601         gmx_fatal(FARGS, "Unknown NhcUsage.");
602     }
603 }
604
605 void NoseHooverChainsElement::propagateNhc()
606 {
607     auto* ekind = energyData_->ekindata();
608
609     for (int temperatureGroup = 0; (temperatureGroup < noseHooverChainData_->numTemperatureGroups());
610          temperatureGroup++)
611     {
612         auto scalingFactor =
613                 noseHooverChainData_->applyNhc(temperatureGroup,
614                                                propagationTimeStep_,
615                                                currentKineticEnergy(ekind->tcstat[temperatureGroup]));
616
617         if (nhcUsage_ == NhcUsage::System)
618         {
619             // Scale system velocities by scalingFactor
620             lambdaStartVelocities_[temperatureGroup] = scalingFactor;
621             // Scale kinetic energy by scalingFactor^2
622             ekind->tcstat[temperatureGroup].ekinscaleh_nhc *= scalingFactor * scalingFactor;
623             ekind->tcstat[temperatureGroup].ekinscalef_nhc *= scalingFactor * scalingFactor;
624         }
625         else if (nhcUsage_ == NhcUsage::Barostat)
626         {
627             // Scale eta velocities by scalingFactor
628             mttkData_->scale(scalingFactor, noseHooverChainData_->isAtFullCouplingTimeStep());
629         }
630     }
631
632     if (nhcUsage_ == NhcUsage::System && noseHooverChainData_->isAtFullCouplingTimeStep())
633     {
634         // We've set the scaling factors for the full time step, so scale
635         // kinetic energy accordingly before it gets printed
636         energyData_->updateKineticEnergy();
637     }
638 }
639
640 NoseHooverChainsElement::NoseHooverChainsElement(int                   nstcouple,
641                                                  int                   offset,
642                                                  NhcUsage              nhcUsage,
643                                                  UseFullStepKE         useFullStepKE,
644                                                  double                propagationTimeStep,
645                                                  ScheduleOnInitStep    scheduleOnInitStep,
646                                                  Step                  initStep,
647                                                  EnergyData*           energyData,
648                                                  NoseHooverChainsData* noseHooverChainData,
649                                                  MttkData*             mttkData) :
650     nsttcouple_(nstcouple),
651     offset_(offset),
652     propagationTimeStep_(propagationTimeStep),
653     nhcUsage_(nhcUsage),
654     useFullStepKE_(useFullStepKE),
655     scheduleOnInitStep_(scheduleOnInitStep),
656     initialStep_(initStep),
657     energyData_(energyData),
658     noseHooverChainData_(noseHooverChainData),
659     mttkData_(mttkData)
660 {
661 }
662
663 void NoseHooverChainsElement::elementSetup()
664 {
665     GMX_RELEASE_ASSERT(
666             !(nhcUsage_ == NhcUsage::System && !propagatorCallback_),
667             "Nose-Hoover chain element was not connected to a propagator.\n"
668             "Connection to a propagator element is needed to scale the velocities.\n"
669             "Use connectWithPropagator(...) before building the ModularSimulatorAlgorithm "
670             "object.");
671 }
672
673 void NoseHooverChainsElement::scheduleTask(Step step, Time /*unused*/, const RegisterRunFunction& registerRunFunction)
674 {
675     if (step == initialStep_ && scheduleOnInitStep_ == ScheduleOnInitStep::No)
676     {
677         return;
678     }
679     if (do_per_step(step + nsttcouple_ + offset_, nsttcouple_))
680     {
681         // do T-coupling this step
682         registerRunFunction([this]() { propagateNhc(); });
683
684         if (propagatorCallback_)
685         {
686             // Let propagator know that we want to do T-coupling
687             propagatorCallback_(step);
688         }
689     }
690 }
691
692 void NoseHooverChainsElement::connectWithPropagator(const PropagatorConnection& connectionData,
693                                                     const PropagatorTag&        propagatorTag)
694 {
695     if (connectionData.tag == propagatorTag)
696     {
697         GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
698                            "Trotter NHC needs start velocity scaling.");
699         connectionData.setNumVelocityScalingVariables(noseHooverChainData_->numTemperatureGroups(),
700                                                       ScaleVelocities::PreStepOnly);
701         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
702         propagatorCallback_    = connectionData.getVelocityScalingCallback();
703     }
704 }
705
706 //! \cond
707 // Doxygen gets confused by the overload
708 ISimulatorElement* NoseHooverChainsElement::getElementPointerImpl(
709         LegacySimulatorData*                    legacySimulatorData,
710         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
711         StatePropagatorData gmx_unused* statePropagatorData,
712         EnergyData*                     energyData,
713         FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
714         GlobalCommunicationHelper gmx_unused*  globalCommunicationHelper,
715         ObservablesReducer*                    observablesReducer,
716         NhcUsage                               nhcUsage,
717         Offset                                 offset,
718         UseFullStepKE                          useFullStepKE,
719         ScheduleOnInitStep                     scheduleOnInitStep,
720         const MttkPropagatorConnectionDetails& mttkPropagatorConnectionDetails)
721 {
722     GMX_RELEASE_ASSERT(nhcUsage == NhcUsage::Barostat, "System NHC element needs a propagator tag.");
723     if (!builderHelper->simulationData<MttkData>(MttkData::dataID()))
724     {
725         MttkData::build(legacySimulatorData, builderHelper, statePropagatorData, energyData, mttkPropagatorConnectionDetails);
726     }
727     return getElementPointerImpl(legacySimulatorData,
728                                  builderHelper,
729                                  statePropagatorData,
730                                  energyData,
731                                  freeEnergyPerturbationData,
732                                  globalCommunicationHelper,
733                                  observablesReducer,
734                                  nhcUsage,
735                                  offset,
736                                  useFullStepKE,
737                                  scheduleOnInitStep,
738                                  PropagatorTag(""));
739 }
740
741 ISimulatorElement* NoseHooverChainsElement::getElementPointerImpl(
742         LegacySimulatorData*                    legacySimulatorData,
743         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
744         StatePropagatorData gmx_unused* statePropagatorData,
745         EnergyData*                     energyData,
746         FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
747         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
748         ObservablesReducer gmx_unused* observablesReducer,
749         NhcUsage                       nhcUsage,
750         Offset                         offset,
751         UseFullStepKE                  useFullStepKE,
752         ScheduleOnInitStep             scheduleOnInitStep,
753         const PropagatorTag&           propagatorTag)
754 {
755     if (!builderHelper->simulationData<NoseHooverChainsData>(NoseHooverChainsData::dataID(nhcUsage)))
756     {
757         NoseHooverChainsData::build(nhcUsage, legacySimulatorData, builderHelper, energyData);
758     }
759     auto* nhcData = builderHelper
760                             ->simulationData<NoseHooverChainsData>(NoseHooverChainsData::dataID(nhcUsage))
761                             .value();
762
763     // MTTK data is only needed when connecting to a barostat
764     MttkData* mttkData = nullptr;
765     if (nhcUsage == NhcUsage::Barostat)
766     {
767         mttkData = builderHelper->simulationData<MttkData>(MttkData::dataID()).value();
768     }
769
770     // Element is now owned by the caller of this method, who will handle lifetime (see ModularSimulatorAlgorithm)
771     auto* element = builderHelper->storeElement(std::make_unique<NoseHooverChainsElement>(
772             legacySimulatorData->inputrec->nsttcouple,
773             offset,
774             nhcUsage,
775             useFullStepKE,
776             legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nsttcouple / 2,
777             scheduleOnInitStep,
778             legacySimulatorData->inputrec->init_step,
779             energyData,
780             nhcData,
781             mttkData));
782     if (nhcUsage == NhcUsage::System)
783     {
784         auto* thermostat = static_cast<NoseHooverChainsElement*>(element);
785         // Capturing pointer is safe because caller handles lifetime
786         builderHelper->registerTemperaturePressureControl(
787                 [thermostat, propagatorTag](const PropagatorConnection& connection) {
788                     thermostat->connectWithPropagator(connection, propagatorTag);
789                 });
790     }
791     else
792     {
793         GMX_RELEASE_ASSERT(propagatorTag == PropagatorTag(""),
794                            "Propagator tag is unused for Barostat NHC element.");
795     }
796     return element;
797 }
798 //! \endcond
799
800 } // namespace gmx