Change PaddedVector to PaddedHostVector for force CPU buffer
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \libinternal
36  * \brief Defines the state for the modular simulator
37  *
38  * \author Pascal Merz <pascal.merz@me.com>
39  * \ingroup module_modularsimulator
40  */
41
42 #include "gmxpre.h"
43
44 #include "statepropagatordata.h"
45
46 #include "gromacs/domdec/domdec.h"
47 #include "gromacs/math/vec.h"
48 #include "gromacs/mdlib/gmx_omp_nthreads.h"
49 #include "gromacs/mdlib/mdoutf.h"
50 #include "gromacs/mdlib/stat.h"
51 #include "gromacs/mdlib/update.h"
52 #include "gromacs/mdtypes/commrec.h"
53 #include "gromacs/mdtypes/inputrec.h"
54 #include "gromacs/mdtypes/mdatom.h"
55 #include "gromacs/mdtypes/state.h"
56 #include "gromacs/topology/atoms.h"
57
58 namespace gmx
59 {
60 StatePropagatorData::StatePropagatorData(
61         int               numAtoms,
62         FILE             *fplog,
63         const t_commrec  *cr,
64         t_state          *globalState,
65         int               nstxout,
66         int               nstvout,
67         int               nstfout,
68         int               nstxout_compressed,
69         bool              useGPU,
70         const t_inputrec *inputrec,
71         const t_mdatoms  *mdatoms) :
72     totalNumAtoms_(numAtoms),
73     nstxout_(nstxout),
74     nstvout_(nstvout),
75     nstfout_(nstfout),
76     nstxout_compressed_(nstxout_compressed),
77     localNAtoms_(0),
78     ddpCount_(0),
79     writeOutStep_(-1),
80     vvResetVelocities_(false),
81     fplog_(fplog),
82     cr_(cr),
83     globalState_(globalState)
84 {
85     // Initialize these here, as box_{{0}} in the initialization list
86     // is confusing uncrustify and doxygen
87     clear_mat(box_);
88     clear_mat(previousBox_);
89
90     bool stateHasVelocities;
91     // Local state only becomes valid now.
92     if (DOMAINDECOMP(cr))
93     {
94         auto localState = std::make_unique<t_state>();
95         if (useGPU)
96         {
97             changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
98         }
99         dd_init_local_state(cr->dd, globalState, localState.get());
100         stateHasVelocities = static_cast<unsigned int>(localState->flags) & (1u << estV);
101         setLocalState(std::move(localState));
102     }
103     else
104     {
105         state_change_natoms(globalState, globalState->natoms);
106         f_.resizeWithPadding(globalState->natoms);
107         localNAtoms_ = globalState->natoms;
108         x_           = globalState->x;
109         v_           = globalState->v;
110         copy_mat(globalState->box, box_);
111         stateHasVelocities = static_cast<unsigned int>(globalState->flags) & (1u << estV);
112         previousX_.resizeWithPadding(localNAtoms_);
113         copyPosition();
114     }
115
116     if (!inputrec->bContinuation)
117     {
118         if (stateHasVelocities)
119         {
120             auto v = velocitiesView().paddedArrayRef();
121             // Set the velocities of vsites, shells and frozen atoms to zero
122             for (int i = 0; i < mdatoms->homenr; i++)
123             {
124                 if (mdatoms->ptype[i] == eptVSite ||
125                     mdatoms->ptype[i] == eptShell)
126                 {
127                     clear_rvec(v[i]);
128                 }
129                 else if (mdatoms->cFREEZE)
130                 {
131                     for (int m = 0; m < DIM; m++)
132                     {
133                         if (inputrec->opts.nFreeze[mdatoms->cFREEZE[i]][m])
134                         {
135                             v[i][m] = 0;
136                         }
137                     }
138                 }
139             }
140         }
141     }
142
143     if (inputrec->eI == eiVV)
144     {
145         vvResetVelocities_ = true;
146     }
147 }
148
149 ArrayRefWithPadding<RVec> StatePropagatorData::positionsView()
150 {
151     return x_.arrayRefWithPadding();
152 }
153
154 ArrayRefWithPadding<const RVec> StatePropagatorData::constPositionsView() const
155 {
156     return x_.constArrayRefWithPadding();
157 }
158
159 ArrayRefWithPadding<RVec> StatePropagatorData::previousPositionsView()
160 {
161     return previousX_.arrayRefWithPadding();
162 }
163
164 ArrayRefWithPadding<const RVec> StatePropagatorData::constPreviousPositionsView() const
165 {
166     return previousX_.constArrayRefWithPadding();
167 }
168
169 ArrayRefWithPadding<RVec> StatePropagatorData::velocitiesView()
170 {
171     return v_.arrayRefWithPadding();
172 }
173
174 ArrayRefWithPadding<const RVec> StatePropagatorData::constVelocitiesView() const
175 {
176     return v_.constArrayRefWithPadding();
177 }
178
179 ArrayRefWithPadding<RVec> StatePropagatorData::forcesView()
180 {
181     return f_.arrayRefWithPadding();
182 }
183
184 ArrayRefWithPadding<const RVec> StatePropagatorData::constForcesView() const
185 {
186     return f_.constArrayRefWithPadding();
187 }
188
189 rvec* StatePropagatorData::box()
190 {
191     return box_;
192 }
193
194 const rvec* StatePropagatorData::constBox()
195 {
196     return box_;
197 }
198
199 rvec* StatePropagatorData::previousBox()
200 {
201     return previousBox_;
202 }
203
204 const rvec* StatePropagatorData::constPreviousBox()
205 {
206     return previousBox_;
207 }
208
209 int StatePropagatorData::localNumAtoms()
210 {
211     return localNAtoms_;
212 }
213
214 std::unique_ptr<t_state> StatePropagatorData::localState()
215 {
216     auto state = std::make_unique<t_state>();
217     state->flags = estX | estV | estBOX;
218     state_change_natoms(state.get(), localNAtoms_);
219     state->x = x_;
220     state->v = v_;
221     copy_mat(box_, state->box);
222     state->ddp_count = ddpCount_;
223     return state;
224 }
225
226 void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
227 {
228     localNAtoms_ = state->natoms;
229     x_.resizeWithPadding(localNAtoms_);
230     previousX_.resizeWithPadding(localNAtoms_);
231     v_.resizeWithPadding(localNAtoms_);
232     x_ = state->x;
233     v_ = state->v;
234     copy_mat(state->box, box_);
235     copyPosition();
236     ddpCount_ = state->ddp_count;
237 }
238
239 t_state* StatePropagatorData::globalState()
240 {
241     return globalState_;
242 }
243
244 PaddedHostVector<RVec>* StatePropagatorData::forcePointer()
245 {
246     return &f_;
247 }
248
249 void StatePropagatorData::copyPosition()
250 {
251     int nth = gmx_omp_nthreads_get(emntUpdate);
252
253     #pragma omp parallel for num_threads(nth) schedule(static) default(none) shared(nth)
254     for (int th = 0; th < nth; th++)
255     {
256         int start_th, end_th;
257         getThreadAtomRange(nth, th, localNAtoms_, &start_th, &end_th);
258         copyPosition(start_th, end_th);
259     }
260
261     /* Box is changed in update() when we do pressure coupling,
262      * but we should still use the old box for energy corrections and when
263      * writing it to the energy file, so it matches the trajectory files for
264      * the same timestep above. Make a copy in a separate array.
265      */
266     copy_mat(box_, previousBox_);
267 }
268
269 void StatePropagatorData::copyPosition(int start, int end)
270 {
271     for (int i = start; i < end; ++i)
272     {
273         previousX_[i] = x_[i];
274     }
275 }
276
277 void StatePropagatorData::scheduleTask(
278         Step step, Time gmx_unused time,
279         const RegisterRunFunctionPtr &registerRunFunction)
280 {
281     if (vvResetVelocities_)
282     {
283         vvResetVelocities_ = false;
284         (*registerRunFunction)(
285                 std::make_unique<SimulatorRunFunction>(
286                         [this](){resetVelocities(); }));
287     }
288     // copy x -> previousX
289     (*registerRunFunction)(
290             std::make_unique<SimulatorRunFunction>(
291                     [this](){copyPosition(); }));
292     // if it's a write out step, keep a copy for writeout
293     if (step == writeOutStep_)
294     {
295         (*registerRunFunction)(
296                 std::make_unique<SimulatorRunFunction>(
297                         [this](){saveState(); }));
298     }
299 }
300
301 void StatePropagatorData::saveState()
302 {
303     GMX_ASSERT(
304             !localStateBackup_,
305             "Save state called again before previous state was written.");
306     localStateBackup_ = localState();
307 }
308
309 SignallerCallbackPtr
310 StatePropagatorData::registerTrajectorySignallerCallback(TrajectoryEvent event)
311 {
312     if (event == TrajectoryEvent::stateWritingStep)
313     {
314         return std::make_unique<SignallerCallback>(
315                 [this](Step step, Time){this->writeOutStep_ = step; });
316     }
317     return nullptr;
318 }
319
320 ITrajectoryWriterCallbackPtr
321 StatePropagatorData::registerTrajectoryWriterCallback(TrajectoryEvent event)
322 {
323     if (event == TrajectoryEvent::stateWritingStep)
324     {
325         return std::make_unique<ITrajectoryWriterCallback>(
326                 [this](gmx_mdoutf *outf, Step step, Time time)
327                 {write(outf, step, time); });
328     }
329     return nullptr;
330 }
331
332 void StatePropagatorData::write(gmx_mdoutf_t outf, Step currentStep, Time currentTime)
333 {
334     unsigned int mdof_flags = 0;
335     if (do_per_step(currentStep, nstxout_))
336     {
337         mdof_flags |= MDOF_X;
338     }
339     if (do_per_step(currentStep, nstvout_))
340     {
341         mdof_flags |= MDOF_V;
342     }
343     if (do_per_step(currentStep, nstfout_))
344     {
345         mdof_flags |= MDOF_F;
346     }
347     if (do_per_step(currentStep, nstxout_compressed_))
348     {
349         mdof_flags |= MDOF_X_COMPRESSED;
350     }
351     if (do_per_step(currentStep, mdoutf_get_tng_box_output_interval(outf)))
352     {
353         mdof_flags |= MDOF_BOX;
354     }
355     if (do_per_step(currentStep, mdoutf_get_tng_lambda_output_interval(outf)))
356     {
357         mdof_flags |= MDOF_LAMBDA;
358     }
359     if (do_per_step(currentStep, mdoutf_get_tng_compressed_box_output_interval(outf)))
360     {
361         mdof_flags |= MDOF_BOX_COMPRESSED;
362     }
363     if (do_per_step(currentStep, mdoutf_get_tng_compressed_lambda_output_interval(outf)))
364     {
365         mdof_flags |= MDOF_LAMBDA_COMPRESSED;
366     }
367
368     if (mdof_flags == 0)
369     {
370         return;
371     }
372     GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
373
374     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
375     ObservablesHistory *observablesHistory = nullptr;
376
377     mdoutf_write_to_trajectory_files(
378             fplog_, cr_, outf, static_cast<int>(mdof_flags), totalNumAtoms_,
379             currentStep, currentTime, localStateBackup_.get(), globalState_, observablesHistory, f_);
380
381     localStateBackup_.reset();
382 }
383
384 void StatePropagatorData::elementSetup()
385 {
386     if (vvResetVelocities_)
387     {
388         velocityBackup_ = v_;
389     }
390 }
391
392 void StatePropagatorData::resetVelocities()
393 {
394     v_ = velocityBackup_;
395 }
396
397 }  // namespace gmx