47df3fc82abcca8a7ebef784aa7aa5800650c1dd
[alexxy/gromacs.git] / src / gromacs / modularsimulator / velocityscalingtemperaturecoupling.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines a velocity-scaling temperature coupling element for
37  * the modular simulator
38  *
39  * \author Pascal Merz <pascal.merz@me.com>
40  * \ingroup module_modularsimulator
41  */
42
43 #include "gmxpre.h"
44
45 #include "velocityscalingtemperaturecoupling.h"
46
47 #include <numeric>
48
49 #include "gromacs/domdec/domdec_network.h"
50 #include "gromacs/math/units.h"
51 #include "gromacs/math/vec.h"
52 #include "gromacs/mdlib/coupling.h"
53 #include "gromacs/mdlib/stat.h"
54 #include "gromacs/mdtypes/checkpointdata.h"
55 #include "gromacs/mdtypes/commrec.h"
56 #include "gromacs/mdtypes/group.h"
57 #include "gromacs/mdtypes/inputrec.h"
58 #include "gromacs/utility/fatalerror.h"
59 #include "gromacs/utility/strconvert.h"
60
61 #include "modularsimulator.h"
62 #include "simulatoralgorithm.h"
63
64 namespace gmx
65 {
66
/*! \internal
 * \brief Data used by the concrete temperature coupling implementations
 *
 * Bundles the per-group coupling parameters which
 * VelocityScalingTemperatureCoupling hands to ITemperatureCouplingImpl::apply()
 * and ITemperatureCouplingImpl::updateReferenceTemperatureAndIntegral().
 *
 * Instances are brace-initialized positionally in this file, so the member
 * order below is part of this struct's interface - do not reorder.
 */
struct TemperatureCouplingData
{
    //! The coupling time step - simulation time step x nstcouple_
    const double couplingTimeStep;
    //! Coupling (reference) temperature per group
    ArrayRef<const real> referenceTemperature;
    //! Coupling time per group
    ArrayRef<const real> couplingTime;
    //! Number of degrees of freedom per group
    ArrayRef<const real> numDegreesOfFreedom;
    //! Work exerted by thermostat per group, i.e. the current temperature coupling integral (kept in double)
    ArrayRef<const double> temperatureCouplingIntegral;
};
83
84 /*! \internal
85  * \brief Interface for temperature coupling implementations
86  */
87 class ITemperatureCouplingImpl
88 {
89 public:
90     //! Allow access to the scaling vectors
91     virtual void connectWithPropagator(const PropagatorConnection& connectionData,
92                                        int                         numTemperatureGroups) = 0;
93
94     /*! \brief Make a temperature control step
95      *
96      * \param step                     The current step
97      * \param temperatureGroup         The current temperature group
98      * \param currentKineticEnergy     The kinetic energy of the temperature group
99      * \param currentTemperature       The temperature of the temperature group
100      * \param temperatureCouplingData  Access to general temperature coupling data
101      *
102      * \return  The temperature coupling integral for the current temperature group
103      */
104     [[nodiscard]] virtual real apply(Step                           step,
105                                      int                            temperatureGroup,
106                                      real                           currentKineticEnergy,
107                                      real                           currentTemperature,
108                                      const TemperatureCouplingData& temperatureCouplingData) = 0;
109
110     //! Write private data to checkpoint
111     virtual void writeCheckpoint(std::optional<WriteCheckpointData> checkpointData,
112                                  const t_commrec*                   cr) = 0;
113     //! Read private data from checkpoint
114     virtual void readCheckpoint(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) = 0;
115
116     //! Update the reference temperature and update and return the temperature coupling integral
117     virtual real updateReferenceTemperatureAndIntegral(int  temperatureGroup,
118                                                        real newTemperature,
119                                                        ReferenceTemperatureChangeAlgorithm algorithm,
120                                                        const TemperatureCouplingData& temperatureCouplingData) = 0;
121
122     //! Standard virtual destructor
123     virtual ~ITemperatureCouplingImpl() = default;
124 };
125
126 /*! \internal
127  * \brief Implements v-rescale temperature coupling
128  */
129 class VRescaleTemperatureCoupling final : public ITemperatureCouplingImpl
130 {
131 public:
132     //! Apply the v-rescale temperature control
133     real apply(Step                           step,
134                int                            temperatureGroup,
135                real                           currentKineticEnergy,
136                real gmx_unused                currentTemperature,
137                const TemperatureCouplingData& temperatureCouplingData) override
138     {
139         if (!(temperatureCouplingData.couplingTime[temperatureGroup] >= 0
140               && temperatureCouplingData.numDegreesOfFreedom[temperatureGroup] > 0
141               && currentKineticEnergy > 0))
142         {
143             lambdaStartVelocities_[temperatureGroup] = 1.0;
144             return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup];
145         }
146
147         const real referenceKineticEnergy =
148                 0.5 * temperatureCouplingData.referenceTemperature[temperatureGroup] * gmx::c_boltz
149                 * temperatureCouplingData.numDegreesOfFreedom[temperatureGroup];
150
151         const real newKineticEnergy =
152                 vrescale_resamplekin(currentKineticEnergy,
153                                      referenceKineticEnergy,
154                                      temperatureCouplingData.numDegreesOfFreedom[temperatureGroup],
155                                      temperatureCouplingData.couplingTime[temperatureGroup]
156                                              / temperatureCouplingData.couplingTimeStep,
157                                      step,
158                                      seed_);
159
160         // Analytically newKineticEnergy >= 0, but we check for rounding errors
161         if (newKineticEnergy <= 0)
162         {
163             lambdaStartVelocities_[temperatureGroup] = 0.0;
164         }
165         else
166         {
167             lambdaStartVelocities_[temperatureGroup] = std::sqrt(newKineticEnergy / currentKineticEnergy);
168         }
169
170         if (debug)
171         {
172             fprintf(debug,
173                     "TC: group %d: Ekr %g, Ek %g, Ek_new %g, Lambda: %g\n",
174                     temperatureGroup,
175                     referenceKineticEnergy,
176                     currentKineticEnergy,
177                     newKineticEnergy,
178                     lambdaStartVelocities_[temperatureGroup]);
179         }
180
181         return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup]
182                - (newKineticEnergy - currentKineticEnergy);
183     }
184
185     //! Connect with propagator - v-rescale only scales start step velocities
186     void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
187     {
188         GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
189                            "V-Rescale requires start velocity scaling.");
190         connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
191         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
192     }
193
194     //! No data to write to checkpoint
195     void writeCheckpoint(std::optional<WriteCheckpointData> gmx_unused checkpointData,
196                          const t_commrec gmx_unused* cr) override
197     {
198     }
199     //! No data to read from checkpoints
200     void readCheckpoint(std::optional<ReadCheckpointData> gmx_unused checkpointData,
201                         const t_commrec gmx_unused* cr) override
202     {
203     }
204
205     //! No changes needed
206     real updateReferenceTemperatureAndIntegral(int             temperatureGroup,
207                                                real gmx_unused newTemperature,
208                                                ReferenceTemperatureChangeAlgorithm gmx_unused algorithm,
209                                                const TemperatureCouplingData& temperatureCouplingData) override
210     {
211         return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup];
212     }
213
214     //! Constructor
215     VRescaleTemperatureCoupling(int64_t seed) : seed_(seed) {}
216
217 private:
218     //! The random seed
219     const int64_t seed_;
220
221     //! View on the scaling factor of the propagator (pre-step velocities)
222     ArrayRef<real> lambdaStartVelocities_;
223 };
224
225 /*! \internal
226  * \brief Implements Berendsen temperature coupling
227  */
228 class BerendsenTemperatureCoupling final : public ITemperatureCouplingImpl
229 {
230 public:
231     //! Apply the v-rescale temperature control
232     real apply(Step gmx_unused                step,
233                int                            temperatureGroup,
234                real                           currentKineticEnergy,
235                real                           currentTemperature,
236                const TemperatureCouplingData& temperatureCouplingData) override
237     {
238         if (!(temperatureCouplingData.couplingTime[temperatureGroup] >= 0
239               && temperatureCouplingData.numDegreesOfFreedom[temperatureGroup] > 0
240               && currentKineticEnergy > 0))
241         {
242             lambdaStartVelocities_[temperatureGroup] = 1.0;
243             return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup];
244         }
245
246         real lambda =
247                 std::sqrt(1.0
248                           + (temperatureCouplingData.couplingTimeStep
249                              / temperatureCouplingData.couplingTime[temperatureGroup])
250                                     * (temperatureCouplingData.referenceTemperature[temperatureGroup] / currentTemperature
251                                        - 1.0));
252         lambdaStartVelocities_[temperatureGroup] =
253                 std::max<real>(std::min<real>(lambda, 1.25_real), 0.8_real);
254         if (debug)
255         {
256             fprintf(debug,
257                     "TC: group %d: T: %g, Lambda: %g\n",
258                     temperatureGroup,
259                     currentTemperature,
260                     lambdaStartVelocities_[temperatureGroup]);
261         }
262         return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup]
263                - (lambdaStartVelocities_[temperatureGroup] * lambdaStartVelocities_[temperatureGroup]
264                   - 1) * currentKineticEnergy;
265     }
266
267     //! Connect with propagator - Berendsen only scales start step velocities
268     void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
269     {
270         GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
271                            "Berendsen T-coupling requires start velocity scaling.");
272         connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
273         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
274     }
275
276     //! No data to write to checkpoint
277     void writeCheckpoint(std::optional<WriteCheckpointData> gmx_unused checkpointData,
278                          const t_commrec gmx_unused* cr) override
279     {
280     }
281     //! No data to read from checkpoints
282     void readCheckpoint(std::optional<ReadCheckpointData> gmx_unused checkpointData,
283                         const t_commrec gmx_unused* cr) override
284     {
285     }
286
287     //! No changes needed
288     real updateReferenceTemperatureAndIntegral(int             temperatureGroup,
289                                                real gmx_unused newTemperature,
290                                                ReferenceTemperatureChangeAlgorithm gmx_unused algorithm,
291                                                const TemperatureCouplingData& temperatureCouplingData) override
292     {
293         return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup];
294     }
295
296 private:
297     //! View on the scaling factor of the propagator (pre-step velocities)
298     ArrayRef<real> lambdaStartVelocities_;
299 };
300
// Checkpoint versioning for NoseHooverTemperatureCoupling
namespace
{
/*!
 * \brief Versions of the data that NoseHooverTemperatureCoupling writes to the modular checkpoint
 *
 * When the checkpoint content changes, add a new enumerator directly above Count and
 * adapt the checkpoint functionality accordingly.
 */
enum class NHCheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The most recent checkpoint version, i.e. the enumerator preceding Count
constexpr NHCheckpointVersion c_nhCurrentVersion =
        static_cast<NHCheckpointVersion>(static_cast<int>(NHCheckpointVersion::Count) - 1);
} // namespace
317
318 /*! \internal
319  * \brief Implements the Nose-Hoover temperature coupling
320  */
321 class NoseHooverTemperatureCoupling final : public ITemperatureCouplingImpl
322 {
323 public:
324     //! Calculate the current value of the temperature coupling integral
325     real integral(int temperatureGroup, real numDegreesOfFreedom, real referenceTemperature)
326     {
327         return 0.5 * c_boltz * numDegreesOfFreedom
328                        * (xiVelocities_[temperatureGroup] * xiVelocities_[temperatureGroup])
329                        / invXiMass_[temperatureGroup]
330                + numDegreesOfFreedom * xi_[temperatureGroup] * c_boltz * referenceTemperature;
331     }
332
333     //! Apply the Nose-Hoover temperature control
334     real apply(Step gmx_unused                step,
335                int                            temperatureGroup,
336                real                           currentKineticEnergy,
337                real                           currentTemperature,
338                const TemperatureCouplingData& thermostatData) override
339     {
340         return applyLeapFrog(
341                 step, temperatureGroup, currentKineticEnergy, currentTemperature, thermostatData);
342     }
343
344     /*! \brief Apply for leap-frog
345      *
346      * This is called after the force calculation, before coordinate update
347      *
348      * We expect system to be at x(t), v(t-dt/2), f(t), T(t-dt/2)
349      * Internal variables are at xi(t-dt), v_xi(t-dt)
350      * Force on xi is calculated at time of system temperature
351      * After calling this, we will have xi(t), v_xi(t)
352      * The thermostat integral returned is a function of xi and v_xi,
353      * and hence at time t.
354      *
355      * This performs an update of the thermostat variables calculated as
356      *     a_xi(t-dt/2) = (T_sys(t-dt/2) - T_ref) / mass_xi;
357      *     v_xi(t) = v_xi(t-dt) + dt_xi * a_xi(t-dt/2);
358      *     xi(t) = xi(t-dt) + dt_xi * (v_xi(t-dt) + v_xi(t))/2;
359      *
360      * This will be followed by leap-frog integration of coordinates, calculated as
361      *     v(t-dt/2) *= - 0.5 * dt * v_xi(t);  // scale previous velocities
362      *     v(t+dt/2) = update_leapfrog_v(v(t-dt/2), f(t));  // do whatever LF does
363      *     v(t+dt/2) *= 1 / (1 + 0.5 * dt * v_xi(t))  // scale new velocities
364      *     x(t+dt) = update_leapfrog_x(x(t), v(t+dt/2));  // do whatever LF does
365      */
366     real applyLeapFrog(Step gmx_unused                step,
367                        int                            temperatureGroup,
368                        real                           currentKineticEnergy,
369                        real                           currentTemperature,
370                        const TemperatureCouplingData& thermostatData)
371     {
372         if (!(thermostatData.couplingTime[temperatureGroup] >= 0
373               && thermostatData.numDegreesOfFreedom[temperatureGroup] > 0 && currentKineticEnergy > 0))
374         {
375             lambdaStartVelocities_[temperatureGroup] = 1.0;
376             lambdaEndVelocities_[temperatureGroup]   = 1.0;
377             return thermostatData.temperatureCouplingIntegral[temperatureGroup];
378         }
379
380         const auto oldXiVelocity = xiVelocities_[temperatureGroup];
381         const auto xiAcceleration =
382                 invXiMass_[temperatureGroup]
383                 * (currentTemperature - thermostatData.referenceTemperature[temperatureGroup]);
384         xiVelocities_[temperatureGroup] += thermostatData.couplingTimeStep * xiAcceleration;
385         xi_[temperatureGroup] += thermostatData.couplingTimeStep
386                                  * (oldXiVelocity + xiVelocities_[temperatureGroup]) * 0.5;
387         lambdaStartVelocities_[temperatureGroup] =
388                 (1 - 0.5 * thermostatData.couplingTimeStep * xiVelocities_[temperatureGroup]);
389         lambdaEndVelocities_[temperatureGroup] =
390                 1. / (1 + 0.5 * thermostatData.couplingTimeStep * xiVelocities_[temperatureGroup]);
391
392         // Current value of the thermostat integral
393         return integral(temperatureGroup,
394                         thermostatData.numDegreesOfFreedom[temperatureGroup],
395                         thermostatData.referenceTemperature[temperatureGroup]);
396     }
397
398     //! Connect with propagator - Nose-Hoover scales start and end step velocities
399     void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
400     {
401         GMX_RELEASE_ASSERT(
402                 connectionData.hasStartVelocityScaling() && connectionData.hasEndVelocityScaling(),
403                 "Nose-Hoover T-coupling requires both start and end velocity scaling.");
404         connectionData.setNumVelocityScalingVariables(numTemperatureGroups,
405                                                       ScaleVelocities::PreStepAndPostStep);
406         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
407         lambdaEndVelocities_   = connectionData.getViewOnEndVelocityScaling();
408     }
409
410     //! Constructor
411     NoseHooverTemperatureCoupling(int                  numTemperatureGroups,
412                                   ArrayRef<const real> referenceTemperature,
413                                   ArrayRef<const real> couplingTime)
414     {
415         xi_.resize(numTemperatureGroups, 0.0);
416         xiVelocities_.resize(numTemperatureGroups, 0.0);
417         invXiMass_.resize(numTemperatureGroups, 0.0);
418         for (auto temperatureGroup = 0; temperatureGroup < numTemperatureGroups; ++temperatureGroup)
419         {
420             if (referenceTemperature[temperatureGroup] > 0 && couplingTime[temperatureGroup] > 0)
421             {
422                 // Note: This mass definition is equal to legacy md
423                 //       legacy md-vv divides the mass by ndof * kB
424                 invXiMass_[temperatureGroup] = 1.0
425                                                / (gmx::square(couplingTime[temperatureGroup] / M_2PI)
426                                                   * referenceTemperature[temperatureGroup]);
427             }
428         }
429     }
430
431     //! Helper function to read from / write to CheckpointData
432     template<CheckpointDataOperation operation>
433     void doCheckpointData(CheckpointData<operation>* checkpointData)
434     {
435         checkpointVersion(checkpointData, "Nose-Hoover version", c_nhCurrentVersion);
436         checkpointData->arrayRef("xi", makeCheckpointArrayRef<operation>(xi_));
437         checkpointData->arrayRef("xi velocities", makeCheckpointArrayRef<operation>(xiVelocities_));
438     }
439
440     //! Write thermostat dof to checkpoint
441     void writeCheckpoint(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override
442     {
443         if (MASTER(cr))
444         {
445             doCheckpointData(&checkpointData.value());
446         }
447     }
448     //! Read thermostat dof from checkpoint
449     void readCheckpoint(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override
450     {
451         if (MASTER(cr))
452         {
453             doCheckpointData(&checkpointData.value());
454         }
455         if (DOMAINDECOMP(cr))
456         {
457             dd_bcast(cr->dd, xi_.size() * sizeof(real), xi_.data());
458             dd_bcast(cr->dd, xiVelocities_.size() * sizeof(real), xiVelocities_.data());
459         }
460     }
461
462     //! Adapt masses
463     real updateReferenceTemperatureAndIntegral(int             temperatureGroup,
464                                                real gmx_unused newTemperature,
465                                                ReferenceTemperatureChangeAlgorithm gmx_unused algorithm,
466                                                const TemperatureCouplingData& temperatureCouplingData) override
467     {
468         // Currently, we don't know about any temperature change algorithms, so we assert this never gets called
469         GMX_ASSERT(false,
470                    "NoseHooverTemperatureCoupling: Unknown ReferenceTemperatureChangeAlgorithm.");
471         const bool newTemperatureIsValid =
472                 (newTemperature > 0 && temperatureCouplingData.couplingTime[temperatureGroup] > 0
473                  && temperatureCouplingData.numDegreesOfFreedom[temperatureGroup] > 0);
474         const bool oldTemperatureIsValid =
475                 (temperatureCouplingData.referenceTemperature[temperatureGroup] > 0
476                  && temperatureCouplingData.couplingTime[temperatureGroup] > 0
477                  && temperatureCouplingData.numDegreesOfFreedom[temperatureGroup] > 0);
478         GMX_RELEASE_ASSERT(newTemperatureIsValid == oldTemperatureIsValid,
479                            "Cannot turn temperature coupling on / off during simulation run.");
480         if (oldTemperatureIsValid && newTemperatureIsValid)
481         {
482             invXiMass_[temperatureGroup] *=
483                     (temperatureCouplingData.referenceTemperature[temperatureGroup] / newTemperature);
484             xiVelocities_[temperatureGroup] *= std::sqrt(
485                     newTemperature / temperatureCouplingData.referenceTemperature[temperatureGroup]);
486         }
487         return integral(temperatureGroup,
488                         temperatureCouplingData.numDegreesOfFreedom[temperatureGroup],
489                         newTemperature);
490     }
491
492 private:
493     //! The thermostat degree of freedom
494     std::vector<real> xi_;
495     //! Velocity of the thermostat dof
496     std::vector<real> xiVelocities_;
497     //! Inverse mass of the thermostat dof
498     std::vector<real> invXiMass_;
499
500     //! View on the scaling factor of the propagator (pre-step velocities)
501     ArrayRef<real> lambdaStartVelocities_;
502     //! View on the scaling factor of the propagator (post-step velocities)
503     ArrayRef<real> lambdaEndVelocities_;
504 };
505
506 VelocityScalingTemperatureCoupling::VelocityScalingTemperatureCoupling(
507         int                               nstcouple,
508         int                               offset,
509         UseFullStepKE                     useFullStepKE,
510         ReportPreviousStepConservedEnergy reportPreviousConservedEnergy,
511         int64_t                           seed,
512         int                               numTemperatureGroups,
513         double                            couplingTimeStep,
514         const real*                       referenceTemperature,
515         const real*                       couplingTime,
516         const real*                       numDegreesOfFreedom,
517         EnergyData*                       energyData,
518         TemperatureCoupling               couplingType) :
519     nstcouple_(nstcouple),
520     offset_(offset),
521     useFullStepKE_(useFullStepKE),
522     reportPreviousConservedEnergy_(reportPreviousConservedEnergy),
523     numTemperatureGroups_(numTemperatureGroups),
524     couplingTimeStep_(couplingTimeStep),
525     referenceTemperature_(referenceTemperature, referenceTemperature + numTemperatureGroups),
526     couplingTime_(couplingTime, couplingTime + numTemperatureGroups),
527     numDegreesOfFreedom_(numDegreesOfFreedom, numDegreesOfFreedom + numTemperatureGroups),
528     temperatureCouplingIntegral_(numTemperatureGroups, 0.0),
529     energyData_(energyData),
530     nextEnergyCalculationStep_(-1)
531 {
532     if (couplingType == TemperatureCoupling::VRescale)
533     {
534         temperatureCouplingImpl_ = std::make_unique<VRescaleTemperatureCoupling>(seed);
535     }
536     else if (couplingType == TemperatureCoupling::Berendsen)
537     {
538         temperatureCouplingImpl_ = std::make_unique<BerendsenTemperatureCoupling>();
539     }
540     else if (couplingType == TemperatureCoupling::NoseHoover)
541     {
542         temperatureCouplingImpl_ = std::make_unique<NoseHooverTemperatureCoupling>(
543                 numTemperatureGroups_, referenceTemperature_, couplingTime_);
544     }
545     else
546     {
547         throw NotImplementedError("Temperature coupling " + std::string(enumValueToString(couplingType))
548                                   + " is not implemented for modular simulator.");
549     }
550     energyData->addConservedEnergyContribution([this](Step gmx_used_in_debug step, Time /*unused*/) {
551         GMX_ASSERT(conservedEnergyContributionStep_ == step,
552                    "VelocityScalingTemperatureCoupling conserved energy step mismatch.");
553         return conservedEnergyContribution_;
554     });
555 }
556
557 void VelocityScalingTemperatureCoupling::connectWithMatchingPropagator(const PropagatorConnection& connectionData,
558                                                                        const PropagatorTag& propagatorTag)
559 {
560     if (connectionData.tag == propagatorTag)
561     {
562         temperatureCouplingImpl_->connectWithPropagator(connectionData, numTemperatureGroups_);
563         propagatorCallback_ = connectionData.getVelocityScalingCallback();
564     }
565 }
566
567 void VelocityScalingTemperatureCoupling::elementSetup()
568 {
569     if (!propagatorCallback_)
570     {
571         throw MissingElementConnectionError(
572                 "Velocity scaling temperature coupling was not connected to a propagator.\n"
573                 "Connection to a propagator element is needed to scale the velocities.\n"
574                 "Use connectWithMatchingPropagator(...) before building the "
575                 "ModularSimulatorAlgorithm "
576                 "object.");
577     }
578 }
579
580 void VelocityScalingTemperatureCoupling::scheduleTask(Step                       step,
581                                                       Time gmx_unused            time,
582                                                       const RegisterRunFunction& registerRunFunction)
583 {
584     /* The thermostat will need a valid kinetic energy when it is running.
585      * Currently, computeGlobalCommunicationPeriod() is making sure this
586      * happens on time.
587      * TODO: Once we're switching to a new global communication scheme, we
588      *       will want the thermostat to signal that global reduction
589      *       of the kinetic energy is needed.
590      *
591      */
592     if (step == nextEnergyCalculationStep_
593         && reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::Yes)
594     {
595         // add conserved energy before we do T-coupling
596         registerRunFunction([this, step]() {
597             conservedEnergyContribution_     = conservedEnergyContribution();
598             conservedEnergyContributionStep_ = step;
599         });
600     }
601     if (do_per_step(step + nstcouple_ + offset_, nstcouple_))
602     {
603         // do T-coupling this step
604         registerRunFunction([this, step]() { setLambda(step); });
605
606         // Let propagator know that we want to do T-coupling
607         propagatorCallback_(step);
608     }
609     if (step == nextEnergyCalculationStep_
610         && reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::No)
611     {
612         // add conserved energy after we did T-coupling
613         registerRunFunction([this, step]() {
614             conservedEnergyContribution_     = conservedEnergyContribution();
615             conservedEnergyContributionStep_ = step;
616         });
617     }
618 }
619
620 void VelocityScalingTemperatureCoupling::setLambda(Step step)
621 {
622     const auto*             ekind          = energyData_->ekindata();
623     TemperatureCouplingData thermostatData = {
624         couplingTimeStep_, referenceTemperature_, couplingTime_, numDegreesOfFreedom_, temperatureCouplingIntegral_
625     };
626
627     for (int temperatureGroup = 0; (temperatureGroup < numTemperatureGroups_); temperatureGroup++)
628     {
629         const real currentKineticEnergy = useFullStepKE_ == UseFullStepKE::Yes
630                                                   ? trace(ekind->tcstat[temperatureGroup].ekinf)
631                                                   : trace(ekind->tcstat[temperatureGroup].ekinh);
632         const real currentTemperature   = useFullStepKE_ == UseFullStepKE::Yes
633                                                   ? ekind->tcstat[temperatureGroup].T
634                                                   : ekind->tcstat[temperatureGroup].Th;
635
636         temperatureCouplingIntegral_[temperatureGroup] = temperatureCouplingImpl_->apply(
637                 step, temperatureGroup, currentKineticEnergy, currentTemperature, thermostatData);
638     }
639 }
640
641 void VelocityScalingTemperatureCoupling::updateReferenceTemperature(ArrayRef<const real> temperatures,
642                                                                     ReferenceTemperatureChangeAlgorithm algorithm)
643 {
644     TemperatureCouplingData thermostatData = {
645         couplingTimeStep_, referenceTemperature_, couplingTime_, numDegreesOfFreedom_, temperatureCouplingIntegral_
646     };
647     for (int temperatureGroup = 0; (temperatureGroup < numTemperatureGroups_); temperatureGroup++)
648     {
649         temperatureCouplingIntegral_[temperatureGroup] =
650                 temperatureCouplingImpl_->updateReferenceTemperatureAndIntegral(
651                         temperatureGroup, temperatures[temperatureGroup], algorithm, thermostatData);
652     }
653     // Currently, we don't know about any temperature change algorithms, so we assert this never gets called
654     GMX_ASSERT(false,
655                "VelocityScalingTemperatureCoupling: Unknown ReferenceTemperatureChangeAlgorithm.");
656     std::copy(temperatures.begin(), temperatures.end(), referenceTemperature_.begin());
657 }
658
namespace
{
/*!
 * \brief Versions of the data VelocityScalingTemperatureCoupling writes to the modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The most recent checkpoint version, i.e. the entry immediately preceding Count
constexpr auto c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
674
675 template<CheckpointDataOperation operation>
676 void VelocityScalingTemperatureCoupling::doCheckpointData(CheckpointData<operation>* checkpointData)
677 {
678     checkpointVersion(checkpointData, "VRescaleThermostat version", c_currentVersion);
679
680     checkpointData->arrayRef("thermostat integral",
681                              makeCheckpointArrayRef<operation>(temperatureCouplingIntegral_));
682 }
683
684 void VelocityScalingTemperatureCoupling::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
685                                                              const t_commrec*                   cr)
686 {
687     if (MASTER(cr))
688     {
689         doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
690     }
691     temperatureCouplingImpl_->writeCheckpoint(
692             checkpointData
693                     ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
694                     : std::nullopt,
695             cr);
696 }
697
698 void VelocityScalingTemperatureCoupling::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
699                                                                 const t_commrec* cr)
700 {
701     if (MASTER(cr))
702     {
703         doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
704     }
705     if (DOMAINDECOMP(cr))
706     {
707         dd_bcast(cr->dd,
708                  ssize(temperatureCouplingIntegral_) * int(sizeof(double)),
709                  temperatureCouplingIntegral_.data());
710     }
711     temperatureCouplingImpl_->readCheckpoint(
712             checkpointData
713                     ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
714                     : std::nullopt,
715             cr);
716 }
717
718 const std::string& VelocityScalingTemperatureCoupling::clientID()
719 {
720     return identifier_;
721 }
722
723 real VelocityScalingTemperatureCoupling::conservedEnergyContribution() const
724 {
725     return std::accumulate(temperatureCouplingIntegral_.begin(), temperatureCouplingIntegral_.end(), 0.0);
726 }
727
728 std::optional<SignallerCallback> VelocityScalingTemperatureCoupling::registerEnergyCallback(EnergySignallerEvent event)
729 {
730     if (event == EnergySignallerEvent::EnergyCalculationStep)
731     {
732         return [this](Step step, Time /*unused*/) { nextEnergyCalculationStep_ = step; };
733     }
734     return std::nullopt;
735 }
736
737 ISimulatorElement* VelocityScalingTemperatureCoupling::getElementPointerImpl(
738         LegacySimulatorData*                    legacySimulatorData,
739         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
740         StatePropagatorData gmx_unused* statePropagatorData,
741         EnergyData*                     energyData,
742         FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
743         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
744         ObservablesReducer* /*observablesReducer*/,
745         Offset                            offset,
746         UseFullStepKE                     useFullStepKE,
747         ReportPreviousStepConservedEnergy reportPreviousStepConservedEnergy,
748         const PropagatorTag&              propagatorTag)
749 {
750     // Element is now owned by the caller of this method, who will handle lifetime (see ModularSimulatorAlgorithm)
751     auto* element = builderHelper->storeElement(std::make_unique<VelocityScalingTemperatureCoupling>(
752             legacySimulatorData->inputrec->nsttcouple,
753             offset,
754             useFullStepKE,
755             reportPreviousStepConservedEnergy,
756             legacySimulatorData->inputrec->ld_seed,
757             legacySimulatorData->inputrec->opts.ngtc,
758             legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nsttcouple,
759             legacySimulatorData->inputrec->opts.ref_t,
760             legacySimulatorData->inputrec->opts.tau_t,
761             legacySimulatorData->inputrec->opts.nrdf,
762             energyData,
763             legacySimulatorData->inputrec->etc));
764     auto* thermostat = static_cast<VelocityScalingTemperatureCoupling*>(element);
765     // Capturing pointer is safe because lifetime is handled by caller
766     builderHelper->registerTemperaturePressureControl(
767             [thermostat, propagatorTag](const PropagatorConnection& connection) {
768                 thermostat->connectWithMatchingPropagator(connection, propagatorTag);
769             });
770     builderHelper->registerReferenceTemperatureUpdate(
771             [thermostat](ArrayRef<const real> temperatures, ReferenceTemperatureChangeAlgorithm algorithm) {
772                 thermostat->updateReferenceTemperature(temperatures, algorithm);
773             });
774     return element;
775 }
776
777 } // namespace gmx