de554c1884997cd684f9657f4ee2224150426c9b
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief Defines the state for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "statepropagatordata.h"
45
46 #include "gromacs/commandline/filenm.h"
47 #include "gromacs/domdec/collect.h"
48 #include "gromacs/domdec/domdec.h"
49 #include "gromacs/fileio/confio.h"
50 #include "gromacs/math/vec.h"
51 #include "gromacs/mdlib/gmx_omp_nthreads.h"
52 #include "gromacs/mdlib/mdatoms.h"
53 #include "gromacs/mdlib/mdoutf.h"
54 #include "gromacs/mdlib/stat.h"
55 #include "gromacs/mdlib/update.h"
56 #include "gromacs/mdtypes/commrec.h"
57 #include "gromacs/mdtypes/forcerec.h"
58 #include "gromacs/mdtypes/inputrec.h"
59 #include "gromacs/mdtypes/mdatom.h"
60 #include "gromacs/mdtypes/mdrunoptions.h"
61 #include "gromacs/mdtypes/state.h"
62 #include "gromacs/nbnxm/nbnxm.h"
63 #include "gromacs/pbcutil/pbc.h"
64 #include "gromacs/topology/atoms.h"
65 #include "gromacs/topology/topology.h"
66
67 #include "freeenergyperturbationdata.h"
68 #include "modularsimulator.h"
69 #include "simulatoralgorithm.h"
70
71 namespace gmx
72 {
73 StatePropagatorData::StatePropagatorData(int                numAtoms,
74                                          FILE*              fplog,
75                                          const t_commrec*   cr,
76                                          t_state*           globalState,
77                                          bool               useGPU,
78                                          bool               canMoleculesBeDistributedOverPBC,
79                                          bool               writeFinalConfiguration,
80                                          const std::string& finalConfigurationFilename,
81                                          const t_inputrec*  inputrec,
82                                          const t_mdatoms*   mdatoms,
83                                          const gmx_mtop_t*  globalTop) :
84     totalNumAtoms_(numAtoms),
85     localNAtoms_(0),
86     box_{ { 0 } },
87     previousBox_{ { 0 } },
88     ddpCount_(0),
89     element_(std::make_unique<Element>(this,
90                                        fplog,
91                                        cr,
92                                        inputrec->nstxout,
93                                        inputrec->nstvout,
94                                        inputrec->nstfout,
95                                        inputrec->nstxout_compressed,
96                                        canMoleculesBeDistributedOverPBC,
97                                        writeFinalConfiguration,
98                                        finalConfigurationFilename,
99                                        inputrec,
100                                        globalTop)),
101     vvResetVelocities_(false),
102     isRegularSimulationEnd_(false),
103     lastStep_(-1),
104     globalState_(globalState)
105 {
106     bool stateHasVelocities;
107     // Local state only becomes valid now.
108     if (DOMAINDECOMP(cr))
109     {
110         auto localState = std::make_unique<t_state>();
111         dd_init_local_state(cr->dd, globalState, localState.get());
112         stateHasVelocities = ((static_cast<unsigned int>(localState->flags) & (1U << estV)) != 0U);
113         setLocalState(std::move(localState));
114     }
115     else
116     {
117         state_change_natoms(globalState, globalState->natoms);
118         f_.resizeWithPadding(globalState->natoms);
119         localNAtoms_ = globalState->natoms;
120         x_           = globalState->x;
121         v_           = globalState->v;
122         copy_mat(globalState->box, box_);
123         stateHasVelocities = ((static_cast<unsigned int>(globalState->flags) & (1U << estV)) != 0U);
124         previousX_.resizeWithPadding(localNAtoms_);
125         ddpCount_ = globalState->ddp_count;
126         copyPosition();
127     }
128     if (useGPU)
129     {
130         changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
131     }
132
133     if (!inputrec->bContinuation)
134     {
135         if (stateHasVelocities)
136         {
137             auto v = velocitiesView().paddedArrayRef();
138             // Set the velocities of vsites, shells and frozen atoms to zero
139             for (int i = 0; i < mdatoms->homenr; i++)
140             {
141                 if (mdatoms->ptype[i] == eptVSite || mdatoms->ptype[i] == eptShell)
142                 {
143                     clear_rvec(v[i]);
144                 }
145                 else if (mdatoms->cFREEZE)
146                 {
147                     for (int m = 0; m < DIM; m++)
148                     {
149                         if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
150                         {
151                             v[i][m] = 0;
152                         }
153                     }
154                 }
155             }
156         }
157         if (inputrec->eI == eiVV)
158         {
159             vvResetVelocities_ = true;
160         }
161     }
162 }
163
164 StatePropagatorData::Element* StatePropagatorData::element()
165 {
166     return element_.get();
167 }
168
169 void StatePropagatorData::setup()
170 {
171     if (element_)
172     {
173         element_->elementSetup();
174     }
175 }
176
177 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
178 {
179     return x_.arrayRefWithPadding();
180 }
181
182 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
183 {
184     return x_.constArrayRefWithPadding();
185 }
186
187 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
188 {
189     return previousX_.arrayRefWithPadding();
190 }
191
192 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
193 {
194     return previousX_.constArrayRefWithPadding();
195 }
196
197 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
198 {
199     return v_.arrayRefWithPadding();
200 }
201
202 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
203 {
204     return v_.constArrayRefWithPadding();
205 }
206
207 ArrayRefWithPadding<RVec> StatePropagatorData::forcesView()
208 {
209     return f_.arrayRefWithPadding();
210 }
211
212 ArrayRefWithPadding<const RVec> StatePropagatorData::constForcesView() const
213 {
214     return f_.constArrayRefWithPadding();
215 }
216
217 rvec* StatePropagatorData::box()
218 {
219     return box_;
220 }
221
222 const rvec* StatePropagatorData::constBox() const
223 {
224     return box_;
225 }
226
227 rvec* StatePropagatorData::previousBox()
228 {
229     return previousBox_;
230 }
231
232 const rvec* StatePropagatorData::constPreviousBox() const
233 {
234     return previousBox_;
235 }
236
237 int StatePropagatorData::localNumAtoms() const
238 {
239     return localNAtoms_;
240 }
241
242 int StatePropagatorData::totalNumAtoms() const
243 {
244     return totalNumAtoms_;
245 }
246
247 std::unique_ptr<t_state> StatePropagatorData::localState()
248 {
249     auto state   = std::make_unique<t_state>();
250     state->flags = (1U << estX) | (1U << estV) | (1U << estBOX);
251     state_change_natoms(state.get(), localNAtoms_);
252     state->x = x_;
253     state->v = v_;
254     copy_mat(box_, state->box);
255     state->ddp_count = ddpCount_;
256     return state;
257 }
258
259 void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
260 {
261     localNAtoms_ = state->natoms;
262     x_.resizeWithPadding(localNAtoms_);
263     previousX_.resizeWithPadding(localNAtoms_);
264     v_.resizeWithPadding(localNAtoms_);
265     x_ = state->x;
266     v_ = state->v;
267     copy_mat(state->box, box_);
268     copyPosition();
269     ddpCount_ = state->ddp_count;
270
271     if (vvResetVelocities_)
272     {
273         /* DomDec runs twice early in the simulation, once at setup time, and once before the first
274          * step. Every time DD runs, it sets a new local state here. We are saving a backup during
275          * setup time (ok for non-DD cases), so we need to update our backup to the DD state before
276          * the first step here to avoid resetting to an earlier DD state. This is done before any
277          * propagation that needs to be reset, so it's not very safe but correct for now.
278          * TODO: Get rid of this once input is assumed to be at half steps
279          */
280         velocityBackup_ = v_;
281     }
282 }
283
284 t_state* StatePropagatorData::globalState()
285 {
286     return globalState_;
287 }
288
289 PaddedHostVector<RVec>* StatePropagatorData::forcePointer()
290 {
291     return &f_;
292 }
293
294 void StatePropagatorData::copyPosition()
295 {
296     int nth = gmx_omp_nthreads_get(emntUpdate);
297
298 #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
299     for (int th = 0; th < nth; th++)
300     {
301         int start_th, end_th;
302         getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
303         copyPosition(start_th, end_th);
304     }
305
306     /* Box is changed in update() when we do pressure coupling,
307      * but we should still use the old box for energy corrections and when
308      * writing it to the energy file, so it matches the trajectory files for
309      * the same timestep above. Make a copy in a separate array.
310      */
311     copy_mat(box_, previousBox_);
312 }
313
314 void StatePropagatorData::copyPosition(int start, int end)
315 {
316     for (int i = start; i < end; ++i)
317     {
318         previousX_[i] = x_[i];
319     }
320 }
321
322 void StatePropagatorData::Element::scheduleTask(Step step,
323                                                 Time gmx_unused            time,
324                                                 const RegisterRunFunction& registerRunFunction)
325 {
326     if (statePropagatorData_->vvResetVelocities_)
327     {
328         statePropagatorData_->vvResetVelocities_ = false;
329         registerRunFunction([this]() { statePropagatorData_->resetVelocities(); });
330     }
331     // copy x -> previousX
332     registerRunFunction([this]() { statePropagatorData_->copyPosition(); });
333     // if it's a write out step, keep a copy for writeout
334     if (step == writeOutStep_ || (step == lastStep_ && writeFinalConfiguration_))
335     {
336         registerRunFunction([this]() { saveState(); });
337     }
338 }
339
340 void StatePropagatorData::Element::saveState()
341 {
342     GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
343     localStateBackup_ = statePropagatorData_->localState();
344     if (freeEnergyPerturbationData_)
345     {
346         localStateBackup_->fep_state = freeEnergyPerturbationData_->currentFEPState();
347         for (unsigned long i = 0; i < localStateBackup_->lambda.size(); ++i)
348         {
349             localStateBackup_->lambda[i] = freeEnergyPerturbationData_->constLambdaView()[i];
350         }
351         localStateBackup_->flags |= (1U << estLAMBDA) | (1U << estFEPSTATE);
352     }
353 }
354
355 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
356 {
357     if (event == TrajectoryEvent::StateWritingStep)
358     {
359         return [this](Step step, Time /*unused*/) { this->writeOutStep_ = step; };
360     }
361     return std::nullopt;
362 }
363
364 std::optional<ITrajectoryWriterCallback>
365 StatePropagatorData::Element::registerTrajectoryWriterCallback(TrajectoryEvent event)
366 {
367     if (event == TrajectoryEvent::StateWritingStep)
368     {
369         return [this](gmx_mdoutf* outf, Step step, Time time, bool writeTrajectory, bool gmx_unused writeLog) {
370             if (writeTrajectory)
371             {
372                 write(outf, step, time);
373             }
374         };
375     }
376     return std::nullopt;
377 }
378
379 void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
380 {
381     wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
382     unsigned int mdof_flags = 0;
383     if (do_per_step(currentStep, nstxout_))
384     {
385         mdof_flags |= MDOF_X;
386     }
387     if (do_per_step(currentStep, nstvout_))
388     {
389         mdof_flags |= MDOF_V;
390     }
391     if (do_per_step(currentStep, nstfout_))
392     {
393         mdof_flags |= MDOF_F;
394     }
395     if (do_per_step(currentStep, nstxout_compressed_))
396     {
397         mdof_flags |= MDOF_X_COMPRESSED;
398     }
399     if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
400     {
401         mdof_flags |= MDOF_BOX;
402     }
403     if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
404     {
405         mdof_flags |= MDOF_LAMBDA;
406     }
407     if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
408     {
409         mdof_flags |= MDOF_BOX_COMPRESSED;
410     }
411     if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
412     {
413         mdof_flags |= MDOF_LAMBDA_COMPRESSED;
414     }
415
416     if (mdof_flags == 0)
417     {
418         wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
419         return;
420     }
421     GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
422
423     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
424     ObservablesHistory* observablesHistory = nullptr;
425
426     mdoutf_write_to_trajectory_files(fplog_, cr_, outf, static_cast<int>(mdof_flags),
427                                      statePropagatorData_->totalNumAtoms_, currentStep, currentTime,
428                                      localStateBackup_.get(), statePropagatorData_->globalState_,
429                                      observablesHistory, statePropagatorData_->f_);
430
431     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
432     {
433         localStateBackup_.reset();
434     }
435     wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
436 }
437
438 void StatePropagatorData::Element::elementSetup()
439 {
440     if (statePropagatorData_->vvResetVelocities_)
441     {
442         // MD-VV does the first velocity half-step only to calculate the constraint virial,
443         // then resets the velocities since the input is assumed to be positions and velocities
444         // at full time step. TODO: Change this to have input at half time steps.
445         statePropagatorData_->velocityBackup_ = statePropagatorData_->v_;
446     }
447 }
448
449 void StatePropagatorData::resetVelocities()
450 {
451     v_ = velocityBackup_;
452 }
453
454 void StatePropagatorData::Element::writeCheckpoint(t_state* localState, t_state gmx_unused* globalState)
455 {
456     state_change_natoms(localState, statePropagatorData_->localNAtoms_);
457     localState->x = statePropagatorData_->x_;
458     localState->v = statePropagatorData_->v_;
459     copy_mat(statePropagatorData_->box_, localState->box);
460     localState->ddp_count = statePropagatorData_->ddpCount_;
461     localState->flags |= (1U << estX) | (1U << estV) | (1U << estBOX);
462 }
463
464 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
465 {
466     // Note that part of this code is duplicated in do_md_trajectory_writing.
467     // This duplication is needed while both legacy and modular code paths are in use.
468     // TODO: Remove duplication asap, make sure to keep in sync in the meantime.
469     if (!writeFinalConfiguration_ || !isRegularSimulationEnd_)
470     {
471         return;
472     }
473
474     GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
475
476     wallcycle_start(mdoutf_get_wcycle(outf), ewcTRAJ);
477     if (DOMAINDECOMP(cr_))
478     {
479         auto globalXRef =
480                 MASTER(cr_) ? statePropagatorData_->globalState_->x : gmx::ArrayRef<gmx::RVec>();
481         dd_collect_vec(cr_->dd, localStateBackup_.get(), localStateBackup_->x, globalXRef);
482         auto globalVRef =
483                 MASTER(cr_) ? statePropagatorData_->globalState_->v : gmx::ArrayRef<gmx::RVec>();
484         dd_collect_vec(cr_->dd, localStateBackup_.get(), localStateBackup_->v, globalVRef);
485     }
486     else
487     {
488         // We have the whole state locally: copy the local state pointer
489         statePropagatorData_->globalState_ = localStateBackup_.get();
490     }
491
492     if (MASTER(cr_))
493     {
494         fprintf(stderr, "\nWriting final coordinates.\n");
495         if (canMoleculesBeDistributedOverPBC_ && !systemHasPeriodicMolecules_)
496         {
497             // Make molecules whole only for confout writing
498             do_pbc_mtop(pbcType_, localStateBackup_->box, top_global_,
499                         statePropagatorData_->globalState_->x.rvec_array());
500         }
501         write_sto_conf_mtop(finalConfigurationFilename_.c_str(), *top_global_->name, top_global_,
502                             statePropagatorData_->globalState_->x.rvec_array(),
503                             statePropagatorData_->globalState_->v.rvec_array(), pbcType_,
504                             localStateBackup_->box);
505     }
506     wallcycle_stop(mdoutf_get_wcycle(outf), ewcTRAJ);
507 }
508
509 std::optional<SignallerCallback> StatePropagatorData::Element::registerLastStepCallback()
510 {
511     return [this](Step step, Time /*time*/) {
512         lastStep_               = step;
513         isRegularSimulationEnd_ = (step == lastPlannedStep_);
514     };
515 }
516
517 StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
518                                       FILE*                fplog,
519                                       const t_commrec*     cr,
520                                       int                  nstxout,
521                                       int                  nstvout,
522                                       int                  nstfout,
523                                       int                  nstxout_compressed,
524                                       bool                 canMoleculesBeDistributedOverPBC,
525                                       bool                 writeFinalConfiguration,
526                                       std::string          finalConfigurationFilename,
527                                       const t_inputrec*    inputrec,
528                                       const gmx_mtop_t*    globalTop) :
529     statePropagatorData_(statePropagatorData),
530     nstxout_(nstxout),
531     nstvout_(nstvout),
532     nstfout_(nstfout),
533     nstxout_compressed_(nstxout_compressed),
534     writeOutStep_(-1),
535     freeEnergyPerturbationData_(nullptr),
536     isRegularSimulationEnd_(false),
537     lastStep_(-1),
538     canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
539     systemHasPeriodicMolecules_(inputrec->bPeriodicMols),
540     pbcType_(inputrec->pbcType),
541     lastPlannedStep_(inputrec->nsteps + inputrec->init_step),
542     writeFinalConfiguration_(writeFinalConfiguration),
543     finalConfigurationFilename_(std::move(finalConfigurationFilename)),
544     fplog_(fplog),
545     cr_(cr),
546     top_global_(globalTop)
547 {
548 }
549 void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
550 {
551     freeEnergyPerturbationData_ = freeEnergyPerturbationData;
552 }
553
554 ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
555         LegacySimulatorData gmx_unused*        legacySimulatorData,
556         ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
557         StatePropagatorData*                               statePropagatorData,
558         EnergyData gmx_unused*      energyData,
559         FreeEnergyPerturbationData* freeEnergyPerturbationData,
560         GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
561 {
562     statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
563     return statePropagatorData->element();
564 }
565
566 } // namespace gmx