Fix random typos
[alexxy/gromacs.git] / src / gromacs / domdec / collect.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /* \internal \file
36  *
37  * \brief Implements functions to collect state data to the master rank.
38  *
39  * \author Berk Hess <hess@kth.se>
40  * \ingroup module_domdec
41  */
42
43 #include "gmxpre.h"
44
45 #include "collect.h"
46
47 #include "config.h"
48
49 #include "gromacs/domdec/domdec_network.h"
50 #include "gromacs/math/vec.h"
51 #include "gromacs/mdtypes/state.h"
52 #include "gromacs/utility/enumerationhelpers.h"
53 #include "gromacs/utility/fatalerror.h"
54
55 #include "atomdistribution.h"
56 #include "distribute.h"
57 #include "domdec_internal.h"
58
59 static void dd_collect_cg(gmx_domdec_t*            dd,
60                           const int                ddpCount,
61                           const int                ddpCountCgGl,
62                           gmx::ArrayRef<const int> localCGNumbers)
63 {
64     if (ddpCount == dd->comm->master_cg_ddp_count)
65     {
66         /* The master has the correct distribution */
67         return;
68     }
69
70     gmx::ArrayRef<const int> atomGroups;
71     int                      nat_home = 0;
72
73     if (ddpCount == dd->ddp_count)
74     {
75         /* The local state and DD are in sync, use the DD indices */
76         atomGroups = gmx::constArrayRefFromArray(dd->globalAtomGroupIndices.data(), dd->numHomeAtoms);
77         nat_home   = dd->comm->atomRanges.numHomeAtoms();
78     }
79     else if (ddpCountCgGl == ddpCount)
80     {
81         /* The DD is out of sync with the local state, but we have stored
82          * the cg indices with the local state, so we can use those.
83          */
84         atomGroups = localCGNumbers;
85         nat_home   = atomGroups.size();
86     }
87     else
88     {
89         gmx_incons(
90                 "Attempted to collect a vector for a state for which the charge group distribution "
91                 "is unknown");
92     }
93
94     AtomDistribution* ma = dd->ma.get();
95
96     /* Collect the charge group and atom counts on the master */
97     int localBuffer[2] = { static_cast<int>(atomGroups.size()), nat_home };
98     dd_gather(dd, 2 * sizeof(int), localBuffer, DDMASTER(dd) ? ma->intBuffer.data() : nullptr);
99
100     if (DDMASTER(dd))
101     {
102         int groupOffset = 0;
103         for (int rank = 0; rank < dd->nnodes; rank++)
104         {
105             auto& domainGroups = ma->domainGroups[rank];
106             int   numGroups    = ma->intBuffer[2 * rank];
107
108             domainGroups.atomGroups =
109                     gmx::constArrayRefFromArray(ma->atomGroups.data() + groupOffset, numGroups);
110
111             domainGroups.numAtoms = ma->intBuffer[2 * rank + 1];
112
113             groupOffset += numGroups;
114         }
115
116         if (debug)
117         {
118             fprintf(debug, "Initial charge group distribution: ");
119             for (int rank = 0; rank < dd->nnodes; rank++)
120             {
121                 fprintf(debug, " %td", ma->domainGroups[rank].atomGroups.ssize());
122             }
123             fprintf(debug, "\n");
124         }
125
126         /* Make byte counts and indices */
127         int offset = 0;
128         for (int rank = 0; rank < dd->nnodes; rank++)
129         {
130             int numGroups                    = ma->domainGroups[rank].atomGroups.size();
131             ma->intBuffer[rank]              = numGroups * sizeof(int);
132             ma->intBuffer[dd->nnodes + rank] = offset * sizeof(int);
133             offset += numGroups;
134         }
135     }
136
137     /* Collect the charge group indices on the master */
138     dd_gatherv(dd,
139                atomGroups.size() * sizeof(int),
140                atomGroups.data(),
141                DDMASTER(dd) ? ma->intBuffer.data() : nullptr,
142                DDMASTER(dd) ? ma->intBuffer.data() + dd->nnodes : nullptr,
143                DDMASTER(dd) ? ma->atomGroups.data() : nullptr);
144
145     dd->comm->master_cg_ddp_count = ddpCount;
146 }
147
148 static void dd_collect_vec_sendrecv(gmx_domdec_t*                  dd,
149                                     gmx::ArrayRef<const gmx::RVec> lv,
150                                     gmx::ArrayRef<gmx::RVec>       v)
151 {
152     if (!DDMASTER(dd))
153     {
154 #if GMX_MPI
155         const int numHomeAtoms = dd->comm->atomRanges.numHomeAtoms();
156         MPI_Send(const_cast<void*>(static_cast<const void*>(lv.data())),
157                  numHomeAtoms * sizeof(rvec),
158                  MPI_BYTE,
159                  dd->masterrank,
160                  dd->rank,
161                  dd->mpi_comm_all);
162 #endif
163     }
164     else
165     {
166         AtomDistribution& ma = *dd->ma;
167
168         int rank      = dd->masterrank;
169         int localAtom = 0;
170         for (const int& globalAtom : ma.domainGroups[rank].atomGroups)
171         {
172             copy_rvec(lv[localAtom++], v[globalAtom]);
173         }
174
175         for (int rank = 0; rank < dd->nnodes; rank++)
176         {
177             if (rank != dd->rank)
178             {
179                 const auto& domainGroups = ma.domainGroups[rank];
180
181                 GMX_RELEASE_ASSERT(v.data() != ma.rvecBuffer.data(),
182                                    "We need different communication and return buffers");
183
184                 /* When we send/recv instead of scatter/gather, we might need
185                  * to increase the communication buffer size here.
186                  */
187                 if (static_cast<size_t>(domainGroups.numAtoms) > ma.rvecBuffer.size())
188                 {
189                     ma.rvecBuffer.resize(domainGroups.numAtoms);
190                 }
191
192 #if GMX_MPI
193                 MPI_Recv(ma.rvecBuffer.data(),
194                          domainGroups.numAtoms * sizeof(rvec),
195                          MPI_BYTE,
196                          rank,
197                          rank,
198                          dd->mpi_comm_all,
199                          MPI_STATUS_IGNORE);
200 #endif
201                 int localAtom = 0;
202                 for (const int& globalAtom : domainGroups.atomGroups)
203                 {
204                     copy_rvec(ma.rvecBuffer[localAtom++], v[globalAtom]);
205                 }
206             }
207         }
208     }
209 }
210
211 static void dd_collect_vec_gatherv(gmx_domdec_t*                  dd,
212                                    gmx::ArrayRef<const gmx::RVec> lv,
213                                    gmx::ArrayRef<gmx::RVec>       v)
214 {
215     int* recvCounts    = nullptr;
216     int* displacements = nullptr;
217
218     if (DDMASTER(dd))
219     {
220         get_commbuffer_counts(dd->ma.get(), &recvCounts, &displacements);
221     }
222
223     const int numHomeAtoms = dd->comm->atomRanges.numHomeAtoms();
224     dd_gatherv(dd,
225                numHomeAtoms * sizeof(rvec),
226                lv.data(),
227                recvCounts,
228                displacements,
229                DDMASTER(dd) ? dd->ma->rvecBuffer.data() : nullptr);
230
231     if (DDMASTER(dd))
232     {
233         const AtomDistribution& ma = *dd->ma;
234
235         int bufferAtom = 0;
236         for (int rank = 0; rank < dd->nnodes; rank++)
237         {
238             const auto& domainGroups = ma.domainGroups[rank];
239             for (const int& globalAtom : domainGroups.atomGroups)
240             {
241                 copy_rvec(ma.rvecBuffer[bufferAtom++], v[globalAtom]);
242             }
243         }
244     }
245 }
246
247 void dd_collect_vec(gmx_domdec_t*                  dd,
248                     const int                      ddpCount,
249                     const int                      ddpCountCgGl,
250                     gmx::ArrayRef<const int>       localCGNumbers,
251                     gmx::ArrayRef<const gmx::RVec> localVector,
252                     gmx::ArrayRef<gmx::RVec>       globalVector)
253 {
254     dd_collect_cg(dd, ddpCount, ddpCountCgGl, localCGNumbers);
255
256     if (dd->nnodes <= c_maxNumRanksUseSendRecvForScatterAndGather)
257     {
258         dd_collect_vec_sendrecv(dd, localVector, globalVector);
259     }
260     else
261     {
262         dd_collect_vec_gatherv(dd, localVector, globalVector);
263     }
264 }
265
266
267 void dd_collect_state(gmx_domdec_t* dd, const t_state* state_local, t_state* state)
268 {
269     int nh = state_local->nhchainlength;
270
271     if (DDMASTER(dd))
272     {
273         GMX_RELEASE_ASSERT(state->nhchainlength == nh,
274                            "The global and local Nose-Hoover chain lengths should match");
275
276         for (auto i : gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real>::keys())
277         {
278             state->lambda[i] = state_local->lambda[i];
279         }
280         state->fep_state = state_local->fep_state;
281         state->veta      = state_local->veta;
282         state->vol0      = state_local->vol0;
283         copy_mat(state_local->box, state->box);
284         copy_mat(state_local->boxv, state->boxv);
285         copy_mat(state_local->svir_prev, state->svir_prev);
286         copy_mat(state_local->fvir_prev, state->fvir_prev);
287         copy_mat(state_local->pres_prev, state->pres_prev);
288
289         for (int i = 0; i < state_local->ngtc; i++)
290         {
291             for (int j = 0; j < nh; j++)
292             {
293                 state->nosehoover_xi[i * nh + j]  = state_local->nosehoover_xi[i * nh + j];
294                 state->nosehoover_vxi[i * nh + j] = state_local->nosehoover_vxi[i * nh + j];
295             }
296             state->therm_integral[i] = state_local->therm_integral[i];
297         }
298         for (int i = 0; i < state_local->nnhpres; i++)
299         {
300             for (int j = 0; j < nh; j++)
301             {
302                 state->nhpres_xi[i * nh + j]  = state_local->nhpres_xi[i * nh + j];
303                 state->nhpres_vxi[i * nh + j] = state_local->nhpres_vxi[i * nh + j];
304             }
305         }
306         state->baros_integral     = state_local->baros_integral;
307         state->pull_com_prev_step = state_local->pull_com_prev_step;
308     }
309     if (state_local->flags & enumValueToBitMask(StateEntry::X))
310     {
311         auto globalXRef = state ? state->x : gmx::ArrayRef<gmx::RVec>();
312         dd_collect_vec(dd,
313                        state_local->ddp_count,
314                        state_local->ddp_count_cg_gl,
315                        state_local->cg_gl,
316                        state_local->x,
317                        globalXRef);
318     }
319     if (state_local->flags & enumValueToBitMask(StateEntry::V))
320     {
321         auto globalVRef = state ? state->v : gmx::ArrayRef<gmx::RVec>();
322         dd_collect_vec(dd,
323                        state_local->ddp_count,
324                        state_local->ddp_count_cg_gl,
325                        state_local->cg_gl,
326                        state_local->v,
327                        globalVRef);
328     }
329     if (state_local->flags & enumValueToBitMask(StateEntry::Cgp))
330     {
331         auto globalCgpRef = state ? state->cg_p : gmx::ArrayRef<gmx::RVec>();
332         dd_collect_vec(dd,
333                        state_local->ddp_count,
334                        state_local->ddp_count_cg_gl,
335                        state_local->cg_gl,
336                        state_local->cg_p,
337                        globalCgpRef);
338     }
339 }