ce7a3c861888814a76ed3c4b693973272c199fad
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines the state for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "statepropagatordata.h"
45
46 #include "gromacs/commandline/filenm.h"
47 #include "gromacs/domdec/collect.h"
48 #include "gromacs/domdec/domdec.h"
49 #include "gromacs/fileio/confio.h"
50 #include "gromacs/math/vec.h"
51 #include "gromacs/mdlib/gmx_omp_nthreads.h"
52 #include "gromacs/mdlib/mdatoms.h"
53 #include "gromacs/mdlib/mdoutf.h"
54 #include "gromacs/mdlib/stat.h"
55 #include "gromacs/mdlib/update.h"
56 #include "gromacs/mdtypes/checkpointdata.h"
57 #include "gromacs/mdtypes/commrec.h"
58 #include "gromacs/mdtypes/forcebuffers.h"
59 #include "gromacs/mdtypes/forcerec.h"
60 #include "gromacs/mdtypes/inputrec.h"
61 #include "gromacs/mdtypes/mdatom.h"
62 #include "gromacs/mdtypes/mdrunoptions.h"
63 #include "gromacs/mdtypes/state.h"
64 #include "gromacs/pbcutil/pbc.h"
65 #include "gromacs/topology/atoms.h"
66 #include "gromacs/topology/topology.h"
67 #include "gromacs/trajectory/trajectoryframe.h"
68
69 #include "freeenergyperturbationdata.h"
70 #include "modularsimulator.h"
71 #include "simulatoralgorithm.h"
72
73 namespace gmx
74 {
/*! Construct the state propagator data from the (legacy) global state.
 *
 * With domain decomposition, a fresh local t_state is initialized via dd_init_local_state()
 * and installed through setLocalState(). Without domain decomposition, positions, velocities
 * and box are copied from \p globalState, the force buffer is sized to its atom count, and the
 * previous-step copies are initialized via copyPosition().
 * If \p useGPU is set, the position buffer gets the PinnedIfSupported pinning policy.
 * On the DD master rank, the global gather buffers are sized to the total number of atoms.
 * For non-continuation runs, shell velocities and the frozen dimensions of frozen atoms are
 * zeroed, and velocity Verlet (eiVV) runs are flagged to reset their velocities once.
 */
StatePropagatorData::StatePropagatorData(int                numAtoms,
                                         FILE*              fplog,
                                         const t_commrec*   cr,
                                         t_state*           globalState,
                                         bool               useGPU,
                                         bool               canMoleculesBeDistributedOverPBC,
                                         bool               writeFinalConfiguration,
                                         const std::string& finalConfigurationFilename,
                                         const t_inputrec*  inputrec,
                                         const t_mdatoms*   mdatoms,
                                         const gmx_mtop_t*  globalTop) :
    totalNumAtoms_(numAtoms),
    localNAtoms_(0),
    box_{ { 0 } },
    previousBox_{ { 0 } },
    ddpCount_(0),
    element_(std::make_unique<Element>(this,
                                       fplog,
                                       cr,
                                       inputrec->nstxout,
                                       inputrec->nstvout,
                                       inputrec->nstfout,
                                       inputrec->nstxout_compressed,
                                       canMoleculesBeDistributedOverPBC,
                                       writeFinalConfiguration,
                                       finalConfigurationFilename,
                                       inputrec,
                                       globalTop)),
    vvResetVelocities_(false),
    isRegularSimulationEnd_(false),
    lastStep_(-1),
    globalState_(globalState)
{
    // Assigned in both branches below
    bool stateHasVelocities;
    // Local state only becomes valid now.
    if (DOMAINDECOMP(cr))
    {
        auto localState = std::make_unique<t_state>();
        dd_init_local_state(cr->dd, globalState, localState.get());
        stateHasVelocities = ((static_cast<unsigned int>(localState->flags) & (1U << estV)) != 0U);
        setLocalState(std::move(localState));
    }
    else
    {
        state_change_natoms(globalState, globalState->natoms);
        f_.resize(globalState->natoms);
        localNAtoms_ = globalState->natoms;
        x_           = globalState->x;
        v_           = globalState->v;
        copy_mat(globalState->box, box_);
        stateHasVelocities = ((static_cast<unsigned int>(globalState->flags) & (1U << estV)) != 0U);
        previousX_.resizeWithPadding(localNAtoms_);
        ddpCount_ = globalState->ddp_count;
        copyPosition();
    }
    if (useGPU)
    {
        changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
    }

    // Only the master rank gathers the full system (checkpointing / trajectory writing)
    if (DOMAINDECOMP(cr) && MASTER(cr))
    {
        xGlobal_.resizeWithPadding(totalNumAtoms_);
        previousXGlobal_.resizeWithPadding(totalNumAtoms_);
        vGlobal_.resizeWithPadding(totalNumAtoms_);
        fGlobal_.resizeWithPadding(totalNumAtoms_);
    }

    if (!inputrec->bContinuation)
    {
        if (stateHasVelocities)
        {
            auto v = velocitiesView().paddedArrayRef();
            // Zero the velocities of shells, and the frozen dimensions of frozen atoms.
            // NOTE(review): only eptShell is checked here - virtual sites are not touched by
            // this loop; confirm their velocities are handled elsewhere.
            for (int i = 0; i < mdatoms->homenr; i++)
            {
                if (mdatoms->ptype[i] == eptShell)
                {
                    clear_rvec(v[i]);
                }
                // cFREEZE may be null (presumably when no freeze groups are used)
                else if (mdatoms->cFREEZE)
                {
                    for (int m = 0; m < DIM; m++)
                    {
                        if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
                        {
                            v[i][m] = 0;
                        }
                    }
                }
            }
        }
        if (inputrec->eI == eiVV)
        {
            // Velocity Verlet: velocities are backed up at setup and restored once, after the
            // first half-step (see elementSetup() and scheduleTask())
            vvResetVelocities_ = true;
        }
    }
}
173
174 StatePropagatorData::Element* StatePropagatorData::element()
175 {
176     return element_.get();
177 }
178
179 void StatePropagatorData::setup()
180 {
181     if (element_)
182     {
183         element_->elementSetup();
184     }
185 }
186
187 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
188 {
189     return x_.arrayRefWithPadding();
190 }
191
192 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
193 {
194     return x_.constArrayRefWithPadding();
195 }
196
197 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
198 {
199     return previousX_.arrayRefWithPadding();
200 }
201
202 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
203 {
204     return previousX_.constArrayRefWithPadding();
205 }
206
207 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
208 {
209     return v_.arrayRefWithPadding();
210 }
211
212 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
213 {
214     return v_.constArrayRefWithPadding();
215 }
216
217 ForceBuffersView& StatePropagatorData::forcesView()
218 {
219     return f_.view();
220 }
221
222 const ForceBuffersView& StatePropagatorData::constForcesView() const
223 {
224     return f_.view();
225 }
226
227 rvec* StatePropagatorData::box()
228 {
229     return box_;
230 }
231
232 const rvec* StatePropagatorData::constBox() const
233 {
234     return box_;
235 }
236
237 rvec* StatePropagatorData::previousBox()
238 {
239     return previousBox_;
240 }
241
242 const rvec* StatePropagatorData::constPreviousBox() const
243 {
244     return previousBox_;
245 }
246
247 int StatePropagatorData::localNumAtoms() const
248 {
249     return localNAtoms_;
250 }
251
252 int StatePropagatorData::totalNumAtoms() const
253 {
254     return totalNumAtoms_;
255 }
256
257 std::unique_ptr<t_state> StatePropagatorData::localState()
258 {
259     auto state   = std::make_unique<t_state>();
260     state->flags = (1U << estX) | (1U << estV) | (1U << estBOX);
261     state_change_natoms(state.get(), localNAtoms_);
262     state->x = x_;
263     state->v = v_;
264     copy_mat(box_, state->box);
265     state->ddp_count       = ddpCount_;
266     state->ddp_count_cg_gl = ddpCountCgGl_;
267     state->cg_gl           = cgGl_;
268     return state;
269 }
270
/*! Install a new local state, e.g. after (re)partitioning by domain decomposition.
 *
 * Copies positions, velocities, box and DD bookkeeping from \p state, refreshes the
 * previous-step copies (previousX_, previousBox_) via copyPosition(), and, if a velocity
 * Verlet velocity reset is still pending, refreshes the velocity backup.
 * Only the data is copied; \p state itself is destroyed when this function returns.
 */
void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
{
    localNAtoms_ = state->natoms;
    x_.resizeWithPadding(localNAtoms_);
    previousX_.resizeWithPadding(localNAtoms_);
    v_.resizeWithPadding(localNAtoms_);
    x_ = state->x;
    v_ = state->v;
    copy_mat(state->box, box_);
    // Must follow the x_ / box_ assignments and the previousX_ resize above
    copyPosition();
    ddpCount_     = state->ddp_count;
    ddpCountCgGl_ = state->ddp_count_cg_gl;
    cgGl_         = state->cg_gl;

    if (vvResetVelocities_)
    {
        /* DomDec runs twice early in the simulation, once at setup time, and once before the first
         * step. Every time DD runs, it sets a new local state here. We are saving a backup during
         * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
         * the first step here to avoid resetting to an earlier DD state. This is done before any
         * propagation that needs to be reset, so it's not very safe but correct for now.
         * TODO: Get rid of this once input is assumed to be at half steps
         */
        velocityBackup_ = v_;
    }
}
297
298 t_state* StatePropagatorData::globalState()
299 {
300     return globalState_;
301 }
302
303 ForceBuffers* StatePropagatorData::forcePointer()
304 {
305     return &f_;
306 }
307
/*! Save the current positions and box as the previous-step copies.
 *
 * Positions are copied in parallel using the thread count configured for update
 * (emntUpdate); each thread copies the contiguous atom range [start_th, end_th)
 * computed by getThreadAtomRange().
 */
void StatePropagatorData::copyPosition()
{
    int nth = gmx_omp_nthreads_get(emntUpdate);

#pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
    for (int th = 0; th < nth; th++)
    {
        int start_th, end_th;
        getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
        copyPosition(start_th, end_th);
    }

    /* Box is changed in update() when we do pressure coupling,
     * but we should still use the old box for energy corrections and when
     * writing it to the energy file, so it matches the trajectory files for
     * the same timestep above. Make a copy in a separate array.
     */
    copy_mat(box_, previousBox_);
}
327
328 void StatePropagatorData::copyPosition(int start, int end)
329 {
330     for (int i = start; i < end; ++i)
331     {
332         previousX_[i] = x_[i];
333     }
334 }
335
336 void StatePropagatorData::Element::scheduleTask(Step step,
337                                                 Time gmx_unused            time,
338                                                 const RegisterRunFunction& registerRunFunction)
339 {
340     if (statePropagatorData_->vvResetVelocities_)
341     {
342         statePropagatorData_->vvResetVelocities_ = false;
343         registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
344     }
345     // copy x -> previousX
346     registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
347     // if it's a write out step, keep a copy for writeout
348     if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
349     {
350         registerRunFunction([this]() { saveState(); });
351     }
352 }
353
354 void StatePropagatorData::Element::saveState()
355 {
356     GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
357     localStateBackup_ = statePropagatorData_->localState();
358     if (freeEnergyPerturbationData_)
359     {
360         localStateBackup_->fep_state = freeEnergyPerturbationData_->currentFEPState();
361         for (unsigned long i = 0; i < localStateBackup_->lambda.size(); ++i)
362         {
363             localStateBackup_->lambda[i] = freeEnergyPerturbationData_->constLambdaView()[i];
364         }
365         localStateBackup_->flags |= (1U << estLAMBDA) | (1U << estFEPSTATE);
366     }
367 }
368
369 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
370 {
371     if (event == TrajectoryEvent::StateWritingStep)
372     {
373         return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
374     }
375     return std::nullopt;
376 }
377
378 std::optional<ITrajectoryWriterCallback>
379 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
380 {
381     if (event == TrajectoryEvent::StateWritingStep)
382     {
383         return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
384             if (writeTrajectory)
385             {
386                 write(outf, step, time);
387             }
388         };
389     }
390     return std::nullopt;
391 }
392
393 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
394 {
395     wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
396     unsigned int mdof_flags = 0;
397     if (do_per_step(currentStep, nstxout_))
398     {
399         mdof_flags |= MDOF_X;
400     }
401     if (do_per_step(currentStep, nstvout_))
402     {
403         mdof_flags |= MDOF_V;
404     }
405     if (do_per_step(currentStep, nstfout_))
406     {
407         mdof_flags |= MDOF_F;
408     }
409     if (do_per_step(currentStep, nstxout_compressed_))
410     {
411         mdof_flags |= MDOF_X_COMPRESSED;
412     }
413     if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
414     {
415         mdof_flags |= MDOF_BOX;
416     }
417     if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
418     {
419         mdof_flags |= MDOF_LAMBDA;
420     }
421     if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
422     {
423         mdof_flags |= MDOF_BOX_COMPRESSED;
424     }
425     if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
426     {
427         mdof_flags |= MDOF_LAMBDA_COMPRESSED;
428     }
429
430     if (mdof_flags == 0)
431     {
432         wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
433         return;
434     }
435     GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
436
437     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
438     ObservablesHistory* observablesHistory = nullptr;
439
440     mdoutf_write_to_trajectory_files(fplog_,
441                                      cr_,
442                                      outf,
443                                      static_cast<int>(mdof_flags),
444                                      statePropagatorData_->totalNumAtoms_,
445                                      currentStep,
446                                      currentTime,
447                                      localStateBackup_.get(),
448                                      statePropagatorData_->globalState_,
449                                      observablesHistory,
450                                      statePropagatorData_->f_.view().force(),
451                                      &dummyCheckpointDataHolder_);
452
453     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
454     {
455         localStateBackup_.reset();
456     }
457     wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
458 }
459
460 void StatePropagatorData::Element::elementSetup()
461 {
462     if (statePropagatorData_->vvResetVelocities_)
463     {
464         // MD-VV does the first velocity half-step only to calculate the constraint virial,
465         // then resets the velocities since the input is assumed to be positions and velocities
466         // at full time step. TODO: Change this to have input at half time steps.
467         statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
468     }
469 }
470
471 void StatePropagatorData::resetVelocities()
472 {
473     v_ = velocityBackup_;
474 }
475
namespace
{
/*!
 * \brief Enum describing the contents StatePropagatorData::Element writes to modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The newest checkpoint version, i.e. the entry directly preceding Count
constexpr auto c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
} // namespace
491
/*! Read or write the checkpointed state, depending on \p operation.
 *
 * The same sequence of keys is used for both directions; it defines the checkpoint
 * content, so do not reorder it without bumping CheckpointVersion. Only the global
 * vectors (xGlobal_, vGlobal_) are serialized. Callers gather them before writing and
 * distribute/copy them after reading. On read, the global position and velocity buffers
 * are resized to the checkpointed number of atoms before being filled.
 */
template<CheckpointDataOperation operation>
void StatePropagatorData::doCheckpointData(CheckpointData<operation>* checkpointData)
{
    checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
    checkpointData->scalar("numAtoms", &totalNumAtoms_);

    if (operation == CheckpointDataOperation::Read)
    {
        xGlobal_.resizeWithPadding(totalNumAtoms_);
        vGlobal_.resizeWithPadding(totalNumAtoms_);
    }

    checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobal_));
    checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobal_));
    checkpointData->tensor("box", box_);
    checkpointData->scalar("ddpCount", &ddpCount_);
    checkpointData->scalar("ddpCountCgGl", &ddpCountCgGl_);
    // NOTE(review): unlike xGlobal_/vGlobal_, cgGl_ is not resized before reading, so on read
    // it keeps whatever size it had. Confirm that the read path copes with a size mismatch.
    checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(cgGl_));
}
511
512 void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
513                                                        const t_commrec*                   cr)
514 {
515     if (DOMAINDECOMP(cr))
516     {
517         // Collect state from all ranks into global vectors
518         dd_collect_vec(cr->dd,
519                        statePropagatorData_->ddpCount_,
520                        statePropagatorData_->ddpCountCgGl_,
521                        statePropagatorData_->cgGl_,
522                        statePropagatorData_->x_,
523                        statePropagatorData_->xGlobal_);
524         dd_collect_vec(cr->dd,
525                        statePropagatorData_->ddpCount_,
526                        statePropagatorData_->ddpCountCgGl_,
527                        statePropagatorData_->cgGl_,
528                        statePropagatorData_->v_,
529                        statePropagatorData_->vGlobal_);
530     }
531     else
532     {
533         // Everything is local - copy local vectors into global ones
534         statePropagatorData_->xGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
535         statePropagatorData_->vGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
536         std::copy(statePropagatorData_->x_.begin(),
537                   statePropagatorData_->x_.end(),
538                   statePropagatorData_->xGlobal_.begin());
539         std::copy(statePropagatorData_->v_.begin(),
540                   statePropagatorData_->v_.end(),
541                   statePropagatorData_->vGlobal_.begin());
542     }
543     if (MASTER(cr))
544     {
545         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
546     }
547 }
548
549 /*!
550  * \brief Update the legacy global state
551  *
552  * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
553  * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
554  */
555 static void updateGlobalState(t_state*                      globalState,
556                               const PaddedHostVector<RVec>& x,
557                               const PaddedHostVector<RVec>& v,
558                               const tensor                  box,
559                               int                           ddpCount,
560                               int                           ddpCountCgGl,
561                               const std::vector<int>&       cgGl)
562 {
563     globalState->x = x;
564     globalState->v = v;
565     copy_mat(box, globalState->box);
566     globalState->ddp_count       = ddpCount;
567     globalState->ddp_count_cg_gl = ddpCountCgGl;
568     globalState->cg_gl           = cgGl;
569 }
570
571 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
572                                                           const t_commrec*                  cr)
573 {
574     if (MASTER(cr))
575     {
576         statePropagatorData_->doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
577     }
578
579     // Copy data to global state to be distributed by DD at setup stage
580     if (DOMAINDECOMP(cr) && MASTER(cr))
581     {
582         updateGlobalState(statePropagatorData_->globalState_,
583                           statePropagatorData_->xGlobal_,
584                           statePropagatorData_->vGlobal_,
585                           statePropagatorData_->box_,
586                           statePropagatorData_->ddpCount_,
587                           statePropagatorData_->ddpCountCgGl_,
588                           statePropagatorData_->cgGl_);
589     }
590     // Everything is local - copy global vectors to local ones
591     if (!DOMAINDECOMP(cr))
592     {
593         statePropagatorData_->x_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
594         statePropagatorData_->v_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
595         std::copy(statePropagatorData_->xGlobal_.begin(),
596                   statePropagatorData_->xGlobal_.end(),
597                   statePropagatorData_->x_.begin());
598         std::copy(statePropagatorData_->vGlobal_.begin(),
599                   statePropagatorData_->vGlobal_.end(),
600                   statePropagatorData_->v_.begin());
601     }
602 }
603
604 const std::string& StatePropagatorData::Element::clientID()
605 {
606     return StatePropagatorData::checkpointID();
607 }
608
609 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
610 {
611     // Note that part of this code is duplicated in do_md_trajectory_writing.
612     // This duplication is needed while both legacy and modular code paths are in use.
613     // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
614     if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
615     {
616         return;
617     }
618
619     GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
620
621     wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
622     if (DOMAINDECOMP(cr_))
623     {
624         auto globalXRef =
625                 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
626         dd_collect_vec(cr_->dd,
627                        localStateBackup_->ddp_count,
628                        localStateBackup_->ddp_count_cg_gl,
629                        localStateBackup_->cg_gl,
630                        localStateBackup_->x,
631                        globalXRef);
632         auto globalVRef =
633                 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
634         dd_collect_vec(cr_->dd,
635                        localStateBackup_->ddp_count,
636                        localStateBackup_->ddp_count_cg_gl,
637                        localStateBackup_->cg_gl,
638                        localStateBackup_->v,
639                        globalVRef);
640     }
641     else
642     {
643         // We have the whole state locally: copy the local state pointer
644         statePropagatorData_->globalState_ = localStateBackup_.get();
645     }
646
647     if (MASTER(cr_))
648     {
649         fprintf(stderr, "\nWriting final coordinates.\n");
650         if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
651         {
652             // Make molecules whole only for confout writing
653             do_pbc_mtop(pbcType_,
654                         localStateBackup_->box,
655                         top_global_,
656                         statePropagatorData_->globalState_->x.rvec_array());
657         }
658         write_sto_conf_mtop(finalConfigurationFilename_.c_str(),
659                             *top_global_->name,
660                             top_global_,
661                             statePropagatorData_->globalState_->x.rvec_array(),
662                             statePropagatorData_->globalState_->v.rvec_array(),
663                             pbcType_,
664                             localStateBackup_->box);
665     }
666     wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
667 }
668
669 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
670 {
671     return [this](Step step, Time /*time*/) {
672         lastStep_               = step;
673         isRegularSimulationEnd_ = (step == lastPlannedStep_);
674     };
675 }
676
/*! Construct the simulator element of the state propagator data.
 *
 * Stores the output intervals and final-configuration settings. From \p inputrec it takes
 * the periodic-molecules flag, the PBC type and the planned last step
 * (nsteps + init_step). The write-out and last step are initialized to -1.
 * \p statePropagatorData, \p cr and \p globalTop are stored as non-owning pointers.
 */
StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
                                      FILE*                fplog,
                                      const t_commrec*     cr,
                                      int                  nstxout,
                                      int                  nstvout,
                                      int                  nstfout,
                                      int                  nstxout_compressed,
                                      bool                 canMoleculesBeDistributedOverPBC,
                                      bool                 writeFinalConfiguration,
                                      std::string          finalConfigurationFilename,
                                      const t_inputrec*    inputrec,
                                      const gmx_mtop_t*    globalTop) :
    statePropagatorData_(statePropagatorData),
    nstxout_(nstxout),
    nstvout_(nstvout),
    nstfout_(nstfout),
    nstxout_compressed_(nstxout_compressed),
    writeOutStep_(-1),
    freeEnergyPerturbationData_(nullptr),
    isRegularSimulationEnd_(false),
    lastStep_(-1),
    canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
    systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
    pbcType_(inputrec->pbcType),
    lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
    writeFinalConfiguration_(writeFinalConfiguration),
    finalConfigurationFilename_(std::move(finalConfigurationFilename)),
    fplog_(fplog),
    cr_(cr),
    top_global_(globalTop)
{
}
709 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
710 {
711     freeEnergyPerturbationData_ = freeEnergyPerturbationData;
712 }
713
714 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
715         LegacySimulatorData gmx_unused*        legacySimulatorData,
716         ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
717         StatePropagatorData*                               statePropagatorData,
718         EnergyData gmx_unused*      energyData,
719         FreeEnergyPerturbationData* freeEnergyPerturbationData,
720         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
721 {
722     statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
723     return statePropagatorData->element();
724 }
725
726 void StatePropagatorData::readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData)
727 {
728     StatePropagatorData statePropagatorData;
729     statePropagatorData.doCheckpointData(&readCheckpointData);
730
731     trxFrame->natoms = statePropagatorData.totalNumAtoms_;
732     trxFrame->bX     = true;
733     trxFrame->x  = makeRvecArray(statePropagatorData.xGlobal_, statePropagatorData.totalNumAtoms_);
734     trxFrame->bV = true;
735     trxFrame->v  = makeRvecArray(statePropagatorData.vGlobal_, statePropagatorData.totalNumAtoms_);
736     trxFrame->bF = false;
737     trxFrame->bBox = true;
738     copy_mat(statePropagatorData.box_, trxFrame->box);
739 }
740
741 const std::string& StatePropagatorData::checkpointID()
742 {
743     static const std::string identifier = "StatePropagatorData";
744     return identifier;
745 }
746
747 } // namespace gmx