22517e52380c4c22f3b43646c9f5255a7037684a
[alexxy/gromacs.git] / src / gromacs / modularsimulator / mttk.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines classes related to MTTK pressure coupling
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "mttk.h"
45
46 #include "gromacs/mdtypes/commrec.h"
47 #include "gromacs/domdec/domdec_network.h"
48 #include "gromacs/math/functions.h"
49 #include "gromacs/math/units.h"
50 #include "gromacs/math/vec.h"
51 #include "gromacs/mdlib/coupling.h"
52 #include "gromacs/mdlib/stat.h"
53 #include "gromacs/topology/ifunc.h"
54 #include "gromacs/mdtypes/inputrec.h"
55 #include "gromacs/mdtypes/group.h"
56 #include "gromacs/mdtypes/enerdata.h"
57
58 #include "energydata.h"
59 #include "velocityscalingtemperaturecoupling.h"
60 #include "nosehooverchains.h"
61 #include "simulatoralgorithm.h"
62 #include "trotterhelperfunctions.h"
63
64 namespace gmx
65 {
66
67 void MttkData::build(LegacySimulatorData*                    legacySimulatorData,
68                      ModularSimulatorAlgorithmBuilderHelper* builderHelper,
69                      StatePropagatorData*                    statePropagatorData,
70                      EnergyData*                             energyData,
71                      const MttkPropagatorConnectionDetails&  mttkPropagatorConnectionDetails)
72 {
73     // Uses reference temperature of first T-group
74     const real referenceTemperature = legacySimulatorData->inputrec->opts.ref_t[0];
75     const real referencePressure    = trace(legacySimulatorData->inputrec->ref_p) / DIM;
76     // Weights are set based on initial volume
77     real initialVolume = det(statePropagatorData->constBox());
78
79     // When using domain decomposition, statePropagatorData might not have the initial
80     // box yet, so we get it from the legacy state_global instead.
81     // TODO: Make sure we have a valid state in statePropagatorData at all times (#3421)
82     if (DOMAINDECOMP(legacySimulatorData->cr))
83     {
84         if (MASTER(legacySimulatorData->cr))
85         {
86             initialVolume = det(legacySimulatorData->state_global->box);
87         }
88         dd_bcast(legacySimulatorData->cr->dd, int(sizeof(real)), &initialVolume);
89     }
90
91     GMX_RELEASE_ASSERT(
92             !builderHelper->simulationData<MttkPropagatorConnection>(MttkPropagatorConnection::dataID()),
93             "Attempted to build MttkPropagatorConnection more than once.");
94     MttkPropagatorConnection::build(builderHelper,
95                                     mttkPropagatorConnectionDetails.propagatorTagPrePosition,
96                                     mttkPropagatorConnectionDetails.propagatorTagPostPosition,
97                                     mttkPropagatorConnectionDetails.positionOffset,
98                                     mttkPropagatorConnectionDetails.propagatorTagPreVelocity1,
99                                     mttkPropagatorConnectionDetails.propagatorTagPostVelocity1,
100                                     mttkPropagatorConnectionDetails.velocityOffset1,
101                                     mttkPropagatorConnectionDetails.propagatorTagPreVelocity2,
102                                     mttkPropagatorConnectionDetails.propagatorTagPostVelocity2,
103                                     mttkPropagatorConnectionDetails.velocityOffset2);
104     auto* mttkPropagatorConnection =
105             builderHelper
106                     ->simulationData<MttkPropagatorConnection>(MttkPropagatorConnection::dataID())
107                     .value();
108
109     builderHelper->storeSimulationData(
110             MttkData::dataID(),
111             MttkData(referenceTemperature,
112                      referencePressure,
113                      legacySimulatorData->inputrec->nstpcouple * legacySimulatorData->inputrec->delta_t,
114                      legacySimulatorData->inputrec->tau_p,
115                      initialVolume,
116                      legacySimulatorData->inputrec->opts.nrdf[0],
117                      legacySimulatorData->inputrec->delta_t,
118                      legacySimulatorData->inputrec->compress,
119                      statePropagatorData,
120                      mttkPropagatorConnection));
121     auto* ptrToDataObject = builderHelper->simulationData<MttkData>(MttkData::dataID()).value();
122
123     energyData->addConservedEnergyContribution([ptrToDataObject](Step /*unused*/, Time time) {
124         return ptrToDataObject->temperatureCouplingIntegral(time);
125     });
126     energyData->setParrinelloRahmanBoxVelocities(
127             [ptrToDataObject]() { return ptrToDataObject->boxVelocity_; });
128     builderHelper->registerReferenceTemperatureUpdate(
129             [ptrToDataObject](ArrayRef<const real> temperatures, ReferenceTemperatureChangeAlgorithm algorithm) {
130                 ptrToDataObject->updateReferenceTemperature(temperatures[0], algorithm);
131             });
132 }
133
134 std::string MttkData::dataID()
135 {
136     return "MttkData";
137 }
138
139 MttkData::MttkData(real                       referenceTemperature,
140                    real                       referencePressure,
141                    real                       couplingTimeStep,
142                    real                       couplingTime,
143                    real                       initialVolume,
144                    real                       numDegreesOfFreedom,
145                    real                       simulationTimeStep,
146                    const tensor               compressibility,
147                    const StatePropagatorData* statePropagatorData,
148                    MttkPropagatorConnection*  mttkPropagatorConnection) :
149     couplingTimeStep_(couplingTimeStep),
150     etaVelocity_(0.0),
151     invMass_((c_presfac * trace(compressibility) * c_boltz * referenceTemperature)
152              / (DIM * initialVolume * gmx::square(couplingTime / M_2PI))),
153     etaVelocityTime_(0.0),
154     temperatureCouplingIntegral_(0.0),
155     integralTime_(0.0),
156     referencePressure_(referencePressure),
157     boxVelocity_{ { 0 } },
158     numDegreesOfFreedom_(numDegreesOfFreedom),
159     simulationTimeStep_(simulationTimeStep),
160     referenceTemperature_(referenceTemperature),
161     statePropagatorData_(statePropagatorData),
162     mttkPropagatorConnection_(mttkPropagatorConnection)
163 {
164     // Set integral based on initial volume
165     calculateIntegral(initialVolume);
166 }
167
168 MttkData::MttkData(const MttkData& other) :
169     couplingTimeStep_(other.couplingTimeStep_),
170     etaVelocity_(other.etaVelocity_),
171     invMass_(other.invMass_),
172     etaVelocityTime_(other.etaVelocityTime_),
173     temperatureCouplingIntegral_(other.temperatureCouplingIntegral_),
174     integralTime_(other.integralTime_),
175     referencePressure_(other.referencePressure_),
176     numDegreesOfFreedom_(other.numDegreesOfFreedom_),
177     simulationTimeStep_(other.simulationTimeStep_),
178     statePropagatorData_(other.statePropagatorData_),
179     mttkPropagatorConnection_(other.mttkPropagatorConnection_)
180 {
181     copy_mat(other.boxVelocity_, boxVelocity_);
182 }
183
184 void MttkData::calculateIntegralIfNeeded()
185 {
186     // Check whether coordinate time divided by the time step is close to integer
187     const bool calculationNeeded = timesClose(
188             lround(etaVelocityTime_ / couplingTimeStep_) * couplingTimeStep_, etaVelocityTime_);
189
190     if (calculationNeeded)
191     {
192         const real volume = det(statePropagatorData_->constBox());
193         // Calculate current value of barostat integral
194         calculateIntegral(volume);
195     }
196 }
197
198 void MttkData::calculateIntegral(real volume)
199 {
200     temperatureCouplingIntegral_ = kineticEnergy() + volume * referencePressure_ / c_presfac;
201     integralTime_                = etaVelocityTime_;
202 }
203
204 real MttkData::kineticEnergy() const
205 {
206     return 0.5 * etaVelocity_ * etaVelocity_ / invMass_;
207 }
208
209 void MttkData::scale(real scalingFactor, bool scalingAtFullCouplingTimeStep)
210 {
211     etaVelocity_ *= scalingFactor;
212     if (scalingAtFullCouplingTimeStep)
213     {
214         calculateIntegralIfNeeded();
215     }
216     updateScalingFactors();
217 }
218
219 real MttkData::etaVelocity() const
220 {
221     return etaVelocity_;
222 }
223
224 real MttkData::invEtaMass() const
225 {
226     return invMass_;
227 }
228
229 void MttkData::setEtaVelocity(real etaVelocity, real etaVelocityTimeIncrement)
230 {
231     etaVelocity_ = etaVelocity;
232     etaVelocityTime_ += etaVelocityTimeIncrement;
233     calculateIntegralIfNeeded();
234     updateScalingFactors();
235 }
236
237 double MttkData::temperatureCouplingIntegral(Time gmx_used_in_debug time) const
238 {
239     /* When using nstpcouple >= nstcalcenergy, we accept that the coupling
240      * integral might be ahead of the current energy calculation step. The
241      * extended system degrees of freedom are either in sync or ahead of the
242      * rest of the system.
243      */
244     GMX_ASSERT(time <= integralTime_ || timesClose(integralTime_, time),
245                "MttkData conserved energy time mismatch.");
246     return temperatureCouplingIntegral_;
247 }
248
249 real MttkData::referencePressure() const
250 {
251     return referencePressure_;
252 }
253
254 rvec* MttkData::boxVelocities()
255 {
256     return boxVelocity_;
257 }
258
259 void MttkData::updateReferenceTemperature(real temperature,
260                                           ReferenceTemperatureChangeAlgorithm gmx_unused algorithm)
261 {
262     // Currently, we don't know about any temperature change algorithms, so we assert this never gets called
263     GMX_ASSERT(false, "MttkData: Unknown ReferenceTemperatureChangeAlgorithm.");
264     invMass_ *= temperature / referenceTemperature_;
265     referenceTemperature_ = temperature;
266 }
267
268 namespace
269 {
270 /*!
271  * \brief Enum describing the contents MttkData writes to modular checkpoint
272  *
273  * When changing the checkpoint content, add a new element just above Count, and adjust the
274  * checkpoint functionality.
275  */
276 enum class CheckpointVersion
277 {
278     Base, //!< First version of modular checkpointing
279     Count //!< Number of entries. Add new versions right above this!
280 };
281 constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
282 } // namespace
283
284 template<CheckpointDataOperation operation>
285 void MttkData::doCheckpointData(CheckpointData<operation>* checkpointData)
286 {
287     checkpointVersion(checkpointData, "MttkData version", c_currentVersion);
288     checkpointData->scalar("veta", &etaVelocity_);
289     // Mass is calculated from initial volume, so need to save it for exact continuation
290     checkpointData->scalar("mass", &invMass_);
291     checkpointData->scalar("time", &etaVelocityTime_);
292     checkpointData->scalar("integral", &temperatureCouplingIntegral_);
293     checkpointData->scalar("integralTime", &integralTime_);
294 }
295
296 void MttkData::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr)
297 {
298     if (MASTER(cr))
299     {
300         doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
301     }
302 }
303
304 void MttkData::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr)
305 {
306     if (MASTER(cr))
307     {
308         doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
309     }
310     if (DOMAINDECOMP(cr))
311     {
312         dd_bcast(cr->dd, int(sizeof(real)), &etaVelocity_);
313         dd_bcast(cr->dd, int(sizeof(real)), &invMass_);
314         dd_bcast(cr->dd, int(sizeof(Time)), &etaVelocityTime_);
315         dd_bcast(cr->dd, int(sizeof(double)), &temperatureCouplingIntegral_);
316         dd_bcast(cr->dd, int(sizeof(Time)), &integralTime_);
317     }
318 }
319
320 const std::string& MttkData::clientID()
321 {
322     return identifier_;
323 }
324
325 void MttkData::propagatorCallback(Step step) const
326 {
327     mttkPropagatorConnection_->propagatorCallback(step);
328 }
329
330 void MttkPropagatorConnection::build(ModularSimulatorAlgorithmBuilderHelper* builderHelper,
331                                      const PropagatorTag& propagatorTagPrePosition,
332                                      const PropagatorTag& propagatorTagPostPosition,
333                                      int                  positionOffset,
334                                      const PropagatorTag& propagatorTagPreVelocity1,
335                                      const PropagatorTag& propagatorTagPostVelocity1,
336                                      int                  velocityOffset1,
337                                      const PropagatorTag& propagatorTagPreVelocity2,
338                                      const PropagatorTag& propagatorTagPostVelocity2,
339                                      int                  velocityOffset2)
340 {
341     GMX_RELEASE_ASSERT(!(propagatorTagPrePosition == propagatorTagPostPosition
342                          && propagatorTagPrePosition != PropagatorTag("")),
343                        "Pre- and post-step position scaling in same element is not supported.");
344     GMX_RELEASE_ASSERT(!((propagatorTagPreVelocity1 == propagatorTagPostVelocity1
345                           && propagatorTagPreVelocity1 != PropagatorTag(""))
346                          || (propagatorTagPreVelocity2 == propagatorTagPostVelocity2
347                              && propagatorTagPreVelocity2 != PropagatorTag(""))),
348                        "Pre- and post-step velocity scaling in same element is not implemented.");
349
350     // Store object with simulation algorithm for safe pointer capturing
351     builderHelper->storeSimulationData(MttkPropagatorConnection::dataID(), MttkPropagatorConnection());
352     auto* object = builderHelper
353                            ->simulationData<MttkPropagatorConnection>(MttkPropagatorConnection::dataID())
354                            .value();
355
356     builderHelper->registerTemperaturePressureControl(
357             [object, propagatorTagPrePosition, positionOffset](const PropagatorConnection& connection) {
358                 object->connectWithPropagatorPositionPreStepScaling(
359                         connection, propagatorTagPrePosition, positionOffset);
360             });
361     builderHelper->registerTemperaturePressureControl(
362             [object, propagatorTagPostPosition, positionOffset](const PropagatorConnection& connection) {
363                 object->connectWithPropagatorPositionPostStepScaling(
364                         connection, propagatorTagPostPosition, positionOffset);
365             });
366     builderHelper->registerTemperaturePressureControl(
367             [object, propagatorTagPreVelocity1, velocityOffset1](const PropagatorConnection& connection) {
368                 object->connectWithPropagatorVelocityPreStepScaling(
369                         connection, propagatorTagPreVelocity1, velocityOffset1);
370             });
371     builderHelper->registerTemperaturePressureControl(
372             [object, propagatorTagPostVelocity1, velocityOffset1](const PropagatorConnection& connection) {
373                 object->connectWithPropagatorVelocityPostStepScaling(
374                         connection, propagatorTagPostVelocity1, velocityOffset1);
375             });
376     builderHelper->registerTemperaturePressureControl(
377             [object, propagatorTagPreVelocity2, velocityOffset2](const PropagatorConnection& connection) {
378                 object->connectWithPropagatorVelocityPreStepScaling(
379                         connection, propagatorTagPreVelocity2, velocityOffset2);
380             });
381     builderHelper->registerTemperaturePressureControl(
382             [object, propagatorTagPostVelocity2, velocityOffset2](const PropagatorConnection& connection) {
383                 object->connectWithPropagatorVelocityPostStepScaling(
384                         connection, propagatorTagPostVelocity2, velocityOffset2);
385             });
386 }
387
388 void MttkPropagatorConnection::propagatorCallback(Step step) const
389 {
390     for (const auto& callback : propagatorCallbacks_)
391     {
392         std::get<0>(callback)(step + std::get<1>(callback));
393     }
394 }
395
396 void MttkPropagatorConnection::setPositionScaling(real preStepScaling, real postStepScaling)
397 {
398     for (const auto& scalingFactor : startPositionScalingFactors_)
399     {
400         std::fill(scalingFactor.begin(), scalingFactor.end(), preStepScaling);
401     }
402     for (const auto& scalingFactor : endPositionScalingFactors_)
403     {
404         std::fill(scalingFactor.begin(), scalingFactor.end(), postStepScaling);
405     }
406 }
407
408 void MttkPropagatorConnection::setVelocityScaling(real preStepScaling, real postStepScaling)
409 {
410     for (const auto& scalingFactor : startVelocityScalingFactors_)
411     {
412         std::fill(scalingFactor.begin(), scalingFactor.end(), preStepScaling);
413     }
414     for (const auto& scalingFactor : endVelocityScalingFactors_)
415     {
416         std::fill(scalingFactor.begin(), scalingFactor.end(), postStepScaling);
417     }
418 }
419
420 std::string MttkPropagatorConnection::dataID()
421 {
422     return "MttkPropagatorConnection";
423 }
424
425 void MttkPropagatorConnection::connectWithPropagatorVelocityPreStepScaling(const PropagatorConnection& connectionData,
426                                                                            const PropagatorTag& propagatorTag,
427                                                                            int offset)
428 {
429     if (connectionData.tag == propagatorTag && connectionData.hasStartVelocityScaling())
430     {
431         connectionData.setNumVelocityScalingVariables(1, ScaleVelocities::PreStepOnly);
432         startVelocityScalingFactors_.emplace_back(connectionData.getViewOnStartVelocityScaling());
433         propagatorCallbacks_.emplace_back(
434                 std::make_tuple(connectionData.getVelocityScalingCallback(), offset));
435     }
436 }
437
438 void MttkPropagatorConnection::connectWithPropagatorVelocityPostStepScaling(const PropagatorConnection& connectionData,
439                                                                             const PropagatorTag& propagatorTag,
440                                                                             int offset)
441 {
442     if (connectionData.tag == propagatorTag && connectionData.hasStartVelocityScaling())
443     {
444         // Although we're using this propagator for scaling after the update, we're using
445         // getViewOnStartVelocityScaling() - getViewOnEndVelocityScaling() is only
446         // used for propagators doing BOTH start and end scaling
447         connectionData.setNumVelocityScalingVariables(1, ScaleVelocities::PreStepOnly);
448         endVelocityScalingFactors_.emplace_back(connectionData.getViewOnStartVelocityScaling());
449         propagatorCallbacks_.emplace_back(
450                 std::make_tuple(connectionData.getVelocityScalingCallback(), offset));
451     }
452 }
453
454 void MttkPropagatorConnection::connectWithPropagatorPositionPreStepScaling(const PropagatorConnection& connectionData,
455                                                                            const PropagatorTag& propagatorTag,
456                                                                            int offset)
457 {
458     if (connectionData.tag == propagatorTag && connectionData.hasPositionScaling())
459     {
460         connectionData.setNumPositionScalingVariables(1);
461         startPositionScalingFactors_.emplace_back(connectionData.getViewOnPositionScaling());
462         propagatorCallbacks_.emplace_back(
463                 std::make_tuple(connectionData.getPositionScalingCallback(), offset));
464     }
465 }
466
467 void MttkPropagatorConnection::connectWithPropagatorPositionPostStepScaling(const PropagatorConnection& connectionData,
468                                                                             const PropagatorTag& propagatorTag,
469                                                                             int offset)
470 {
471     if (connectionData.tag == propagatorTag && connectionData.hasPositionScaling())
472     {
473         connectionData.setNumPositionScalingVariables(1);
474         endPositionScalingFactors_.emplace_back(connectionData.getViewOnPositionScaling());
475         propagatorCallbacks_.emplace_back(
476                 std::make_tuple(connectionData.getPositionScalingCallback(), offset));
477     }
478 }
479
480 void MttkData::updateScalingFactors()
481 {
482     // Tuckerman et al. 2006, Eq 5.8
483     // Note that we're using the dof of the first temperature group only
484     const real alpha = 1.0 + DIM / (numDegreesOfFreedom_);
485     /* Tuckerman et al. 2006, eqs 5.11 and 5.13:
486      *
487      * r(t+dt)   = r(t)*exp(v_eta*dt) + dt*v*exp(v_eta*dt/2) * [sinh(v_eta*dt/2) / (v_eta*dt/2)]
488      * v(t+dt/2) = v(t)*exp(-a*v_eta*dt/2) +
489      *             dt/2*f/m*exp(-a*v_eta*dt/4) * [sinh(a*v_eta*dt/4) / (a*v_eta*dt/4)]
490      * with a = 1 + 1/Natoms
491      *
492      * For r, let
493      *   s1 = exp(v_eta*dt/2)
494      *   s2 = [sinh(v_eta*dt/2) / (v_eta*dt/2)]
495      * so we can use
496      *   r(t) *= s1/s2
497      *   r(t+dt) = r(t) + dt*v
498      *   r(t+dt) *= s1*s2  <=>  r(t+dt) = s1*s2 * (r(t)*s1/s2 + dt*v) = s1^2*r(t) + dt*v*s1*s2
499      *
500      * For v, let
501      *   s1 = exp(-a*v_eta*dt/4)
502      *   s2 = [sinh(a*v_eta*dt/4) / (a*v_eta*dt/4)]
503      * so we can use
504      *   v(t) *= s1/s2
505      *   v(t+dt/2) = v(t) + dt/2*f/m
506      *   v(t+dt/2) *= s1*s2  <=>  v(t+dt/2) = s1^2*v(t) + dt/2*f/m*s1*s2
507      *
508      * In legacy simulator, this scaling is applied every step, even if the barostat is updated
509      * less frequently, so we are mirroring this by using the simulation time step for dt and
510      * requesting scaling every step. This could likely be applied impulse-style by using the
511      * coupling time step for dt and only applying it when the barostat gets updated.
512      */
513     const real scalingPos1 = std::exp(0.5 * simulationTimeStep_ * etaVelocity_);
514     const real scalingPos2 = gmx::series_sinhx(0.5 * simulationTimeStep_ * etaVelocity_);
515     const real scalingVel1 = std::exp(-alpha * 0.25 * simulationTimeStep_ * etaVelocity_);
516     const real scalingVel2 = gmx::series_sinhx(alpha * 0.25 * simulationTimeStep_ * etaVelocity_);
517
518     mttkPropagatorConnection_->setPositionScaling(scalingPos1 / scalingPos2, scalingPos1 * scalingPos2);
519     mttkPropagatorConnection_->setVelocityScaling(scalingVel1 / scalingVel2, scalingVel1 * scalingVel2);
520 }
521
522 void MttkElement::propagateEtaVelocity(Step step)
523 {
524     const auto* ekind         = energyData_->ekindata();
525     const auto* virial        = energyData_->totalVirial(step);
526     const real  currentVolume = det(statePropagatorData_->constBox());
527     // Tuckerman et al. 2006, Eq 5.8
528     // Note that we're using the dof of the first temperature group only
529     const real alpha = 1.0 + DIM / (numDegreesOfFreedom_);
530     // Also here, using first group only
531     const real kineticEnergyFactor = alpha * ekind->tcstat[0].ekinscalef_nhc;
532     // Now, we're using full system kinetic energy!
533     tensor modifiedKineticEnergy;
534     msmul(ekind->ekin, kineticEnergyFactor, modifiedKineticEnergy);
535
536     tensor currentPressureTensor;
537
538     const real currentPressure =
539             calc_pres(pbcType_, numWalls_, statePropagatorData_->constBox(), modifiedKineticEnergy, virial, currentPressureTensor)
540             + energyData_->enerdata()->term[F_PDISPCORR];
541
542     const real etaAcceleration = DIM * currentVolume * (mttkData_->invEtaMass() / c_presfac)
543                                  * (currentPressure - mttkData_->referencePressure());
544
545     mttkData_->setEtaVelocity(mttkData_->etaVelocity() + propagationTimeStep_ * etaAcceleration,
546                               propagationTimeStep_);
547 }
548
549 MttkElement::MttkElement(int                        nstcouple,
550                          int                        offset,
551                          real                       propagationTimeStep,
552                          ScheduleOnInitStep         scheduleOnInitStep,
553                          Step                       initStep,
554                          const StatePropagatorData* statePropagatorData,
555                          EnergyData*                energyData,
556                          MttkData*                  mttkData,
557                          PbcType                    pbcType,
558                          int                        numWalls,
559                          real                       numDegreesOfFreedom) :
560     pbcType_(pbcType),
561     numWalls_(numWalls),
562     numDegreesOfFreedom_(numDegreesOfFreedom),
563     nstcouple_(nstcouple),
564     offset_(offset),
565     propagationTimeStep_(propagationTimeStep),
566     scheduleOnInitStep_(scheduleOnInitStep),
567     initialStep_(initStep),
568     statePropagatorData_(statePropagatorData),
569     energyData_(energyData),
570     mttkData_(mttkData)
571 {
572 }
573
574 void MttkElement::scheduleTask(Step step, Time /*unused*/, const RegisterRunFunction& registerRunFunction)
575 {
576     if (step == initialStep_ && scheduleOnInitStep_ == ScheduleOnInitStep::No)
577     {
578         return;
579     }
580     if (do_per_step(step + nstcouple_ + offset_, nstcouple_))
581     {
582         // do T-coupling this step
583         registerRunFunction([this, step]() { propagateEtaVelocity(step); });
584     }
585
586     // Let propagators know that we want to scale
587     // (we're scaling every step - see comment in MttkData::updateScalingFactors())
588     mttkData_->propagatorCallback(step);
589 }
590
591 ISimulatorElement* MttkElement::getElementPointerImpl(
592         LegacySimulatorData*                    legacySimulatorData,
593         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
594         StatePropagatorData gmx_unused* statePropagatorData,
595         EnergyData*                     energyData,
596         FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
597         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
598         ObservablesReducer gmx_unused*         observablesReducer,
599         Offset                                 offset,
600         ScheduleOnInitStep                     scheduleOnInitStep,
601         const MttkPropagatorConnectionDetails& mttkPropagatorConnectionDetails)
602 {
603     // Data is now owned by the caller of this method, who will handle lifetime (see ModularSimulatorAlgorithm)
604     if (!builderHelper->simulationData<MttkData>(MttkData::dataID()))
605     {
606         MttkData::build(legacySimulatorData, builderHelper, statePropagatorData, energyData, mttkPropagatorConnectionDetails);
607     }
608     auto* mttkData = builderHelper->simulationData<MttkData>(MttkData::dataID()).value();
609
610     // Element is now owned by the caller of this method, who will handle lifetime (see ModularSimulatorAlgorithm)
611     auto* element = static_cast<MttkElement*>(builderHelper->storeElement(std::make_unique<MttkElement>(
612             legacySimulatorData->inputrec->nsttcouple,
613             offset,
614             legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nstpcouple / 2,
615             scheduleOnInitStep,
616             legacySimulatorData->inputrec->init_step,
617             statePropagatorData,
618             energyData,
619             mttkData,
620             legacySimulatorData->inputrec->pbcType,
621             legacySimulatorData->inputrec->nwall,
622             legacySimulatorData->inputrec->opts.nrdf[0])));
623
624     return element;
625 }
626
627 MttkBoxScaling::MttkBoxScaling(real                 simulationTimeStep,
628                                StatePropagatorData* statePropagatorData,
629                                MttkData*            mttkData) :
630     simulationTimeStep_(simulationTimeStep), statePropagatorData_(statePropagatorData), mttkData_(mttkData)
631 {
632 }
633
634 void MttkBoxScaling::scheduleTask(Step gmx_unused            step,
635                                   gmx_unused Time            time,
636                                   const RegisterRunFunction& registerRunFunction)
637 {
638     registerRunFunction([this]() { scaleBox(); });
639 }
640
641 void MttkBoxScaling::scaleBox()
642 {
643     auto* box = statePropagatorData_->box();
644
645     /* DIM * eta = ln V.  so DIM*eta_new = DIM*eta_old + DIM*dt*veta =>
646        ln V_new = ln V_old + 3*dt*veta => V_new = V_old*exp(3*dt*veta) =>
647        Side length scales as exp(veta*dt) */
648     msmul(box, std::exp(mttkData_->etaVelocity() * simulationTimeStep_), box);
649
650     /* Relate veta to boxv.  veta = d(eta)/dT = (1/DIM)*1/V dV/dT.
651        o               If we assume isotropic scaling, and box length scaling
652        factor L, then V = L^DIM (det(M)).  So dV/dt = DIM
653        L^(DIM-1) dL/dt det(M), and veta = (1/L) dL/dt.  The
654        determinant of B is L^DIM det(M), and the determinant
655        of dB/dt is (dL/dT)^DIM det (M).  veta will be
656        (det(dB/dT)/det(B))^(1/3).  Then since M =
657        B_new*(vol_new)^(1/3), dB/dT_new = (veta_new)*B(new). */
658     msmul(box, mttkData_->etaVelocity(), mttkData_->boxVelocities());
659
660     mttkData_->calculateIntegralIfNeeded();
661 }
662
663 ISimulatorElement* MttkBoxScaling::getElementPointerImpl(
664         LegacySimulatorData*                    legacySimulatorData,
665         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
666         StatePropagatorData*                    statePropagatorData,
667         EnergyData*                             energyData,
668         FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
669         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
670         ObservablesReducer gmx_unused*         observablesReducer,
671         const MttkPropagatorConnectionDetails& mttkPropagatorConnectionDetails)
672 {
673     // Data is now owned by the caller of this method, who will handle lifetime (see ModularSimulatorAlgorithm)
674     if (!builderHelper->simulationData<MttkData>(MttkData::dataID()))
675     {
676         MttkData::build(legacySimulatorData, builderHelper, statePropagatorData, energyData, mttkPropagatorConnectionDetails);
677     }
678
679     return builderHelper->storeElement(std::make_unique<MttkBoxScaling>(
680             legacySimulatorData->inputrec->delta_t,
681             statePropagatorData,
682             builderHelper->simulationData<MttkData>(MttkData::dataID()).value()));
683 }
684
685 } // namespace gmx