cbafcfb871cd27f6e195b1db57b7b6ae248f5b64
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines the state for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "gromacs/utility/enumerationhelpers.h"
45 #include "statepropagatordata.h"
46
47 #include "gromacs/commandline/filenm.h"
48 #include "gromacs/domdec/collect.h"
49 #include "gromacs/domdec/domdec.h"
50 #include "gromacs/fileio/confio.h"
51 #include "gromacs/math/vec.h"
52 #include "gromacs/mdlib/gmx_omp_nthreads.h"
53 #include "gromacs/mdlib/mdatoms.h"
54 #include "gromacs/mdlib/mdoutf.h"
55 #include "gromacs/mdlib/stat.h"
56 #include "gromacs/mdlib/update.h"
57 #include "gromacs/mdtypes/checkpointdata.h"
58 #include "gromacs/mdtypes/commrec.h"
59 #include "gromacs/mdtypes/forcebuffers.h"
60 #include "gromacs/mdtypes/forcerec.h"
61 #include "gromacs/mdtypes/inputrec.h"
62 #include "gromacs/mdtypes/mdatom.h"
63 #include "gromacs/mdtypes/mdrunoptions.h"
64 #include "gromacs/mdtypes/state.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/topology/atoms.h"
67 #include "gromacs/topology/topology.h"
68 #include "gromacs/trajectory/trajectoryframe.h"
69
70 #include "freeenergyperturbationdata.h"
71 #include "modularsimulator.h"
72 #include "simulatoralgorithm.h"
73
74 namespace gmx
75 {
76 StatePropagatorData::StatePropagatorData(int                numAtoms,
77                                          FILE*              fplog,
78                                          const t_commrec*   cr,
79                                          t_state*           globalState,
80                                          bool               useGPU,
81                                          bool               canMoleculesBeDistributedOverPBC,
82                                          bool               writeFinalConfiguration,
83                                          const std::string& finalConfigurationFilename,
84                                          const t_inputrec*  inputrec,
85                                          const t_mdatoms*   mdatoms,
86                                          const gmx_mtop_t&  globalTop) :
87     totalNumAtoms_(numAtoms),
88     localNAtoms_(0),
89     box_{ { 0 } },
90     previousBox_{ { 0 } },
91     ddpCount_(0),
92     element_(std::make_unique<Element>(this,
93                                        fplog,
94                                        cr,
95                                        inputrec->nstxout,
96                                        inputrec->nstvout,
97                                        inputrec->nstfout,
98                                        inputrec->nstxout_compressed,
99                                        canMoleculesBeDistributedOverPBC,
100                                        writeFinalConfiguration,
101                                        finalConfigurationFilename,
102                                        inputrec,
103                                        globalTop)),
104     vvResetVelocities_(false),
105     isRegularSimulationEnd_(false),
106     lastStep_(-1),
107     globalState_(globalState)
108 {
109     bool stateHasVelocities;
110     // Local state only becomes valid now.
111     if (DOMAINDECOMP(cr))
112     {
113         auto localState = std::make_unique<t_state>();
114         dd_init_local_state(*cr->dd, globalState, localState.get());
115         stateHasVelocities = ((localState->flags & enumValueToBitMask(StateEntry::V)) != 0);
116         setLocalState(std::move(localState));
117     }
118     else
119     {
120         state_change_natoms(globalState, globalState->natoms);
121         f_.resize(globalState->natoms);
122         localNAtoms_ = globalState->natoms;
123         x_           = globalState->x;
124         v_           = globalState->v;
125         copy_mat(globalState->box, box_);
126         stateHasVelocities = ((globalState->flags & enumValueToBitMask(StateEntry::V)) != 0);
127         previousX_.resizeWithPadding(localNAtoms_);
128         ddpCount_ = globalState->ddp_count;
129         copyPosition();
130     }
131     if (useGPU)
132     {
133         changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
134     }
135
136     if (DOMAINDECOMP(cr) && MASTER(cr))
137     {
138         xGlobal_.resizeWithPadding(totalNumAtoms_);
139         previousXGlobal_.resizeWithPadding(totalNumAtoms_);
140         vGlobal_.resizeWithPadding(totalNumAtoms_);
141         fGlobal_.resizeWithPadding(totalNumAtoms_);
142     }
143
144     if (!inputrec->bContinuation)
145     {
146         if (stateHasVelocities)
147         {
148             auto v = velocitiesView().paddedArrayRef();
149             // Set the velocities of vsites, shells and frozen atoms to zero
150             for (int i = 0; i < mdatoms->homenr; i++)
151             {
152                 if (mdatoms->ptype[i] == ParticleType::Shell)
153                 {
154                     clear_rvec(v[i]);
155                 }
156                 else if (mdatoms->cFREEZE)
157                 {
158                     for (int m = 0; m < DIM; m++)
159                     {
160                         if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
161                         {
162                             v[i][m] = 0;
163                         }
164                     }
165                 }
166             }
167         }
168         if (inputrec->eI == IntegrationAlgorithm::VV)
169         {
170             vvResetVelocities_ = true;
171         }
172     }
173 }
174
175 StatePropagatorData::Element* StatePropagatorData::element()
176 {
177     return element_.get();
178 }
179
180 void StatePropagatorData::setup()
181 {
182     if (element_)
183     {
184         element_->elementSetup();
185     }
186 }
187
188 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
189 {
190     return x_.arrayRefWithPadding();
191 }
192
193 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
194 {
195     return x_.constArrayRefWithPadding();
196 }
197
198 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
199 {
200     return previousX_.arrayRefWithPadding();
201 }
202
203 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
204 {
205     return previousX_.constArrayRefWithPadding();
206 }
207
208 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
209 {
210     return v_.arrayRefWithPadding();
211 }
212
213 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
214 {
215     return v_.constArrayRefWithPadding();
216 }
217
218 ForceBuffersView& StatePropagatorData::forcesView()
219 {
220     return f_.view();
221 }
222
223 const ForceBuffersView& StatePropagatorData::constForcesView() const
224 {
225     return f_.view();
226 }
227
228 rvec* StatePropagatorData::box()
229 {
230     return box_;
231 }
232
233 const rvec* StatePropagatorData::constBox() const
234 {
235     return box_;
236 }
237
238 rvec* StatePropagatorData::previousBox()
239 {
240     return previousBox_;
241 }
242
243 const rvec* StatePropagatorData::constPreviousBox() const
244 {
245     return previousBox_;
246 }
247
248 int StatePropagatorData::localNumAtoms() const
249 {
250     return localNAtoms_;
251 }
252
253 int StatePropagatorData::totalNumAtoms() const
254 {
255     return totalNumAtoms_;
256 }
257
258 std::unique_ptr<t_state> StatePropagatorData::localState()
259 {
260     auto state   = std::make_unique<t_state>();
261     state->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
262                    | enumValueToBitMask(StateEntry::Box);
263     state_change_natoms(state.get(), localNAtoms_);
264     state->x = x_;
265     state->v = v_;
266     copy_mat(box_, state->box);
267     state->ddp_count       = ddpCount_;
268     state->ddp_count_cg_gl = ddpCountCgGl_;
269     state->cg_gl           = cgGl_;
270     return state;
271 }
272
273 void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
274 {
275     localNAtoms_ = state->natoms;
276     x_.resizeWithPadding(localNAtoms_);
277     previousX_.resizeWithPadding(localNAtoms_);
278     v_.resizeWithPadding(localNAtoms_);
279     x_ = state->x;
280     v_ = state->v;
281     copy_mat(state->box, box_);
282     copyPosition();
283     ddpCount_     = state->ddp_count;
284     ddpCountCgGl_ = state->ddp_count_cg_gl;
285     cgGl_         = state->cg_gl;
286
287     if (vvResetVelocities_)
288     {
289         /* DomDec runs twice early in the simulation, once at setup time, and once before the first
290          * step. Every time DD runs, it sets a new local state here. We are saving a backup during
291          * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
292          * the first step here to avoid resetting to an earlier DD state. This is done before any
293          * propagation that needs to be reset, so it's not very safe but correct for now.
294          * TODO: Get rid of this once input is assumed to be at half steps
295          */
296         velocityBackup_ = v_;
297     }
298 }
299
300 t_state* StatePropagatorData::globalState()
301 {
302     return globalState_;
303 }
304
305 ForceBuffers* StatePropagatorData::forcePointer()
306 {
307     return &f_;
308 }
309
310 void StatePropagatorData::copyPosition()
311 {
312     int nth = gmx_omp_nthreads_get(emntUpdate);
313
314 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
315     for (int th = 0; th < nth; th++)
316     {
317         int start_th, end_th;
318         getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
319         copyPosition(start_th, end_th);
320     }
321
322     /* Box is changed in update() when we do pressure coupling,
323      * but we should still use the old box for energy corrections and when
324      * writing it to the energy file, so it matches the trajectory files for
325      * the same timestep above. Make a copy in a separate array.
326      */
327     copy_mat(box_, previousBox_);
328 }
329
330 void StatePropagatorData::copyPosition(int start, int end)
331 {
332     for (int i = start; i < end; ++i)
333     {
334         previousX_[i] = x_[i];
335     }
336 }
337
338 void StatePropagatorData::Element::scheduleTask(Step step,
339                                                 Time gmx_unused            time,
340                                                 const RegisterRunFunction& registerRunFunction)
341 {
342     if (statePropagatorData_->vvResetVelocities_)
343     {
344         statePropagatorData_->vvResetVelocities_ = false;
345         registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
346     }
347     // copy x -> previousX
348     registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
349     // if it's a write out step, keep a copy for writeout
350     if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
351     {
352         registerRunFunction([this]() { saveState(); });
353     }
354 }
355
356 void StatePropagatorData::Element::saveState()
357 {
358     GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
359     localStateBackup_ = statePropagatorData_->localState();
360     if (freeEnergyPerturbationData_)
361     {
362         localStateBackup_->fep_state = freeEnergyPerturbationData_->currentFEPState();
363         for (unsigned long i = 0;
364              i < gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, double>::size();
365              ++i)
366         {
367             localStateBackup_->lambda[i] = freeEnergyPerturbationData_->constLambdaView()[i];
368         }
369         localStateBackup_->flags |=
370                 enumValueToBitMask(StateEntry::Lambda) | enumValueToBitMask(StateEntry::FepState);
371     }
372 }
373
374 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
375 {
376     if (event == TrajectoryEvent::StateWritingStep)
377     {
378         return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
379     }
380     return std::nullopt;
381 }
382
383 std::optional<ITrajectoryWriterCallback>
384 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
385 {
386     if (event == TrajectoryEvent::StateWritingStep)
387     {
388         return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
389             if (writeTrajectory)
390             {
391                 write(outf, step, time);
392             }
393         };
394     }
395     return std::nullopt;
396 }
397
398 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
399 {
400     wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
401     unsigned int mdof_flags = 0;
402     if (do_per_step(currentStep, nstxout_))
403     {
404         mdof_flags |= MDOF_X;
405     }
406     if (do_per_step(currentStep, nstvout_))
407     {
408         mdof_flags |= MDOF_V;
409     }
410     if (do_per_step(currentStep, nstfout_))
411     {
412         mdof_flags |= MDOF_F;
413     }
414     if (do_per_step(currentStep, nstxout_compressed_))
415     {
416         mdof_flags |= MDOF_X_COMPRESSED;
417     }
418     if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
419     {
420         mdof_flags |= MDOF_BOX;
421     }
422     if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
423     {
424         mdof_flags |= MDOF_LAMBDA;
425     }
426     if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
427     {
428         mdof_flags |= MDOF_BOX_COMPRESSED;
429     }
430     if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
431     {
432         mdof_flags |= MDOF_LAMBDA_COMPRESSED;
433     }
434
435     if (mdof_flags == 0)
436     {
437         wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
438         return;
439     }
440     GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
441
442     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
443     ObservablesHistory* observablesHistory = nullptr;
444
445     mdoutf_write_to_trajectory_files(fplog_,
446                                      cr_,
447                                      outf,
448                                      static_cast<int>(mdof_flags),
449                                      statePropagatorData_->totalNumAtoms_,
450                                      currentStep,
451                                      currentTime,
452                                      localStateBackup_.get(),
453                                      statePropagatorData_->globalState_,
454                                      observablesHistory,
455                                      statePropagatorData_->f_.view().force(),
456                                      &dummyCheckpointDataHolder_);
457
458     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
459     {
460         localStateBackup_.reset();
461     }
462     wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
463 }
464
465 void StatePropagatorData::Element::elementSetup()
466 {
467     if (statePropagatorData_->vvResetVelocities_)
468     {
469         // MD-VV does the first velocity half-step only to calculate the constraint virial,
470         // then resets the velocities since the input is assumed to be positions and velocities
471         // at full time step. TODO: Change this to have input at half time steps.
472         statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
473     }
474 }
475
476 void StatePropagatorData::resetVelocities()
477 {
478     v_ = velocityBackup_;
479 }
480
namespace
{
/*!
 * \brief Enum describing the version of the content StatePropagatorData::Element writes to
 *        the modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The most recent checkpoint version, i.e. the last entry before Count
constexpr auto c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
496
497 template<CheckpointDataOperation operation>
498 void StatePropagatorData::doCheckpointData(CheckpointData<operation>* checkpointData)
499 {
500     checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
501     checkpointData->scalar("numAtoms", &totalNumAtoms_);
502
503     if (operation == CheckpointDataOperation::Read)
504     {
505         xGlobal_.resizeWithPadding(totalNumAtoms_);
506         vGlobal_.resizeWithPadding(totalNumAtoms_);
507     }
508
509     checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobal_));
510     checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobal_));
511     checkpointData->tensor("box", box_);
512     checkpointData->scalar("ddpCount", &ddpCount_);
513     checkpointData->scalar("ddpCountCgGl", &ddpCountCgGl_);
514     checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(cgGl_));
515 }
516
517 void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
518                                                        const t_commrec*                   cr)
519 {
520     if (DOMAINDECOMP(cr))
521     {
522         // Collect state from all ranks into global vectors
523         dd_collect_vec(cr->dd,
524                        statePropagatorData_->ddpCount_,
525                        statePropagatorData_->ddpCountCgGl_,
526                        statePropagatorData_->cgGl_,
527                        statePropagatorData_->x_,
528                        statePropagatorData_->xGlobal_);
529         dd_collect_vec(cr->dd,
530                        statePropagatorData_->ddpCount_,
531                        statePropagatorData_->ddpCountCgGl_,
532                        statePropagatorData_->cgGl_,
533                        statePropagatorData_->v_,
534                        statePropagatorData_->vGlobal_);
535     }
536     else
537     {
538         // Everything is local - copy local vectors into global ones
539         statePropagatorData_->xGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
540         statePropagatorData_->vGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
541         std::copy(statePropagatorData_->x_.begin(),
542                   statePropagatorData_->x_.end(),
543                   statePropagatorData_->xGlobal_.begin());
544         std::copy(statePropagatorData_->v_.begin(),
545                   statePropagatorData_->v_.end(),
546                   statePropagatorData_->vGlobal_.begin());
547     }
548     if (MASTER(cr))
549     {
550         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
551     }
552 }
553
554 /*!
555  * \brief Update the legacy global state
556  *
557  * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
558  * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
559  */
560 static void updateGlobalState(t_state*                      globalState,
561                               const PaddedHostVector<RVec>& x,
562                               const PaddedHostVector<RVec>& v,
563                               const tensor                  box,
564                               int                           ddpCount,
565                               int                           ddpCountCgGl,
566                               const std::vector<int>&       cgGl)
567 {
568     globalState->x = x;
569     globalState->v = v;
570     copy_mat(box, globalState->box);
571     globalState->ddp_count       = ddpCount;
572     globalState->ddp_count_cg_gl = ddpCountCgGl;
573     globalState->cg_gl           = cgGl;
574 }
575
576 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
577                                                           const t_commrec*                  cr)
578 {
579     if (MASTER(cr))
580     {
581         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
582     }
583
584     // Copy data to global state to be distributed by DD at setup stage
585     if (DOMAINDECOMP(cr) && MASTER(cr))
586     {
587         updateGlobalState(statePropagatorData_->globalState_,
588                           statePropagatorData_->xGlobal_,
589                           statePropagatorData_->vGlobal_,
590                           statePropagatorData_->box_,
591                           statePropagatorData_->ddpCount_,
592                           statePropagatorData_->ddpCountCgGl_,
593                           statePropagatorData_->cgGl_);
594     }
595     // Everything is local - copy global vectors to local ones
596     if (!DOMAINDECOMP(cr))
597     {
598         statePropagatorData_->x_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
599         statePropagatorData_->v_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
600         std::copy(statePropagatorData_->xGlobal_.begin(),
601                   statePropagatorData_->xGlobal_.end(),
602                   statePropagatorData_->x_.begin());
603         std::copy(statePropagatorData_->vGlobal_.begin(),
604                   statePropagatorData_->vGlobal_.end(),
605                   statePropagatorData_->v_.begin());
606     }
607 }
608
609 const std::string& StatePropagatorData::Element::clientID()
610 {
611     return StatePropagatorData::checkpointID();
612 }
613
614 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
615 {
616     // Note that part of this code is duplicated in do_md_trajectory_writing.
617     // This duplication is needed while both legacy and modular code paths are in use.
618     // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
619     if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
620     {
621         return;
622     }
623
624     GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
625
626     wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
627     if (DOMAINDECOMP(cr_))
628     {
629         auto globalXRef =
630                 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
631         dd_collect_vec(cr_->dd,
632                        localStateBackup_->ddp_count,
633                        localStateBackup_->ddp_count_cg_gl,
634                        localStateBackup_->cg_gl,
635                        localStateBackup_->x,
636                        globalXRef);
637         auto globalVRef =
638                 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
639         dd_collect_vec(cr_->dd,
640                        localStateBackup_->ddp_count,
641                        localStateBackup_->ddp_count_cg_gl,
642                        localStateBackup_->cg_gl,
643                        localStateBackup_->v,
644                        globalVRef);
645     }
646     else
647     {
648         // We have the whole state locally: copy the local state pointer
649         statePropagatorData_->globalState_ = localStateBackup_.get();
650     }
651
652     if (MASTER(cr_))
653     {
654         fprintf(stderr, "\nWriting final coordinates.\n");
655         if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
656         {
657             // Make molecules whole only for confout writing
658             do_pbc_mtop(pbcType_,
659                         localStateBackup_->box,
660                         &top_global_,
661                         statePropagatorData_->globalState_->x.rvec_array());
662         }
663         write_sto_conf_mtop(finalConfigurationFilename_.c_str(),
664                             *top_global_.name,
665                             top_global_,
666                             statePropagatorData_->globalState_->x.rvec_array(),
667                             statePropagatorData_->globalState_->v.rvec_array(),
668                             pbcType_,
669                             localStateBackup_->box);
670     }
671     wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
672 }
673
674 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
675 {
676     return [this](Step step, Time /*time*/) {
677         lastStep_               = step;
678         isRegularSimulationEnd_ = (step == lastPlannedStep_);
679     };
680 }
681
682 StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
683                                       FILE*                fplog,
684                                       const t_commrec*     cr,
685                                       int                  nstxout,
686                                       int                  nstvout,
687                                       int                  nstfout,
688                                       int                  nstxout_compressed,
689                                       bool                 canMoleculesBeDistributedOverPBC,
690                                       bool                 writeFinalConfiguration,
691                                       std::string          finalConfigurationFilename,
692                                       const t_inputrec*    inputrec,
693                                       const gmx_mtop_t&    globalTop) :
694     statePropagatorData_(statePropagatorData),
695     nstxout_(nstxout),
696     nstvout_(nstvout),
697     nstfout_(nstfout),
698     nstxout_compressed_(nstxout_compressed),
699     writeOutStep_(-1),
700     freeEnergyPerturbationData_(nullptr),
701     isRegularSimulationEnd_(false),
702     lastStep_(-1),
703     canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
704     systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
705     pbcType_(inputrec->pbcType),
706     lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
707     writeFinalConfiguration_(writeFinalConfiguration),
708     finalConfigurationFilename_(std::move(finalConfigurationFilename)),
709     fplog_(fplog),
710     cr_(cr),
711     top_global_(globalTop)
712 {
713 }
714 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
715 {
716     freeEnergyPerturbationData_ = freeEnergyPerturbationData;
717 }
718
719 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
720         LegacySimulatorData gmx_unused*        legacySimulatorData,
721         ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
722         StatePropagatorData*                               statePropagatorData,
723         EnergyData gmx_unused*      energyData,
724         FreeEnergyPerturbationData* freeEnergyPerturbationData,
725         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
726 {
727     statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
728     return statePropagatorData->element();
729 }
730
731 void StatePropagatorData::readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData)
732 {
733     StatePropagatorData statePropagatorData;
734     statePropagatorData.doCheckpointData(&readCheckpointData);
735
736     trxFrame->natoms = statePropagatorData.totalNumAtoms_;
737     trxFrame->bX     = true;
738     trxFrame->x  = makeRvecArray(statePropagatorData.xGlobal_, statePropagatorData.totalNumAtoms_);
739     trxFrame->bV = true;
740     trxFrame->v  = makeRvecArray(statePropagatorData.vGlobal_, statePropagatorData.totalNumAtoms_);
741     trxFrame->bF = false;
742     trxFrame->bBox = true;
743     copy_mat(statePropagatorData.box_, trxFrame->box);
744 }
745
746 const std::string& StatePropagatorData::checkpointID()
747 {
748     static const std::string identifier = "StatePropagatorData";
749     return identifier;
750 }
751
752 } // namespace gmx