2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2019,2020, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
36 * \brief Defines the state for the modular simulator
38 * \author Pascal Merz <pascal.merz@me.com>
39 * \ingroup module_modularsimulator
44 #include "statepropagatordata.h"
46 #include "gromacs/commandline/filenm.h"
47 #include "gromacs/domdec/collect.h"
48 #include "gromacs/domdec/domdec.h"
49 #include "gromacs/fileio/confio.h"
50 #include "gromacs/math/vec.h"
51 #include "gromacs/mdlib/gmx_omp_nthreads.h"
52 #include "gromacs/mdlib/mdatoms.h"
53 #include "gromacs/mdlib/mdoutf.h"
54 #include "gromacs/mdlib/stat.h"
55 #include "gromacs/mdlib/update.h"
56 #include "gromacs/mdtypes/commrec.h"
57 #include "gromacs/mdtypes/forcerec.h"
58 #include "gromacs/mdtypes/inputrec.h"
59 #include "gromacs/mdtypes/mdatom.h"
60 #include "gromacs/mdtypes/mdrunoptions.h"
61 #include "gromacs/mdtypes/state.h"
62 #include "gromacs/nbnxm/nbnxm.h"
63 #include "gromacs/pbcutil/pbc.h"
64 #include "gromacs/topology/atoms.h"
65 #include "gromacs/topology/topology.h"
67 #include "freeenergyperturbationdata.h"
68 #include "modularsimulator.h"
69 #include "simulatoralgorithm.h"
// Constructs the state propagator data from the simulation's global state.
// With domain decomposition, a local state is created via dd_init_local_state() and installed
// with setLocalState(); otherwise the local buffers are sized from globalState directly.
// On a fresh start (!inputrec->bContinuation) with velocities present in the state, the in-code
// comment below states that velocities of vsites, shells and frozen dimensions are zeroed.
// For the velocity-Verlet integrator (eiVV), vvResetVelocities_ is set so that scheduleTask()
// restores the backed-up initial velocities at the first scheduled step.
// NOTE(review): several parameter, initializer and statement lines of this definition are not
// present in this excerpt (e.g. the declarations of globalState/cr and the branch braces); the
// branch structure described in the comments below is partly inferred — confirm.
73 StatePropagatorData::StatePropagatorData(int numAtoms,
78 bool canMoleculesBeDistributedOverPBC,
79 bool writeFinalConfiguration,
80 const std::string& finalConfigurationFilename,
81 const t_inputrec* inputrec,
82 const t_mdatoms* mdatoms,
83 const gmx_mtop_t* globalTop) :
84 totalNumAtoms_(numAtoms),
87 previousBox_{ { 0 } },
89 element_(std::make_unique<Element>(this,
95 inputrec->nstxout_compressed,
96 canMoleculesBeDistributedOverPBC,
97 writeFinalConfiguration,
98 finalConfigurationFilename,
101 vvResetVelocities_(false),
102 isRegularSimulationEnd_(false),
104 globalState_(globalState)
// Set in either branch below from the estV bit of the relevant state's flags.
106 bool stateHasVelocities;
107 // Local state only becomes valid now.
108 if (DOMAINDECOMP(cr))
// DD: build the local state from the global one, then hand ownership to setLocalState().
110 auto localState = std::make_unique<t_state>();
111 dd_init_local_state(cr->dd, globalState, localState.get());
112 stateHasVelocities = ((static_cast<unsigned int>(localState->flags) & (1U << estV)) != 0U);
113 setLocalState(std::move(localState));
// NOTE(review): presumably the non-DD branch (its `else` is not visible here): the local data
// mirrors the global state in size, box and DD partitioning count.
117 state_change_natoms(globalState, globalState->natoms);
118 f_.resizeWithPadding(globalState->natoms);
119 localNAtoms_ = globalState->natoms;
122 copy_mat(globalState->box, box_);
123 stateHasVelocities = ((static_cast<unsigned int>(globalState->flags) & (1U << estV)) != 0U);
124 previousX_.resizeWithPadding(localNAtoms_);
125 ddpCount_ = globalState->ddp_count;
// NOTE(review): any condition guarding this pinning call is not visible in this excerpt.
130 changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
133 if (!inputrec->bContinuation)
135 if (stateHasVelocities)
137 auto v = velocitiesView().paddedArrayRef();
138 // Set the velocities of vsites, shells and frozen atoms to zero
139 for (int i = 0; i < mdatoms->homenr; i++)
141 if (mdatoms->ptype[i] == eptVSite || mdatoms->ptype[i] == eptShell)
145 else if (mdatoms->cFREEZE)
147 for (int m = 0; m < DIM; m++)
149 if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
// MD-VV: request a velocity reset after the initial half-step (see elementSetup()/scheduleTask()).
157 if (inputrec->eI == eiVV)
159 vvResetVelocities_ = true;
164 StatePropagatorData::Element* StatePropagatorData::element()
166 return element_.get();
// Runs the element's setup step (Element::elementSetup(), which backs up velocities for MD-VV).
// NOTE(review): some lines of this definition are missing from this excerpt; there may be a
// guard (e.g. a null check on element_) around the call — confirm against the full file.
169 void StatePropagatorData::setup()
173 element_->elementSetup();
177 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
179 return x_.arrayRefWithPadding();
182 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
184 return x_.constArrayRefWithPadding();
187 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
189 return previousX_.arrayRefWithPadding();
192 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
194 return previousX_.constArrayRefWithPadding();
197 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
199 return v_.arrayRefWithPadding();
202 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
204 return v_.constArrayRefWithPadding();
207 ArrayRefWithPadding<RVec> StatePropagatorData::forcesView()
209 return f_.arrayRefWithPadding();
212 ArrayRefWithPadding<const RVec> StatePropagatorData::constForcesView() const
214 return f_.constArrayRefWithPadding();
// NOTE(review): the bodies of the following accessors are not present in this excerpt.
// Presumably returns the current box box_ (writable) — TODO confirm.
217 rvec* StatePropagatorData::box()
// Presumably returns the current box box_ (read-only) — TODO confirm.
222 const rvec* StatePropagatorData::constBox() const
// Presumably returns previousBox_ (set by copyPosition()) — TODO confirm.
227 rvec* StatePropagatorData::previousBox()
// Presumably returns previousBox_ (read-only) — TODO confirm.
232 const rvec* StatePropagatorData::constPreviousBox() const
// Presumably returns localNAtoms_, the number of atoms handled locally — TODO confirm.
237 int StatePropagatorData::localNumAtoms() const
242 int StatePropagatorData::totalNumAtoms() const
244 return totalNumAtoms_;
// Builds a new t_state snapshot of the local data: flags mark x, v and box as present, the
// state is sized to localNAtoms_, and it receives the current box and DD partitioning count.
// NOTE(review): lines are missing from this excerpt (presumably copies of x_/v_ into the new
// state and the final `return state;`) — confirm against the full file.
247 std::unique_ptr<t_state> StatePropagatorData::localState()
249 auto state = std::make_unique<t_state>();
250 state->flags = (1U << estX) | (1U << estV) | (1U << estBOX);
251 state_change_natoms(state.get(), localNAtoms_);
254 copy_mat(box_, state->box);
255 state->ddp_count = ddpCount_;
// Installs a new local state (e.g. after domain decomposition): resizes x_, previousX_ and v_
// to the new local atom count and takes over the box and DD partitioning count.
// NOTE(review): lines are missing from this excerpt (presumably the copies of state->x and
// state->v into x_ and v_) — confirm against the full file.
259 void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
261 localNAtoms_ = state->natoms;
262 x_.resizeWithPadding(localNAtoms_);
263 previousX_.resizeWithPadding(localNAtoms_);
264 v_.resizeWithPadding(localNAtoms_);
267 copy_mat(state->box, box_);
269 ddpCount_ = state->ddp_count;
// While an MD-VV velocity reset is still pending, refresh velocityBackup_ from v_.
271 if (vvResetVelocities_)
273 /* DomDec runs twice early in the simulation, once at setup time, and once before the first
274 * step. Every time DD runs, it sets a new local state here. We are saving a backup during
275 * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
276 * the first step here to avoid resetting to an earlier DD state. This is done before any
277 * propagation that needs to be reset, so it's not very safe but correct for now.
278 * TODO: Get rid of this once input is assumed to be at half steps
280 velocityBackup_ = v_;
// NOTE(review): the bodies of these accessors are not present in this excerpt.
// Presumably returns globalState_ (which trajectoryWriterTeardown() may redirect) — TODO confirm.
284 t_state* StatePropagatorData::globalState()
// Presumably returns a pointer to the force buffer f_ — TODO confirm.
289 PaddedHostVector<RVec>* StatePropagatorData::forcePointer()
294 void StatePropagatorData::copyPosition()
296 int nth = gmx_omp_nthreads_get(emntUpdate);
298 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
299 for (int th = 0; th < nth; th++)
301 int start_th, end_th;
302 getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
303 copyPosition(start_th, end_th);
306 /* Box is changed in update() when we do pressure coupling,
307 * but we should still use the old box for energy corrections and when
308 * writing it to the energy file, so it matches the trajectory files for
309 * the same timestep above. Make a copy in a separate array.
311 copy_mat(box_, previousBox_);
314 void StatePropagatorData::copyPosition(int start, int end)
316 for (int i = start; i < end; ++i)
318 previousX_[i] = x_[i];
322 void StatePropagatorData::Element::scheduleTask(Step step,
323 Time gmx_unused time,
324 const RegisterRunFunction& registerRunFunction)
326 if (statePropagatorData_->vvResetVelocities_)
328 statePropagatorData_->vvResetVelocities_ = false;
329 registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
331 // copy x -> previousX
332 registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
333 // if it's a write out step, keep a copy for writeout
334 if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
336 registerRunFunction([this]() { saveState(); });
340 void StatePropagatorData::Element::saveState()
342 GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
343 localStateBackup_ = statePropagatorData_->localState();
344 if (freeEnergyPerturbationData_)
346 localStateBackup_->fep_state = freeEnergyPerturbationData_->currentFEPState();
347 for (unsigned long i = 0; i < localStateBackup_->lambda.size(); ++i)
349 localStateBackup_->lambda[i] = freeEnergyPerturbationData_->constLambdaView()[i];
351 localStateBackup_->flags |= (1U << estLAMBDA) | (1U << estFEPSTATE);
// For StateWritingStep events, returns a callback that records the signalled step in
// writeOutStep_, which scheduleTask() compares against to schedule saveState().
// NOTE(review): the return for other events is not visible in this excerpt.
355 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
357 if (event == TrajectoryEvent::StateWritingStep)
359 return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
// For StateWritingStep events, returns a writer callback that forwards to write() with the
// output handle, step and time.
// NOTE(review): part of the lambda (presumably a check of writeTrajectory) and the return for
// other events are not visible in this excerpt — confirm.
364 std::optional<ITrajectoryWriterCallback>
365 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
367 if (event == TrajectoryEvent::StateWritingStep)
369 return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
372 write(outf, step, time);
// Writes the state saved by saveState() (localStateBackup_) to the trajectory files.
// An MDOF_* flag is set for each output kind whose interval matches currentStep (do_per_step),
// including the TNG box/lambda intervals queried from outf. The saved state is released
// afterwards, except at the last step of a regular simulation end, where
// trajectoryWriterTeardown() still needs it for the final configuration.
379 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
381 wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
382 unsigned int mdof_flags = 0;
383 if (do_per_step(currentStep, nstxout_))
385 mdof_flags |= MDOF_X;
387 if (do_per_step(currentStep, nstvout_))
389 mdof_flags |= MDOF_V;
391 if (do_per_step(currentStep, nstfout_))
393 mdof_flags |= MDOF_F;
395 if (do_per_step(currentStep, nstxout_compressed_))
397 mdof_flags |= MDOF_X_COMPRESSED;
399 if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
401 mdof_flags |= MDOF_BOX;
403 if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
405 mdof_flags |= MDOF_LAMBDA;
407 if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
409 mdof_flags |= MDOF_BOX_COMPRESSED;
411 if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
413 mdof_flags |= MDOF_LAMBDA_COMPRESSED;
// NOTE(review): the lines around this stop are not visible; presumably it belongs to an early
// return taken when no output flag is set — otherwise the timer would be stopped twice. Confirm.
418 wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
421 GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
423 // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
424 ObservablesHistory* observablesHistory = nullptr;
426 mdoutf_write_to_trajectory_files(fplog_, cr_, outf, static_cast<int>(mdof_flags),
427 statePropagatorData_->totalNumAtoms_, currentStep, currentTime,
428 localStateBackup_.get(), statePropagatorData_->globalState_,
429 observablesHistory, statePropagatorData_->f_);
// Keep the snapshot only when the final configuration will still be written from it.
431 if (currentStep != lastStep_ || !isRegularSimulationEnd_)
433 localStateBackup_.reset();
435 wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
438 void StatePropagatorData::Element::elementSetup()
440 if (statePropagatorData_->vvResetVelocities_)
442 // MD-VV does the first velocity half-step only to calculate the constraint virial,
443 // then resets the velocities since the input is assumed to be positions and velocities
444 // at full time step. TODO: Change this to have input at half time steps.
445 statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
449 void StatePropagatorData::resetVelocities()
451 v_ = velocityBackup_;
454 void StatePropagatorData::Element::writeCheckpoint(t_state* localState, t_state gmx_unused* globalState)
456 state_change_natoms(localState, statePropagatorData_->localNAtoms_);
457 localState->x = statePropagatorData_->x_;
458 localState->v = statePropagatorData_->v_;
459 copy_mat(statePropagatorData_->box_, localState->box);
460 localState->ddp_count = statePropagatorData_->ddpCount_;
461 localState->flags |= (1U << estX) | (1U << estV) | (1U << estBOX);
// Writes the final configuration to finalConfigurationFilename_ from the state saved at the
// last step, but only if writeFinalConfiguration_ is set and the simulation ended regularly.
// With DD, x and v of the saved local state are collected into the global state (on the master
// rank the global vectors are the destination); otherwise globalState_ is pointed at the saved
// state. If molecules can be split over PBC and the system has no periodic molecules, molecules
// are made whole before writing.
// NOTE(review): `outf` is marked gmx_unused but is used for the wallcycle counters below.
// NOTE(review): several lines (the early return, the globalXRef/globalVRef declarations, and
// presumably a master-rank guard around the output) are not visible in this excerpt.
464 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
466 // Note that part of this code is duplicated in do_md_trajectory_writing.
467 // This duplication is needed while both legacy and modular code paths are in use.
468 // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
469 if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
474 GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
476 wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
477 if (DOMAINDECOMP(cr_))
480 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
481 dd_collect_vec(cr_->dd, localStateBackup_.get(), localStateBackup_->x, globalXRef);
483 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
484 dd_collect_vec(cr_->dd, localStateBackup_.get(), localStateBackup_->v, globalVRef);
488 // We have the whole state locally: copy the local state pointer
489 statePropagatorData_->globalState_ = localStateBackup_.get();
494 fprintf(stderr, "\nWriting final coordinates.\n");
495 if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
497 // Make molecules whole only for confout writing
498 do_pbc_mtop(pbcType_, localStateBackup_->box, top_global_,
499 statePropagatorData_->globalState_->x.rvec_array());
501 write_sto_conf_mtop(finalConfigurationFilename_.c_str(), *top_global_->name, top_global_,
502 statePropagatorData_->globalState_->x.rvec_array(),
503 statePropagatorData_->globalState_->v.rvec_array(), pbcType_,
504 localStateBackup_->box);
506 wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
// Returns a callback for the last-step signal. It records whether the signalled step equals
// lastPlannedStep_ (init_step + nsteps, set in the Element constructor). write() and
// trajectoryWriterTeardown() read isRegularSimulationEnd_ to decide on final-configuration handling.
// NOTE(review): part of the lambda body is not visible here (presumably it also stores the step
// in lastStep_, which scheduleTask() and write() compare against) — confirm.
509 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
511 return [this](Step step, Time /*time*/) {
513 isRegularSimulationEnd_ = (step == lastPlannedStep_);
// Constructs the element with its output intervals and final-configuration settings.
// NOTE(review): several parameters and initializers are not visible in this excerpt.
517 StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
523 int nstxout_compressed,
524 bool canMoleculesBeDistributedOverPBC,
525 bool writeFinalConfiguration,
526 std::string finalConfigurationFilename,
527 const t_inputrec* inputrec,
528 const gmx_mtop_t* globalTop) :
529 statePropagatorData_(statePropagatorData),
533 nstxout_compressed_(nstxout_compressed),
// Attached later through setFreeEnergyPerturbationData(); saveState() checks it for nullptr.
535 freeEnergyPerturbationData_(nullptr),
536 isRegularSimulationEnd_(false),
538 canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
539 systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
540 pbcType_(inputrec->pbcType),
// Last step of a regular run; compared against the signalled last step in registerLastStepCallback().
541 lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
542 writeFinalConfiguration_(writeFinalConfiguration),
543 finalConfigurationFilename_(std::move(finalConfigurationFilename)),
546 top_global_(globalTop)
549 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
551 freeEnergyPerturbationData_ = freeEnergyPerturbationData;
// Attaches freeEnergyPerturbationData (which may be nullptr) to the element held by
// statePropagatorData and returns that element. The element remains owned by
// statePropagatorData (element() returns element_.get()). The other arguments are unused.
554 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
555 LegacySimulatorData gmx_unused* legacySimulatorData,
556 ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
557 StatePropagatorData* statePropagatorData,
558 EnergyData gmx_unused* energyData,
559 FreeEnergyPerturbationData* freeEnergyPerturbationData,
560 GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
562 statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
563 return statePropagatorData->element();