5b9c6d5c07dd55b6770b79c548416fbc9c4333a7
[alexxy/gromacs.git] / src / gromacs / domdec / collect.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2018, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /* \internal \file
36  *
37  * \brief Implements functions to collect state data to the master rank.
38  *
39  * \author Berk Hess <hess@kth.se>
40  * \ingroup module_domdec
41  */
42
43 #include "gmxpre.h"
44
45 #include "collect.h"
46
47 #include "config.h"
48
49 #include "gromacs/domdec/domdec_network.h"
50 #include "gromacs/math/vec.h"
51 #include "gromacs/mdtypes/state.h"
52 #include "gromacs/utility/fatalerror.h"
53
54 #include "atomdistribution.h"
55 #include "distribute.h"
56 #include "domdec_internal.h"
57
58 static void dd_collect_cg(gmx_domdec_t  *dd,
59                           const t_state *state_local)
60 {
61     if (state_local->ddp_count == dd->comm->master_cg_ddp_count)
62     {
63         /* The master has the correct distribution */
64         return;
65     }
66
67     gmx::ArrayRef<const int> atomGroups;
68     int                      nat_home = 0;
69
70     if (state_local->ddp_count == dd->ddp_count)
71     {
72         /* The local state and DD are in sync, use the DD indices */
73         atomGroups = gmx::constArrayRefFromArray(dd->globalAtomGroupIndices.data(), dd->ncg_home);
74         nat_home   = dd->comm->atomRanges.numHomeAtoms();
75     }
76     else if (state_local->ddp_count_cg_gl == state_local->ddp_count)
77     {
78         /* The DD is out of sync with the local state, but we have stored
79          * the cg indices with the local state, so we can use those.
80          */
81         const t_block &cgs_gl = dd->comm->cgs_gl;
82
83         atomGroups = state_local->cg_gl;
84         nat_home   = 0;
85         for (const int &i : atomGroups)
86         {
87             nat_home += cgs_gl.index[i + 1] - cgs_gl.index[i];
88         }
89     }
90     else
91     {
92         gmx_incons("Attempted to collect a vector for a state for which the charge group distribution is unknown");
93     }
94
95     AtomDistribution *ma = dd->ma.get();
96
97     /* Collect the charge group and atom counts on the master */
98     int localBuffer[2] = { static_cast<int>(atomGroups.size()), nat_home };
99     dd_gather(dd, 2*sizeof(int), localBuffer,
100               DDMASTER(dd) ? ma->intBuffer.data() : nullptr);
101
102     if (DDMASTER(dd))
103     {
104         int groupOffset = 0;
105         for (int rank = 0; rank < dd->nnodes; rank++)
106         {
107             auto &domainGroups       = ma->domainGroups[rank];
108             int   numGroups          = ma->intBuffer[2*rank];
109
110             domainGroups.atomGroups  = gmx::constArrayRefFromArray(ma->atomGroups.data() + groupOffset, numGroups);
111
112             domainGroups.numAtoms    = ma->intBuffer[2*rank + 1];
113
114             groupOffset             += numGroups;
115         }
116
117         if (debug)
118         {
119             fprintf(debug, "Initial charge group distribution: ");
120             for (int rank = 0; rank < dd->nnodes; rank++)
121             {
122                 fprintf(debug, " %td", ma->domainGroups[rank].atomGroups.size());
123             }
124             fprintf(debug, "\n");
125         }
126
127         /* Make byte counts and indices */
128         int offset = 0;
129         for (int rank = 0; rank < dd->nnodes; rank++)
130         {
131             int numGroups                     = ma->domainGroups[rank].atomGroups.size();
132             ma->intBuffer[rank]               = numGroups*sizeof(int);
133             ma->intBuffer[dd->nnodes + rank]  = offset*sizeof(int);
134             offset                           += numGroups;
135         }
136     }
137
138     /* Collect the charge group indices on the master */
139     dd_gatherv(dd,
140                atomGroups.size()*sizeof(int), atomGroups.data(),
141                DDMASTER(dd) ? ma->intBuffer.data() : nullptr,
142                DDMASTER(dd) ? ma->intBuffer.data() + dd->nnodes : nullptr,
143                DDMASTER(dd) ? ma->atomGroups.data() : nullptr);
144
145     dd->comm->master_cg_ddp_count = state_local->ddp_count;
146 }
147
148 static void dd_collect_vec_sendrecv(gmx_domdec_t                  *dd,
149                                     gmx::ArrayRef<const gmx::RVec> lv,
150                                     gmx::ArrayRef<gmx::RVec>       v)
151 {
152     if (!DDMASTER(dd))
153     {
154 #if GMX_MPI
155         const int numHomeAtoms = dd->comm->atomRanges.numHomeAtoms();
156         MPI_Send(const_cast<void *>(static_cast<const void *>(lv.data())), numHomeAtoms*sizeof(rvec), MPI_BYTE,
157                  dd->masterrank, dd->rank, dd->mpi_comm_all);
158 #endif
159     }
160     else
161     {
162         AtomDistribution &ma = *dd->ma;
163
164         /* Copy the master coordinates to the global array */
165         const t_block &cgs_gl    = dd->comm->cgs_gl;
166
167         int            rank      = dd->masterrank;
168         int            localAtom = 0;
169         for (const int &i : ma.domainGroups[rank].atomGroups)
170         {
171             for (int globalAtom = cgs_gl.index[i]; globalAtom < cgs_gl.index[i + 1]; globalAtom++)
172             {
173                 copy_rvec(lv[localAtom++], v[globalAtom]);
174             }
175         }
176
177         for (int rank = 0; rank < dd->nnodes; rank++)
178         {
179             if (rank != dd->rank)
180             {
181                 const auto &domainGroups = ma.domainGroups[rank];
182
183                 GMX_RELEASE_ASSERT(v.data() != ma.rvecBuffer.data(), "We need different communication and return buffers");
184
185                 /* When we send/recv instead of scatter/gather, we might need
186                  * to increase the communication buffer size here.
187                  */
188                 if (static_cast<size_t>(domainGroups.numAtoms) > ma.rvecBuffer.size())
189                 {
190                     ma.rvecBuffer.resize(domainGroups.numAtoms);
191                 }
192
193 #if GMX_MPI
194                 MPI_Recv(ma.rvecBuffer.data(), domainGroups.numAtoms*sizeof(rvec), MPI_BYTE, rank,
195                          rank, dd->mpi_comm_all, MPI_STATUS_IGNORE);
196 #endif
197                 int localAtom = 0;
198                 for (const int &cg : domainGroups.atomGroups)
199                 {
200                     for (int globalAtom = cgs_gl.index[cg]; globalAtom < cgs_gl.index[cg + 1]; globalAtom++)
201                     {
202                         copy_rvec(ma.rvecBuffer[localAtom++], v[globalAtom]);
203                     }
204                 }
205             }
206         }
207     }
208 }
209
210 static void dd_collect_vec_gatherv(gmx_domdec_t                  *dd,
211                                    gmx::ArrayRef<const gmx::RVec> lv,
212                                    gmx::ArrayRef<gmx::RVec>       v)
213 {
214     int *recvCounts    = nullptr;
215     int *displacements = nullptr;
216
217     if (DDMASTER(dd))
218     {
219         get_commbuffer_counts(dd->ma.get(), &recvCounts, &displacements);
220     }
221
222     const int numHomeAtoms = dd->comm->atomRanges.numHomeAtoms();
223     dd_gatherv(dd, numHomeAtoms*sizeof(rvec), lv.data(), recvCounts, displacements,
224                DDMASTER(dd) ? dd->ma->rvecBuffer.data() : nullptr);
225
226     if (DDMASTER(dd))
227     {
228         const AtomDistribution &ma     = *dd->ma;
229
230         const t_block          &cgs_gl = dd->comm->cgs_gl;
231
232         int                     bufferAtom = 0;
233         for (int rank = 0; rank < dd->nnodes; rank++)
234         {
235             const auto &domainGroups = ma.domainGroups[rank];
236             for (const int &cg : domainGroups.atomGroups)
237             {
238                 for (int globalAtom = cgs_gl.index[cg]; globalAtom < cgs_gl.index[cg + 1]; globalAtom++)
239                 {
240                     copy_rvec(ma.rvecBuffer[bufferAtom++], v[globalAtom]);
241                 }
242             }
243         }
244     }
245 }
246
247 void dd_collect_vec(gmx_domdec_t                  *dd,
248                     const t_state                 *state_local,
249                     gmx::ArrayRef<const gmx::RVec> lv,
250                     gmx::ArrayRef<gmx::RVec>       v)
251 {
252     dd_collect_cg(dd, state_local);
253
254     if (dd->nnodes <= c_maxNumRanksUseSendRecvForScatterAndGather)
255     {
256         dd_collect_vec_sendrecv(dd, lv, v);
257     }
258     else
259     {
260         dd_collect_vec_gatherv(dd, lv, v);
261     }
262 }
263
264
265 void dd_collect_state(gmx_domdec_t *dd,
266                       const t_state *state_local, t_state *state)
267 {
268     int nh = state_local->nhchainlength;
269
270     if (DDMASTER(dd))
271     {
272         GMX_RELEASE_ASSERT(state->nhchainlength == nh, "The global and local Nose-Hoover chain lengths should match");
273
274         for (int i = 0; i < efptNR; i++)
275         {
276             state->lambda[i] = state_local->lambda[i];
277         }
278         state->fep_state = state_local->fep_state;
279         state->veta      = state_local->veta;
280         state->vol0      = state_local->vol0;
281         copy_mat(state_local->box, state->box);
282         copy_mat(state_local->boxv, state->boxv);
283         copy_mat(state_local->svir_prev, state->svir_prev);
284         copy_mat(state_local->fvir_prev, state->fvir_prev);
285         copy_mat(state_local->pres_prev, state->pres_prev);
286
287         for (int i = 0; i < state_local->ngtc; i++)
288         {
289             for (int j = 0; j < nh; j++)
290             {
291                 state->nosehoover_xi[i*nh+j]        = state_local->nosehoover_xi[i*nh+j];
292                 state->nosehoover_vxi[i*nh+j]       = state_local->nosehoover_vxi[i*nh+j];
293             }
294             state->therm_integral[i] = state_local->therm_integral[i];
295         }
296         for (int i = 0; i < state_local->nnhpres; i++)
297         {
298             for (int j = 0; j < nh; j++)
299             {
300                 state->nhpres_xi[i*nh+j]        = state_local->nhpres_xi[i*nh+j];
301                 state->nhpres_vxi[i*nh+j]       = state_local->nhpres_vxi[i*nh+j];
302             }
303         }
304         state->baros_integral      = state_local->baros_integral;
305         state->pull_com_prev_step  = state_local->pull_com_prev_step;
306     }
307     if (state_local->flags & (1 << estX))
308     {
309         gmx::ArrayRef<gmx::RVec> globalXRef = state ? makeArrayRef(state->x) : gmx::EmptyArrayRef();
310         dd_collect_vec(dd, state_local, makeConstArrayRef(state_local->x), globalXRef);
311     }
312     if (state_local->flags & (1 << estV))
313     {
314         gmx::ArrayRef<gmx::RVec> globalVRef = state ? makeArrayRef(state->v) : gmx::EmptyArrayRef();
315         dd_collect_vec(dd, state_local, makeConstArrayRef(state_local->v), globalVRef);
316     }
317     if (state_local->flags & (1 << estCGP))
318     {
319         gmx::ArrayRef<gmx::RVec> globalCgpRef = state ? makeArrayRef(state->cg_p) : gmx::EmptyArrayRef();
320         dd_collect_vec(dd, state_local, makeConstArrayRef(state_local->cg_p), globalCgpRef);
321     }
322 }