Use ObservablesReducer for check of DD bonded interaction count.
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines the state for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "gromacs/utility/enumerationhelpers.h"
45 #include "statepropagatordata.h"
46
47 #include "gromacs/commandline/filenm.h"
48 #include "gromacs/domdec/collect.h"
49 #include "gromacs/domdec/domdec.h"
50 #include "gromacs/fileio/confio.h"
51 #include "gromacs/math/vec.h"
52 #include "gromacs/mdlib/gmx_omp_nthreads.h"
53 #include "gromacs/mdlib/mdatoms.h"
54 #include "gromacs/mdlib/mdoutf.h"
55 #include "gromacs/mdlib/stat.h"
56 #include "gromacs/mdlib/update.h"
57 #include "gromacs/mdtypes/checkpointdata.h"
58 #include "gromacs/mdtypes/commrec.h"
59 #include "gromacs/mdtypes/forcebuffers.h"
60 #include "gromacs/mdtypes/forcerec.h"
61 #include "gromacs/mdtypes/inputrec.h"
62 #include "gromacs/mdtypes/mdatom.h"
63 #include "gromacs/mdtypes/mdrunoptions.h"
64 #include "gromacs/mdtypes/state.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/topology/atoms.h"
67 #include "gromacs/topology/topology.h"
68 #include "gromacs/trajectory/trajectoryframe.h"
69
70 #include "freeenergyperturbationdata.h"
71 #include "modularsimulator.h"
72 #include "simulatoralgorithm.h"
73
74 namespace gmx
75 {
76 /*! \internal
77  * \brief Helper object to scale velocities according to reference temperature change
78  */
79 class StatePropagatorData::ReferenceTemperatureHelper
80 {
81 public:
82     //! Constructor
83     ReferenceTemperatureHelper(const t_inputrec*    inputrec,
84                                StatePropagatorData* statePropagatorData,
85                                const t_mdatoms*     mdatoms) :
86         numTemperatureGroups_(inputrec->opts.ngtc),
87         referenceTemperature_(inputrec->opts.ref_t, inputrec->opts.ref_t + inputrec->opts.ngtc),
88         velocityScalingFactors_(numTemperatureGroups_),
89         statePropagatorData_(statePropagatorData),
90         mdatoms_(mdatoms)
91     {
92     }
93
94     /*! \brief Update the reference temperature
95      *
96      * Changing the reference temperature requires scaling the velocities, which
97      * is done here.
98      *
99      * \param temperatures  New reference temperatures
100      * \param algorithm     The algorithm which initiated the temperature update
101      */
102     void updateReferenceTemperature(ArrayRef<const real>                           temperatures,
103                                     ReferenceTemperatureChangeAlgorithm gmx_unused algorithm)
104     {
105         // Currently, we don't know about any temperature change algorithms, so we assert this never gets called
106         GMX_ASSERT(false, "StatePropagatorData: Unknown ReferenceTemperatureChangeAlgorithm.");
107         for (int temperatureGroup = 0; temperatureGroup < numTemperatureGroups_; ++temperatureGroup)
108         {
109             velocityScalingFactors_[temperatureGroup] =
110                     std::sqrt(temperatures[temperatureGroup] / referenceTemperature_[temperatureGroup]);
111         }
112
113         auto velocities = statePropagatorData_->velocitiesView().unpaddedArrayRef();
114         int  nth        = gmx_omp_nthreads_get(ModuleMultiThread::Update);
115 #pragma omp parallel for num_threads(nth) schedule(static) default(none) \
116         shared(nth, velocityScalingFactors_, velocities)
117         for (int threadIndex = 0; threadIndex < nth; threadIndex++)
118         {
119             int startAtom = 0;
120             int endAtom   = 0;
121             getThreadAtomRange(nth, threadIndex, statePropagatorData_->localNAtoms_, &startAtom, &endAtom);
122             for (int atomIdx = startAtom; atomIdx < endAtom; ++atomIdx)
123             {
124                 const int temperatureGroup = (mdatoms_->cTC != nullptr) ? mdatoms_->cTC[atomIdx] : 0;
125                 velocities[atomIdx] *= velocityScalingFactors_[temperatureGroup];
126             }
127         }
128         std::copy(temperatures.begin(), temperatures.end(), referenceTemperature_.begin());
129     }
130
131 private:
132     //! The number of temperature groups
133     const int numTemperatureGroups_;
134     //! Coupling temperature per group
135     std::vector<real> referenceTemperature_;
136     //! Working array used at every temperature update
137     std::vector<real> velocityScalingFactors_;
138
139     //! Pointer to StatePropagatorData to scale velocities
140     StatePropagatorData* statePropagatorData_;
141     //! Atom parameters for this domain (temperature group information)
142     const t_mdatoms* mdatoms_;
143 };
144
145 StatePropagatorData::StatePropagatorData(int                numAtoms,
146                                          FILE*              fplog,
147                                          const t_commrec*   cr,
148                                          t_state*           globalState,
149                                          t_state*           localState,
150                                          bool               useGPU,
151                                          bool               canMoleculesBeDistributedOverPBC,
152                                          bool               writeFinalConfiguration,
153                                          const std::string& finalConfigurationFilename,
154                                          const t_inputrec*  inputrec,
155                                          const t_mdatoms*   mdatoms,
156                                          const gmx_mtop_t&  globalTop) :
157     totalNumAtoms_(numAtoms),
158     localNAtoms_(0),
159     box_{ { 0 } },
160     previousBox_{ { 0 } },
161     ddpCount_(0),
162     element_(std::make_unique<Element>(this,
163                                        fplog,
164                                        cr,
165                                        inputrec->nstxout,
166                                        inputrec->nstvout,
167                                        inputrec->nstfout,
168                                        inputrec->nstxout_compressed,
169                                        canMoleculesBeDistributedOverPBC,
170                                        writeFinalConfiguration,
171                                        finalConfigurationFilename,
172                                        inputrec,
173                                        globalTop)),
174     referenceTemperatureHelper_(std::make_unique<ReferenceTemperatureHelper>(inputrec, this, mdatoms)),
175     vvResetVelocities_(false),
176     isRegularSimulationEnd_(false),
177     lastStep_(-1),
178     globalState_(globalState)
179 {
180     bool stateHasVelocities;
181     // Local state only becomes valid now.
182     if (DOMAINDECOMP(cr))
183     {
184         dd_init_local_state(*cr->dd, globalState, localState);
185         stateHasVelocities = ((localState->flags & enumValueToBitMask(StateEntry::V)) != 0);
186         setLocalState(localState);
187     }
188     else
189     {
190         state_change_natoms(globalState, globalState->natoms);
191         f_.resize(globalState->natoms);
192         localNAtoms_ = globalState->natoms;
193         x_           = globalState->x;
194         v_           = globalState->v;
195         copy_mat(globalState->box, box_);
196         stateHasVelocities = ((globalState->flags & enumValueToBitMask(StateEntry::V)) != 0);
197         previousX_.resizeWithPadding(localNAtoms_);
198         ddpCount_ = globalState->ddp_count;
199         copyPosition();
200     }
201     if (useGPU)
202     {
203         changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
204     }
205
206     if (DOMAINDECOMP(cr) && MASTER(cr))
207     {
208         xGlobal_.resizeWithPadding(totalNumAtoms_);
209         previousXGlobal_.resizeWithPadding(totalNumAtoms_);
210         vGlobal_.resizeWithPadding(totalNumAtoms_);
211         fGlobal_.resizeWithPadding(totalNumAtoms_);
212     }
213
214     if (!inputrec->bContinuation)
215     {
216         if (stateHasVelocities)
217         {
218             auto v = velocitiesView().paddedArrayRef();
219             // Set the velocities of vsites, shells and frozen atoms to zero
220             for (int i = 0; i < mdatoms->homenr; i++)
221             {
222                 if (mdatoms->ptype[i] == ParticleType::Shell)
223                 {
224                     clear_rvec(v[i]);
225                 }
226                 else if (mdatoms->cFREEZE)
227                 {
228                     for (int m = 0; m < DIM; m++)
229                     {
230                         if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
231                         {
232                             v[i][m] = 0;
233                         }
234                     }
235                 }
236             }
237         }
238         if (inputrec->eI == IntegrationAlgorithm::VV)
239         {
240             vvResetVelocities_ = true;
241         }
242     }
243 }
244
245 StatePropagatorData::~StatePropagatorData() = default;
246
247 StatePropagatorData::Element* StatePropagatorData::element()
248 {
249     return element_.get();
250 }
251
252 void StatePropagatorData::setup()
253 {
254     if (element_)
255     {
256         element_->elementSetup();
257     }
258 }
259
260 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
261 {
262     return x_.arrayRefWithPadding();
263 }
264
265 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
266 {
267     return x_.constArrayRefWithPadding();
268 }
269
270 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
271 {
272     return previousX_.arrayRefWithPadding();
273 }
274
275 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
276 {
277     return previousX_.constArrayRefWithPadding();
278 }
279
280 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
281 {
282     return v_.arrayRefWithPadding();
283 }
284
285 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
286 {
287     return v_.constArrayRefWithPadding();
288 }
289
290 ForceBuffersView& StatePropagatorData::forcesView()
291 {
292     return f_.view();
293 }
294
295 const ForceBuffersView& StatePropagatorData::constForcesView() const
296 {
297     return f_.view();
298 }
299
300 rvec* StatePropagatorData::box()
301 {
302     return box_;
303 }
304
305 const rvec* StatePropagatorData::constBox() const
306 {
307     return box_;
308 }
309
310 rvec* StatePropagatorData::previousBox()
311 {
312     return previousBox_;
313 }
314
315 const rvec* StatePropagatorData::constPreviousBox() const
316 {
317     return previousBox_;
318 }
319
320 int StatePropagatorData::localNumAtoms() const
321 {
322     return localNAtoms_;
323 }
324
325 int StatePropagatorData::totalNumAtoms() const
326 {
327     return totalNumAtoms_;
328 }
329
330 t_state* StatePropagatorData::localState()
331 {
332     localState_->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
333                          | enumValueToBitMask(StateEntry::Box);
334     state_change_natoms(localState_, localNAtoms_);
335     std::swap(localState_->x, x_);
336     std::swap(localState_->v, v_);
337     copy_mat(box_, localState_->box);
338     localState_->ddp_count       = ddpCount_;
339     localState_->ddp_count_cg_gl = ddpCountCgGl_;
340     localState_->cg_gl           = cgGl_;
341     return localState_;
342 }
343
344 std::unique_ptr<t_state> StatePropagatorData::copyLocalState(std::unique_ptr<t_state> copy)
345 {
346     copy->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
347                   | enumValueToBitMask(StateEntry::Box);
348     state_change_natoms(copy.get(), localNAtoms_);
349     copy->x = x_;
350     copy->v = v_;
351     copy_mat(box_, copy->box);
352     copy->ddp_count       = ddpCount_;
353     copy->ddp_count_cg_gl = ddpCountCgGl_;
354     copy->cg_gl           = cgGl_;
355     return copy;
356 }
357
358 void StatePropagatorData::setLocalState(t_state* state)
359 {
360     localState_  = state;
361     localNAtoms_ = state->natoms;
362     previousX_.resizeWithPadding(localNAtoms_);
363     std::swap(x_, state->x);
364     std::swap(v_, state->v);
365     copy_mat(state->box, box_);
366     copyPosition();
367     ddpCount_     = state->ddp_count;
368     ddpCountCgGl_ = state->ddp_count_cg_gl;
369     cgGl_         = state->cg_gl;
370
371     if (vvResetVelocities_)
372     {
373         /* DomDec runs twice early in the simulation, once at setup time, and once before the first
374          * step. Every time DD runs, it sets a new local state here. We are saving a backup during
375          * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
376          * the first step here to avoid resetting to an earlier DD state. This is done before any
377          * propagation that needs to be reset, so it's not very safe but correct for now.
378          * TODO: Get rid of this once input is assumed to be at half steps
379          */
380         velocityBackup_ = v_;
381     }
382 }
383
384 t_state* StatePropagatorData::globalState()
385 {
386     return globalState_;
387 }
388
389 ForceBuffers* StatePropagatorData::forcePointer()
390 {
391     return &f_;
392 }
393
394 void StatePropagatorData::copyPosition()
395 {
396     int nth = gmx_omp_nthreads_get(ModuleMultiThread::Update);
397
398 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
399     for (int th = 0; th < nth; th++)
400     {
401         int start_th, end_th;
402         getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
403         copyPosition(start_th, end_th);
404     }
405
406     /* Box is changed in update() when we do pressure coupling,
407      * but we should still use the old box for energy corrections and when
408      * writing it to the energy file, so it matches the trajectory files for
409      * the same timestep above. Make a copy in a separate array.
410      */
411     copy_mat(box_, previousBox_);
412 }
413
414 void StatePropagatorData::copyPosition(int start, int end)
415 {
416     for (int i = start; i < end; ++i)
417     {
418         previousX_[i] = x_[i];
419     }
420 }
421
422 void StatePropagatorData::updateReferenceTemperature(ArrayRef<const real> temperatures,
423                                                      ReferenceTemperatureChangeAlgorithm algorithm)
424 {
425     referenceTemperatureHelper_->updateReferenceTemperature(temperatures, algorithm);
426 }
427
428 void StatePropagatorData::Element::scheduleTask(Step                       step,
429                                                 Time gmx_unused            time,
430                                                 const RegisterRunFunction& registerRunFunction)
431 {
432     if (statePropagatorData_->vvResetVelocities_)
433     {
434         statePropagatorData_->vvResetVelocities_ = false;
435         registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
436     }
437     // copy x -> previousX
438     registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
439     // if it's a write out step, keep a copy for writeout
440     if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
441     {
442         registerRunFunction([this]() { saveState(); });
443     }
444 }
445
446 void StatePropagatorData::Element::saveState()
447 {
448     GMX_ASSERT(!localStateBackupValid_,
449                "Save state called again before previous state was written.");
450     localStateBackup_ = statePropagatorData_->copyLocalState(std::move(localStateBackup_));
451     if (freeEnergyPerturbationData_)
452     {
453         localStateBackup_->fep_state    = freeEnergyPerturbationData_->currentFEPState();
454         ArrayRef<const real> lambdaView = freeEnergyPerturbationData_->constLambdaView();
455         std::copy(lambdaView.begin(), lambdaView.end(), localStateBackup_->lambda.begin());
456         localStateBackup_->flags |=
457                 enumValueToBitMask(StateEntry::Lambda) | enumValueToBitMask(StateEntry::FepState);
458     }
459     localStateBackupValid_ = true;
460 }
461
462 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
463 {
464     if (event == TrajectoryEvent::StateWritingStep)
465     {
466         return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
467     }
468     return std::nullopt;
469 }
470
471 std::optional<ITrajectoryWriterCallback>
472 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
473 {
474     if (event == TrajectoryEvent::StateWritingStep)
475     {
476         return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
477             if (writeTrajectory)
478             {
479                 write(outf, step, time);
480             }
481         };
482     }
483     return std::nullopt;
484 }
485
486 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
487 {
488     wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
489     unsigned int mdof_flags = 0;
490     if (do_per_step(currentStep, nstxout_))
491     {
492         mdof_flags |= MDOF_X;
493     }
494     if (do_per_step(currentStep, nstvout_))
495     {
496         mdof_flags |= MDOF_V;
497     }
498     if (do_per_step(currentStep, nstfout_))
499     {
500         mdof_flags |= MDOF_F;
501     }
502     if (do_per_step(currentStep, nstxout_compressed_))
503     {
504         mdof_flags |= MDOF_X_COMPRESSED;
505     }
506     if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
507     {
508         mdof_flags |= MDOF_BOX;
509     }
510     if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
511     {
512         mdof_flags |= MDOF_LAMBDA;
513     }
514     if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
515     {
516         mdof_flags |= MDOF_BOX_COMPRESSED;
517     }
518     if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
519     {
520         mdof_flags |= MDOF_LAMBDA_COMPRESSED;
521     }
522
523     if (mdof_flags == 0)
524     {
525         wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
526         return;
527     }
528     GMX_ASSERT(localStateBackupValid_, "Trajectory writing called, but no state saved.");
529
530     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
531     ObservablesHistory* observablesHistory = nullptr;
532
533     mdoutf_write_to_trajectory_files(fplog_,
534                                      cr_,
535                                      outf,
536                                      static_cast<int>(mdof_flags),
537                                      statePropagatorData_->totalNumAtoms_,
538                                      currentStep,
539                                      currentTime,
540                                      localStateBackup_.get(),
541                                      statePropagatorData_->globalState_,
542                                      observablesHistory,
543                                      statePropagatorData_->f_.view().force(),
544                                      &dummyCheckpointDataHolder_);
545
546     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
547     {
548         localStateBackupValid_ = false;
549     }
550     wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
551 }
552
553 void StatePropagatorData::Element::elementSetup()
554 {
555     if (statePropagatorData_->vvResetVelocities_)
556     {
557         // MD-VV does the first velocity half-step only to calculate the constraint virial,
558         // then resets the velocities since the input is assumed to be positions and velocities
559         // at full time step. TODO: Change this to have input at half time steps.
560         statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
561     }
562 }
563
564 void StatePropagatorData::resetVelocities()
565 {
566     v_ = velocityBackup_;
567 }
568
namespace
{
/*!
 * \brief Versions of the content StatePropagatorData::Element writes to the modular checkpoint
 *
 * To change the checkpoint content, add a new enumerator directly above Count and
 * update the checkpointing code accordingly.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The most recent checkpoint version, i.e. the last enumerator before Count
constexpr CheckpointVersion c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
584
585 template<CheckpointDataOperation operation>
586 void StatePropagatorData::doCheckpointData(CheckpointData<operation>* checkpointData)
587 {
588     checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
589     checkpointData->scalar("numAtoms", &totalNumAtoms_);
590
591     if (operation == CheckpointDataOperation::Read)
592     {
593         xGlobal_.resizeWithPadding(totalNumAtoms_);
594         vGlobal_.resizeWithPadding(totalNumAtoms_);
595     }
596
597     checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobal_));
598     checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobal_));
599     checkpointData->tensor("box", box_);
600     checkpointData->scalar("ddpCount", &ddpCount_);
601     checkpointData->scalar("ddpCountCgGl", &ddpCountCgGl_);
602     checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(cgGl_));
603 }
604
605 void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
606                                                        const t_commrec*                   cr)
607 {
608     if (DOMAINDECOMP(cr))
609     {
610         // Collect state from all ranks into global vectors
611         dd_collect_vec(cr->dd,
612                        statePropagatorData_->ddpCount_,
613                        statePropagatorData_->ddpCountCgGl_,
614                        statePropagatorData_->cgGl_,
615                        statePropagatorData_->x_,
616                        statePropagatorData_->xGlobal_);
617         dd_collect_vec(cr->dd,
618                        statePropagatorData_->ddpCount_,
619                        statePropagatorData_->ddpCountCgGl_,
620                        statePropagatorData_->cgGl_,
621                        statePropagatorData_->v_,
622                        statePropagatorData_->vGlobal_);
623     }
624     else
625     {
626         // Everything is local - copy local vectors into global ones
627         statePropagatorData_->xGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
628         statePropagatorData_->vGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
629         std::copy(statePropagatorData_->x_.begin(),
630                   statePropagatorData_->x_.end(),
631                   statePropagatorData_->xGlobal_.begin());
632         std::copy(statePropagatorData_->v_.begin(),
633                   statePropagatorData_->v_.end(),
634                   statePropagatorData_->vGlobal_.begin());
635     }
636     if (MASTER(cr))
637     {
638         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
639     }
640 }
641
642 /*!
643  * \brief Update the legacy global state
644  *
645  * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
646  * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
647  */
648 static void updateGlobalState(t_state*                      globalState,
649                               const PaddedHostVector<RVec>& x,
650                               const PaddedHostVector<RVec>& v,
651                               const tensor                  box,
652                               int                           ddpCount,
653                               int                           ddpCountCgGl,
654                               const std::vector<int>&       cgGl)
655 {
656     globalState->x = x;
657     globalState->v = v;
658     copy_mat(box, globalState->box);
659     globalState->ddp_count       = ddpCount;
660     globalState->ddp_count_cg_gl = ddpCountCgGl;
661     globalState->cg_gl           = cgGl;
662 }
663
664 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
665                                                           const t_commrec*                  cr)
666 {
667     if (MASTER(cr))
668     {
669         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
670     }
671
672     // Copy data to global state to be distributed by DD at setup stage
673     if (DOMAINDECOMP(cr) && MASTER(cr))
674     {
675         updateGlobalState(statePropagatorData_->globalState_,
676                           statePropagatorData_->xGlobal_,
677                           statePropagatorData_->vGlobal_,
678                           statePropagatorData_->box_,
679                           statePropagatorData_->ddpCount_,
680                           statePropagatorData_->ddpCountCgGl_,
681                           statePropagatorData_->cgGl_);
682     }
683     // Everything is local - copy global vectors to local ones
684     if (!DOMAINDECOMP(cr))
685     {
686         statePropagatorData_->x_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
687         statePropagatorData_->v_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
688         std::copy(statePropagatorData_->xGlobal_.begin(),
689                   statePropagatorData_->xGlobal_.end(),
690                   statePropagatorData_->x_.begin());
691         std::copy(statePropagatorData_->vGlobal_.begin(),
692                   statePropagatorData_->vGlobal_.end(),
693                   statePropagatorData_->v_.begin());
694     }
695 }
696
697 const std::string& StatePropagatorData::Element::clientID()
698 {
699     return StatePropagatorData::checkpointID();
700 }
701
702 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
703 {
704     // Note that part of this code is duplicated in do_md_trajectory_writing.
705     // This duplication is needed while both legacy and modular code paths are in use.
706     // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
707     if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
708     {
709         return;
710     }
711
712     GMX_ASSERT(localStateBackupValid_, "Final trajectory writing called, but no state saved.");
713
714     wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
715     if (DOMAINDECOMP(cr_))
716     {
717         auto globalXRef =
718                 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
719         dd_collect_vec(cr_->dd,
720                        localStateBackup_->ddp_count,
721                        localStateBackup_->ddp_count_cg_gl,
722                        localStateBackup_->cg_gl,
723                        localStateBackup_->x,
724                        globalXRef);
725         auto globalVRef =
726                 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
727         dd_collect_vec(cr_->dd,
728                        localStateBackup_->ddp_count,
729                        localStateBackup_->ddp_count_cg_gl,
730                        localStateBackup_->cg_gl,
731                        localStateBackup_->v,
732                        globalVRef);
733     }
734     else
735     {
736         // We have the whole state locally: copy the local state pointer
737         statePropagatorData_->globalState_ = localStateBackup_.get();
738     }
739
740     if (MASTER(cr_))
741     {
742         fprintf(stderr, "\nWriting final coordinates.\n");
743         if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
744         {
745             // Make molecules whole only for confout writing
746             do_pbc_mtop(pbcType_,
747                         localStateBackup_->box,
748                         &top_global_,
749                         statePropagatorData_->globalState_->x.rvec_array());
750         }
751         write_sto_conf_mtop(finalConfigurationFilename_.c_str(),
752                             *top_global_.name,
753                             top_global_,
754                             statePropagatorData_->globalState_->x.rvec_array(),
755                             statePropagatorData_->globalState_->v.rvec_array(),
756                             pbcType_,
757                             localStateBackup_->box);
758     }
759     wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
760 }
761
762 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
763 {
764     return [this](Step step, Time /*time*/) {
765         lastStep_               = step;
766         isRegularSimulationEnd_ = (step == lastPlannedStep_);
767     };
768 }
769
770 StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
771                                       FILE*                fplog,
772                                       const t_commrec*     cr,
773                                       int                  nstxout,
774                                       int                  nstvout,
775                                       int                  nstfout,
776                                       int                  nstxout_compressed,
777                                       bool                 canMoleculesBeDistributedOverPBC,
778                                       bool                 writeFinalConfiguration,
779                                       std::string          finalConfigurationFilename,
780                                       const t_inputrec*    inputrec,
781                                       const gmx_mtop_t&    globalTop) :
782     statePropagatorData_(statePropagatorData),
783     nstxout_(nstxout),
784     nstvout_(nstvout),
785     nstfout_(nstfout),
786     nstxout_compressed_(nstxout_compressed),
787     localStateBackup_(std::make_unique<t_state>()),
788     writeOutStep_(-1),
789     freeEnergyPerturbationData_(nullptr),
790     isRegularSimulationEnd_(false),
791     lastStep_(-1),
792     canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
793     systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
794     pbcType_(inputrec->pbcType),
795     lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
796     writeFinalConfiguration_(writeFinalConfiguration),
797     finalConfigurationFilename_(std::move(finalConfigurationFilename)),
798     fplog_(fplog),
799     cr_(cr),
800     top_global_(globalTop)
801 {
802 }
803 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
804 {
805     freeEnergyPerturbationData_ = freeEnergyPerturbationData;
806 }
807
808 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
809         LegacySimulatorData gmx_unused*        legacySimulatorData,
810         ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
811         StatePropagatorData*                               statePropagatorData,
812         EnergyData gmx_unused*      energyData,
813         FreeEnergyPerturbationData* freeEnergyPerturbationData,
814         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
815         ObservablesReducer* /*observablesReducer*/)
816 {
817     statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
818     return statePropagatorData->element();
819 }
820
821 void StatePropagatorData::readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData)
822 {
823     StatePropagatorData statePropagatorData;
824     statePropagatorData.doCheckpointData(&readCheckpointData);
825
826     trxFrame->natoms = statePropagatorData.totalNumAtoms_;
827     trxFrame->bX     = true;
828     trxFrame->x  = makeRvecArray(statePropagatorData.xGlobal_, statePropagatorData.totalNumAtoms_);
829     trxFrame->bV = true;
830     trxFrame->v  = makeRvecArray(statePropagatorData.vGlobal_, statePropagatorData.totalNumAtoms_);
831     trxFrame->bF = false;
832     trxFrame->bBox = true;
833     copy_mat(statePropagatorData.box_, trxFrame->box);
834 }
835
836 const std::string& StatePropagatorData::checkpointID()
837 {
838     static const std::string identifier = "StatePropagatorData";
839     return identifier;
840 }
841
842 } // namespace gmx