Apply clang-format-11
[alexxy/gromacs.git] / src / gromacs / modularsimulator / velocityscalingtemperaturecoupling.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines a velocity-scaling temperature coupling element for
37  * the modular simulator
38  *
39  * \author Pascal Merz <pascal.merz@me.com>
40  * \ingroup module_modularsimulator
41  */
42
43 #include "gmxpre.h"
44
45 #include "velocityscalingtemperaturecoupling.h"
46
47 #include <numeric>
48
49 #include "gromacs/domdec/domdec_network.h"
50 #include "gromacs/math/units.h"
51 #include "gromacs/math/vec.h"
52 #include "gromacs/mdlib/coupling.h"
53 #include "gromacs/mdlib/stat.h"
54 #include "gromacs/mdtypes/checkpointdata.h"
55 #include "gromacs/mdtypes/commrec.h"
56 #include "gromacs/mdtypes/group.h"
57 #include "gromacs/mdtypes/inputrec.h"
58 #include "gromacs/utility/fatalerror.h"
59 #include "gromacs/utility/strconvert.h"
60
61 #include "modularsimulator.h"
62 #include "simulatoralgorithm.h"
63
64 namespace gmx
65 {
66
67 /*! \internal
68  * \brief Data used by the concrete temperature coupling implementations
69  */
70 struct TemperatureCouplingData
71 {
72     //! The coupling time step - simulation time step x nstcouple_
73     const double couplingTimeStep;
74     //! Coupling temperature per group
75     ArrayRef<const real> referenceTemperature;
76     //! Coupling time per group
77     ArrayRef<const real> couplingTime;
78     //! Number of degrees of freedom per group
79     ArrayRef<const real> numDegreesOfFreedom;
80     //! Work exerted by thermostat per group
81     ArrayRef<const double> temperatureCouplingIntegral;
82 };
83
84 /*! \internal
85  * \brief Interface for temperature coupling implementations
86  */
87 class ITemperatureCouplingImpl
88 {
89 public:
90     //! Allow access to the scaling vectors
91     virtual void connectWithPropagator(const PropagatorConnection& connectionData,
92                                        int                         numTemperatureGroups) = 0;
93
94     /*! \brief Make a temperature control step
95      *
96      * \param step                     The current step
97      * \param temperatureGroup         The current temperature group
98      * \param currentKineticEnergy     The kinetic energy of the temperature group
99      * \param currentTemperature       The temperature of the temperature group
100      * \param temperatureCouplingData  Access to general temperature coupling data
101      *
102      * \return  The temperature coupling integral for the current temperature group
103      */
104     [[nodiscard]] virtual real apply(Step                           step,
105                                      int                            temperatureGroup,
106                                      real                           currentKineticEnergy,
107                                      real                           currentTemperature,
108                                      const TemperatureCouplingData& temperatureCouplingData) = 0;
109
110     //! Write private data to checkpoint
111     virtual void writeCheckpoint(std::optional<WriteCheckpointData> checkpointData,
112                                  const t_commrec*                   cr) = 0;
113     //! Read private data from checkpoint
114     virtual void readCheckpoint(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) = 0;
115
116     //! Standard virtual destructor
117     virtual ~ITemperatureCouplingImpl() = default;
118 };
119
120 /*! \internal
121  * \brief Implements v-rescale temperature coupling
122  */
123 class VRescaleTemperatureCoupling final : public ITemperatureCouplingImpl
124 {
125 public:
126     //! Apply the v-rescale temperature control
127     real apply(Step                           step,
128                int                            temperatureGroup,
129                real                           currentKineticEnergy,
130                real gmx_unused                currentTemperature,
131                const TemperatureCouplingData& temperatureCouplingData) override
132     {
133         if (!(temperatureCouplingData.couplingTime[temperatureGroup] >= 0
134               && temperatureCouplingData.numDegreesOfFreedom[temperatureGroup] > 0
135               && currentKineticEnergy > 0))
136         {
137             lambdaStartVelocities_[temperatureGroup] = 1.0;
138             return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup];
139         }
140
141         const real referenceKineticEnergy =
142                 0.5 * temperatureCouplingData.referenceTemperature[temperatureGroup] * gmx::c_boltz
143                 * temperatureCouplingData.numDegreesOfFreedom[temperatureGroup];
144
145         const real newKineticEnergy =
146                 vrescale_resamplekin(currentKineticEnergy,
147                                      referenceKineticEnergy,
148                                      temperatureCouplingData.numDegreesOfFreedom[temperatureGroup],
149                                      temperatureCouplingData.couplingTime[temperatureGroup]
150                                              / temperatureCouplingData.couplingTimeStep,
151                                      step,
152                                      seed_);
153
154         // Analytically newKineticEnergy >= 0, but we check for rounding errors
155         if (newKineticEnergy <= 0)
156         {
157             lambdaStartVelocities_[temperatureGroup] = 0.0;
158         }
159         else
160         {
161             lambdaStartVelocities_[temperatureGroup] = std::sqrt(newKineticEnergy / currentKineticEnergy);
162         }
163
164         if (debug)
165         {
166             fprintf(debug,
167                     "TC: group %d: Ekr %g, Ek %g, Ek_new %g, Lambda: %g\n",
168                     temperatureGroup,
169                     referenceKineticEnergy,
170                     currentKineticEnergy,
171                     newKineticEnergy,
172                     lambdaStartVelocities_[temperatureGroup]);
173         }
174
175         return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup]
176                - (newKineticEnergy - currentKineticEnergy);
177     }
178
179     //! Connect with propagator - v-rescale only scales start step velocities
180     void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
181     {
182         GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
183                            "V-Rescale requires start velocity scaling.");
184         connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
185         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
186     }
187
188     //! No data to write to checkpoint
189     void writeCheckpoint(std::optional<WriteCheckpointData> gmx_unused checkpointData,
190                          const t_commrec gmx_unused* cr) override
191     {
192     }
193     //! No data to read from checkpoints
194     void readCheckpoint(std::optional<ReadCheckpointData> gmx_unused checkpointData,
195                         const t_commrec gmx_unused* cr) override
196     {
197     }
198
199     //! Constructor
200     VRescaleTemperatureCoupling(int64_t seed) : seed_(seed) {}
201
202 private:
203     //! The random seed
204     const int64_t seed_;
205
206     //! View on the scaling factor of the propagator (pre-step velocities)
207     ArrayRef<real> lambdaStartVelocities_;
208 };
209
210 /*! \internal
211  * \brief Implements Berendsen temperature coupling
212  */
213 class BerendsenTemperatureCoupling final : public ITemperatureCouplingImpl
214 {
215 public:
216     //! Apply the v-rescale temperature control
217     real apply(Step gmx_unused                step,
218                int                            temperatureGroup,
219                real                           currentKineticEnergy,
220                real                           currentTemperature,
221                const TemperatureCouplingData& temperatureCouplingData) override
222     {
223         if (!(temperatureCouplingData.couplingTime[temperatureGroup] >= 0
224               && temperatureCouplingData.numDegreesOfFreedom[temperatureGroup] > 0
225               && currentKineticEnergy > 0))
226         {
227             lambdaStartVelocities_[temperatureGroup] = 1.0;
228             return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup];
229         }
230
231         real lambda =
232                 std::sqrt(1.0
233                           + (temperatureCouplingData.couplingTimeStep
234                              / temperatureCouplingData.couplingTime[temperatureGroup])
235                                     * (temperatureCouplingData.referenceTemperature[temperatureGroup] / currentTemperature
236                                        - 1.0));
237         lambdaStartVelocities_[temperatureGroup] =
238                 std::max<real>(std::min<real>(lambda, 1.25_real), 0.8_real);
239         if (debug)
240         {
241             fprintf(debug,
242                     "TC: group %d: T: %g, Lambda: %g\n",
243                     temperatureGroup,
244                     currentTemperature,
245                     lambdaStartVelocities_[temperatureGroup]);
246         }
247         return temperatureCouplingData.temperatureCouplingIntegral[temperatureGroup]
248                - (lambdaStartVelocities_[temperatureGroup] * lambdaStartVelocities_[temperatureGroup]
249                   - 1) * currentKineticEnergy;
250     }
251
252     //! Connect with propagator - Berendsen only scales start step velocities
253     void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
254     {
255         GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
256                            "Berendsen T-coupling requires start velocity scaling.");
257         connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
258         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
259     }
260
261     //! No data to write to checkpoint
262     void writeCheckpoint(std::optional<WriteCheckpointData> gmx_unused checkpointData,
263                          const t_commrec gmx_unused* cr) override
264     {
265     }
266     //! No data to read from checkpoints
267     void readCheckpoint(std::optional<ReadCheckpointData> gmx_unused checkpointData,
268                         const t_commrec gmx_unused* cr) override
269     {
270     }
271
272 private:
273     //! View on the scaling factor of the propagator (pre-step velocities)
274     ArrayRef<real> lambdaStartVelocities_;
275 };
276
// Prepare NoseHooverTemperatureCoupling checkpoint data
namespace
{
/*!
 * \brief Enum describing the contents that NoseHoover writes to the modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class NHCheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};

//! The newest Nose-Hoover checkpoint version, i.e. the entry directly before Count
constexpr auto c_nhCurrentVersion =
        static_cast<NHCheckpointVersion>(static_cast<int>(NHCheckpointVersion::Count) - 1);
} // namespace
293
294 /*! \internal
295  * \brief Implements the Nose-Hoover temperature coupling
296  */
297 class NoseHooverTemperatureCoupling final : public ITemperatureCouplingImpl
298 {
299 public:
300     //! Apply the Nose-Hoover temperature control
301     real apply(Step gmx_unused                step,
302                int                            temperatureGroup,
303                real                           currentKineticEnergy,
304                real                           currentTemperature,
305                const TemperatureCouplingData& thermostatData) override
306     {
307         return applyLeapFrog(
308                 step, temperatureGroup, currentKineticEnergy, currentTemperature, thermostatData);
309     }
310
311     /*! \brief Apply for leap-frog
312      *
313      * This is called after the force calculation, before coordinate update
314      *
315      * We expect system to be at x(t), v(t-dt/2), f(t), T(t-dt/2)
316      * Internal variables are at xi(t-dt), v_xi(t-dt)
317      * Force on xi is calculated at time of system temperature
318      * After calling this, we will have xi(t), v_xi(t)
319      * The thermostat integral returned is a function of xi and v_xi,
320      * and hence at time t.
321      *
322      * This performs an update of the thermostat variables calculated as
323      *     a_xi(t-dt/2) = (T_sys(t-dt/2) - T_ref) / mass_xi;
324      *     v_xi(t) = v_xi(t-dt) + dt_xi * a_xi(t-dt/2);
325      *     xi(t) = xi(t-dt) + dt_xi * (v_xi(t-dt) + v_xi(t))/2;
326      *
327      * This will be followed by leap-frog integration of coordinates, calculated as
328      *     v(t-dt/2) *= - 0.5 * dt * v_xi(t);  // scale previous velocities
329      *     v(t+dt/2) = update_leapfrog_v(v(t-dt/2), f(t));  // do whatever LF does
330      *     v(t+dt/2) *= 1 / (1 + 0.5 * dt * v_xi(t))  // scale new velocities
331      *     x(t+dt) = update_leapfrog_x(x(t), v(t+dt/2));  // do whatever LF does
332      */
333     real applyLeapFrog(Step gmx_unused                step,
334                        int                            temperatureGroup,
335                        real                           currentKineticEnergy,
336                        real                           currentTemperature,
337                        const TemperatureCouplingData& thermostatData)
338     {
339         if (!(thermostatData.couplingTime[temperatureGroup] >= 0
340               && thermostatData.numDegreesOfFreedom[temperatureGroup] > 0 && currentKineticEnergy > 0))
341         {
342             lambdaStartVelocities_[temperatureGroup] = 1.0;
343             lambdaEndVelocities_[temperatureGroup]   = 1.0;
344             return thermostatData.temperatureCouplingIntegral[temperatureGroup];
345         }
346
347         const auto oldXiVelocity = xiVelocities_[temperatureGroup];
348         const auto xiAcceleration =
349                 invXiMass_[temperatureGroup]
350                 * (currentTemperature - thermostatData.referenceTemperature[temperatureGroup]);
351         xiVelocities_[temperatureGroup] += thermostatData.couplingTimeStep * xiAcceleration;
352         xi_[temperatureGroup] += thermostatData.couplingTimeStep
353                                  * (oldXiVelocity + xiVelocities_[temperatureGroup]) * 0.5;
354         lambdaStartVelocities_[temperatureGroup] =
355                 (1 - 0.5 * thermostatData.couplingTimeStep * xiVelocities_[temperatureGroup]);
356         lambdaEndVelocities_[temperatureGroup] =
357                 1. / (1 + 0.5 * thermostatData.couplingTimeStep * xiVelocities_[temperatureGroup]);
358
359         // Current value of the thermostat integral
360         return 0.5 * c_boltz * thermostatData.numDegreesOfFreedom[temperatureGroup]
361                        * (xiVelocities_[temperatureGroup] * xiVelocities_[temperatureGroup])
362                        / invXiMass_[temperatureGroup]
363                + thermostatData.numDegreesOfFreedom[temperatureGroup] * xi_[temperatureGroup]
364                          * c_boltz * thermostatData.referenceTemperature[temperatureGroup];
365     }
366
367     //! Connect with propagator - Nose-Hoover scales start and end step velocities
368     void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
369     {
370         GMX_RELEASE_ASSERT(
371                 connectionData.hasStartVelocityScaling() && connectionData.hasEndVelocityScaling(),
372                 "Nose-Hoover T-coupling requires both start and end velocity scaling.");
373         connectionData.setNumVelocityScalingVariables(numTemperatureGroups,
374                                                       ScaleVelocities::PreStepAndPostStep);
375         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
376         lambdaEndVelocities_   = connectionData.getViewOnEndVelocityScaling();
377     }
378
379     //! Constructor
380     NoseHooverTemperatureCoupling(int                  numTemperatureGroups,
381                                   ArrayRef<const real> referenceTemperature,
382                                   ArrayRef<const real> couplingTime)
383     {
384         xi_.resize(numTemperatureGroups, 0.0);
385         xiVelocities_.resize(numTemperatureGroups, 0.0);
386         invXiMass_.resize(numTemperatureGroups, 0.0);
387         for (auto temperatureGroup = 0; temperatureGroup < numTemperatureGroups; ++temperatureGroup)
388         {
389             if (referenceTemperature[temperatureGroup] > 0 && couplingTime[temperatureGroup] > 0)
390             {
391                 // Note: This mass definition is equal to legacy md
392                 //       legacy md-vv divides the mass by ndof * kB
393                 invXiMass_[temperatureGroup] = 1.0
394                                                / (gmx::square(couplingTime[temperatureGroup] / M_2PI)
395                                                   * referenceTemperature[temperatureGroup]);
396             }
397         }
398     }
399
400     //! Helper function to read from / write to CheckpointData
401     template<CheckpointDataOperation operation>
402     void doCheckpointData(CheckpointData<operation>* checkpointData)
403     {
404         checkpointVersion(checkpointData, "Nose-Hoover version", c_nhCurrentVersion);
405         checkpointData->arrayRef("xi", makeCheckpointArrayRef<operation>(xi_));
406         checkpointData->arrayRef("xi velocities", makeCheckpointArrayRef<operation>(xiVelocities_));
407     }
408
409     //! Write thermostat dof to checkpoint
410     void writeCheckpoint(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override
411     {
412         if (MASTER(cr))
413         {
414             doCheckpointData(&checkpointData.value());
415         }
416     }
417     //! Read thermostat dof from checkpoint
418     void readCheckpoint(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override
419     {
420         if (MASTER(cr))
421         {
422             doCheckpointData(&checkpointData.value());
423         }
424         if (DOMAINDECOMP(cr))
425         {
426             dd_bcast(cr->dd, xi_.size() * sizeof(real), xi_.data());
427             dd_bcast(cr->dd, xiVelocities_.size() * sizeof(real), xiVelocities_.data());
428         }
429     }
430
431 private:
432     //! The thermostat degree of freedom
433     std::vector<real> xi_;
434     //! Velocity of the thermostat dof
435     std::vector<real> xiVelocities_;
436     //! Inverse mass of the thermostat dof
437     std::vector<real> invXiMass_;
438
439     //! View on the scaling factor of the propagator (pre-step velocities)
440     ArrayRef<real> lambdaStartVelocities_;
441     //! View on the scaling factor of the propagator (post-step velocities)
442     ArrayRef<real> lambdaEndVelocities_;
443 };
444
445 VelocityScalingTemperatureCoupling::VelocityScalingTemperatureCoupling(
446         int                               nstcouple,
447         int                               offset,
448         UseFullStepKE                     useFullStepKE,
449         ReportPreviousStepConservedEnergy reportPreviousConservedEnergy,
450         int64_t                           seed,
451         int                               numTemperatureGroups,
452         double                            couplingTimeStep,
453         const real*                       referenceTemperature,
454         const real*                       couplingTime,
455         const real*                       numDegreesOfFreedom,
456         EnergyData*                       energyData,
457         TemperatureCoupling               couplingType) :
458     nstcouple_(nstcouple),
459     offset_(offset),
460     useFullStepKE_(useFullStepKE),
461     reportPreviousConservedEnergy_(reportPreviousConservedEnergy),
462     numTemperatureGroups_(numTemperatureGroups),
463     couplingTimeStep_(couplingTimeStep),
464     referenceTemperature_(referenceTemperature, referenceTemperature + numTemperatureGroups),
465     couplingTime_(couplingTime, couplingTime + numTemperatureGroups),
466     numDegreesOfFreedom_(numDegreesOfFreedom, numDegreesOfFreedom + numTemperatureGroups),
467     temperatureCouplingIntegral_(numTemperatureGroups, 0.0),
468     energyData_(energyData),
469     nextEnergyCalculationStep_(-1)
470 {
471     if (couplingType == TemperatureCoupling::VRescale)
472     {
473         temperatureCouplingImpl_ = std::make_unique<VRescaleTemperatureCoupling>(seed);
474     }
475     else if (couplingType == TemperatureCoupling::Berendsen)
476     {
477         temperatureCouplingImpl_ = std::make_unique<BerendsenTemperatureCoupling>();
478     }
479     else if (couplingType == TemperatureCoupling::NoseHoover)
480     {
481         temperatureCouplingImpl_ = std::make_unique<NoseHooverTemperatureCoupling>(
482                 numTemperatureGroups_, referenceTemperature_, couplingTime_);
483     }
484     else
485     {
486         throw NotImplementedError("Temperature coupling " + std::string(enumValueToString(couplingType))
487                                   + " is not implemented for modular simulator.");
488     }
489     energyData->addConservedEnergyContribution([this](Step gmx_used_in_debug step, Time /*unused*/) {
490         GMX_ASSERT(conservedEnergyContributionStep_ == step,
491                    "VelocityScalingTemperatureCoupling conserved energy step mismatch.");
492         return conservedEnergyContribution_;
493     });
494 }
495
496 void VelocityScalingTemperatureCoupling::connectWithMatchingPropagator(const PropagatorConnection& connectionData,
497                                                                        const PropagatorTag& propagatorTag)
498 {
499     if (connectionData.tag == propagatorTag)
500     {
501         temperatureCouplingImpl_->connectWithPropagator(connectionData, numTemperatureGroups_);
502         propagatorCallback_ = connectionData.getVelocityScalingCallback();
503     }
504 }
505
506 void VelocityScalingTemperatureCoupling::elementSetup()
507 {
508     if (!propagatorCallback_)
509     {
510         throw MissingElementConnectionError(
511                 "Velocity scaling temperature coupling was not connected to a propagator.\n"
512                 "Connection to a propagator element is needed to scale the velocities.\n"
513                 "Use connectWithMatchingPropagator(...) before building the "
514                 "ModularSimulatorAlgorithm "
515                 "object.");
516     }
517 }
518
519 void VelocityScalingTemperatureCoupling::scheduleTask(Step                       step,
520                                                       Time gmx_unused            time,
521                                                       const RegisterRunFunction& registerRunFunction)
522 {
523     /* The thermostat will need a valid kinetic energy when it is running.
524      * Currently, computeGlobalCommunicationPeriod() is making sure this
525      * happens on time.
526      * TODO: Once we're switching to a new global communication scheme, we
527      *       will want the thermostat to signal that global reduction
528      *       of the kinetic energy is needed.
529      *
530      */
531     if (step == nextEnergyCalculationStep_
532         && reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::Yes)
533     {
534         // add conserved energy before we do T-coupling
535         registerRunFunction([this, step]() {
536             conservedEnergyContribution_     = conservedEnergyContribution();
537             conservedEnergyContributionStep_ = step;
538         });
539     }
540     if (do_per_step(step + nstcouple_ + offset_, nstcouple_))
541     {
542         // do T-coupling this step
543         registerRunFunction([this, step]() { setLambda(step); });
544
545         // Let propagator know that we want to do T-coupling
546         propagatorCallback_(step);
547     }
548     if (step == nextEnergyCalculationStep_
549         && reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::No)
550     {
551         // add conserved energy after we did T-coupling
552         registerRunFunction([this, step]() {
553             conservedEnergyContribution_     = conservedEnergyContribution();
554             conservedEnergyContributionStep_ = step;
555         });
556     }
557 }
558
559 void VelocityScalingTemperatureCoupling::setLambda(Step step)
560 {
561     const auto*             ekind          = energyData_->ekindata();
562     TemperatureCouplingData thermostatData = {
563         couplingTimeStep_, referenceTemperature_, couplingTime_, numDegreesOfFreedom_, temperatureCouplingIntegral_
564     };
565
566     for (int temperatureGroup = 0; (temperatureGroup < numTemperatureGroups_); temperatureGroup++)
567     {
568         const real currentKineticEnergy = useFullStepKE_ == UseFullStepKE::Yes
569                                                   ? trace(ekind->tcstat[temperatureGroup].ekinf)
570                                                   : trace(ekind->tcstat[temperatureGroup].ekinh);
571         const real currentTemperature   = useFullStepKE_ == UseFullStepKE::Yes
572                                                   ? ekind->tcstat[temperatureGroup].T
573                                                   : ekind->tcstat[temperatureGroup].Th;
574
575         temperatureCouplingIntegral_[temperatureGroup] = temperatureCouplingImpl_->apply(
576                 step, temperatureGroup, currentKineticEnergy, currentTemperature, thermostatData);
577     }
578 }
579
namespace
{
/*!
 * \brief Enum describing the contents that VelocityScalingTemperatureCoupling writes to the modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};

//! The newest checkpoint version, i.e. the entry directly before Count
constexpr auto c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
595
596 template<CheckpointDataOperation operation>
597 void VelocityScalingTemperatureCoupling::doCheckpointData(CheckpointData<operation>* checkpointData)
598 {
599     checkpointVersion(checkpointData, "VRescaleThermostat version", c_currentVersion);
600
601     checkpointData->arrayRef("thermostat integral",
602                              makeCheckpointArrayRef<operation>(temperatureCouplingIntegral_));
603 }
604
605 void VelocityScalingTemperatureCoupling::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
606                                                              const t_commrec*                   cr)
607 {
608     if (MASTER(cr))
609     {
610         doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
611     }
612     temperatureCouplingImpl_->writeCheckpoint(
613             checkpointData
614                     ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
615                     : std::nullopt,
616             cr);
617 }
618
619 void VelocityScalingTemperatureCoupling::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
620                                                                 const t_commrec* cr)
621 {
622     if (MASTER(cr))
623     {
624         doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
625     }
626     if (DOMAINDECOMP(cr))
627     {
628         dd_bcast(cr->dd,
629                  ssize(temperatureCouplingIntegral_) * int(sizeof(double)),
630                  temperatureCouplingIntegral_.data());
631     }
632     temperatureCouplingImpl_->readCheckpoint(
633             checkpointData
634                     ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
635                     : std::nullopt,
636             cr);
637 }
638
639 const std::string& VelocityScalingTemperatureCoupling::clientID()
640 {
641     return identifier_;
642 }
643
644 real VelocityScalingTemperatureCoupling::conservedEnergyContribution() const
645 {
646     return std::accumulate(temperatureCouplingIntegral_.begin(), temperatureCouplingIntegral_.end(), 0.0);
647 }
648
649 std::optional<SignallerCallback> VelocityScalingTemperatureCoupling::registerEnergyCallback(EnergySignallerEvent event)
650 {
651     if (event == EnergySignallerEvent::EnergyCalculationStep)
652     {
653         return [this](Step step, Time /*unused*/) { nextEnergyCalculationStep_ = step; };
654     }
655     return std::nullopt;
656 }
657
658 ISimulatorElement* VelocityScalingTemperatureCoupling::getElementPointerImpl(
659         LegacySimulatorData*                    legacySimulatorData,
660         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
661         StatePropagatorData gmx_unused* statePropagatorData,
662         EnergyData*                     energyData,
663         FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
664         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
665         Offset                                offset,
666         UseFullStepKE                         useFullStepKE,
667         ReportPreviousStepConservedEnergy     reportPreviousStepConservedEnergy,
668         const PropagatorTag&                  propagatorTag)
669 {
670     // Element is now owned by the caller of this method, who will handle lifetime (see ModularSimulatorAlgorithm)
671     auto* element = builderHelper->storeElement(std::make_unique<VelocityScalingTemperatureCoupling>(
672             legacySimulatorData->inputrec->nsttcouple,
673             offset,
674             useFullStepKE,
675             reportPreviousStepConservedEnergy,
676             legacySimulatorData->inputrec->ld_seed,
677             legacySimulatorData->inputrec->opts.ngtc,
678             legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nsttcouple,
679             legacySimulatorData->inputrec->opts.ref_t,
680             legacySimulatorData->inputrec->opts.tau_t,
681             legacySimulatorData->inputrec->opts.nrdf,
682             energyData,
683             legacySimulatorData->inputrec->etc));
684     auto* thermostat = static_cast<VelocityScalingTemperatureCoupling*>(element);
685     // Capturing pointer is safe because lifetime is handled by caller
686     builderHelper->registerTemperaturePressureControl(
687             [thermostat, propagatorTag](const PropagatorConnection& connection) {
688                 thermostat->connectWithMatchingPropagator(connection, propagatorTag);
689             });
690     return element;
691 }
692
693 } // namespace gmx