4dd6a899f27f8bd358a543a8e2a4ecec5f22eaea
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines the state for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "gromacs/utility/enumerationhelpers.h"
45 #include "statepropagatordata.h"
46
47 #include "gromacs/commandline/filenm.h"
48 #include "gromacs/domdec/collect.h"
49 #include "gromacs/domdec/domdec.h"
50 #include "gromacs/fileio/confio.h"
51 #include "gromacs/math/vec.h"
52 #include "gromacs/mdlib/gmx_omp_nthreads.h"
53 #include "gromacs/mdlib/mdatoms.h"
54 #include "gromacs/mdlib/mdoutf.h"
55 #include "gromacs/mdlib/stat.h"
56 #include "gromacs/mdlib/update.h"
57 #include "gromacs/mdtypes/checkpointdata.h"
58 #include "gromacs/mdtypes/commrec.h"
59 #include "gromacs/mdtypes/forcebuffers.h"
60 #include "gromacs/mdtypes/forcerec.h"
61 #include "gromacs/mdtypes/inputrec.h"
62 #include "gromacs/mdtypes/mdatom.h"
63 #include "gromacs/mdtypes/mdrunoptions.h"
64 #include "gromacs/mdtypes/state.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/topology/atoms.h"
67 #include "gromacs/topology/topology.h"
68 #include "gromacs/trajectory/trajectoryframe.h"
69
70 #include "freeenergyperturbationdata.h"
71 #include "modularsimulator.h"
72 #include "simulatoralgorithm.h"
73
74 namespace gmx
75 {
76 StatePropagatorData::StatePropagatorData(int                numAtoms,
77                                          FILE*              fplog,
78                                          const t_commrec*   cr,
79                                          t_state*           globalState,
80                                          bool               useGPU,
81                                          bool               canMoleculesBeDistributedOverPBC,
82                                          bool               writeFinalConfiguration,
83                                          const std::string& finalConfigurationFilename,
84                                          const t_inputrec*  inputrec,
85                                          const t_mdatoms*   mdatoms,
86                                          const gmx_mtop_t&  globalTop) :
87     totalNumAtoms_(numAtoms),
88     localNAtoms_(0),
89     box_{ { 0 } },
90     previousBox_{ { 0 } },
91     ddpCount_(0),
92     element_(std::make_unique<Element>(this,
93                                        fplog,
94                                        cr,
95                                        inputrec->nstxout,
96                                        inputrec->nstvout,
97                                        inputrec->nstfout,
98                                        inputrec->nstxout_compressed,
99                                        canMoleculesBeDistributedOverPBC,
100                                        writeFinalConfiguration,
101                                        finalConfigurationFilename,
102                                        inputrec,
103                                        globalTop)),
104     vvResetVelocities_(false),
105     isRegularSimulationEnd_(false),
106     lastStep_(-1),
107     globalState_(globalState)
108 {
109     bool stateHasVelocities;
110     // Local state only becomes valid now.
111     if (DOMAINDECOMP(cr))
112     {
113         auto localState = std::make_unique<t_state>();
114         dd_init_local_state(*cr->dd, globalState, localState.get());
115         stateHasVelocities = ((localState->flags & enumValueToBitMask(StateEntry::V)) != 0);
116         setLocalState(std::move(localState));
117     }
118     else
119     {
120         state_change_natoms(globalState, globalState->natoms);
121         f_.resize(globalState->natoms);
122         localNAtoms_ = globalState->natoms;
123         x_           = globalState->x;
124         v_           = globalState->v;
125         copy_mat(globalState->box, box_);
126         stateHasVelocities = ((globalState->flags & enumValueToBitMask(StateEntry::V)) != 0);
127         previousX_.resizeWithPadding(localNAtoms_);
128         ddpCount_ = globalState->ddp_count;
129         copyPosition();
130     }
131     if (useGPU)
132     {
133         changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
134     }
135
136     if (DOMAINDECOMP(cr) && MASTER(cr))
137     {
138         xGlobal_.resizeWithPadding(totalNumAtoms_);
139         previousXGlobal_.resizeWithPadding(totalNumAtoms_);
140         vGlobal_.resizeWithPadding(totalNumAtoms_);
141         fGlobal_.resizeWithPadding(totalNumAtoms_);
142     }
143
144     if (!inputrec->bContinuation)
145     {
146         if (stateHasVelocities)
147         {
148             auto v = velocitiesView().paddedArrayRef();
149             // Set the velocities of vsites, shells and frozen atoms to zero
150             for (int i = 0; i < mdatoms->homenr; i++)
151             {
152                 if (mdatoms->ptype[i] == ParticleType::Shell)
153                 {
154                     clear_rvec(v[i]);
155                 }
156                 else if (mdatoms->cFREEZE)
157                 {
158                     for (int m = 0; m < DIM; m++)
159                     {
160                         if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
161                         {
162                             v[i][m] = 0;
163                         }
164                     }
165                 }
166             }
167         }
168         if (inputrec->eI == IntegrationAlgorithm::VV)
169         {
170             vvResetVelocities_ = true;
171         }
172     }
173 }
174
175 StatePropagatorData::Element* StatePropagatorData::element()
176 {
177     return element_.get();
178 }
179
180 void StatePropagatorData::setup()
181 {
182     if (element_)
183     {
184         element_->elementSetup();
185     }
186 }
187
188 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
189 {
190     return x_.arrayRefWithPadding();
191 }
192
193 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
194 {
195     return x_.constArrayRefWithPadding();
196 }
197
198 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
199 {
200     return previousX_.arrayRefWithPadding();
201 }
202
203 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
204 {
205     return previousX_.constArrayRefWithPadding();
206 }
207
208 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
209 {
210     return v_.arrayRefWithPadding();
211 }
212
213 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
214 {
215     return v_.constArrayRefWithPadding();
216 }
217
218 ForceBuffersView& StatePropagatorData::forcesView()
219 {
220     return f_.view();
221 }
222
223 const ForceBuffersView& StatePropagatorData::constForcesView() const
224 {
225     return f_.view();
226 }
227
228 rvec* StatePropagatorData::box()
229 {
230     return box_;
231 }
232
233 const rvec* StatePropagatorData::constBox() const
234 {
235     return box_;
236 }
237
238 rvec* StatePropagatorData::previousBox()
239 {
240     return previousBox_;
241 }
242
243 const rvec* StatePropagatorData::constPreviousBox() const
244 {
245     return previousBox_;
246 }
247
248 int StatePropagatorData::localNumAtoms() const
249 {
250     return localNAtoms_;
251 }
252
253 int StatePropagatorData::totalNumAtoms() const
254 {
255     return totalNumAtoms_;
256 }
257
258 std::unique_ptr<t_state> StatePropagatorData::localState()
259 {
260     auto state   = std::make_unique<t_state>();
261     state->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
262                    | enumValueToBitMask(StateEntry::Box);
263     state_change_natoms(state.get(), localNAtoms_);
264     state->x = x_;
265     state->v = v_;
266     copy_mat(box_, state->box);
267     state->ddp_count       = ddpCount_;
268     state->ddp_count_cg_gl = ddpCountCgGl_;
269     state->cg_gl           = cgGl_;
270     return state;
271 }
272
/*
 * Adopt the data of a new local state, as produced when DD (re)partitions the system.
 * The data is copied out of \p state; the passed object itself is released on return.
 * copyPosition() also refreshes previousX_ and previousBox_ from the new positions and box.
 */
void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
{
    localNAtoms_ = state->natoms;
    x_.resizeWithPadding(localNAtoms_);
    previousX_.resizeWithPadding(localNAtoms_);
    v_.resizeWithPadding(localNAtoms_);
    x_ = state->x;
    v_ = state->v;
    copy_mat(state->box, box_);
    // Must come after x_ and box_ are updated, since it copies them into the "previous" buffers.
    copyPosition();
    ddpCount_     = state->ddp_count;
    ddpCountCgGl_ = state->ddp_count_cg_gl;
    cgGl_         = state->cg_gl;

    if (vvResetVelocities_)
    {
        /* DomDec runs twice early in the simulation, once at setup time, and once before the first
         * step. Every time DD runs, it sets a new local state here. We are saving a backup during
         * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
         * the first step here to avoid resetting to an earlier DD state. This is done before any
         * propagation that needs to be reset, so it's not very safe but correct for now.
         * TODO: Get rid of this once input is assumed to be at half steps
         */
        velocityBackup_ = v_;
    }
}
299
300 t_state* StatePropagatorData::globalState()
301 {
302     return globalState_;
303 }
304
305 ForceBuffers* StatePropagatorData::forcePointer()
306 {
307     return &f_;
308 }
309
/*
 * Back up the current local positions into previousX_ and the current box into previousBox_.
 * The position copy is split over the threads configured for the Update module.
 */
void StatePropagatorData::copyPosition()
{
    int nth = gmx_omp_nthreads_get(ModuleMultiThread::Update);

    // Each thread copies its own atom range [start_th, end_th) of the local atoms.
#pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
    for (int th = 0; th < nth; th++)
    {
        int start_th, end_th;
        getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
        copyPosition(start_th, end_th);
    }

    /* Box is changed in update() when we do pressure coupling,
     * but we should still use the old box for energy corrections and when
     * writing it to the energy file, so it matches the trajectory files for
     * the same timestep above. Make a copy in a separate array.
     */
    copy_mat(box_, previousBox_);
}
329
330 void StatePropagatorData::copyPosition(int start, int end)
331 {
332     for (int i = start; i < end; ++i)
333     {
334         previousX_[i] = x_[i];
335     }
336 }
337
338 void StatePropagatorData::Element::scheduleTask(Step step,
339                                                 Time gmx_unused            time,
340                                                 const RegisterRunFunction& registerRunFunction)
341 {
342     if (statePropagatorData_->vvResetVelocities_)
343     {
344         statePropagatorData_->vvResetVelocities_ = false;
345         registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
346     }
347     // copy x -> previousX
348     registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
349     // if it's a write out step, keep a copy for writeout
350     if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
351     {
352         registerRunFunction([this]() { saveState(); });
353     }
354 }
355
356 void StatePropagatorData::Element::saveState()
357 {
358     GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
359     localStateBackup_ = statePropagatorData_->localState();
360     if (freeEnergyPerturbationData_)
361     {
362         localStateBackup_->fep_state    = freeEnergyPerturbationData_->currentFEPState();
363         ArrayRef<const real> lambdaView = freeEnergyPerturbationData_->constLambdaView();
364         std::copy(lambdaView.begin(), lambdaView.end(), localStateBackup_->lambda.begin());
365         localStateBackup_->flags |=
366                 enumValueToBitMask(StateEntry::Lambda) | enumValueToBitMask(StateEntry::FepState);
367     }
368 }
369
370 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
371 {
372     if (event == TrajectoryEvent::StateWritingStep)
373     {
374         return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
375     }
376     return std::nullopt;
377 }
378
379 std::optional<ITrajectoryWriterCallback>
380 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
381 {
382     if (event == TrajectoryEvent::StateWritingStep)
383     {
384         return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
385             if (writeTrajectory)
386             {
387                 write(outf, step, time);
388             }
389         };
390     }
391     return std::nullopt;
392 }
393
394 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
395 {
396     wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
397     unsigned int mdof_flags = 0;
398     if (do_per_step(currentStep, nstxout_))
399     {
400         mdof_flags |= MDOF_X;
401     }
402     if (do_per_step(currentStep, nstvout_))
403     {
404         mdof_flags |= MDOF_V;
405     }
406     if (do_per_step(currentStep, nstfout_))
407     {
408         mdof_flags |= MDOF_F;
409     }
410     if (do_per_step(currentStep, nstxout_compressed_))
411     {
412         mdof_flags |= MDOF_X_COMPRESSED;
413     }
414     if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
415     {
416         mdof_flags |= MDOF_BOX;
417     }
418     if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
419     {
420         mdof_flags |= MDOF_LAMBDA;
421     }
422     if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
423     {
424         mdof_flags |= MDOF_BOX_COMPRESSED;
425     }
426     if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
427     {
428         mdof_flags |= MDOF_LAMBDA_COMPRESSED;
429     }
430
431     if (mdof_flags == 0)
432     {
433         wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
434         return;
435     }
436     GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
437
438     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
439     ObservablesHistory* observablesHistory = nullptr;
440
441     mdoutf_write_to_trajectory_files(fplog_,
442                                      cr_,
443                                      outf,
444                                      static_cast<int>(mdof_flags),
445                                      statePropagatorData_->totalNumAtoms_,
446                                      currentStep,
447                                      currentTime,
448                                      localStateBackup_.get(),
449                                      statePropagatorData_->globalState_,
450                                      observablesHistory,
451                                      statePropagatorData_->f_.view().force(),
452                                      &dummyCheckpointDataHolder_);
453
454     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
455     {
456         localStateBackup_.reset();
457     }
458     wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
459 }
460
461 void StatePropagatorData::Element::elementSetup()
462 {
463     if (statePropagatorData_->vvResetVelocities_)
464     {
465         // MD-VV does the first velocity half-step only to calculate the constraint virial,
466         // then resets the velocities since the input is assumed to be positions and velocities
467         // at full time step. TODO: Change this to have input at half time steps.
468         statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
469     }
470 }
471
472 void StatePropagatorData::resetVelocities()
473 {
474     v_ = velocityBackup_;
475 }
476
namespace
{
/*!
 * \brief Versions of the content StatePropagatorData::Element writes to the modular checkpoint
 *
 * To change the checkpoint content, add a new entry directly above Count and adapt the
 * checkpointing code accordingly.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The newest checkpoint version, i.e. the entry directly preceding Count
constexpr CheckpointVersion c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
492
/*
 * Read or write (depending on \p operation) the checkpointed part of the state: version,
 * total atom count, global positions and velocities, box, and the DD bookkeeping values.
 * The key strings below define the checkpoint content; see CheckpointVersion before changing them.
 */
template<CheckpointDataOperation operation>
void StatePropagatorData::doCheckpointData(CheckpointData<operation>* checkpointData)
{
    checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
    checkpointData->scalar("numAtoms", &totalNumAtoms_);

    // When reading, size the global vectors to the stored atom count before reading into them.
    if (operation == CheckpointDataOperation::Read)
    {
        xGlobal_.resizeWithPadding(totalNumAtoms_);
        vGlobal_.resizeWithPadding(totalNumAtoms_);
    }

    checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobal_));
    checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobal_));
    checkpointData->tensor("box", box_);
    checkpointData->scalar("ddpCount", &ddpCount_);
    checkpointData->scalar("ddpCountCgGl", &ddpCountCgGl_);
    // NOTE(review): unlike xGlobal_/vGlobal_, cgGl_ is not resized before reading, so a read
    // relies on cgGl_ already having a suitable size - confirm against ReadCheckpointData::arrayRef.
    checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(cgGl_));
}
512
513 void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
514                                                        const t_commrec*                   cr)
515 {
516     if (DOMAINDECOMP(cr))
517     {
518         // Collect state from all ranks into global vectors
519         dd_collect_vec(cr->dd,
520                        statePropagatorData_->ddpCount_,
521                        statePropagatorData_->ddpCountCgGl_,
522                        statePropagatorData_->cgGl_,
523                        statePropagatorData_->x_,
524                        statePropagatorData_->xGlobal_);
525         dd_collect_vec(cr->dd,
526                        statePropagatorData_->ddpCount_,
527                        statePropagatorData_->ddpCountCgGl_,
528                        statePropagatorData_->cgGl_,
529                        statePropagatorData_->v_,
530                        statePropagatorData_->vGlobal_);
531     }
532     else
533     {
534         // Everything is local - copy local vectors into global ones
535         statePropagatorData_->xGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
536         statePropagatorData_->vGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
537         std::copy(statePropagatorData_->x_.begin(),
538                   statePropagatorData_->x_.end(),
539                   statePropagatorData_->xGlobal_.begin());
540         std::copy(statePropagatorData_->v_.begin(),
541                   statePropagatorData_->v_.end(),
542                   statePropagatorData_->vGlobal_.begin());
543     }
544     if (MASTER(cr))
545     {
546         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
547     }
548 }
549
550 /*!
551  * \brief Update the legacy global state
552  *
553  * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
554  * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
555  */
556 static void updateGlobalState(t_state*                      globalState,
557                               const PaddedHostVector<RVec>& x,
558                               const PaddedHostVector<RVec>& v,
559                               const tensor                  box,
560                               int                           ddpCount,
561                               int                           ddpCountCgGl,
562                               const std::vector<int>&       cgGl)
563 {
564     globalState->x = x;
565     globalState->v = v;
566     copy_mat(box, globalState->box);
567     globalState->ddp_count       = ddpCount;
568     globalState->ddp_count_cg_gl = ddpCountCgGl;
569     globalState->cg_gl           = cgGl;
570 }
571
572 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
573                                                           const t_commrec*                  cr)
574 {
575     if (MASTER(cr))
576     {
577         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
578     }
579
580     // Copy data to global state to be distributed by DD at setup stage
581     if (DOMAINDECOMP(cr) && MASTER(cr))
582     {
583         updateGlobalState(statePropagatorData_->globalState_,
584                           statePropagatorData_->xGlobal_,
585                           statePropagatorData_->vGlobal_,
586                           statePropagatorData_->box_,
587                           statePropagatorData_->ddpCount_,
588                           statePropagatorData_->ddpCountCgGl_,
589                           statePropagatorData_->cgGl_);
590     }
591     // Everything is local - copy global vectors to local ones
592     if (!DOMAINDECOMP(cr))
593     {
594         statePropagatorData_->x_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
595         statePropagatorData_->v_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
596         std::copy(statePropagatorData_->xGlobal_.begin(),
597                   statePropagatorData_->xGlobal_.end(),
598                   statePropagatorData_->x_.begin());
599         std::copy(statePropagatorData_->vGlobal_.begin(),
600                   statePropagatorData_->vGlobal_.end(),
601                   statePropagatorData_->v_.begin());
602     }
603 }
604
605 const std::string& StatePropagatorData::Element::clientID()
606 {
607     return StatePropagatorData::checkpointID();
608 }
609
610 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
611 {
612     // Note that part of this code is duplicated in do_md_trajectory_writing.
613     // This duplication is needed while both legacy and modular code paths are in use.
614     // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
615     if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
616     {
617         return;
618     }
619
620     GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
621
622     wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
623     if (DOMAINDECOMP(cr_))
624     {
625         auto globalXRef =
626                 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
627         dd_collect_vec(cr_->dd,
628                        localStateBackup_->ddp_count,
629                        localStateBackup_->ddp_count_cg_gl,
630                        localStateBackup_->cg_gl,
631                        localStateBackup_->x,
632                        globalXRef);
633         auto globalVRef =
634                 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
635         dd_collect_vec(cr_->dd,
636                        localStateBackup_->ddp_count,
637                        localStateBackup_->ddp_count_cg_gl,
638                        localStateBackup_->cg_gl,
639                        localStateBackup_->v,
640                        globalVRef);
641     }
642     else
643     {
644         // We have the whole state locally: copy the local state pointer
645         statePropagatorData_->globalState_ = localStateBackup_.get();
646     }
647
648     if (MASTER(cr_))
649     {
650         fprintf(stderr, "\nWriting final coordinates.\n");
651         if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
652         {
653             // Make molecules whole only for confout writing
654             do_pbc_mtop(pbcType_,
655                         localStateBackup_->box,
656                         &top_global_,
657                         statePropagatorData_->globalState_->x.rvec_array());
658         }
659         write_sto_conf_mtop(finalConfigurationFilename_.c_str(),
660                             *top_global_.name,
661                             top_global_,
662                             statePropagatorData_->globalState_->x.rvec_array(),
663                             statePropagatorData_->globalState_->v.rvec_array(),
664                             pbcType_,
665                             localStateBackup_->box);
666     }
667     wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
668 }
669
670 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
671 {
672     return [this](Step step, Time /*time*/) {
673         lastStep_               = step;
674         isRegularSimulationEnd_ = (step == lastPlannedStep_);
675     };
676 }
677
// Stores the output intervals and final-configuration settings passed in; periodicity settings
// and the planned last step are taken from the input record. The element keeps a pointer back
// to its owning StatePropagatorData and a reference to the global topology.
StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
                                      FILE*                fplog,
                                      const t_commrec*     cr,
                                      int                  nstxout,
                                      int                  nstvout,
                                      int                  nstfout,
                                      int                  nstxout_compressed,
                                      bool                 canMoleculesBeDistributedOverPBC,
                                      bool                 writeFinalConfiguration,
                                      std::string          finalConfigurationFilename,
                                      const t_inputrec*    inputrec,
                                      const gmx_mtop_t&    globalTop) :
    statePropagatorData_(statePropagatorData),
    nstxout_(nstxout),
    nstvout_(nstvout),
    nstfout_(nstfout),
    nstxout_compressed_(nstxout_compressed),
    writeOutStep_(-1),
    freeEnergyPerturbationData_(nullptr),
    isRegularSimulationEnd_(false),
    lastStep_(-1),
    canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
    systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
    pbcType_(inputrec->pbcType),
    // Step number a regular (not stopped early) run ends on; see registerLastStepCallback()
    lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
    writeFinalConfiguration_(writeFinalConfiguration),
    finalConfigurationFilename_(std::move(finalConfigurationFilename)),
    fplog_(fplog),
    cr_(cr),
    top_global_(globalTop)
{
}
710 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
711 {
712     freeEnergyPerturbationData_ = freeEnergyPerturbationData;
713 }
714
715 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
716         LegacySimulatorData gmx_unused*        legacySimulatorData,
717         ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
718         StatePropagatorData*                               statePropagatorData,
719         EnergyData gmx_unused*      energyData,
720         FreeEnergyPerturbationData* freeEnergyPerturbationData,
721         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
722 {
723     statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
724     return statePropagatorData->element();
725 }
726
727 void StatePropagatorData::readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData)
728 {
729     StatePropagatorData statePropagatorData;
730     statePropagatorData.doCheckpointData(&readCheckpointData);
731
732     trxFrame->natoms = statePropagatorData.totalNumAtoms_;
733     trxFrame->bX     = true;
734     trxFrame->x  = makeRvecArray(statePropagatorData.xGlobal_, statePropagatorData.totalNumAtoms_);
735     trxFrame->bV = true;
736     trxFrame->v  = makeRvecArray(statePropagatorData.vGlobal_, statePropagatorData.totalNumAtoms_);
737     trxFrame->bF = false;
738     trxFrame->bBox = true;
739     copy_mat(statePropagatorData.box_, trxFrame->box);
740 }
741
742 const std::string& StatePropagatorData::checkpointID()
743 {
744     static const std::string identifier = "StatePropagatorData";
745     return identifier;
746 }
747
748 } // namespace gmx