c9f72964099d40a2f2f8b19893ff2d0ab26c11e7
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines the state for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "gromacs/utility/enumerationhelpers.h"
45 #include "statepropagatordata.h"
46
47 #include "gromacs/commandline/filenm.h"
48 #include "gromacs/domdec/collect.h"
49 #include "gromacs/domdec/domdec.h"
50 #include "gromacs/fileio/confio.h"
51 #include "gromacs/math/vec.h"
52 #include "gromacs/mdlib/gmx_omp_nthreads.h"
53 #include "gromacs/mdlib/mdatoms.h"
54 #include "gromacs/mdlib/mdoutf.h"
55 #include "gromacs/mdlib/stat.h"
56 #include "gromacs/mdlib/update.h"
57 #include "gromacs/mdtypes/checkpointdata.h"
58 #include "gromacs/mdtypes/commrec.h"
59 #include "gromacs/mdtypes/forcebuffers.h"
60 #include "gromacs/mdtypes/forcerec.h"
61 #include "gromacs/mdtypes/inputrec.h"
62 #include "gromacs/mdtypes/mdatom.h"
63 #include "gromacs/mdtypes/mdrunoptions.h"
64 #include "gromacs/mdtypes/state.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/topology/atoms.h"
67 #include "gromacs/topology/topology.h"
68 #include "gromacs/trajectory/trajectoryframe.h"
69
70 #include "freeenergyperturbationdata.h"
71 #include "modularsimulator.h"
72 #include "simulatoralgorithm.h"
73
74 namespace gmx
75 {
76 /*! \internal
77  * \brief Helper object to scale velocities according to reference temperature change
78  */
79 class StatePropagatorData::ReferenceTemperatureHelper
80 {
81 public:
82     //! Constructor
83     ReferenceTemperatureHelper(const t_inputrec*    inputrec,
84                                StatePropagatorData* statePropagatorData,
85                                const t_mdatoms*     mdatoms) :
86         numTemperatureGroups_(inputrec->opts.ngtc),
87         referenceTemperature_(inputrec->opts.ref_t, inputrec->opts.ref_t + inputrec->opts.ngtc),
88         velocityScalingFactors_(numTemperatureGroups_),
89         statePropagatorData_(statePropagatorData),
90         mdatoms_(mdatoms)
91     {
92     }
93
94     /*! \brief Update the reference temperature
95      *
96      * Changing the reference temperature requires scaling the velocities, which
97      * is done here.
98      *
99      * \param temperatures  New reference temperatures
100      * \param algorithm     The algorithm which initiated the temperature update
101      */
102     void updateReferenceTemperature(ArrayRef<const real>                           temperatures,
103                                     ReferenceTemperatureChangeAlgorithm gmx_unused algorithm)
104     {
105         // Currently, we don't know about any temperature change algorithms, so we assert this never gets called
106         GMX_ASSERT(false, "StatePropagatorData: Unknown ReferenceTemperatureChangeAlgorithm.");
107         for (int temperatureGroup = 0; temperatureGroup < numTemperatureGroups_; ++temperatureGroup)
108         {
109             velocityScalingFactors_[temperatureGroup] =
110                     std::sqrt(temperatures[temperatureGroup] / referenceTemperature_[temperatureGroup]);
111         }
112
113         auto velocities = statePropagatorData_->velocitiesView().unpaddedArrayRef();
114         int  nth        = gmx_omp_nthreads_get(ModuleMultiThread::Update);
115 #pragma omp parallel for num_threads(nth) schedule(static) default(none) \
116         shared(nth, velocityScalingFactors_, velocities)
117         for (int threadIndex = 0; threadIndex < nth; threadIndex++)
118         {
119             int startAtom = 0;
120             int endAtom   = 0;
121             getThreadAtomRange(nth, threadIndex, statePropagatorData_->localNAtoms_, &startAtom, &endAtom);
122             for (int atomIdx = startAtom; atomIdx < endAtom; ++atomIdx)
123             {
124                 const int temperatureGroup = (mdatoms_->cTC != nullptr) ? mdatoms_->cTC[atomIdx] : 0;
125                 velocities[atomIdx] *= velocityScalingFactors_[temperatureGroup];
126             }
127         }
128         std::copy(temperatures.begin(), temperatures.end(), referenceTemperature_.begin());
129     }
130
131 private:
132     //! The number of temperature groups
133     const int numTemperatureGroups_;
134     //! Coupling temperature per group
135     std::vector<real> referenceTemperature_;
136     //! Working array used at every temperature update
137     std::vector<real> velocityScalingFactors_;
138
139     //! Pointer to StatePropagatorData to scale velocities
140     StatePropagatorData* statePropagatorData_;
141     //! Atom parameters for this domain (temperature group information)
142     const t_mdatoms* mdatoms_;
143 };
144
// Construct the propagator data from the global state.
// With domain decomposition, an initial local state is created with dd_init_local_state()
// and stored through setLocalState(). Without DD, the global state is copied into the
// local buffers here.
// Unless this is a continuation run, velocities of shells and of frozen dimensions are
// zeroed. For MD-VV, the velocity reset after the first half step is also requested.
StatePropagatorData::StatePropagatorData(int                numAtoms,
                                         FILE*              fplog,
                                         const t_commrec*   cr,
                                         t_state*           globalState,
                                         bool               useGPU,
                                         bool               canMoleculesBeDistributedOverPBC,
                                         bool               writeFinalConfiguration,
                                         const std::string& finalConfigurationFilename,
                                         const t_inputrec*  inputrec,
                                         const t_mdatoms*   mdatoms,
                                         const gmx_mtop_t&  globalTop) :
    totalNumAtoms_(numAtoms),
    localNAtoms_(0),
    box_{ { 0 } },
    previousBox_{ { 0 } },
    ddpCount_(0),
    element_(std::make_unique<Element>(this,
                                       fplog,
                                       cr,
                                       inputrec->nstxout,
                                       inputrec->nstvout,
                                       inputrec->nstfout,
                                       inputrec->nstxout_compressed,
                                       canMoleculesBeDistributedOverPBC,
                                       writeFinalConfiguration,
                                       finalConfigurationFilename,
                                       inputrec,
                                       globalTop)),
    referenceTemperatureHelper_(std::make_unique<ReferenceTemperatureHelper>(inputrec, this, mdatoms)),
    vvResetVelocities_(false),
    isRegularSimulationEnd_(false),
    lastStep_(-1),
    globalState_(globalState)
{
    // Assigned in both branches below
    bool stateHasVelocities;
    // Local state only becomes valid now.
    if (DOMAINDECOMP(cr))
    {
        auto localState = std::make_unique<t_state>();
        dd_init_local_state(*cr->dd, globalState, localState.get());
        stateHasVelocities = ((localState->flags & enumValueToBitMask(StateEntry::V)) != 0);
        setLocalState(std::move(localState));
    }
    else
    {
        // Single domain: the local buffers are a full copy of the global state
        state_change_natoms(globalState, globalState->natoms);
        f_.resize(globalState->natoms);
        localNAtoms_ = globalState->natoms;
        x_           = globalState->x;
        v_           = globalState->v;
        copy_mat(globalState->box, box_);
        stateHasVelocities = ((globalState->flags & enumValueToBitMask(StateEntry::V)) != 0);
        previousX_.resizeWithPadding(localNAtoms_);
        ddpCount_ = globalState->ddp_count;
        copyPosition();
    }
    if (useGPU)
    {
        changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
    }

    // Global-size buffers are only allocated on the master rank of a DD run
    // (xGlobal_/vGlobal_ receive the collected state in saveCheckpointState())
    if (DOMAINDECOMP(cr) && MASTER(cr))
    {
        xGlobal_.resizeWithPadding(totalNumAtoms_);
        previousXGlobal_.resizeWithPadding(totalNumAtoms_);
        vGlobal_.resizeWithPadding(totalNumAtoms_);
        fGlobal_.resizeWithPadding(totalNumAtoms_);
    }

    if (!inputrec->bContinuation)
    {
        if (stateHasVelocities)
        {
            auto v = velocitiesView().paddedArrayRef();
            // Set the velocities of shells and of frozen dimensions to zero.
            // NOTE(review): an earlier version of this comment also mentioned vsites, but no
            // vsite handling is done here - confirm vsite velocities are handled elsewhere.
            for (int i = 0; i < mdatoms->homenr; i++)
            {
                if (mdatoms->ptype[i] == ParticleType::Shell)
                {
                    clear_rvec(v[i]);
                }
                else if (mdatoms->cFREEZE)
                {
                    for (int m = 0; m < DIM; m++)
                    {
                        if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
                        {
                            v[i][m] = 0;
                        }
                    }
                }
            }
        }
        // MD-VV restores the start velocities after its first half step (see elementSetup())
        if (inputrec->eI == IntegrationAlgorithm::VV)
        {
            vvResetVelocities_ = true;
        }
    }
}
244
245 StatePropagatorData::~StatePropagatorData() = default;
246
247 StatePropagatorData::Element* StatePropagatorData::element()
248 {
249     return element_.get();
250 }
251
252 void StatePropagatorData::setup()
253 {
254     if (element_)
255     {
256         element_->elementSetup();
257     }
258 }
259
260 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
261 {
262     return x_.arrayRefWithPadding();
263 }
264
265 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
266 {
267     return x_.constArrayRefWithPadding();
268 }
269
270 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
271 {
272     return previousX_.arrayRefWithPadding();
273 }
274
275 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
276 {
277     return previousX_.constArrayRefWithPadding();
278 }
279
280 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
281 {
282     return v_.arrayRefWithPadding();
283 }
284
285 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
286 {
287     return v_.constArrayRefWithPadding();
288 }
289
290 ForceBuffersView& StatePropagatorData::forcesView()
291 {
292     return f_.view();
293 }
294
295 const ForceBuffersView& StatePropagatorData::constForcesView() const
296 {
297     return f_.view();
298 }
299
300 rvec* StatePropagatorData::box()
301 {
302     return box_;
303 }
304
305 const rvec* StatePropagatorData::constBox() const
306 {
307     return box_;
308 }
309
310 rvec* StatePropagatorData::previousBox()
311 {
312     return previousBox_;
313 }
314
315 const rvec* StatePropagatorData::constPreviousBox() const
316 {
317     return previousBox_;
318 }
319
320 int StatePropagatorData::localNumAtoms() const
321 {
322     return localNAtoms_;
323 }
324
325 int StatePropagatorData::totalNumAtoms() const
326 {
327     return totalNumAtoms_;
328 }
329
330 std::unique_ptr<t_state> StatePropagatorData::localState()
331 {
332     auto state   = std::make_unique<t_state>();
333     state->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
334                    | enumValueToBitMask(StateEntry::Box);
335     state_change_natoms(state.get(), localNAtoms_);
336     state->x = x_;
337     state->v = v_;
338     copy_mat(box_, state->box);
339     state->ddp_count       = ddpCount_;
340     state->ddp_count_cg_gl = ddpCountCgGl_;
341     state->cg_gl           = cgGl_;
342     return state;
343 }
344
345 void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
346 {
347     localNAtoms_ = state->natoms;
348     x_.resizeWithPadding(localNAtoms_);
349     previousX_.resizeWithPadding(localNAtoms_);
350     v_.resizeWithPadding(localNAtoms_);
351     x_ = state->x;
352     v_ = state->v;
353     copy_mat(state->box, box_);
354     copyPosition();
355     ddpCount_     = state->ddp_count;
356     ddpCountCgGl_ = state->ddp_count_cg_gl;
357     cgGl_         = state->cg_gl;
358
359     if (vvResetVelocities_)
360     {
361         /* DomDec runs twice early in the simulation, once at setup time, and once before the first
362          * step. Every time DD runs, it sets a new local state here. We are saving a backup during
363          * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
364          * the first step here to avoid resetting to an earlier DD state. This is done before any
365          * propagation that needs to be reset, so it's not very safe but correct for now.
366          * TODO: Get rid of this once input is assumed to be at half steps
367          */
368         velocityBackup_ = v_;
369     }
370 }
371
372 t_state* StatePropagatorData::globalState()
373 {
374     return globalState_;
375 }
376
377 ForceBuffers* StatePropagatorData::forcePointer()
378 {
379     return &f_;
380 }
381
382 void StatePropagatorData::copyPosition()
383 {
384     int nth = gmx_omp_nthreads_get(ModuleMultiThread::Update);
385
386 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
387     for (int th = 0; th < nth; th++)
388     {
389         int start_th, end_th;
390         getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
391         copyPosition(start_th, end_th);
392     }
393
394     /* Box is changed in update() when we do pressure coupling,
395      * but we should still use the old box for energy corrections and when
396      * writing it to the energy file, so it matches the trajectory files for
397      * the same timestep above. Make a copy in a separate array.
398      */
399     copy_mat(box_, previousBox_);
400 }
401
402 void StatePropagatorData::copyPosition(int start, int end)
403 {
404     for (int i = start; i < end; ++i)
405     {
406         previousX_[i] = x_[i];
407     }
408 }
409
410 void StatePropagatorData::updateReferenceTemperature(ArrayRef<const real> temperatures,
411                                                      ReferenceTemperatureChangeAlgorithm algorithm)
412 {
413     referenceTemperatureHelper_->updateReferenceTemperature(temperatures, algorithm);
414 }
415
416 void StatePropagatorData::Element::scheduleTask(Step                       step,
417                                                 Time gmx_unused            time,
418                                                 const RegisterRunFunction& registerRunFunction)
419 {
420     if (statePropagatorData_->vvResetVelocities_)
421     {
422         statePropagatorData_->vvResetVelocities_ = false;
423         registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
424     }
425     // copy x -> previousX
426     registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
427     // if it's a write out step, keep a copy for writeout
428     if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
429     {
430         registerRunFunction([this]() { saveState(); });
431     }
432 }
433
434 void StatePropagatorData::Element::saveState()
435 {
436     GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
437     localStateBackup_ = statePropagatorData_->localState();
438     if (freeEnergyPerturbationData_)
439     {
440         localStateBackup_->fep_state    = freeEnergyPerturbationData_->currentFEPState();
441         ArrayRef<const real> lambdaView = freeEnergyPerturbationData_->constLambdaView();
442         std::copy(lambdaView.begin(), lambdaView.end(), localStateBackup_->lambda.begin());
443         localStateBackup_->flags |=
444                 enumValueToBitMask(StateEntry::Lambda) | enumValueToBitMask(StateEntry::FepState);
445     }
446 }
447
448 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
449 {
450     if (event == TrajectoryEvent::StateWritingStep)
451     {
452         return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
453     }
454     return std::nullopt;
455 }
456
457 std::optional<ITrajectoryWriterCallback>
458 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
459 {
460     if (event == TrajectoryEvent::StateWritingStep)
461     {
462         return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
463             if (writeTrajectory)
464             {
465                 write(outf, step, time);
466             }
467         };
468     }
469     return std::nullopt;
470 }
471
472 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
473 {
474     wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
475     unsigned int mdof_flags = 0;
476     if (do_per_step(currentStep, nstxout_))
477     {
478         mdof_flags |= MDOF_X;
479     }
480     if (do_per_step(currentStep, nstvout_))
481     {
482         mdof_flags |= MDOF_V;
483     }
484     if (do_per_step(currentStep, nstfout_))
485     {
486         mdof_flags |= MDOF_F;
487     }
488     if (do_per_step(currentStep, nstxout_compressed_))
489     {
490         mdof_flags |= MDOF_X_COMPRESSED;
491     }
492     if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
493     {
494         mdof_flags |= MDOF_BOX;
495     }
496     if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
497     {
498         mdof_flags |= MDOF_LAMBDA;
499     }
500     if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
501     {
502         mdof_flags |= MDOF_BOX_COMPRESSED;
503     }
504     if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
505     {
506         mdof_flags |= MDOF_LAMBDA_COMPRESSED;
507     }
508
509     if (mdof_flags == 0)
510     {
511         wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
512         return;
513     }
514     GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
515
516     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
517     ObservablesHistory* observablesHistory = nullptr;
518
519     mdoutf_write_to_trajectory_files(fplog_,
520                                      cr_,
521                                      outf,
522                                      static_cast<int>(mdof_flags),
523                                      statePropagatorData_->totalNumAtoms_,
524                                      currentStep,
525                                      currentTime,
526                                      localStateBackup_.get(),
527                                      statePropagatorData_->globalState_,
528                                      observablesHistory,
529                                      statePropagatorData_->f_.view().force(),
530                                      &dummyCheckpointDataHolder_);
531
532     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
533     {
534         localStateBackup_.reset();
535     }
536     wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
537 }
538
539 void StatePropagatorData::Element::elementSetup()
540 {
541     if (statePropagatorData_->vvResetVelocities_)
542     {
543         // MD-VV does the first velocity half-step only to calculate the constraint virial,
544         // then resets the velocities since the input is assumed to be positions and velocities
545         // at full time step. TODO: Change this to have input at half time steps.
546         statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
547     }
548 }
549
550 void StatePropagatorData::resetVelocities()
551 {
552     v_ = velocityBackup_;
553 }
554
namespace
{
/*!
 * \brief Enum describing the version of the contents StatePropagatorData::Element writes to
 * the modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The latest checkpoint version, i.e. the entry directly preceding Count
constexpr auto c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
570
// Read or write (depending on operation) the state owned by StatePropagatorData from/to a
// modular checkpoint. On read, the global position and velocity buffers are first sized
// from the checkpointed atom count.
template<CheckpointDataOperation operation>
void StatePropagatorData::doCheckpointData(CheckpointData<operation>* checkpointData)
{
    checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
    checkpointData->scalar("numAtoms", &totalNumAtoms_);

    if (operation == CheckpointDataOperation::Read)
    {
        // Make room for the global vectors before they are read
        xGlobal_.resizeWithPadding(totalNumAtoms_);
        vGlobal_.resizeWithPadding(totalNumAtoms_);
    }

    // Global vectors; before writing they are filled by saveCheckpointState()
    checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobal_));
    checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobal_));
    checkpointData->tensor("box", box_);
    checkpointData->scalar("ddpCount", &ddpCount_);
    checkpointData->scalar("ddpCountCgGl", &ddpCountCgGl_);
    // NOTE(review): cgGl_ is not resized before a Read, so reading relies on it already having
    // a suitable size - confirm, e.g. for readCheckpointToTrxFrame(), which reads into a
    // default-constructed object.
    checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(cgGl_));
}
590
591 void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
592                                                        const t_commrec*                   cr)
593 {
594     if (DOMAINDECOMP(cr))
595     {
596         // Collect state from all ranks into global vectors
597         dd_collect_vec(cr->dd,
598                        statePropagatorData_->ddpCount_,
599                        statePropagatorData_->ddpCountCgGl_,
600                        statePropagatorData_->cgGl_,
601                        statePropagatorData_->x_,
602                        statePropagatorData_->xGlobal_);
603         dd_collect_vec(cr->dd,
604                        statePropagatorData_->ddpCount_,
605                        statePropagatorData_->ddpCountCgGl_,
606                        statePropagatorData_->cgGl_,
607                        statePropagatorData_->v_,
608                        statePropagatorData_->vGlobal_);
609     }
610     else
611     {
612         // Everything is local - copy local vectors into global ones
613         statePropagatorData_->xGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
614         statePropagatorData_->vGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
615         std::copy(statePropagatorData_->x_.begin(),
616                   statePropagatorData_->x_.end(),
617                   statePropagatorData_->xGlobal_.begin());
618         std::copy(statePropagatorData_->v_.begin(),
619                   statePropagatorData_->v_.end(),
620                   statePropagatorData_->vGlobal_.begin());
621     }
622     if (MASTER(cr))
623     {
624         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
625     }
626 }
627
628 /*!
629  * \brief Update the legacy global state
630  *
631  * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
632  * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
633  */
634 static void updateGlobalState(t_state*                      globalState,
635                               const PaddedHostVector<RVec>& x,
636                               const PaddedHostVector<RVec>& v,
637                               const tensor                  box,
638                               int                           ddpCount,
639                               int                           ddpCountCgGl,
640                               const std::vector<int>&       cgGl)
641 {
642     globalState->x = x;
643     globalState->v = v;
644     copy_mat(box, globalState->box);
645     globalState->ddp_count       = ddpCount;
646     globalState->ddp_count_cg_gl = ddpCountCgGl;
647     globalState->cg_gl           = cgGl;
648 }
649
650 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
651                                                           const t_commrec*                  cr)
652 {
653     if (MASTER(cr))
654     {
655         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
656     }
657
658     // Copy data to global state to be distributed by DD at setup stage
659     if (DOMAINDECOMP(cr) && MASTER(cr))
660     {
661         updateGlobalState(statePropagatorData_->globalState_,
662                           statePropagatorData_->xGlobal_,
663                           statePropagatorData_->vGlobal_,
664                           statePropagatorData_->box_,
665                           statePropagatorData_->ddpCount_,
666                           statePropagatorData_->ddpCountCgGl_,
667                           statePropagatorData_->cgGl_);
668     }
669     // Everything is local - copy global vectors to local ones
670     if (!DOMAINDECOMP(cr))
671     {
672         statePropagatorData_->x_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
673         statePropagatorData_->v_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
674         std::copy(statePropagatorData_->xGlobal_.begin(),
675                   statePropagatorData_->xGlobal_.end(),
676                   statePropagatorData_->x_.begin());
677         std::copy(statePropagatorData_->vGlobal_.begin(),
678                   statePropagatorData_->vGlobal_.end(),
679                   statePropagatorData_->v_.begin());
680     }
681 }
682
683 const std::string& StatePropagatorData::Element::clientID()
684 {
685     return StatePropagatorData::checkpointID();
686 }
687
688 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
689 {
690     // Note that part of this code is duplicated in do_md_trajectory_writing.
691     // This duplication is needed while both legacy and modular code paths are in use.
692     // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
693     if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
694     {
695         return;
696     }
697
698     GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
699
700     wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
701     if (DOMAINDECOMP(cr_))
702     {
703         auto globalXRef =
704                 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
705         dd_collect_vec(cr_->dd,
706                        localStateBackup_->ddp_count,
707                        localStateBackup_->ddp_count_cg_gl,
708                        localStateBackup_->cg_gl,
709                        localStateBackup_->x,
710                        globalXRef);
711         auto globalVRef =
712                 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
713         dd_collect_vec(cr_->dd,
714                        localStateBackup_->ddp_count,
715                        localStateBackup_->ddp_count_cg_gl,
716                        localStateBackup_->cg_gl,
717                        localStateBackup_->v,
718                        globalVRef);
719     }
720     else
721     {
722         // We have the whole state locally: copy the local state pointer
723         statePropagatorData_->globalState_ = localStateBackup_.get();
724     }
725
726     if (MASTER(cr_))
727     {
728         fprintf(stderr, "\nWriting final coordinates.\n");
729         if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
730         {
731             // Make molecules whole only for confout writing
732             do_pbc_mtop(pbcType_,
733                         localStateBackup_->box,
734                         &top_global_,
735                         statePropagatorData_->globalState_->x.rvec_array());
736         }
737         write_sto_conf_mtop(finalConfigurationFilename_.c_str(),
738                             *top_global_.name,
739                             top_global_,
740                             statePropagatorData_->globalState_->x.rvec_array(),
741                             statePropagatorData_->globalState_->v.rvec_array(),
742                             pbcType_,
743                             localStateBackup_->box);
744     }
745     wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
746 }
747
748 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
749 {
750     return [this](Step step, Time /*time*/) {
751         lastStep_               = step;
752         isRegularSimulationEnd_ = (step == lastPlannedStep_);
753     };
754 }
755
// Store the output intervals, PBC settings and references needed for trajectory and
// final-configuration writing. No writing or save step is known yet (-1 markers), and
// the last planned step is the initial step plus the number of steps from the input record.
StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
                                      FILE*                fplog,
                                      const t_commrec*     cr,
                                      int                  nstxout,
                                      int                  nstvout,
                                      int                  nstfout,
                                      int                  nstxout_compressed,
                                      bool                 canMoleculesBeDistributedOverPBC,
                                      bool                 writeFinalConfiguration,
                                      std::string          finalConfigurationFilename,
                                      const t_inputrec*    inputrec,
                                      const gmx_mtop_t&    globalTop) :
    statePropagatorData_(statePropagatorData),
    nstxout_(nstxout),
    nstvout_(nstvout),
    nstfout_(nstfout),
    nstxout_compressed_(nstxout_compressed),
    writeOutStep_(-1),
    freeEnergyPerturbationData_(nullptr),
    isRegularSimulationEnd_(false),
    lastStep_(-1),
    canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
    systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
    pbcType_(inputrec->pbcType),
    lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
    writeFinalConfiguration_(writeFinalConfiguration),
    finalConfigurationFilename_(std::move(finalConfigurationFilename)),
    fplog_(fplog),
    cr_(cr),
    top_global_(globalTop)
{
}
788 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
789 {
790     freeEnergyPerturbationData_ = freeEnergyPerturbationData;
791 }
792
793 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
794         LegacySimulatorData gmx_unused*        legacySimulatorData,
795         ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
796         StatePropagatorData*                               statePropagatorData,
797         EnergyData gmx_unused*      energyData,
798         FreeEnergyPerturbationData* freeEnergyPerturbationData,
799         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
800 {
801     statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
802     return statePropagatorData->element();
803 }
804
805 void StatePropagatorData::readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData)
806 {
807     StatePropagatorData statePropagatorData;
808     statePropagatorData.doCheckpointData(&readCheckpointData);
809
810     trxFrame->natoms = statePropagatorData.totalNumAtoms_;
811     trxFrame->bX     = true;
812     trxFrame->x  = makeRvecArray(statePropagatorData.xGlobal_, statePropagatorData.totalNumAtoms_);
813     trxFrame->bV = true;
814     trxFrame->v  = makeRvecArray(statePropagatorData.vGlobal_, statePropagatorData.totalNumAtoms_);
815     trxFrame->bF = false;
816     trxFrame->bBox = true;
817     copy_mat(statePropagatorData.box_, trxFrame->box);
818 }
819
820 const std::string& StatePropagatorData::checkpointID()
821 {
822     static const std::string identifier = "StatePropagatorData";
823     return identifier;
824 }
825
826 } // namespace gmx