9cff8a7376d189814a2af0066dd77e6854be14c4
[alexxy/gromacs.git] / src / gromacs / domdec / distribute.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /* \internal \file
36  *
37  * \brief Implements atom distribution functions.
38  *
39  * \author Berk Hess <hess@kth.se>
40  * \ingroup module_domdec
41  */
42
43 #include "gmxpre.h"
44
45 #include "distribute.h"
46
47 #include "config.h"
48
49 #include <vector>
50
51 #include "gromacs/domdec/domdec_network.h"
52 #include "gromacs/math/vec.h"
53 #include "gromacs/mdtypes/commrec.h"
54 #include "gromacs/mdtypes/df_history.h"
55 #include "gromacs/mdtypes/state.h"
56 #include "gromacs/topology/topology.h"
57 #include "gromacs/utility/enumerationhelpers.h"
58 #include "gromacs/utility/fatalerror.h"
59 #include "gromacs/utility/logger.h"
60
61 #include "atomdistribution.h"
62 #include "cellsizes.h"
63 #include "domdec_internal.h"
64 #include "utility.h"
65
66 static void distributeVecSendrecv(gmx_domdec_t*                  dd,
67                                   gmx::ArrayRef<const gmx::RVec> globalVec,
68                                   gmx::ArrayRef<gmx::RVec>       localVec)
69 {
70     if (DDMASTER(dd))
71     {
72         std::vector<gmx::RVec> buffer;
73
74         for (int rank = 0; rank < dd->nnodes; rank++)
75         {
76             if (rank != dd->rank)
77             {
78                 const auto& domainGroups = dd->ma->domainGroups[rank];
79
80                 buffer.resize(domainGroups.numAtoms);
81
82                 int localAtom = 0;
83                 for (const int& globalAtom : domainGroups.atomGroups)
84                 {
85                     buffer[localAtom++] = globalVec[globalAtom];
86                 }
87                 GMX_RELEASE_ASSERT(localAtom == domainGroups.numAtoms,
88                                    "The index count and number of indices should match");
89
90 #if GMX_MPI
91                 MPI_Send(buffer.data(), domainGroups.numAtoms * sizeof(gmx::RVec), MPI_BYTE, rank, rank, dd->mpi_comm_all);
92 #endif
93             }
94         }
95
96         const auto& domainGroups = dd->ma->domainGroups[dd->masterrank];
97         int         localAtom    = 0;
98         for (const int& globalAtom : domainGroups.atomGroups)
99         {
100             localVec[localAtom++] = globalVec[globalAtom];
101         }
102     }
103     else
104     {
105 #if GMX_MPI
106         int numHomeAtoms = dd->comm->atomRanges.numHomeAtoms();
107         MPI_Recv(localVec.data(),
108                  numHomeAtoms * sizeof(gmx::RVec),
109                  MPI_BYTE,
110                  dd->masterrank,
111                  MPI_ANY_TAG,
112                  dd->mpi_comm_all,
113                  MPI_STATUS_IGNORE);
114 #endif
115     }
116 }
117
118 static void distributeVecScatterv(gmx_domdec_t*                  dd,
119                                   gmx::ArrayRef<const gmx::RVec> globalVec,
120                                   gmx::ArrayRef<gmx::RVec>       localVec)
121 {
122     int* sendCounts    = nullptr;
123     int* displacements = nullptr;
124
125     if (DDMASTER(dd))
126     {
127         AtomDistribution& ma = *dd->ma;
128
129         get_commbuffer_counts(&ma, &sendCounts, &displacements);
130
131         gmx::ArrayRef<gmx::RVec> buffer    = ma.rvecBuffer;
132         int                      localAtom = 0;
133         for (int rank = 0; rank < dd->nnodes; rank++)
134         {
135             const auto& domainGroups = ma.domainGroups[rank];
136             for (const int& globalAtom : domainGroups.atomGroups)
137             {
138                 buffer[localAtom++] = globalVec[globalAtom];
139             }
140         }
141     }
142
143     int numHomeAtoms = dd->comm->atomRanges.numHomeAtoms();
144     dd_scatterv(dd,
145                 sendCounts,
146                 displacements,
147                 DDMASTER(dd) ? dd->ma->rvecBuffer.data() : nullptr,
148                 numHomeAtoms * sizeof(gmx::RVec),
149                 localVec.data());
150 }
151
152 static void distributeVec(gmx_domdec_t*                  dd,
153                           gmx::ArrayRef<const gmx::RVec> globalVec,
154                           gmx::ArrayRef<gmx::RVec>       localVec)
155 {
156     if (dd->nnodes <= c_maxNumRanksUseSendRecvForScatterAndGather)
157     {
158         distributeVecSendrecv(dd, globalVec, localVec);
159     }
160     else
161     {
162         distributeVecScatterv(dd, globalVec, localVec);
163     }
164 }
165
166 static void dd_distribute_dfhist(gmx_domdec_t* dd, df_history_t* dfhist)
167 {
168     if (dfhist == nullptr)
169     {
170         return;
171     }
172
173     dd_bcast(dd, sizeof(int), &dfhist->bEquil);
174     dd_bcast(dd, sizeof(int), &dfhist->nlambda);
175     dd_bcast(dd, sizeof(real), &dfhist->wl_delta);
176
177     if (dfhist->nlambda > 0)
178     {
179         int nlam = dfhist->nlambda;
180         dd_bcast(dd, sizeof(int) * nlam, dfhist->n_at_lam);
181         dd_bcast(dd, sizeof(real) * nlam, dfhist->wl_histo);
182         dd_bcast(dd, sizeof(real) * nlam, dfhist->sum_weights);
183         dd_bcast(dd, sizeof(real) * nlam, dfhist->sum_dg);
184         dd_bcast(dd, sizeof(real) * nlam, dfhist->sum_minvar);
185         dd_bcast(dd, sizeof(real) * nlam, dfhist->sum_variance);
186
187         for (int i = 0; i < nlam; i++)
188         {
189             dd_bcast(dd, sizeof(real) * nlam, dfhist->accum_p[i]);
190             dd_bcast(dd, sizeof(real) * nlam, dfhist->accum_m[i]);
191             dd_bcast(dd, sizeof(real) * nlam, dfhist->accum_p2[i]);
192             dd_bcast(dd, sizeof(real) * nlam, dfhist->accum_m2[i]);
193             dd_bcast(dd, sizeof(real) * nlam, dfhist->Tij[i]);
194             dd_bcast(dd, sizeof(real) * nlam, dfhist->Tij_empirical[i]);
195         }
196     }
197 }
198
199 static void dd_distribute_state(gmx_domdec_t* dd, const t_state* state, t_state* state_local)
200 {
201     int nh = state_local->nhchainlength;
202
203     if (DDMASTER(dd))
204     {
205         GMX_RELEASE_ASSERT(state->nhchainlength == nh,
206                            "The global and local Nose-Hoover chain lengths should match");
207
208         for (auto i : gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real>::keys())
209         {
210             state_local->lambda[i] = state->lambda[i];
211         }
212         state_local->fep_state = state->fep_state;
213         state_local->veta      = state->veta;
214         state_local->vol0      = state->vol0;
215         copy_mat(state->box, state_local->box);
216         copy_mat(state->box_rel, state_local->box_rel);
217         copy_mat(state->boxv, state_local->boxv);
218         copy_mat(state->svir_prev, state_local->svir_prev);
219         copy_mat(state->fvir_prev, state_local->fvir_prev);
220         if (state->dfhist != nullptr)
221         {
222             copy_df_history(state_local->dfhist, state->dfhist);
223         }
224         for (int i = 0; i < state_local->ngtc; i++)
225         {
226             for (int j = 0; j < nh; j++)
227             {
228                 state_local->nosehoover_xi[i * nh + j]  = state->nosehoover_xi[i * nh + j];
229                 state_local->nosehoover_vxi[i * nh + j] = state->nosehoover_vxi[i * nh + j];
230             }
231             state_local->therm_integral[i] = state->therm_integral[i];
232         }
233         for (int i = 0; i < state_local->nnhpres; i++)
234         {
235             for (int j = 0; j < nh; j++)
236             {
237                 state_local->nhpres_xi[i * nh + j]  = state->nhpres_xi[i * nh + j];
238                 state_local->nhpres_vxi[i * nh + j] = state->nhpres_vxi[i * nh + j];
239             }
240         }
241         state_local->baros_integral = state->baros_integral;
242     }
243     dd_bcast(dd,
244              (static_cast<int>(FreeEnergyPerturbationCouplingType::Count) * sizeof(real)),
245              state_local->lambda.data());
246     dd_bcast(dd, sizeof(int), &state_local->fep_state);
247     dd_bcast(dd, sizeof(real), &state_local->veta);
248     dd_bcast(dd, sizeof(real), &state_local->vol0);
249     dd_bcast(dd, sizeof(state_local->box), state_local->box);
250     dd_bcast(dd, sizeof(state_local->box_rel), state_local->box_rel);
251     dd_bcast(dd, sizeof(state_local->boxv), state_local->boxv);
252     dd_bcast(dd, sizeof(state_local->svir_prev), state_local->svir_prev);
253     dd_bcast(dd, sizeof(state_local->fvir_prev), state_local->fvir_prev);
254     dd_bcast(dd, ((state_local->ngtc * nh) * sizeof(double)), state_local->nosehoover_xi.data());
255     dd_bcast(dd, ((state_local->ngtc * nh) * sizeof(double)), state_local->nosehoover_vxi.data());
256     dd_bcast(dd, state_local->ngtc * sizeof(double), state_local->therm_integral.data());
257     dd_bcast(dd, ((state_local->nnhpres * nh) * sizeof(double)), state_local->nhpres_xi.data());
258     dd_bcast(dd, ((state_local->nnhpres * nh) * sizeof(double)), state_local->nhpres_vxi.data());
259
260     /* communicate df_history -- required for restarting from checkpoint */
261     dd_distribute_dfhist(dd, state_local->dfhist);
262
263     state_change_natoms(state_local, dd->comm->atomRanges.numHomeAtoms());
264
265     if (state_local->flags & enumValueToBitMask(StateEntry::X))
266     {
267         distributeVec(dd, DDMASTER(dd) ? state->x : gmx::ArrayRef<const gmx::RVec>(), state_local->x);
268     }
269     if (state_local->flags & enumValueToBitMask(StateEntry::V))
270     {
271         distributeVec(dd, DDMASTER(dd) ? state->v : gmx::ArrayRef<const gmx::RVec>(), state_local->v);
272     }
273     if (state_local->flags & enumValueToBitMask(StateEntry::Cgp))
274     {
275         distributeVec(dd, DDMASTER(dd) ? state->cg_p : gmx::ArrayRef<const gmx::RVec>(), state_local->cg_p);
276     }
277 }
278
279 /* Computes and returns the domain index for the given atom group.
280  *
281  * Also updates the coordinates in pos for PBC, when necessary.
282  */
283 static inline int computeAtomGroupDomainIndex(const gmx_domdec_t& dd,
284                                               const gmx_ddbox_t&  ddbox,
285                                               const matrix&       triclinicCorrectionMatrix,
286                                               gmx::ArrayRef<const std::vector<real>> cellBoundaries,
287                                               int                                    atomBegin,
288                                               int                                    atomEnd,
289                                               const matrix                           box,
290                                               rvec*                                  pos)
291 {
292     /* Set the reference location cg_cm for assigning the group */
293     rvec cog;
294     int  numAtoms = atomEnd - atomBegin;
295     if (numAtoms == 1)
296     {
297         copy_rvec(pos[atomBegin], cog);
298     }
299     else
300     {
301         real invNumAtoms = 1 / static_cast<real>(numAtoms);
302
303         clear_rvec(cog);
304         for (int a = atomBegin; a < atomEnd; a++)
305         {
306             rvec_inc(cog, pos[a]);
307         }
308         for (int d = 0; d < DIM; d++)
309         {
310             cog[d] *= invNumAtoms;
311         }
312     }
313     /* Put the charge group in the box and determine the cell index ind */
314     ivec ind;
315     for (int d = DIM - 1; d >= 0; d--)
316     {
317         real pos_d = cog[d];
318         if (d < dd.unitCellInfo.npbcdim)
319         {
320             bool bScrew = (dd.unitCellInfo.haveScrewPBC && d == XX);
321             if (ddbox.tric_dir[d] && dd.numCells[d] > 1)
322             {
323                 /* Use triclinic coordinates for this dimension */
324                 for (int j = d + 1; j < DIM; j++)
325                 {
326                     pos_d += cog[j] * triclinicCorrectionMatrix[j][d];
327                 }
328             }
329             while (pos_d >= box[d][d])
330             {
331                 pos_d -= box[d][d];
332                 rvec_dec(cog, box[d]);
333                 if (bScrew)
334                 {
335                     cog[YY] = box[YY][YY] - cog[YY];
336                     cog[ZZ] = box[ZZ][ZZ] - cog[ZZ];
337                 }
338                 for (int a = atomBegin; a < atomEnd; a++)
339                 {
340                     rvec_dec(pos[a], box[d]);
341                     if (bScrew)
342                     {
343                         pos[a][YY] = box[YY][YY] - pos[a][YY];
344                         pos[a][ZZ] = box[ZZ][ZZ] - pos[a][ZZ];
345                     }
346                 }
347             }
348             while (pos_d < 0)
349             {
350                 pos_d += box[d][d];
351                 rvec_inc(cog, box[d]);
352                 if (bScrew)
353                 {
354                     cog[YY] = box[YY][YY] - cog[YY];
355                     cog[ZZ] = box[ZZ][ZZ] - cog[ZZ];
356                 }
357                 for (int a = atomBegin; a < atomEnd; a++)
358                 {
359                     rvec_inc(pos[a], box[d]);
360                     if (bScrew)
361                     {
362                         pos[a][YY] = box[YY][YY] - pos[a][YY];
363                         pos[a][ZZ] = box[ZZ][ZZ] - pos[a][ZZ];
364                     }
365                 }
366             }
367         }
368         /* This could be done more efficiently */
369         ind[d] = 0;
370         while (ind[d] + 1 < dd.numCells[d] && pos_d >= cellBoundaries[d][ind[d] + 1])
371         {
372             ind[d]++;
373         }
374     }
375
376     return dd_index(dd.numCells, ind);
377 }
378
379
380 static std::vector<std::vector<int>> getAtomGroupDistribution(const gmx::MDLogger& mdlog,
381                                                               const gmx_mtop_t&    mtop,
382                                                               const matrix         box,
383                                                               const gmx_ddbox_t&   ddbox,
384                                                               rvec                 pos[],
385                                                               gmx_domdec_t*        dd)
386 {
387     AtomDistribution& ma = *dd->ma;
388
389     /* Clear the count */
390     for (int rank = 0; rank < dd->nnodes; rank++)
391     {
392         ma.domainGroups[rank].numAtoms = 0;
393     }
394
395     matrix triclinicCorrectionMatrix;
396     make_tric_corr_matrix(dd->unitCellInfo.npbcdim, box, triclinicCorrectionMatrix);
397
398     ivec       npulse;
399     const auto cellBoundaries = set_dd_cell_sizes_slb(dd, &ddbox, setcellsizeslbMASTER, npulse);
400
401     std::vector<std::vector<int>> indices(dd->nnodes);
402
403     if (dd->comm->systemInfo.useUpdateGroups)
404     {
405         int atomOffset = 0;
406         for (const gmx_molblock_t& molblock : mtop.molblock)
407         {
408             const auto& updateGrouping =
409                     dd->comm->systemInfo.updateGroupingsPerMoleculeType[molblock.type];
410
411             for (int mol = 0; mol < molblock.nmol; mol++)
412             {
413                 for (int g = 0; g < updateGrouping.numBlocks(); g++)
414                 {
415                     const auto& block       = updateGrouping.block(g);
416                     const int   atomBegin   = atomOffset + block.begin();
417                     const int   atomEnd     = atomOffset + block.end();
418                     const int   domainIndex = computeAtomGroupDomainIndex(
419                             *dd, ddbox, triclinicCorrectionMatrix, cellBoundaries, atomBegin, atomEnd, box, pos);
420
421                     for (int atomIndex : block)
422                     {
423                         indices[domainIndex].push_back(atomOffset + atomIndex);
424                     }
425                     ma.domainGroups[domainIndex].numAtoms += block.size();
426                 }
427
428                 atomOffset += updateGrouping.fullRange().end();
429             }
430         }
431
432         GMX_RELEASE_ASSERT(atomOffset == mtop.natoms, "Should distribute all atoms");
433     }
434     else
435     {
436         /* Compute the center of geometry for all atoms */
437         for (int atom = 0; atom < mtop.natoms; atom++)
438         {
439             int domainIndex = computeAtomGroupDomainIndex(
440                     *dd, ddbox, triclinicCorrectionMatrix, cellBoundaries, atom, atom + 1, box, pos);
441
442             indices[domainIndex].push_back(atom);
443             ma.domainGroups[domainIndex].numAtoms += 1;
444         }
445     }
446
447     {
448         // Use double for the sums to avoid natoms^2 overflowing
449         // (65537^2 > 2^32)
450         int    nat_sum  = 0;
451         double nat2_sum = 0;
452         int    nat_min  = ma.domainGroups[0].numAtoms;
453         int    nat_max  = ma.domainGroups[0].numAtoms;
454         for (int rank = 0; rank < dd->nnodes; rank++)
455         {
456             int numAtoms = ma.domainGroups[rank].numAtoms;
457             nat_sum += numAtoms;
458             // convert to double to avoid integer overflows when squaring
459             nat2_sum += gmx::square(double(numAtoms));
460             nat_min = std::min(nat_min, numAtoms);
461             nat_max = std::max(nat_max, numAtoms);
462         }
463         nat_sum /= dd->nnodes;
464         nat2_sum /= dd->nnodes;
465
466         GMX_LOG(mdlog.info)
467                 .appendTextFormatted(
468                         "Atom distribution over %d domains: av %d stddev %d min %d max %d",
469                         dd->nnodes,
470                         nat_sum,
471                         gmx::roundToInt(std::sqrt(nat2_sum - gmx::square(static_cast<double>(nat_sum)))),
472                         nat_min,
473                         nat_max);
474     }
475
476     return indices;
477 }
478
479 static void distributeAtomGroups(const gmx::MDLogger& mdlog,
480                                  gmx_domdec_t*        dd,
481                                  const gmx_mtop_t&    mtop,
482                                  const matrix         box,
483                                  const gmx_ddbox_t*   ddbox,
484                                  rvec                 pos[])
485 {
486     AtomDistribution* ma   = dd->ma.get();
487     int *             ibuf = nullptr, buf2[2] = { 0, 0 };
488     gmx_bool          bMaster = DDMASTER(dd);
489
490     std::vector<std::vector<int>> groupIndices;
491
492     if (bMaster)
493     {
494         GMX_ASSERT(box && pos, "box or pos not set on master");
495
496         if (dd->unitCellInfo.haveScrewPBC)
497         {
498             check_screw_box(box);
499         }
500
501         groupIndices = getAtomGroupDistribution(mdlog, mtop, box, *ddbox, pos, dd);
502
503         for (int rank = 0; rank < dd->nnodes; rank++)
504         {
505             ma->intBuffer[rank * 2]     = groupIndices[rank].size();
506             ma->intBuffer[rank * 2 + 1] = ma->domainGroups[rank].numAtoms;
507         }
508         ibuf = ma->intBuffer.data();
509     }
510     else
511     {
512         ibuf = nullptr;
513     }
514     dd_scatter(dd, 2 * sizeof(int), ibuf, buf2);
515
516     dd->ncg_home = buf2[0];
517     dd->comm->atomRanges.setEnd(DDAtomRanges::Type::Home, buf2[1]);
518     dd->globalAtomGroupIndices.resize(dd->ncg_home);
519     dd->globalAtomIndices.resize(dd->comm->atomRanges.numHomeAtoms());
520
521     if (bMaster)
522     {
523         ma->atomGroups.clear();
524
525         int groupOffset = 0;
526         for (int rank = 0; rank < dd->nnodes; rank++)
527         {
528             ma->intBuffer[rank]              = groupIndices[rank].size() * sizeof(int);
529             ma->intBuffer[dd->nnodes + rank] = groupOffset * sizeof(int);
530
531             ma->atomGroups.insert(
532                     ma->atomGroups.end(), groupIndices[rank].begin(), groupIndices[rank].end());
533
534             ma->domainGroups[rank].atomGroups = gmx::constArrayRefFromArray(
535                     ma->atomGroups.data() + groupOffset, groupIndices[rank].size());
536
537             groupOffset += groupIndices[rank].size();
538         }
539     }
540
541     dd_scatterv(dd,
542                 bMaster ? ma->intBuffer.data() : nullptr,
543                 bMaster ? ma->intBuffer.data() + dd->nnodes : nullptr,
544                 bMaster ? ma->atomGroups.data() : nullptr,
545                 dd->ncg_home * sizeof(int),
546                 dd->globalAtomGroupIndices.data());
547
548     if (debug)
549     {
550         fprintf(debug, "Home charge groups:\n");
551         for (int i = 0; i < dd->ncg_home; i++)
552         {
553             fprintf(debug, " %d", dd->globalAtomGroupIndices[i]);
554             if (i % 10 == 9)
555             {
556                 fprintf(debug, "\n");
557             }
558         }
559         fprintf(debug, "\n");
560     }
561 }
562
563 void distributeState(const gmx::MDLogger& mdlog,
564                      gmx_domdec_t*        dd,
565                      const gmx_mtop_t&    mtop,
566                      t_state*             state_global,
567                      const gmx_ddbox_t&   ddbox,
568                      t_state*             state_local)
569 {
570     rvec* xGlobal = (DDMASTER(dd) ? state_global->x.rvec_array() : nullptr);
571
572     distributeAtomGroups(mdlog, dd, mtop, DDMASTER(dd) ? state_global->box : nullptr, &ddbox, xGlobal);
573
574     dd_distribute_state(dd, state_global, state_local);
575 }