2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2019,2020, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 * \brief Defines the state for the modular simulator
38 * \author Pascal Merz <pascal.merz@me.com>
39 * \ingroup module_modularsimulator
44 #include "statepropagatordata.h"
46 #include "gromacs/commandline/filenm.h"
47 #include "gromacs/domdec/collect.h"
48 #include "gromacs/domdec/domdec.h"
49 #include "gromacs/fileio/confio.h"
50 #include "gromacs/math/vec.h"
51 #include "gromacs/mdlib/gmx_omp_nthreads.h"
52 #include "gromacs/mdlib/mdatoms.h"
53 #include "gromacs/mdlib/mdoutf.h"
54 #include "gromacs/mdlib/stat.h"
55 #include "gromacs/mdlib/update.h"
56 #include "gromacs/mdtypes/checkpointdata.h"
57 #include "gromacs/mdtypes/commrec.h"
58 #include "gromacs/mdtypes/forcebuffers.h"
59 #include "gromacs/mdtypes/forcerec.h"
60 #include "gromacs/mdtypes/inputrec.h"
61 #include "gromacs/mdtypes/mdatom.h"
62 #include "gromacs/mdtypes/mdrunoptions.h"
63 #include "gromacs/mdtypes/state.h"
64 #include "gromacs/pbcutil/pbc.h"
65 #include "gromacs/topology/atoms.h"
66 #include "gromacs/topology/topology.h"
68 #include "freeenergyperturbationdata.h"
69 #include "modularsimulator.h"
70 #include "simulatoralgorithm.h"
// Constructs the state propagator data: the output and final-configuration settings are handed
// to the owned Element, and the local state is initialized from globalState. With domain
// decomposition (DD) the local state is built by dd_init_local_state() and installed via
// setLocalState(); otherwise sizes, box and ddp_count are taken from globalState directly.
// For non-continuation runs whose state has velocities, vsite/shell velocities and frozen
// velocity components are cleared. For MD-VV (eiVV) a one-time velocity reset is requested.
74 StatePropagatorData::StatePropagatorData(int numAtoms,
79 bool canMoleculesBeDistributedOverPBC,
80 bool writeFinalConfiguration,
81 const std::string& finalConfigurationFilename,
82 const t_inputrec* inputrec,
83 const t_mdatoms* mdatoms,
84 const gmx_mtop_t* globalTop) :
85 totalNumAtoms_(numAtoms),
88 previousBox_{ { 0 } },
90 element_(std::make_unique<Element>(this,
96 inputrec->nstxout_compressed,
97 canMoleculesBeDistributedOverPBC,
98 writeFinalConfiguration,
99 finalConfigurationFilename,
102 vvResetVelocities_(false),
103 isRegularSimulationEnd_(false),
105 globalState_(globalState)
// Derived from the estV bit of the state flags (local state under DD, global state otherwise).
107 bool stateHasVelocities;
108 // Local state only becomes valid now.
109 if (DOMAINDECOMP(cr))
111 auto localState = std::make_unique<t_state>();
112 dd_init_local_state(cr->dd, globalState, localState.get());
113 stateHasVelocities = ((static_cast<unsigned int>(localState->flags) & (1U << estV)) != 0U);
114 setLocalState(std::move(localState));
// Without DD (presumably the else branch): the global state serves as the local state.
118 state_change_natoms(globalState, globalState->natoms);
119 f_.resize(globalState->natoms);
120 localNAtoms_ = globalState->natoms;
123 copy_mat(globalState->box, box_);
124 stateHasVelocities = ((static_cast<unsigned int>(globalState->flags) & (1U << estV)) != 0U);
125 previousX_.resizeWithPadding(localNAtoms_);
126 ddpCount_ = globalState->ddp_count;
// Pin the position buffer where supported (presumably for faster host-device transfers - confirm).
131 changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
// Under DD, the master rank allocates full-system buffers; doCheckpointData() collects the
// distributed positions and velocities into xGlobal_/vGlobal_.
134 if (DOMAINDECOMP(cr) && MASTER(cr))
136 xGlobal_.resizeWithPadding(totalNumAtoms_);
137 previousXGlobal_.resizeWithPadding(totalNumAtoms_);
138 vGlobal_.resizeWithPadding(totalNumAtoms_);
139 fGlobal_.resizeWithPadding(totalNumAtoms_);
142 if (!inputrec->bContinuation)
144 if (stateHasVelocities)
146 auto v = velocitiesView().paddedArrayRef();
147 // Set the velocities of vsites, shells and frozen atoms to zero
148 for (int i = 0; i < mdatoms->homenr; i++)
150 if (mdatoms->ptype[i] == eptVSite || mdatoms->ptype[i] == eptShell)
154 else if (mdatoms->cFREEZE)
156 for (int m = 0; m < DIM; m++)
158 if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
// MD-VV: elementSetup() backs up the velocities, and the first scheduleTask() call registers
// a single resetVelocities() that restores them.
166 if (inputrec->eI == eiVV)
168 vvResetVelocities_ = true;
// Returns a non-owning pointer to the Element held by element_.
173 StatePropagatorData::Element* StatePropagatorData::element()
175 return element_.get();
// Delegates setup to the owned Element (elementSetup() backs up the velocities for MD-VV).
178 void StatePropagatorData::setup()
182 element_->elementSetup();
// Mutable padded view of the local positions (x_).
186 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
188 return x_.arrayRefWithPadding();
// Read-only padded view of the local positions (x_).
191 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
193 return x_.constArrayRefWithPadding();
// Mutable padded view of the positions saved by copyPosition() (previousX_).
196 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
198 return previousX_.arrayRefWithPadding();
// Read-only padded view of the positions saved by copyPosition() (previousX_).
201 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
203 return previousX_.constArrayRefWithPadding();
// Mutable padded view of the local velocities (v_).
206 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
208 return v_.arrayRefWithPadding();
// Read-only padded view of the local velocities (v_).
211 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
213 return v_.constArrayRefWithPadding();
// Mutable view of the force buffers (presumably f_.view() - confirm).
216 ForceBuffersView& StatePropagatorData::forcesView()
// Read-only view of the force buffers (presumably f_.view() - confirm).
221 const ForceBuffersView& StatePropagatorData::constForcesView() const
// Mutable access to the current box (presumably box_ decayed to rvec* - confirm).
226 rvec* StatePropagatorData::box()
// Read-only access to the current box (presumably box_ - confirm).
231 const rvec* StatePropagatorData::constBox() const
// Mutable access to the box saved by copyPosition() (presumably previousBox_ - confirm).
236 rvec* StatePropagatorData::previousBox()
// Read-only access to the box saved by copyPosition() (presumably previousBox_ - confirm).
241 const rvec* StatePropagatorData::constPreviousBox() const
// Number of atoms in the local state (presumably localNAtoms_ - confirm).
246 int StatePropagatorData::localNumAtoms() const
// Total number of atoms in the system, as passed to the constructor.
251 int StatePropagatorData::totalNumAtoms() const
253 return totalNumAtoms_;
// Creates a newly allocated t_state from the local data. Its flags mark positions, velocities
// and box as present, it is sized to the local atom count, and it receives the box and the DD
// bookkeeping (ddp_count, ddp_count_cg_gl, cg_gl). Element::saveState() uses it as a backup.
256 std::unique_ptr<t_state> StatePropagatorData::localState()
258 auto state = std::make_unique<t_state>();
259 state->flags = (1U << estX) | (1U << estV) | (1U << estBOX);
260 state_change_natoms(state.get(), localNAtoms_);
263 copy_mat(box_, state->box);
264 state->ddp_count = ddpCount_;
265 state->ddp_count_cg_gl = ddpCountCgGl_;
266 state->cg_gl = cgGl_;
// Installs a new local state. The constructor calls this under DD, and, per the comment
// below, so does every DD run. It resizes the local buffers to state->natoms and copies the
// box and DD bookkeeping. If an MD-VV velocity reset is still pending, it also refreshes the
// velocity backup, so that resetVelocities() restores the latest DD state.
270 void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
272 localNAtoms_ = state->natoms;
273 x_.resizeWithPadding(localNAtoms_);
274 previousX_.resizeWithPadding(localNAtoms_);
275 v_.resizeWithPadding(localNAtoms_);
278 copy_mat(state->box, box_);
280 ddpCount_ = state->ddp_count;
281 ddpCountCgGl_ = state->ddp_count_cg_gl;
282 cgGl_ = state->cg_gl;
284 if (vvResetVelocities_)
286 /* DomDec runs twice early in the simulation, once at setup time, and once before the first
287 * step. Every time DD runs, it sets a new local state here. We are saving a backup during
288 * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
289 * the first step here to avoid resetting to an earlier DD state. This is done before any
290 * propagation that needs to be reset, so it's not very safe but correct for now.
291 * TODO: Get rid of this once input is assumed to be at half steps
293 velocityBackup_ = v_;
// Pointer to the legacy global state (presumably globalState_ - confirm).
297 t_state* StatePropagatorData::globalState()
// Pointer to the force buffers (presumably &f_ - confirm).
302 ForceBuffers* StatePropagatorData::forcePointer()
// Copies the current local positions into previousX_ and the current box into previousBox_.
// The position copy is split across the update OpenMP threads: each thread handles the atom
// range given by getThreadAtomRange().
307 void StatePropagatorData::copyPosition()
309 int nth = gmx_omp_nthreads_get(emntUpdate);
311 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
312 for (int th = 0; th < nth; th++)
314 int start_th, end_th;
315 getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
316 copyPosition(start_th, end_th);
319 /* Box is changed in update() when we do pressure coupling,
320 * but we should still use the old box for energy corrections and when
321 * writing it to the energy file, so it matches the trajectory files for
322 * the same timestep above. Make a copy in a separate array.
324 copy_mat(box_, previousBox_);
// Copies the positions x_[start, end) into previousX_ (one thread's share of copyPosition()).
327 void StatePropagatorData::copyPosition(int start, int end)
329 for (int i = start; i < end; ++i)
331 previousX_[i] = x_[i];
// Registers this step's tasks, in this order:
//  1. for MD-VV, once (the flag is cleared here), restoring the velocities backed up at setup;
//  2. copying the current positions and box into their "previous" copies;
//  3. on a state-writing step, or on the last step when the final configuration is written,
//     backing up the local state for later writing.
335 void StatePropagatorData::Element::scheduleTask(Step step,
336 Time gmx_unused time,
337 const RegisterRunFunction& registerRunFunction)
339 if (statePropagatorData_->vvResetVelocities_)
341 statePropagatorData_->vvResetVelocities_ = false;
342 registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
344 // copy x -> previousX and box -> previousBox
345 registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
346 // on a state-writing step, or on the last step when the final configuration is written, back up the state
347 if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
349 registerRunFunction([this]() { saveState(); });
// Backs up the local state for later trajectory / final-configuration writing. If free-energy
// perturbation data is set, the current FEP state and lambda values are stored as well and
// flagged as present. Asserts that no earlier backup is still pending.
353 void StatePropagatorData::Element::saveState()
355 GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
356 localStateBackup_ = statePropagatorData_->localState();
357 if (freeEnergyPerturbationData_)
359 localStateBackup_->fep_state = freeEnergyPerturbationData_->currentFEPState();
// NOTE(review): constLambdaView() is re-queried on every iteration, and the index is
// `unsigned long` rather than the size_t returned by size(); consider hoisting the view into
// a local and using size_t.
360 for (unsigned long i = 0; i < localStateBackup_->lambda.size(); ++i)
362 localStateBackup_->lambda[i] = freeEnergyPerturbationData_->constLambdaView()[i];
364 localStateBackup_->flags |= (1U << estLAMBDA) | (1U << estFEPSTATE);
// For StateWritingStep events, returns a callback that records the upcoming state-writing
// step in writeOutStep_; scheduleTask() compares against it to decide when to save a backup.
// Presumably returns std::nullopt for other events - confirm.
368 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
370 if (event == TrajectoryEvent::StateWritingStep)
372 return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
// For StateWritingStep events, returns a trajectory writer callback that forwards to write()
// (presumably only when writeTrajectory is set - confirm).
377 std::optional<ITrajectoryWriterCallback>
378 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
380 if (event == TrajectoryEvent::StateWritingStep)
382 return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
385 write(outf, step, time);
// Writes the state backed up by saveState() to the trajectory files for currentStep, timed
// under ewcTRAJ. The output types are selected from the configured output intervals. The
// backup is released afterwards unless this is the last step of a regular simulation end,
// where trajectoryWriterTeardown() still needs it.
392 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
394 wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
// Collect the output types that are due at this step
395 unsigned int mdof_flags = 0;
396 if (do_per_step(currentStep, nstxout_))
398 mdof_flags |= MDOF_X;
400 if (do_per_step(currentStep, nstvout_))
402 mdof_flags |= MDOF_V;
404 if (do_per_step(currentStep, nstfout_))
406 mdof_flags |= MDOF_F;
408 if (do_per_step(currentStep, nstxout_compressed_))
410 mdof_flags |= MDOF_X_COMPRESSED;
412 if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
414 mdof_flags |= MDOF_BOX;
416 if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
418 mdof_flags |= MDOF_LAMBDA;
420 if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
422 mdof_flags |= MDOF_BOX_COMPRESSED;
424 if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
426 mdof_flags |= MDOF_LAMBDA_COMPRESSED;
// Presumably the early-exit path when nothing is due (mdof_flags == 0): stop timing and return - confirm.
431 wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
434 GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
436 // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
437 ObservablesHistory* observablesHistory = nullptr;
// Note: forces come from the current force buffer (f_), not from the state backup.
439 mdoutf_write_to_trajectory_files(
440 fplog_, cr_, outf, static_cast<int>(mdof_flags), statePropagatorData_->totalNumAtoms_,
441 currentStep, currentTime, localStateBackup_.get(), statePropagatorData_->globalState_,
442 observablesHistory, statePropagatorData_->f_.view().force(), &dummyCheckpointDataHolder_);
// Keep the backup at a regular last step: trajectoryWriterTeardown() asserts it exists.
444 if (currentStep != lastStep_ || !isRegularSimulationEnd_)
446 localStateBackup_.reset();
448 wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
// If an MD-VV velocity reset is pending, saves a copy of the current velocities, which
// resetVelocities() restores later.
451 void StatePropagatorData::Element::elementSetup()
453 if (statePropagatorData_->vvResetVelocities_)
455 // MD-VV does the first velocity half-step only to calculate the constraint virial,
456 // then resets the velocities since the input is assumed to be positions and velocities
457 // at full time step. TODO: Change this to have input at half time steps.
458 statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
// Restores the velocities from velocityBackup_. The backup is filled by elementSetup() and
// refreshed by setLocalState() for MD-VV.
462 void StatePropagatorData::resetVelocities()
464 v_ = velocityBackup_;
470 * \brief Enum describing the contents that StatePropagatorData::Element writes to the modular checkpoint
472 * When changing the checkpoint content, add a new element just above Count, and adjust the
473 * checkpoint functionality.
475 enum class CheckpointVersion
477 Base, //!< First version of modular checkpointing
478 Count //!< Number of entries. Add new versions right above this!
//! The version written by this code: the last entry before Count.
480 constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
// Reads or writes, depending on `operation`, this element's checkpoint content: version,
// positions, velocities, box and DD bookkeeping (ddpCount, ddpCountCgGl, cgGl).
// Under DD, the full-system xGlobal_/vGlobal_ buffers are used; these are allocated on the
// master rank only. When writing, the distributed x_/v_ are first collected into them with
// dd_collect_vec(). Without DD, the local x_/v_ are read or written directly.
483 template<CheckpointDataOperation operation>
484 void StatePropagatorData::Element::doCheckpointData(CheckpointData<operation>* checkpointData,
487 ArrayRef<RVec> xGlobalRef;
488 ArrayRef<RVec> vGlobalRef;
489 if (DOMAINDECOMP(cr))
493 xGlobalRef = statePropagatorData_->xGlobal_;
494 vGlobalRef = statePropagatorData_->vGlobal_;
496 if (operation == CheckpointDataOperation::Write)
498 dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
499 statePropagatorData_->cgGl_, statePropagatorData_->x_, xGlobalRef);
500 dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
501 statePropagatorData_->cgGl_, statePropagatorData_->v_, vGlobalRef);
// Without DD (presumably the else branch): the local buffers already hold the full system.
506 xGlobalRef = statePropagatorData_->x_;
507 vGlobalRef = statePropagatorData_->v_;
// Presumably reached on the master rank only, which must provide checkpointData - confirm.
511 GMX_ASSERT(checkpointData, "Master needs a valid pointer to a CheckpointData object");
512 checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
514 checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobalRef));
515 checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobalRef));
516 checkpointData->tensor("box", statePropagatorData_->box_);
517 checkpointData->scalar("ddpCount", &statePropagatorData_->ddpCount_);
518 checkpointData->scalar("ddpCountCgGl", &statePropagatorData_->ddpCountCgGl_);
519 checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(statePropagatorData_->cgGl_));
// Writes this element's checkpoint data. An empty checkpointData is passed on as nullptr
// (presumably the case on non-master ranks - confirm).
523 void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
526 doCheckpointData<CheckpointDataOperation::Write>(
527 checkpointData ? &checkpointData.value() : nullptr, cr);
531 * \brief Update the legacy global state
533 * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
534 * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
536 static void updateGlobalState(t_state* globalState,
537 const PaddedHostVector<RVec>& x,
538 const PaddedHostVector<RVec>& v,
542 const std::vector<int>& cgGl)
// Copy the box and DD bookkeeping into the legacy global state (x and v are presumably copied
// into globalState->x / globalState->v as well - confirm).
546 copy_mat(box, globalState->box);
547 globalState->ddp_count = ddpCount;
548 globalState->ddp_count_cg_gl = ddpCountCgGl;
549 globalState->cg_gl = cgGl;
// Reads this element's checkpoint data; an empty checkpointData is passed on as nullptr.
// Under DD, the master rank then copies the restored full-system data into the legacy global
// state, from which DD distributes it at setup.
552 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
555 doCheckpointData<CheckpointDataOperation::Read>(checkpointData ? &checkpointData.value() : nullptr, cr);
557 // Copy data to global state to be distributed by DD at setup stage
558 if (DOMAINDECOMP(cr) && MASTER(cr))
560 updateGlobalState(statePropagatorData_->globalState_, statePropagatorData_->xGlobal_,
561 statePropagatorData_->vGlobal_, statePropagatorData_->box_,
562 statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
563 statePropagatorData_->cgGl_);
// Identifier of this element as a checkpointing client (presumably a fixed string - confirm).
567 const std::string& StatePropagatorData::Element::clientID()
// Writes the final configuration, but only if it was requested and the simulation ended
// regularly. The backed-up local state is gathered into the global state: under DD it is
// collected with dd_collect_vec(); otherwise globalState_ is repointed at the backup.
// Molecules are optionally made whole, and the configuration is written with
// write_sto_conf_mtop().
572 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
574 // Note that part of this code is duplicated in do_md_trajectory_writing.
575 // This duplication is needed while both legacy and modular code paths are in use.
576 // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
577 if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
582 GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
// NOTE(review): outf is marked gmx_unused but is used for wallcycle accounting here.
584 wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
585 if (DOMAINDECOMP(cr_))
588 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
589 dd_collect_vec(cr_->dd, localStateBackup_->ddp_count, localStateBackup_->ddp_count_cg_gl,
590 localStateBackup_->cg_gl, localStateBackup_->x, globalXRef);
592 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
593 dd_collect_vec(cr_->dd, localStateBackup_->ddp_count, localStateBackup_->ddp_count_cg_gl,
594 localStateBackup_->cg_gl, localStateBackup_->v, globalVRef);
// NOTE(review): this repoints globalState_ at memory owned by localStateBackup_; globalState_
// dangles once that backup is reset or destroyed.
598 // We have the whole state locally: copy the local state pointer
599 statePropagatorData_->globalState_ = localStateBackup_.get();
604 fprintf(stderr, "\nWriting final coordinates.\n");
605 if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
607 // Make molecules whole only for confout writing
608 do_pbc_mtop(pbcType_, localStateBackup_->box, top_global_,
609 statePropagatorData_->globalState_->x.rvec_array());
611 write_sto_conf_mtop(finalConfigurationFilename_.c_str(), *top_global_->name, top_global_,
612 statePropagatorData_->globalState_->x.rvec_array(),
613 statePropagatorData_->globalState_->v.rvec_array(), pbcType_,
614 localStateBackup_->box);
616 wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
// Returns a callback for the last-step signal. It records whether the simulation ends at the
// planned last step (init_step + nsteps, set in the Element constructor); this gates
// final-configuration writing in trajectoryWriterTeardown().
619 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
621 return [this](Step step, Time /*time*/) {
623 isRegularSimulationEnd_ = (step == lastPlannedStep_);
// Stores the output intervals, the final-configuration settings and the PBC information from
// inputrec. The planned last step is init_step + nsteps. The FEP data starts unset and is
// provided later through setFreeEnergyPerturbationData().
627 StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
633 int nstxout_compressed,
634 bool canMoleculesBeDistributedOverPBC,
635 bool writeFinalConfiguration,
636 std::string finalConfigurationFilename,
637 const t_inputrec* inputrec,
638 const gmx_mtop_t* globalTop) :
639 statePropagatorData_(statePropagatorData),
643 nstxout_compressed_(nstxout_compressed),
645 freeEnergyPerturbationData_(nullptr),
646 isRegularSimulationEnd_(false),
648 canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
649 systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
650 pbcType_(inputrec->pbcType),
651 lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
652 writeFinalConfiguration_(writeFinalConfiguration),
653 finalConfigurationFilename_(std::move(finalConfigurationFilename)),
656 top_global_(globalTop)
// Stores the free-energy perturbation data pointer used by saveState(); a null pointer means
// no FEP data is saved with the state.
659 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
661 freeEnergyPerturbationData_ = freeEnergyPerturbationData;
// Builder hook: hands the FEP data to the Element already owned by statePropagatorData and
// returns that existing element (no new element is created).
664 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
665 LegacySimulatorData gmx_unused* legacySimulatorData,
666 ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
667 StatePropagatorData* statePropagatorData,
668 EnergyData gmx_unused* energyData,
669 FreeEnergyPerturbationData* freeEnergyPerturbationData,
670 GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
672 statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
673 return statePropagatorData->element();