2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 * \brief Defines the state for the modular simulator
38 * \author Pascal Merz <pascal.merz@me.com>
39 * \ingroup module_modularsimulator
44 #include "gromacs/utility/enumerationhelpers.h"
45 #include "statepropagatordata.h"
47 #include "gromacs/commandline/filenm.h"
48 #include "gromacs/domdec/collect.h"
49 #include "gromacs/domdec/domdec.h"
50 #include "gromacs/fileio/confio.h"
51 #include "gromacs/math/vec.h"
52 #include "gromacs/mdlib/gmx_omp_nthreads.h"
53 #include "gromacs/mdlib/mdatoms.h"
54 #include "gromacs/mdlib/mdoutf.h"
55 #include "gromacs/mdlib/stat.h"
56 #include "gromacs/mdlib/update.h"
57 #include "gromacs/mdtypes/checkpointdata.h"
58 #include "gromacs/mdtypes/commrec.h"
59 #include "gromacs/mdtypes/forcebuffers.h"
60 #include "gromacs/mdtypes/forcerec.h"
61 #include "gromacs/mdtypes/inputrec.h"
62 #include "gromacs/mdtypes/mdatom.h"
63 #include "gromacs/mdtypes/mdrunoptions.h"
64 #include "gromacs/mdtypes/state.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/topology/atoms.h"
67 #include "gromacs/topology/topology.h"
68 #include "gromacs/trajectory/trajectoryframe.h"
70 #include "freeenergyperturbationdata.h"
71 #include "modularsimulator.h"
72 #include "simulatoralgorithm.h"
// Constructs the state container. The local x/v/f storage and box are initialized either
// from a local state created via dd_init_local_state() when domain decomposition is active,
// or from the global state otherwise. On the DD master rank, global-size x, previous-x, v
// and f buffers are allocated. For non-continuation runs with velocities in the state,
// velocities of shell particles and of frozen dimensions are zeroed; velocity-Verlet runs
// are additionally flagged (vvResetVelocities_) to reset velocities after the first half-step.
76 StatePropagatorData::StatePropagatorData(int numAtoms,
81 bool canMoleculesBeDistributedOverPBC,
82 bool writeFinalConfiguration,
83 const std::string& finalConfigurationFilename,
84 const t_inputrec* inputrec,
85 const t_mdatoms* mdatoms,
86 const gmx_mtop_t& globalTop) :
87 totalNumAtoms_(numAtoms),
90 previousBox_{ { 0 } },
92 element_(std::make_unique<Element>(this,
98 inputrec->nstxout_compressed,
99 canMoleculesBeDistributedOverPBC,
100 writeFinalConfiguration,
101 finalConfigurationFilename,
104 vvResetVelocities_(false),
105 isRegularSimulationEnd_(false),
107 globalState_(globalState)
109 bool stateHasVelocities;
110 // Local state only becomes valid now.
111 if (DOMAINDECOMP(cr))
// With DD, build the local state from the global one and adopt it via setLocalState()
113 auto localState = std::make_unique<t_state>();
114 dd_init_local_state(*cr->dd, globalState, localState.get());
115 stateHasVelocities = ((localState->flags & enumValueToBitMask(StateEntry::V)) != 0);
116 setLocalState(std::move(localState));
// Without DD, the whole system is local: size local buffers from the global state
120 state_change_natoms(globalState, globalState->natoms);
121 f_.resize(globalState->natoms);
122 localNAtoms_ = globalState->natoms;
125 copy_mat(globalState->box, box_);
126 stateHasVelocities = ((globalState->flags & enumValueToBitMask(StateEntry::V)) != 0);
127 previousX_.resizeWithPadding(localNAtoms_);
128 ddpCount_ = globalState->ddp_count;
133 changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
// Only the DD master rank needs buffers holding the full (global) system
136 if (DOMAINDECOMP(cr) && MASTER(cr))
138 xGlobal_.resizeWithPadding(totalNumAtoms_);
139 previousXGlobal_.resizeWithPadding(totalNumAtoms_);
140 vGlobal_.resizeWithPadding(totalNumAtoms_);
141 fGlobal_.resizeWithPadding(totalNumAtoms_);
144 if (!inputrec->bContinuation)
146 if (stateHasVelocities)
148 auto v = velocitiesView().paddedArrayRef();
// NOTE(review): the loop below only checks ParticleType::Shell and frozen dimensions;
// vsite velocities are not handled here (presumably elsewhere) - confirm the comment below.
149 // Set the velocities of vsites, shells and frozen atoms to zero
150 for (int i = 0; i < mdatoms->homenr; i++)
152 if (mdatoms->ptype[i] == ParticleType::Shell)
156 else if (mdatoms->cFREEZE)
158 for (int m = 0; m < DIM; m++)
160 if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
// Velocity Verlet: velocities are backed up at setup and restored after the first
// half-step (see Element::elementSetup() and resetVelocities())
168 if (inputrec->eI == IntegrationAlgorithm::VV)
170 vvResetVelocities_ = true;
175 StatePropagatorData::Element* StatePropagatorData::element()
177 return element_.get();
// Setup hook: forwards to the element's setup (Element::elementSetup()), which backs up
// the velocities when a velocity-Verlet reset is pending.
180 void StatePropagatorData::setup()
184 element_->elementSetup();
188 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
190 return x_.arrayRefWithPadding();
193 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
195 return x_.constArrayRefWithPadding();
198 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
200 return previousX_.arrayRefWithPadding();
203 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
205 return previousX_.constArrayRefWithPadding();
208 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
210 return v_.arrayRefWithPadding();
213 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
215 return v_.constArrayRefWithPadding();
218 ForceBuffersView& StatePropagatorData::forcesView()
223 const ForceBuffersView& StatePropagatorData::constForcesView() const
228 rvec* StatePropagatorData::box()
233 const rvec* StatePropagatorData::constBox() const
238 rvec* StatePropagatorData::previousBox()
243 const rvec* StatePropagatorData::constPreviousBox() const
248 int StatePropagatorData::localNumAtoms() const
253 int StatePropagatorData::totalNumAtoms() const
255 return totalNumAtoms_;
// Creates a new t_state describing the local data, flagged as holding X, V and Box. It is
// sized to the local atom count and carries a copy of the box and of the DD bookkeeping
// (ddp_count, ddp_count_cg_gl, cg_gl). The caller owns the returned object.
258 std::unique_ptr<t_state> StatePropagatorData::localState()
260 auto state = std::make_unique<t_state>();
261 state->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
262 | enumValueToBitMask(StateEntry::Box);
263 state_change_natoms(state.get(), localNAtoms_);
266 copy_mat(box_, state->box);
267 state->ddp_count = ddpCount_;
268 state->ddp_count_cg_gl = ddpCountCgGl_;
269 state->cg_gl = cgGl_;
// Adopts a new local state, as set every time domain decomposition runs. It resizes the
// local x, previous-x and v buffers to the new local atom count and copies the box and
// DD bookkeeping. While a velocity-Verlet reset is still pending, it refreshes the
// velocity backup.
273 void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
275 localNAtoms_ = state->natoms;
276 x_.resizeWithPadding(localNAtoms_);
277 previousX_.resizeWithPadding(localNAtoms_);
278 v_.resizeWithPadding(localNAtoms_);
281 copy_mat(state->box, box_);
283 ddpCount_ = state->ddp_count;
284 ddpCountCgGl_ = state->ddp_count_cg_gl;
285 cgGl_ = state->cg_gl;
287 if (vvResetVelocities_)
289 /* DomDec runs twice early in the simulation, once at setup time, and once before the first
290 * step. Every time DD runs, it sets a new local state here. We are saving a backup during
291 * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
292 * the first step here to avoid resetting to an earlier DD state. This is done before any
293 * propagation that needs to be reset, so it's not very safe but correct for now.
294 * TODO: Get rid of this once input is assumed to be at half steps
296 velocityBackup_ = v_;
300 t_state* StatePropagatorData::globalState()
305 ForceBuffers* StatePropagatorData::forcePointer()
310 void StatePropagatorData::copyPosition()
312 int nth = gmx_omp_nthreads_get(ModuleMultiThread::Update);
314 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
315 for (int th = 0; th < nth; th++)
317 int start_th, end_th;
318 getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
319 copyPosition(start_th, end_th);
322 /* Box is changed in update() when we do pressure coupling,
323 * but we should still use the old box for energy corrections and when
324 * writing it to the energy file, so it matches the trajectory files for
325 * the same timestep above. Make a copy in a separate array.
327 copy_mat(box_, previousBox_);
330 void StatePropagatorData::copyPosition(int start, int end)
332 for (int i = start; i < end; ++i)
334 previousX_[i] = x_[i];
338 void StatePropagatorData::Element::scheduleTask(Step step,
339 Time gmx_unused time,
340 const RegisterRunFunction& registerRunFunction)
342 if (statePropagatorData_->vvResetVelocities_)
344 statePropagatorData_->vvResetVelocities_ = false;
345 registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
347 // copy x -> previousX
348 registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
349 // if it's a write out step, keep a copy for writeout
350 if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
352 registerRunFunction([this]() { saveState(); });
356 void StatePropagatorData::Element::saveState()
358 GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
359 localStateBackup_ = statePropagatorData_->localState();
360 if (freeEnergyPerturbationData_)
362 localStateBackup_->fep_state = freeEnergyPerturbationData_->currentFEPState();
363 ArrayRef<const real> lambdaView = freeEnergyPerturbationData_->constLambdaView();
364 std::copy(lambdaView.begin(), lambdaView.end(), localStateBackup_->lambda.begin());
365 localStateBackup_->flags |=
366 enumValueToBitMask(StateEntry::Lambda) | enumValueToBitMask(StateEntry::FepState);
370 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
372 if (event == TrajectoryEvent::StateWritingStep)
374 return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
// For TrajectoryEvent::StateWritingStep, returns a trajectory-writer callback that writes the
// state saved by saveState() via write(). The writeLog argument is not used.
379 std::optional<ITrajectoryWriterCallback>
380 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
382 if (event == TrajectoryEvent::StateWritingStep)
384 return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
387 write(outf, step, time);
// Writes the state saved by saveState() to the trajectory output files. For each step, the
// outputs due (x, v, f, compressed x, TNG box/lambda and their compressed variants) are
// selected from the configured output intervals. Afterwards the saved state is released,
// unless this is the last step of a regular simulation end; trajectoryWriterTeardown() still
// needs it then for the final configuration. Time is accounted to WallCycleCounter::Traj.
394 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
396 wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
397 unsigned int mdof_flags = 0;
398 if (do_per_step(currentStep, nstxout_))
400 mdof_flags |= MDOF_X;
402 if (do_per_step(currentStep, nstvout_))
404 mdof_flags |= MDOF_V;
406 if (do_per_step(currentStep, nstfout_))
408 mdof_flags |= MDOF_F;
410 if (do_per_step(currentStep, nstxout_compressed_))
412 mdof_flags |= MDOF_X_COMPRESSED;
414 if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
416 mdof_flags |= MDOF_BOX;
418 if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
420 mdof_flags |= MDOF_LAMBDA;
422 if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
424 mdof_flags |= MDOF_BOX_COMPRESSED;
426 if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
428 mdof_flags |= MDOF_LAMBDA_COMPRESSED;
// NOTE(review): this stop presumably belongs to an early-exit path taken when no output
// is due this step (a second stop follows at the end) - confirm.
433 wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
436 GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
438 // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
439 ObservablesHistory* observablesHistory = nullptr;
441 mdoutf_write_to_trajectory_files(fplog_,
444 static_cast<int>(mdof_flags),
445 statePropagatorData_->totalNumAtoms_,
448 localStateBackup_.get(),
449 statePropagatorData_->globalState_,
451 statePropagatorData_->f_.view().force(),
452 &dummyCheckpointDataHolder_);
// Keep the backup on the final step of a regular run; trajectoryWriterTeardown() uses it
454 if (currentStep != lastStep_ || !isRegularSimulationEnd_)
456 localStateBackup_.reset();
458 wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
461 void StatePropagatorData::Element::elementSetup()
463 if (statePropagatorData_->vvResetVelocities_)
465 // MD-VV does the first velocity half-step only to calculate the constraint virial,
466 // then resets the velocities since the input is assumed to be positions and velocities
467 // at full time step. TODO: Change this to have input at half time steps.
468 statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
472 void StatePropagatorData::resetVelocities()
474 v_ = velocityBackup_;
/*!
 * \brief Enum describing the contents that StatePropagatorData::Element writes to the modular checkpoint
 *
 * When changing the checkpoint content, add a new element just above Count, and adjust the
 * checkpoint functionality.
 */
enum class CheckpointVersion
{
    Base, //!< First version of modular checkpointing
    Count //!< Number of entries. Add new versions right above this!
};
//! The newest checkpoint version, i.e. the last entry before Count
constexpr auto c_currentVersion =
        static_cast<CheckpointVersion>(static_cast<int>(CheckpointVersion::Count) - 1);
493 template<CheckpointDataOperation operation>
494 void StatePropagatorData::doCheckpointData(CheckpointData<operation>* checkpointData)
496 checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
497 checkpointData->scalar("numAtoms", &totalNumAtoms_);
499 if (operation == CheckpointDataOperation::Read)
501 xGlobal_.resizeWithPadding(totalNumAtoms_);
502 vGlobal_.resizeWithPadding(totalNumAtoms_);
505 checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobal_));
506 checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobal_));
507 checkpointData->tensor("box", box_);
508 checkpointData->scalar("ddpCount", &ddpCount_);
509 checkpointData->scalar("ddpCountCgGl", &ddpCountCgGl_);
510 checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(cgGl_));
// Writes this element's checkpoint data. First the local x and v are gathered into the
// global vectors: via dd_collect_vec() with domain decomposition, otherwise by plain copies,
// since all atoms are local. They are then serialized with doCheckpointData<Write>().
513 void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
516 if (DOMAINDECOMP(cr))
518 // Collect state from all ranks into global vectors
519 dd_collect_vec(cr->dd,
520 statePropagatorData_->ddpCount_,
521 statePropagatorData_->ddpCountCgGl_,
522 statePropagatorData_->cgGl_,
523 statePropagatorData_->x_,
524 statePropagatorData_->xGlobal_);
525 dd_collect_vec(cr->dd,
526 statePropagatorData_->ddpCount_,
527 statePropagatorData_->ddpCountCgGl_,
528 statePropagatorData_->cgGl_,
529 statePropagatorData_->v_,
530 statePropagatorData_->vGlobal_);
534 // Everything is local - copy local vectors into global ones
535 statePropagatorData_->xGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
536 statePropagatorData_->vGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
537 std::copy(statePropagatorData_->x_.begin(),
538 statePropagatorData_->x_.end(),
539 statePropagatorData_->xGlobal_.begin());
540 std::copy(statePropagatorData_->v_.begin(),
541 statePropagatorData_->v_.end(),
542 statePropagatorData_->vGlobal_.begin());
546 statePropagatorData_->doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
551 * \brief Update the legacy global state
553 * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
554 * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
556 static void updateGlobalState(t_state* globalState,
557 const PaddedHostVector<RVec>& x,
558 const PaddedHostVector<RVec>& v,
562 const std::vector<int>& cgGl)
// Copy the restored box and DD bookkeeping into the legacy global state
566 copy_mat(box, globalState->box);
567 globalState->ddp_count = ddpCount;
568 globalState->ddp_count_cg_gl = ddpCountCgGl;
569 globalState->cg_gl = cgGl;
// Reads this element's checkpoint data via doCheckpointData<Read>(). With domain
// decomposition, the master rank forwards the restored data to the legacy global state so
// that DD can distribute it during setup. Without DD, the restored global x/v are copied
// directly into the local vectors.
572 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
577 statePropagatorData_->doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
580 // Copy data to global state to be distributed by DD at setup stage
581 if (DOMAINDECOMP(cr) && MASTER(cr))
583 updateGlobalState(statePropagatorData_->globalState_,
584 statePropagatorData_->xGlobal_,
585 statePropagatorData_->vGlobal_,
586 statePropagatorData_->box_,
587 statePropagatorData_->ddpCount_,
588 statePropagatorData_->ddpCountCgGl_,
589 statePropagatorData_->cgGl_);
591 // Everything is local - copy global vectors to local ones
592 if (!DOMAINDECOMP(cr))
594 statePropagatorData_->x_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
595 statePropagatorData_->v_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
596 std::copy(statePropagatorData_->xGlobal_.begin(),
597 statePropagatorData_->xGlobal_.end(),
598 statePropagatorData_->x_.begin());
599 std::copy(statePropagatorData_->vGlobal_.begin(),
600 statePropagatorData_->vGlobal_.end(),
601 statePropagatorData_->v_.begin());
605 const std::string& StatePropagatorData::Element::clientID()
607 return StatePropagatorData::checkpointID();
// Writes the final configuration when final-configuration writing is enabled and the
// simulation ended regularly. It uses the state saved on the last step. With DD, x and v
// are first collected into the master rank's global state; otherwise the saved state is
// used directly. Molecules are made whole before writing when they can be split over PBC
// and the system has no periodic molecules.
// NOTE(review): outf is marked gmx_unused but is used below for wall-cycle accounting.
610 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
612 // Note that part of this code is duplicated in do_md_trajectory_writing.
613 // This duplication is needed while both legacy and modular code paths are in use.
614 // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
615 if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
620 GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
622 wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
623 if (DOMAINDECOMP(cr_))
626 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
627 dd_collect_vec(cr_->dd,
628 localStateBackup_->ddp_count,
629 localStateBackup_->ddp_count_cg_gl,
630 localStateBackup_->cg_gl,
631 localStateBackup_->x,
634 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
635 dd_collect_vec(cr_->dd,
636 localStateBackup_->ddp_count,
637 localStateBackup_->ddp_count_cg_gl,
638 localStateBackup_->cg_gl,
639 localStateBackup_->v,
644 // We have the whole state locally: copy the local state pointer
// NOTE(review): globalState_ now points into localStateBackup_ (owned by this element);
// it must not be dereferenced after that backup is released.
645 statePropagatorData_->globalState_ = localStateBackup_.get();
650 fprintf(stderr, "\nWriting final coordinates.\n");
651 if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
653 // Make molecules whole only for confout writing
654 do_pbc_mtop(pbcType_,
655 localStateBackup_->box,
657 statePropagatorData_->globalState_->x.rvec_array());
659 write_sto_conf_mtop(finalConfigurationFilename_.c_str(),
662 statePropagatorData_->globalState_->x.rvec_array(),
663 statePropagatorData_->globalState_->v.rvec_array(),
665 localStateBackup_->box);
667 wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
// Returns a callback that is invoked with the last step of the run. The callback records
// whether that step equals the planned final step (nsteps + init_step), i.e. whether the
// simulation ends regularly.
670 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
672 return [this](Step step, Time /*time*/) {
674 isRegularSimulationEnd_ = (step == lastPlannedStep_);
// Stores the configuration used for trajectory and final-configuration writing. The PBC
// type, the periodic-molecule flag and the last planned step (nsteps + init_step) are
// taken from the input record. No free-energy data is attached initially.
678 StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
684 int nstxout_compressed,
685 bool canMoleculesBeDistributedOverPBC,
686 bool writeFinalConfiguration,
687 std::string finalConfigurationFilename,
688 const t_inputrec* inputrec,
689 const gmx_mtop_t& globalTop) :
690 statePropagatorData_(statePropagatorData),
694 nstxout_compressed_(nstxout_compressed),
696 freeEnergyPerturbationData_(nullptr),
697 isRegularSimulationEnd_(false),
699 canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
700 systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
701 pbcType_(inputrec->pbcType),
702 lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
703 writeFinalConfiguration_(writeFinalConfiguration),
704 finalConfigurationFilename_(std::move(finalConfigurationFilename)),
707 top_global_(globalTop)
710 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
712 freeEnergyPerturbationData_ = freeEnergyPerturbationData;
715 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
716 LegacySimulatorData gmx_unused* legacySimulatorData,
717 ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
718 StatePropagatorData* statePropagatorData,
719 EnergyData gmx_unused* energyData,
720 FreeEnergyPerturbationData* freeEnergyPerturbationData,
721 GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
723 statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
724 return statePropagatorData->element();
// Reads the StatePropagatorData checkpoint content into a trajectory frame (atom count,
// positions, velocities and box) using a temporary StatePropagatorData. Forces are marked
// as absent (bF = false).
// NOTE(review): x and v come from makeRvecArray(), which presumably allocates; confirm
// that the caller frees them.
727 void StatePropagatorData::readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData)
729 StatePropagatorData statePropagatorData;
730 statePropagatorData.doCheckpointData(&readCheckpointData);
732 trxFrame->natoms = statePropagatorData.totalNumAtoms_;
734 trxFrame->x = makeRvecArray(statePropagatorData.xGlobal_, statePropagatorData.totalNumAtoms_);
736 trxFrame->v = makeRvecArray(statePropagatorData.vGlobal_, statePropagatorData.totalNumAtoms_);
737 trxFrame->bF = false;
738 trxFrame->bBox = true;
739 copy_mat(statePropagatorData.box_, trxFrame->box);
// Identifier of the StatePropagatorData checkpoint client (also returned by Element::clientID()).
742 const std::string& StatePropagatorData::checkpointID()
744 static const std::string identifier = "StatePropagatorData";