Tests of restrained listed potentials.
[alexxy/gromacs.git] / src / gromacs / mdlib / nbnxn_search.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2012,2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 #include "gmxpre.h"
37
38 #include "nbnxn_search.h"
39
40 #include "config.h"
41
42 #include <cassert>
43 #include <cmath>
44 #include <cstring>
45
46 #include <algorithm>
47
48 #include "gromacs/domdec/domdec_struct.h"
49 #include "gromacs/gmxlib/nrnb.h"
50 #include "gromacs/math/functions.h"
51 #include "gromacs/math/utilities.h"
52 #include "gromacs/math/vec.h"
53 #include "gromacs/mdlib/gmx_omp_nthreads.h"
54 #include "gromacs/mdlib/nb_verlet.h"
55 #include "gromacs/mdlib/nbnxn_atomdata.h"
56 #include "gromacs/mdlib/nbnxn_consts.h"
57 #include "gromacs/mdlib/nbnxn_grid.h"
58 #include "gromacs/mdlib/nbnxn_internal.h"
59 #include "gromacs/mdlib/nbnxn_simd.h"
60 #include "gromacs/mdlib/nbnxn_util.h"
61 #include "gromacs/mdlib/ns.h"
62 #include "gromacs/mdtypes/group.h"
63 #include "gromacs/mdtypes/md_enums.h"
64 #include "gromacs/pbcutil/ishift.h"
65 #include "gromacs/pbcutil/pbc.h"
66 #include "gromacs/simd/simd.h"
67 #include "gromacs/simd/vector_operations.h"
68 #include "gromacs/topology/block.h"
69 #include "gromacs/utility/exceptions.h"
70 #include "gromacs/utility/fatalerror.h"
71 #include "gromacs/utility/gmxomp.h"
72 #include "gromacs/utility/smalloc.h"
73
74 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
75
76
77 /* We shift the i-particles backward for PBC.
78  * This leads to more conditionals than shifting forward.
79  * We do this to get more balanced pair lists.
80  */
81 constexpr bool c_pbcShiftBackward = true;
82
83
84 static void nbs_cycle_clear(nbnxn_cycle_t *cc)
85 {
86     for (int i = 0; i < enbsCCnr; i++)
87     {
88         cc[i].count = 0;
89         cc[i].c     = 0;
90     }
91 }
92
93 static double Mcyc_av(const nbnxn_cycle_t *cc)
94 {
95     return static_cast<double>(cc->c)*1e-6/cc->count;
96 }
97
98 static void nbs_cycle_print(FILE *fp, const nbnxn_search *nbs)
99 {
100     fprintf(fp, "\n");
101     fprintf(fp, "ns %4d grid %4.1f search %4.1f red.f %5.3f",
102             nbs->cc[enbsCCgrid].count,
103             Mcyc_av(&nbs->cc[enbsCCgrid]),
104             Mcyc_av(&nbs->cc[enbsCCsearch]),
105             Mcyc_av(&nbs->cc[enbsCCreducef]));
106
107     if (nbs->work.size() > 1)
108     {
109         if (nbs->cc[enbsCCcombine].count > 0)
110         {
111             fprintf(fp, " comb %5.2f",
112                     Mcyc_av(&nbs->cc[enbsCCcombine]));
113         }
114         fprintf(fp, " s. th");
115         for (const nbnxn_search_work_t &work : nbs->work)
116         {
117             fprintf(fp, " %4.1f",
118                     Mcyc_av(&work.cc[enbsCCsearch]));
119         }
120     }
121     fprintf(fp, "\n");
122 }
123
124 /* Layout for the nonbonded NxN pair lists */
125 enum class NbnxnLayout
126 {
127     NoSimd4x4, // i-cluster size 4, j-cluster size 4
128     Simd4xN,   // i-cluster size 4, j-cluster size SIMD width
129     Simd2xNN,  // i-cluster size 4, j-cluster size half SIMD width
130     Gpu8x8x8   // i-cluster size 8, j-cluster size 8 + super-clustering
131 };
132
133 #if GMX_SIMD
134 /* Returns the j-cluster size */
135 template <NbnxnLayout layout>
136 static constexpr int jClusterSize()
137 {
138     static_assert(layout == NbnxnLayout::NoSimd4x4 || layout == NbnxnLayout::Simd4xN || layout == NbnxnLayout::Simd2xNN, "Currently jClusterSize only supports CPU layouts");
139
140     return layout == NbnxnLayout::Simd4xN ? GMX_SIMD_REAL_WIDTH : (layout == NbnxnLayout::Simd2xNN ? GMX_SIMD_REAL_WIDTH/2 : c_nbnxnCpuIClusterSize);
141 }
142
143 /*! \brief Returns the j-cluster index given the i-cluster index.
144  *
145  * \tparam    jClusterSize      The number of atoms in a j-cluster
146  * \tparam    jSubClusterIndex  The j-sub-cluster index (0/1), used when size(j-cluster) < size(i-cluster)
147  * \param[in] ci                The i-cluster index
148  */
149 template <int jClusterSize, int jSubClusterIndex>
150 static inline int cjFromCi(int ci)
151 {
152     static_assert(jClusterSize == c_nbnxnCpuIClusterSize/2 || jClusterSize == c_nbnxnCpuIClusterSize || jClusterSize == c_nbnxnCpuIClusterSize*2, "Only j-cluster sizes 2, 4 and 8 are currently implemented");
153
154     static_assert(jSubClusterIndex == 0 || jSubClusterIndex == 1,
155                   "Only sub-cluster indices 0 and 1 are supported");
156
157     if (jClusterSize == c_nbnxnCpuIClusterSize/2)
158     {
159         if (jSubClusterIndex == 0)
160         {
161             return ci << 1;
162         }
163         else
164         {
165             return ((ci + 1) << 1) - 1;
166         }
167     }
168     else if (jClusterSize == c_nbnxnCpuIClusterSize)
169     {
170         return ci;
171     }
172     else
173     {
174         return ci >> 1;
175     }
176 }
177
178 /*! \brief Returns the j-cluster index given the i-cluster index.
179  *
180  * \tparam    layout            The pair-list layout
181  * \tparam    jSubClusterIndex  The j-sub-cluster index (0/1), used when size(j-cluster) < size(i-cluster)
182  * \param[in] ci                The i-cluster index
183  */
184 template <NbnxnLayout layout, int jSubClusterIndex>
185 static inline int cjFromCi(int ci)
186 {
187     constexpr int clusterSize = jClusterSize<layout>();
188
189     return cjFromCi<clusterSize, jSubClusterIndex>(ci);
190 }
191
192 /* Returns the nbnxn coordinate data index given the i-cluster index */
193 template <NbnxnLayout layout>
194 static inline int xIndexFromCi(int ci)
195 {
196     constexpr int clusterSize = jClusterSize<layout>();
197
198     static_assert(clusterSize == c_nbnxnCpuIClusterSize/2 || clusterSize == c_nbnxnCpuIClusterSize || clusterSize == c_nbnxnCpuIClusterSize*2, "Only j-cluster sizes 2, 4 and 8 are currently implemented");
199
200     if (clusterSize <= c_nbnxnCpuIClusterSize)
201     {
202         /* Coordinates are stored packed in groups of 4 */
203         return ci*STRIDE_P4;
204     }
205     else
206     {
207         /* Coordinates packed in 8, i-cluster size is half the packing width */
208         return (ci >> 1)*STRIDE_P8 + (ci & 1)*(c_packX8 >> 1);
209     }
210 }
211
212 /* Returns the nbnxn coordinate data index given the j-cluster index */
213 template <NbnxnLayout layout>
214 static inline int xIndexFromCj(int cj)
215 {
216     constexpr int clusterSize = jClusterSize<layout>();
217
218     static_assert(clusterSize == c_nbnxnCpuIClusterSize/2 || clusterSize == c_nbnxnCpuIClusterSize || clusterSize == c_nbnxnCpuIClusterSize*2, "Only j-cluster sizes 2, 4 and 8 are currently implemented");
219
220     if (clusterSize == c_nbnxnCpuIClusterSize/2)
221     {
222         /* Coordinates are stored packed in groups of 4 */
223         return (cj >> 1)*STRIDE_P4 + (cj & 1)*(c_packX4 >> 1);
224     }
225     else if (clusterSize == c_nbnxnCpuIClusterSize)
226     {
227         /* Coordinates are stored packed in groups of 4 */
228         return cj*STRIDE_P4;
229     }
230     else
231     {
232         /* Coordinates are stored packed in groups of 8 */
233         return cj*STRIDE_P8;
234     }
235 }
236 #endif //GMX_SIMD
237
238 gmx_bool nbnxn_kernel_pairlist_simple(int nb_kernel_type)
239 {
240     if (nb_kernel_type == nbnxnkNotSet)
241     {
242         gmx_fatal(FARGS, "Non-bonded kernel type not set for Verlet-style pair-list.");
243     }
244
245     switch (nb_kernel_type)
246     {
247         case nbnxnk8x8x8_GPU:
248         case nbnxnk8x8x8_PlainC:
249             return FALSE;
250
251         case nbnxnk4x4_PlainC:
252         case nbnxnk4xN_SIMD_4xN:
253         case nbnxnk4xN_SIMD_2xNN:
254             return TRUE;
255
256         default:
257             gmx_incons("Invalid nonbonded kernel type passed!");
258             return FALSE;
259     }
260 }
261
262 /* Initializes a single nbnxn_pairlist_t data structure */
263 static void nbnxn_init_pairlist_fep(t_nblist *nl)
264 {
265     nl->type        = GMX_NBLIST_INTERACTION_FREE_ENERGY;
266     nl->igeometry   = GMX_NBLIST_GEOMETRY_PARTICLE_PARTICLE;
267     /* The interaction functions are set in the free energy kernel fuction */
268     nl->ivdw        = -1;
269     nl->ivdwmod     = -1;
270     nl->ielec       = -1;
271     nl->ielecmod    = -1;
272
273     nl->maxnri      = 0;
274     nl->maxnrj      = 0;
275     nl->nri         = 0;
276     nl->nrj         = 0;
277     nl->iinr        = nullptr;
278     nl->gid         = nullptr;
279     nl->shift       = nullptr;
280     nl->jindex      = nullptr;
281     nl->jjnr        = nullptr;
282     nl->excl_fep    = nullptr;
283
284 }
285
286 static void free_nblist(t_nblist *nl)
287 {
288     sfree(nl->iinr);
289     sfree(nl->gid);
290     sfree(nl->shift);
291     sfree(nl->jindex);
292     sfree(nl->jjnr);
293     sfree(nl->excl_fep);
294 }
295
296 nbnxn_search_work_t::nbnxn_search_work_t() :
297     cp0({{0}}
298         ),
299     buffer_flags({0, nullptr, 0}),
300     ndistc(0),
301     nbl_fep(new t_nblist),
302     cp1({{0}})
303 {
304     nbnxn_init_pairlist_fep(nbl_fep.get());
305
306     nbs_cycle_clear(cc);
307 }
308
309 nbnxn_search_work_t::~nbnxn_search_work_t()
310 {
311     sfree(buffer_flags.flag);
312
313     free_nblist(nbl_fep.get());
314 }
315
316 nbnxn_search::nbnxn_search(const ivec               *n_dd_cells,
317                            const gmx_domdec_zones_t *zones,
318                            gmx_bool                  bFEP,
319                            int                       nthread_max) :
320     bFEP(bFEP),
321     ePBC(epbcNONE), // The correct value will be set during the gridding
322     zones(zones),
323     natoms_local(0),
324     natoms_nonlocal(0),
325     search_count(0),
326     work(nthread_max)
327 {
328     // The correct value will be set during the gridding
329     clear_mat(box);
330     clear_ivec(dd_dim);
331     int numGrids = 1;
332     DomDec = n_dd_cells != nullptr;
333     if (DomDec)
334     {
335         for (int d = 0; d < DIM; d++)
336         {
337             if ((*n_dd_cells)[d] > 1)
338             {
339                 dd_dim[d] = 1;
340                 /* Each grid matches a DD zone */
341                 numGrids *= 2;
342             }
343         }
344     }
345
346     grid.resize(numGrids);
347
348     /* Initialize detailed nbsearch cycle counting */
349     print_cycles = (getenv("GMX_NBNXN_CYCLE") != nullptr);
350     nbs_cycle_clear(cc);
351 }
352
353 nbnxn_search *nbnxn_init_search(const ivec                *n_dd_cells,
354                                 const gmx_domdec_zones_t  *zones,
355                                 gmx_bool                   bFEP,
356                                 int                        nthread_max)
357 {
358     return new nbnxn_search(n_dd_cells, zones, bFEP, nthread_max);
359 }
360
361 static void init_buffer_flags(nbnxn_buffer_flags_t *flags,
362                               int                   natoms)
363 {
364     flags->nflag = (natoms + NBNXN_BUFFERFLAG_SIZE - 1)/NBNXN_BUFFERFLAG_SIZE;
365     if (flags->nflag > flags->flag_nalloc)
366     {
367         flags->flag_nalloc = over_alloc_large(flags->nflag);
368         srenew(flags->flag, flags->flag_nalloc);
369     }
370     for (int b = 0; b < flags->nflag; b++)
371     {
372         bitmask_clear(&(flags->flag[b]));
373     }
374 }
375
376 /* Returns the pair-list cutoff between a bounding box and a grid cell given an atom-to-atom pair-list cutoff
377  *
378  * Given a cutoff distance between atoms, this functions returns the cutoff
379  * distance2 between a bounding box of a group of atoms and a grid cell.
380  * Since atoms can be geometrically outside of the cell they have been
381  * assigned to (when atom groups instead of individual atoms are assigned
382  * to cells), this distance returned can be larger than the input.
383  */
384 static real listRangeForBoundingBoxToGridCell(real                rlist,
385                                               const nbnxn_grid_t &grid)
386 {
387     return rlist + grid.maxAtomGroupRadius;
388
389 }
390 /* Returns the pair-list cutoff between a grid cells given an atom-to-atom pair-list cutoff
391  *
392  * Given a cutoff distance between atoms, this functions returns the cutoff
393  * distance2 between two grid cells.
394  * Since atoms can be geometrically outside of the cell they have been
395  * assigned to (when atom groups instead of individual atoms are assigned
396  * to cells), this distance returned can be larger than the input.
397  */
398 static real listRangeForGridCellToGridCell(real                rlist,
399                                            const nbnxn_grid_t &iGrid,
400                                            const nbnxn_grid_t &jGrid)
401 {
402     return rlist + iGrid.maxAtomGroupRadius + jGrid.maxAtomGroupRadius;
403 }
404
405 /* Determines the cell range along one dimension that
406  * the bounding box b0 - b1 sees.
407  */
408 template<int dim>
409 static void get_cell_range(real b0, real b1,
410                            const nbnxn_grid_t &jGrid,
411                            real d2, real rlist, int *cf, int *cl)
412 {
413     real listRangeBBToCell2 = gmx::square(listRangeForBoundingBoxToGridCell(rlist, jGrid));
414     real distanceInCells    = (b0 - jGrid.c0[dim])*jGrid.invCellSize[dim];
415     *cf                     = std::max(static_cast<int>(distanceInCells), 0);
416
417     while (*cf > 0 &&
418            d2 + gmx::square((b0 - jGrid.c0[dim]) - (*cf - 1 + 1)*jGrid.cellSize[dim]) < listRangeBBToCell2)
419     {
420         (*cf)--;
421     }
422
423     *cl = std::min(static_cast<int>((b1 - jGrid.c0[dim])*jGrid.invCellSize[dim]), jGrid.numCells[dim] - 1);
424     while (*cl < jGrid.numCells[dim] - 1 &&
425            d2 + gmx::square((*cl + 1)*jGrid.cellSize[dim] - (b1 - jGrid.c0[dim])) < listRangeBBToCell2)
426     {
427         (*cl)++;
428     }
429 }
430
431 /* Reference code calculating the distance^2 between two bounding boxes */
432 /*
433    static float box_dist2(float bx0, float bx1, float by0,
434                        float by1, float bz0, float bz1,
435                        const nbnxn_bb_t *bb)
436    {
437     float d2;
438     float dl, dh, dm, dm0;
439
440     d2 = 0;
441
442     dl  = bx0 - bb->upper[BB_X];
443     dh  = bb->lower[BB_X] - bx1;
444     dm  = std::max(dl, dh);
445     dm0 = std::max(dm, 0.0f);
446     d2 += dm0*dm0;
447
448     dl  = by0 - bb->upper[BB_Y];
449     dh  = bb->lower[BB_Y] - by1;
450     dm  = std::max(dl, dh);
451     dm0 = std::max(dm, 0.0f);
452     d2 += dm0*dm0;
453
454     dl  = bz0 - bb->upper[BB_Z];
455     dh  = bb->lower[BB_Z] - bz1;
456     dm  = std::max(dl, dh);
457     dm0 = std::max(dm, 0.0f);
458     d2 += dm0*dm0;
459
460     return d2;
461    }
462  */
463
464 /* Plain C code calculating the distance^2 between two bounding boxes */
465 static float subc_bb_dist2(int                              si,
466                            const nbnxn_bb_t                *bb_i_ci,
467                            int                              csj,
468                            gmx::ArrayRef<const nbnxn_bb_t>  bb_j_all)
469 {
470     const nbnxn_bb_t *bb_i = bb_i_ci         +  si;
471     const nbnxn_bb_t *bb_j = bb_j_all.data() + csj;
472
473     float             d2 = 0;
474     float             dl, dh, dm, dm0;
475
476     dl  = bb_i->lower[BB_X] - bb_j->upper[BB_X];
477     dh  = bb_j->lower[BB_X] - bb_i->upper[BB_X];
478     dm  = std::max(dl, dh);
479     dm0 = std::max(dm, 0.0f);
480     d2 += dm0*dm0;
481
482     dl  = bb_i->lower[BB_Y] - bb_j->upper[BB_Y];
483     dh  = bb_j->lower[BB_Y] - bb_i->upper[BB_Y];
484     dm  = std::max(dl, dh);
485     dm0 = std::max(dm, 0.0f);
486     d2 += dm0*dm0;
487
488     dl  = bb_i->lower[BB_Z] - bb_j->upper[BB_Z];
489     dh  = bb_j->lower[BB_Z] - bb_i->upper[BB_Z];
490     dm  = std::max(dl, dh);
491     dm0 = std::max(dm, 0.0f);
492     d2 += dm0*dm0;
493
494     return d2;
495 }
496
497 #if NBNXN_SEARCH_BB_SIMD4
498
499 /* 4-wide SIMD code for bb distance for bb format xyz0 */
500 static float subc_bb_dist2_simd4(int                              si,
501                                  const nbnxn_bb_t                *bb_i_ci,
502                                  int                              csj,
503                                  gmx::ArrayRef<const nbnxn_bb_t>  bb_j_all)
504 {
505     // TODO: During SIMDv2 transition only some archs use namespace (remove when done)
506     using namespace gmx;
507
508     Simd4Float bb_i_S0, bb_i_S1;
509     Simd4Float bb_j_S0, bb_j_S1;
510     Simd4Float dl_S;
511     Simd4Float dh_S;
512     Simd4Float dm_S;
513     Simd4Float dm0_S;
514
515     bb_i_S0 = load4(&bb_i_ci[si].lower[0]);
516     bb_i_S1 = load4(&bb_i_ci[si].upper[0]);
517     bb_j_S0 = load4(&bb_j_all[csj].lower[0]);
518     bb_j_S1 = load4(&bb_j_all[csj].upper[0]);
519
520     dl_S    = bb_i_S0 - bb_j_S1;
521     dh_S    = bb_j_S0 - bb_i_S1;
522
523     dm_S    = max(dl_S, dh_S);
524     dm0_S   = max(dm_S, simd4SetZeroF());
525
526     return dotProduct(dm0_S, dm0_S);
527 }
528
529 /* Calculate bb bounding distances of bb_i[si,...,si+3] and store them in d2 */
530 #define SUBC_BB_DIST2_SIMD4_XXXX_INNER(si, bb_i, d2) \
531     {                                                \
532         int               shi;                                  \
533                                                  \
534         Simd4Float        dx_0, dy_0, dz_0;                    \
535         Simd4Float        dx_1, dy_1, dz_1;                    \
536                                                  \
537         Simd4Float        mx, my, mz;                          \
538         Simd4Float        m0x, m0y, m0z;                       \
539                                                  \
540         Simd4Float        d2x, d2y, d2z;                       \
541         Simd4Float        d2s, d2t;                            \
542                                                  \
543         shi = (si)*NNBSBB_D*DIM;                       \
544                                                  \
545         xi_l = load4((bb_i)+shi+0*STRIDE_PBB);   \
546         yi_l = load4((bb_i)+shi+1*STRIDE_PBB);   \
547         zi_l = load4((bb_i)+shi+2*STRIDE_PBB);   \
548         xi_h = load4((bb_i)+shi+3*STRIDE_PBB);   \
549         yi_h = load4((bb_i)+shi+4*STRIDE_PBB);   \
550         zi_h = load4((bb_i)+shi+5*STRIDE_PBB);   \
551                                                  \
552         dx_0 = xi_l - xj_h;                 \
553         dy_0 = yi_l - yj_h;                 \
554         dz_0 = zi_l - zj_h;                 \
555                                                  \
556         dx_1 = xj_l - xi_h;                 \
557         dy_1 = yj_l - yi_h;                 \
558         dz_1 = zj_l - zi_h;                 \
559                                                  \
560         mx   = max(dx_0, dx_1);                 \
561         my   = max(dy_0, dy_1);                 \
562         mz   = max(dz_0, dz_1);                 \
563                                                  \
564         m0x  = max(mx, zero);                   \
565         m0y  = max(my, zero);                   \
566         m0z  = max(mz, zero);                   \
567                                                  \
568         d2x  = m0x * m0x;                   \
569         d2y  = m0y * m0y;                   \
570         d2z  = m0z * m0z;                   \
571                                                  \
572         d2s  = d2x + d2y;                   \
573         d2t  = d2s + d2z;                   \
574                                                  \
575         store4((d2)+(si), d2t);                      \
576     }
577
578 /* 4-wide SIMD code for nsi bb distances for bb format xxxxyyyyzzzz */
579 static void subc_bb_dist2_simd4_xxxx(const float *bb_j,
580                                      int nsi, const float *bb_i,
581                                      float *d2)
582 {
583     // TODO: During SIMDv2 transition only some archs use namespace (remove when done)
584     using namespace gmx;
585
586     Simd4Float xj_l, yj_l, zj_l;
587     Simd4Float xj_h, yj_h, zj_h;
588     Simd4Float xi_l, yi_l, zi_l;
589     Simd4Float xi_h, yi_h, zi_h;
590
591     Simd4Float zero;
592
593     zero = setZero();
594
595     xj_l = Simd4Float(bb_j[0*STRIDE_PBB]);
596     yj_l = Simd4Float(bb_j[1*STRIDE_PBB]);
597     zj_l = Simd4Float(bb_j[2*STRIDE_PBB]);
598     xj_h = Simd4Float(bb_j[3*STRIDE_PBB]);
599     yj_h = Simd4Float(bb_j[4*STRIDE_PBB]);
600     zj_h = Simd4Float(bb_j[5*STRIDE_PBB]);
601
602     /* Here we "loop" over si (0,STRIDE_PBB) from 0 to nsi with step STRIDE_PBB.
603      * But as we know the number of iterations is 1 or 2, we unroll manually.
604      */
605     SUBC_BB_DIST2_SIMD4_XXXX_INNER(0, bb_i, d2);
606     if (STRIDE_PBB < nsi)
607     {
608         SUBC_BB_DIST2_SIMD4_XXXX_INNER(STRIDE_PBB, bb_i, d2);
609     }
610 }
611
612 #endif /* NBNXN_SEARCH_BB_SIMD4 */
613
614
615 /* Returns if any atom pair from two clusters is within distance sqrt(rlist2) */
616 static inline gmx_bool
617 clusterpair_in_range(const NbnxnPairlistGpuWork &work,
618                      int si,
619                      int csj, int stride, const real *x_j,
620                      real rlist2)
621 {
622 #if !GMX_SIMD4_HAVE_REAL
623
624     /* Plain C version.
625      * All coordinates are stored as xyzxyz...
626      */
627
628     const real *x_i = work.iSuperClusterData.x.data();
629
630     for (int i = 0; i < c_nbnxnGpuClusterSize; i++)
631     {
632         int i0 = (si*c_nbnxnGpuClusterSize + i)*DIM;
633         for (int j = 0; j < c_nbnxnGpuClusterSize; j++)
634         {
635             int  j0 = (csj*c_nbnxnGpuClusterSize + j)*stride;
636
637             real d2 = gmx::square(x_i[i0  ] - x_j[j0  ]) + gmx::square(x_i[i0+1] - x_j[j0+1]) + gmx::square(x_i[i0+2] - x_j[j0+2]);
638
639             if (d2 < rlist2)
640             {
641                 return TRUE;
642             }
643         }
644     }
645
646     return FALSE;
647
648 #else /* !GMX_SIMD4_HAVE_REAL */
649
650     /* 4-wide SIMD version.
651      * The coordinates x_i are stored as xxxxyyyy..., x_j is stored xyzxyz...
652      * Using 8-wide AVX(2) is not faster on Intel Sandy Bridge and Haswell.
653      */
654     static_assert(c_nbnxnGpuClusterSize == 8 || c_nbnxnGpuClusterSize == 4,
655                   "A cluster is hard-coded to 4/8 atoms.");
656
657     Simd4Real   rc2_S      = Simd4Real(rlist2);
658
659     const real *x_i        = work.iSuperClusterData.xSimd.data();
660
661     int         dim_stride = c_nbnxnGpuClusterSize*DIM;
662     Simd4Real   ix_S0      = load4(x_i + si*dim_stride + 0*GMX_SIMD4_WIDTH);
663     Simd4Real   iy_S0      = load4(x_i + si*dim_stride + 1*GMX_SIMD4_WIDTH);
664     Simd4Real   iz_S0      = load4(x_i + si*dim_stride + 2*GMX_SIMD4_WIDTH);
665
666     Simd4Real   ix_S1, iy_S1, iz_S1;
667     if (c_nbnxnGpuClusterSize == 8)
668     {
669         ix_S1      = load4(x_i + si*dim_stride + 3*GMX_SIMD4_WIDTH);
670         iy_S1      = load4(x_i + si*dim_stride + 4*GMX_SIMD4_WIDTH);
671         iz_S1      = load4(x_i + si*dim_stride + 5*GMX_SIMD4_WIDTH);
672     }
673     /* We loop from the outer to the inner particles to maximize
674      * the chance that we find a pair in range quickly and return.
675      */
676     int j0 = csj*c_nbnxnGpuClusterSize;
677     int j1 = j0 + c_nbnxnGpuClusterSize - 1;
678     while (j0 < j1)
679     {
680         Simd4Real jx0_S, jy0_S, jz0_S;
681         Simd4Real jx1_S, jy1_S, jz1_S;
682
683         Simd4Real dx_S0, dy_S0, dz_S0;
684         Simd4Real dx_S1, dy_S1, dz_S1;
685         Simd4Real dx_S2, dy_S2, dz_S2;
686         Simd4Real dx_S3, dy_S3, dz_S3;
687
688         Simd4Real rsq_S0;
689         Simd4Real rsq_S1;
690         Simd4Real rsq_S2;
691         Simd4Real rsq_S3;
692
693         Simd4Bool wco_S0;
694         Simd4Bool wco_S1;
695         Simd4Bool wco_S2;
696         Simd4Bool wco_S3;
697         Simd4Bool wco_any_S01, wco_any_S23, wco_any_S;
698
699         jx0_S = Simd4Real(x_j[j0*stride+0]);
700         jy0_S = Simd4Real(x_j[j0*stride+1]);
701         jz0_S = Simd4Real(x_j[j0*stride+2]);
702
703         jx1_S = Simd4Real(x_j[j1*stride+0]);
704         jy1_S = Simd4Real(x_j[j1*stride+1]);
705         jz1_S = Simd4Real(x_j[j1*stride+2]);
706
707         /* Calculate distance */
708         dx_S0            = ix_S0 - jx0_S;
709         dy_S0            = iy_S0 - jy0_S;
710         dz_S0            = iz_S0 - jz0_S;
711         dx_S2            = ix_S0 - jx1_S;
712         dy_S2            = iy_S0 - jy1_S;
713         dz_S2            = iz_S0 - jz1_S;
714         if (c_nbnxnGpuClusterSize == 8)
715         {
716             dx_S1            = ix_S1 - jx0_S;
717             dy_S1            = iy_S1 - jy0_S;
718             dz_S1            = iz_S1 - jz0_S;
719             dx_S3            = ix_S1 - jx1_S;
720             dy_S3            = iy_S1 - jy1_S;
721             dz_S3            = iz_S1 - jz1_S;
722         }
723
724         /* rsq = dx*dx+dy*dy+dz*dz */
725         rsq_S0           = norm2(dx_S0, dy_S0, dz_S0);
726         rsq_S2           = norm2(dx_S2, dy_S2, dz_S2);
727         if (c_nbnxnGpuClusterSize == 8)
728         {
729             rsq_S1           = norm2(dx_S1, dy_S1, dz_S1);
730             rsq_S3           = norm2(dx_S3, dy_S3, dz_S3);
731         }
732
733         wco_S0           = (rsq_S0 < rc2_S);
734         wco_S2           = (rsq_S2 < rc2_S);
735         if (c_nbnxnGpuClusterSize == 8)
736         {
737             wco_S1           = (rsq_S1 < rc2_S);
738             wco_S3           = (rsq_S3 < rc2_S);
739         }
740         if (c_nbnxnGpuClusterSize == 8)
741         {
742             wco_any_S01      = wco_S0 || wco_S1;
743             wco_any_S23      = wco_S2 || wco_S3;
744             wco_any_S        = wco_any_S01 || wco_any_S23;
745         }
746         else
747         {
748             wco_any_S = wco_S0 || wco_S2;
749         }
750
751         if (anyTrue(wco_any_S))
752         {
753             return TRUE;
754         }
755
756         j0++;
757         j1--;
758     }
759
760     return FALSE;
761
762 #endif /* !GMX_SIMD4_HAVE_REAL */
763 }
764
765 /* Returns the j-cluster index for index cjIndex in a cj list */
766 static inline int nblCj(gmx::ArrayRef<const nbnxn_cj_t> cjList,
767                         int                             cjIndex)
768 {
769     return cjList[cjIndex].cj;
770 }
771
772 /* Returns the j-cluster index for index cjIndex in a cj4 list */
773 static inline int nblCj(gmx::ArrayRef<const nbnxn_cj4_t> cj4List,
774                         int                              cjIndex)
775 {
776     return cj4List[cjIndex/c_nbnxnGpuJgroupSize].cj[cjIndex & (c_nbnxnGpuJgroupSize - 1)];
777 }
778
779 /* Returns the i-interaction mask of the j sub-cell for index cj_ind */
780 static unsigned int nbl_imask0(const NbnxnPairlistGpu *nbl, int cj_ind)
781 {
782     return nbl->cj4[cj_ind/c_nbnxnGpuJgroupSize].imei[0].imask;
783 }
784
785 /* Initializes a single NbnxnPairlistCpu data structure */
786 static void nbnxn_init_pairlist(NbnxnPairlistCpu *nbl)
787 {
788     nbl->na_ci       = c_nbnxnCpuIClusterSize;
789     nbl->na_cj       = 0;
790     nbl->ci.clear();
791     nbl->ciOuter.clear();
792     nbl->ncjInUse    = 0;
793     nbl->cj.clear();
794     nbl->cjOuter.clear();
795     nbl->nci_tot     = 0;
796
797     nbl->work        = new NbnxnPairlistCpuWork();
798 }
799
800 NbnxnPairlistGpu::NbnxnPairlistGpu(gmx::PinningPolicy pinningPolicy) :
801     na_ci(c_nbnxnGpuClusterSize),
802     na_cj(c_nbnxnGpuClusterSize),
803     na_sc(c_gpuNumClusterPerCell*c_nbnxnGpuClusterSize),
804     rlist(0),
805     sci({}, {pinningPolicy}),
806     cj4({}, {pinningPolicy}),
807     excl({}, {pinningPolicy}),
808     nci_tot(0)
809 {
810     static_assert(c_nbnxnGpuNumClusterPerSupercluster == c_gpuNumClusterPerCell,
811                   "The search code assumes that the a super-cluster matches a search grid cell");
812
813     static_assert(sizeof(cj4[0].imei[0].imask)*8 >= c_nbnxnGpuJgroupSize*c_gpuNumClusterPerCell,
814                   "The i super-cluster cluster interaction mask does not contain a sufficient number of bits");
815
816     static_assert(sizeof(excl[0])*8 >= c_nbnxnGpuJgroupSize*c_gpuNumClusterPerCell, "The GPU exclusion mask does not contain a sufficient number of bits");
817
818     // We always want a first entry without any exclusions
819     excl.resize(1);
820
821     work = new NbnxnPairlistGpuWork();
822 }
823
824 void nbnxn_init_pairlist_set(nbnxn_pairlist_set_t *nbl_list,
825                              gmx_bool bSimple, gmx_bool bCombined)
826 {
827     GMX_RELEASE_ASSERT(!bSimple || !bCombined, "Can only combine non-simple lists");
828
829     nbl_list->bSimple   = bSimple;
830     nbl_list->bCombined = bCombined;
831
832     nbl_list->nnbl = gmx_omp_nthreads_get(emntNonbonded);
833
834     if (!nbl_list->bCombined &&
835         nbl_list->nnbl > NBNXN_BUFFERFLAG_MAX_THREADS)
836     {
837         gmx_fatal(FARGS, "%d OpenMP threads were requested. Since the non-bonded force buffer reduction is prohibitively slow with more than %d threads, we do not allow this. Use %d or less OpenMP threads.",
838                   nbl_list->nnbl, NBNXN_BUFFERFLAG_MAX_THREADS, NBNXN_BUFFERFLAG_MAX_THREADS);
839     }
840
841     if (bSimple)
842     {
843         snew(nbl_list->nbl, nbl_list->nnbl);
844         if (nbl_list->nnbl > 1)
845         {
846             snew(nbl_list->nbl_work, nbl_list->nnbl);
847         }
848     }
849     else
850     {
851         snew(nbl_list->nblGpu, nbl_list->nnbl);
852     }
853     snew(nbl_list->nbl_fep, nbl_list->nnbl);
854     /* Execute in order to avoid memory interleaving between threads */
855 #pragma omp parallel for num_threads(nbl_list->nnbl) schedule(static)
856     for (int i = 0; i < nbl_list->nnbl; i++)
857     {
858         try
859         {
860             /* Allocate the nblist data structure locally on each thread
861              * to optimize memory access for NUMA architectures.
862              */
863             if (bSimple)
864             {
865                 nbl_list->nbl[i] = new NbnxnPairlistCpu();
866
867                 nbnxn_init_pairlist(nbl_list->nbl[i]);
868                 if (nbl_list->nnbl > 1)
869                 {
870                     nbl_list->nbl_work[i] = new NbnxnPairlistCpu();
871                     nbnxn_init_pairlist(nbl_list->nbl_work[i]);
872                 }
873             }
874             else
875             {
876                 /* Only list 0 is used on the GPU, use normal allocation for i>0 */
877                 auto pinningPolicy = (i == 0 ? gmx::PinningPolicy::PinnedIfSupported : gmx::PinningPolicy::CannotBePinned);
878
879                 nbl_list->nblGpu[i] = new NbnxnPairlistGpu(pinningPolicy);
880             }
881
882             snew(nbl_list->nbl_fep[i], 1);
883             nbnxn_init_pairlist_fep(nbl_list->nbl_fep[i]);
884         }
885         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
886     }
887 }
888
889 /* Print statistics of a pair list, used for debug output */
890 static void print_nblist_statistics(FILE *fp, const NbnxnPairlistCpu *nbl,
891                                     const nbnxn_search *nbs, real rl)
892 {
893     const nbnxn_grid_t *grid;
894     int                 cs[SHIFTS];
895     int                 npexcl;
896
897     grid = &nbs->grid[0];
898
899     fprintf(fp, "nbl nci %zu ncj %d\n",
900             nbl->ci.size(), nbl->ncjInUse);
901     fprintf(fp, "nbl na_cj %d rl %g ncp %d per cell %.1f atoms %.1f ratio %.2f\n",
902             nbl->na_cj, rl, nbl->ncjInUse, nbl->ncjInUse/static_cast<double>(grid->nc),
903             nbl->ncjInUse/static_cast<double>(grid->nc)*grid->na_cj,
904             nbl->ncjInUse/static_cast<double>(grid->nc)*grid->na_cj/(0.5*4.0/3.0*M_PI*rl*rl*rl*grid->nc*grid->na_cj/(grid->size[XX]*grid->size[YY]*grid->size[ZZ])));
905
906     fprintf(fp, "nbl average j cell list length %.1f\n",
907             0.25*nbl->ncjInUse/std::max(static_cast<double>(nbl->ci.size()), 1.0));
908
909     for (int s = 0; s < SHIFTS; s++)
910     {
911         cs[s] = 0;
912     }
913     npexcl = 0;
914     for (const nbnxn_ci_t &ciEntry : nbl->ci)
915     {
916         cs[ciEntry.shift & NBNXN_CI_SHIFT] +=
917             ciEntry.cj_ind_end - ciEntry.cj_ind_start;
918
919         int j = ciEntry.cj_ind_start;
920         while (j < ciEntry.cj_ind_end &&
921                nbl->cj[j].excl != NBNXN_INTERACTION_MASK_ALL)
922         {
923             npexcl++;
924             j++;
925         }
926     }
927     fprintf(fp, "nbl cell pairs, total: %zu excl: %d %.1f%%\n",
928             nbl->cj.size(), npexcl, 100*npexcl/std::max(static_cast<double>(nbl->cj.size()), 1.0));
929     for (int s = 0; s < SHIFTS; s++)
930     {
931         if (cs[s] > 0)
932         {
933             fprintf(fp, "nbl shift %2d ncj %3d\n", s, cs[s]);
934         }
935     }
936 }
937
938 /* Print statistics of a pair lists, used for debug output */
939 static void print_nblist_statistics(FILE *fp, const NbnxnPairlistGpu *nbl,
940                                     const nbnxn_search *nbs, real rl)
941 {
942     const nbnxn_grid_t *grid;
943     int                 b;
944     int                 c[c_gpuNumClusterPerCell + 1];
945     double              sum_nsp, sum_nsp2;
946     int                 nsp_max;
947
948     /* This code only produces correct statistics with domain decomposition */
949     grid = &nbs->grid[0];
950
951     fprintf(fp, "nbl nsci %zu ncj4 %zu nsi %d excl4 %zu\n",
952             nbl->sci.size(), nbl->cj4.size(), nbl->nci_tot, nbl->excl.size());
953     fprintf(fp, "nbl na_c %d rl %g ncp %d per cell %.1f atoms %.1f ratio %.2f\n",
954             nbl->na_ci, rl, nbl->nci_tot, nbl->nci_tot/static_cast<double>(grid->nsubc_tot),
955             nbl->nci_tot/static_cast<double>(grid->nsubc_tot)*grid->na_c,
956             nbl->nci_tot/static_cast<double>(grid->nsubc_tot)*grid->na_c/(0.5*4.0/3.0*M_PI*rl*rl*rl*grid->nsubc_tot*grid->na_c/(grid->size[XX]*grid->size[YY]*grid->size[ZZ])));
957
958     sum_nsp  = 0;
959     sum_nsp2 = 0;
960     nsp_max  = 0;
961     for (int si = 0; si <= c_gpuNumClusterPerCell; si++)
962     {
963         c[si] = 0;
964     }
965     for (const nbnxn_sci_t &sci : nbl->sci)
966     {
967         int nsp = 0;
968         for (int j4 = sci.cj4_ind_start; j4 < sci.cj4_ind_end; j4++)
969         {
970             for (int j = 0; j < c_nbnxnGpuJgroupSize; j++)
971             {
972                 b = 0;
973                 for (int si = 0; si < c_gpuNumClusterPerCell; si++)
974                 {
975                     if (nbl->cj4[j4].imei[0].imask & (1U << (j*c_gpuNumClusterPerCell + si)))
976                     {
977                         b++;
978                     }
979                 }
980                 nsp += b;
981                 c[b]++;
982             }
983         }
984         sum_nsp  += nsp;
985         sum_nsp2 += nsp*nsp;
986         nsp_max   = std::max(nsp_max, nsp);
987     }
988     if (!nbl->sci.empty())
989     {
990         sum_nsp  /= nbl->sci.size();
991         sum_nsp2 /= nbl->sci.size();
992     }
993     fprintf(fp, "nbl #cluster-pairs: av %.1f stddev %.1f max %d\n",
994             sum_nsp, std::sqrt(sum_nsp2 - sum_nsp*sum_nsp), nsp_max);
995
996     if (!nbl->cj4.empty())
997     {
998         for (b = 0; b <= c_gpuNumClusterPerCell; b++)
999         {
1000             fprintf(fp, "nbl j-list #i-subcell %d %7d %4.1f\n",
1001                     b, c[b], 100.0*c[b]/size_t {nbl->cj4.size()*c_nbnxnGpuJgroupSize});
1002         }
1003     }
1004 }
1005
1006 /* Returns a pointer to the exclusion mask for j-cluster-group \p cj4 and warp \p warp
1007  * Generates a new exclusion entry when the j-cluster-group uses
1008  * the default all-interaction mask at call time, so the returned mask
1009  * can be modified when needed.
1010  */
1011 static nbnxn_excl_t *get_exclusion_mask(NbnxnPairlistGpu *nbl,
1012                                         int               cj4,
1013                                         int               warp)
1014 {
1015     if (nbl->cj4[cj4].imei[warp].excl_ind == 0)
1016     {
1017         /* No exclusions set, make a new list entry */
1018         const size_t oldSize = nbl->excl.size();
1019         GMX_ASSERT(oldSize >= 1, "We should always have entry [0]");
1020         /* Add entry with default values: no exclusions */
1021         nbl->excl.resize(oldSize + 1);
1022         nbl->cj4[cj4].imei[warp].excl_ind = oldSize;
1023     }
1024
1025     return &nbl->excl[nbl->cj4[cj4].imei[warp].excl_ind];
1026 }
1027
1028 static void set_self_and_newton_excls_supersub(NbnxnPairlistGpu *nbl,
1029                                                int cj4_ind, int sj_offset,
1030                                                int i_cluster_in_cell)
1031 {
1032     nbnxn_excl_t *excl[c_nbnxnGpuClusterpairSplit];
1033
1034     /* Here we only set the set self and double pair exclusions */
1035
1036     /* Reserve extra elements, so the resize() in get_exclusion_mask()
1037      * will not invalidate excl entries in the loop below
1038      */
1039     nbl->excl.reserve(nbl->excl.size() + c_nbnxnGpuClusterpairSplit);
1040     for (int w = 0; w < c_nbnxnGpuClusterpairSplit; w++)
1041     {
1042         excl[w] = get_exclusion_mask(nbl, cj4_ind, w);
1043     }
1044
1045     /* Only minor < major bits set */
1046     for (int ej = 0; ej < nbl->na_ci; ej++)
1047     {
1048         int w = (ej>>2);
1049         for (int ei = ej; ei < nbl->na_ci; ei++)
1050         {
1051             excl[w]->pair[(ej & (c_nbnxnGpuJgroupSize-1))*nbl->na_ci + ei] &=
1052                 ~(1U << (sj_offset*c_gpuNumClusterPerCell + i_cluster_in_cell));
1053         }
1054     }
1055 }
1056
1057 /* Returns a diagonal or off-diagonal interaction mask for plain C lists */
1058 static unsigned int get_imask(gmx_bool rdiag, int ci, int cj)
1059 {
1060     return (rdiag && ci == cj ? NBNXN_INTERACTION_MASK_DIAG : NBNXN_INTERACTION_MASK_ALL);
1061 }
1062
1063 /* Returns a diagonal or off-diagonal interaction mask for cj-size=2 */
1064 gmx_unused static unsigned int get_imask_simd_j2(gmx_bool rdiag, int ci, int cj)
1065 {
1066     return (rdiag && ci*2 == cj ? NBNXN_INTERACTION_MASK_DIAG_J2_0 :
1067             (rdiag && ci*2+1 == cj ? NBNXN_INTERACTION_MASK_DIAG_J2_1 :
1068              NBNXN_INTERACTION_MASK_ALL));
1069 }
1070
1071 /* Returns a diagonal or off-diagonal interaction mask for cj-size=4 */
1072 gmx_unused static unsigned int get_imask_simd_j4(gmx_bool rdiag, int ci, int cj)
1073 {
1074     return (rdiag && ci == cj ? NBNXN_INTERACTION_MASK_DIAG : NBNXN_INTERACTION_MASK_ALL);
1075 }
1076
1077 /* Returns a diagonal or off-diagonal interaction mask for cj-size=8 */
1078 gmx_unused static unsigned int get_imask_simd_j8(gmx_bool rdiag, int ci, int cj)
1079 {
1080     return (rdiag && ci == cj*2 ? NBNXN_INTERACTION_MASK_DIAG_J8_0 :
1081             (rdiag && ci == cj*2+1 ? NBNXN_INTERACTION_MASK_DIAG_J8_1 :
1082              NBNXN_INTERACTION_MASK_ALL));
1083 }
1084
1085 #if GMX_SIMD
1086 #if GMX_SIMD_REAL_WIDTH == 2
1087 #define get_imask_simd_4xn  get_imask_simd_j2
1088 #endif
1089 #if GMX_SIMD_REAL_WIDTH == 4
1090 #define get_imask_simd_4xn  get_imask_simd_j4
1091 #endif
1092 #if GMX_SIMD_REAL_WIDTH == 8
1093 #define get_imask_simd_4xn  get_imask_simd_j8
1094 #define get_imask_simd_2xnn get_imask_simd_j4
1095 #endif
1096 #if GMX_SIMD_REAL_WIDTH == 16
1097 #define get_imask_simd_2xnn get_imask_simd_j8
1098 #endif
1099 #endif
1100
1101 /* Plain C code for checking and adding cluster-pairs to the list.
1102  *
1103  * \param[in]     gridj               The j-grid
1104  * \param[in,out] nbl                 The pair-list to store the cluster pairs in
1105  * \param[in]     icluster            The index of the i-cluster
1106  * \param[in]     jclusterFirst       The first cluster in the j-range
1107  * \param[in]     jclusterLast        The last cluster in the j-range
1108  * \param[in]     excludeSubDiagonal  Exclude atom pairs with i-index > j-index
1109  * \param[in]     x_j                 Coordinates for the j-atom, in xyz format
1110  * \param[in]     rlist2              The squared list cut-off
1111  * \param[in]     rbb2                The squared cut-off for putting cluster-pairs in the list based on bounding box distance only
1112  * \param[in,out] numDistanceChecks   The number of distance checks performed
1113  */
1114 static void
1115 makeClusterListSimple(const nbnxn_grid_t       &jGrid,
1116                       NbnxnPairlistCpu *        nbl,
1117                       int                       icluster,
1118                       int                       jclusterFirst,
1119                       int                       jclusterLast,
1120                       bool                      excludeSubDiagonal,
1121                       const real * gmx_restrict x_j,
1122                       real                      rlist2,
1123                       float                     rbb2,
1124                       int * gmx_restrict        numDistanceChecks)
1125 {
1126     const nbnxn_bb_t * gmx_restrict bb_ci = nbl->work->iClusterData.bb.data();
1127     const real * gmx_restrict       x_ci  = nbl->work->iClusterData.x.data();
1128
1129     gmx_bool                        InRange;
1130
1131     InRange = FALSE;
1132     while (!InRange && jclusterFirst <= jclusterLast)
1133     {
1134         real d2  = subc_bb_dist2(0, bb_ci, jclusterFirst, jGrid.bb);
1135         *numDistanceChecks += 2;
1136
1137         /* Check if the distance is within the distance where
1138          * we use only the bounding box distance rbb,
1139          * or within the cut-off and there is at least one atom pair
1140          * within the cut-off.
1141          */
1142         if (d2 < rbb2)
1143         {
1144             InRange = TRUE;
1145         }
1146         else if (d2 < rlist2)
1147         {
1148             int cjf_gl = jGrid.cell0 + jclusterFirst;
1149             for (int i = 0; i < c_nbnxnCpuIClusterSize && !InRange; i++)
1150             {
1151                 for (int j = 0; j < c_nbnxnCpuIClusterSize; j++)
1152                 {
1153                     InRange = InRange ||
1154                         (gmx::square(x_ci[i*STRIDE_XYZ+XX] - x_j[(cjf_gl*c_nbnxnCpuIClusterSize+j)*STRIDE_XYZ+XX]) +
1155                          gmx::square(x_ci[i*STRIDE_XYZ+YY] - x_j[(cjf_gl*c_nbnxnCpuIClusterSize+j)*STRIDE_XYZ+YY]) +
1156                          gmx::square(x_ci[i*STRIDE_XYZ+ZZ] - x_j[(cjf_gl*c_nbnxnCpuIClusterSize+j)*STRIDE_XYZ+ZZ]) < rlist2);
1157                 }
1158             }
1159             *numDistanceChecks += c_nbnxnCpuIClusterSize*c_nbnxnCpuIClusterSize;
1160         }
1161         if (!InRange)
1162         {
1163             jclusterFirst++;
1164         }
1165     }
1166     if (!InRange)
1167     {
1168         return;
1169     }
1170
1171     InRange = FALSE;
1172     while (!InRange && jclusterLast > jclusterFirst)
1173     {
1174         real d2  = subc_bb_dist2(0, bb_ci, jclusterLast, jGrid.bb);
1175         *numDistanceChecks += 2;
1176
1177         /* Check if the distance is within the distance where
1178          * we use only the bounding box distance rbb,
1179          * or within the cut-off and there is at least one atom pair
1180          * within the cut-off.
1181          */
1182         if (d2 < rbb2)
1183         {
1184             InRange = TRUE;
1185         }
1186         else if (d2 < rlist2)
1187         {
1188             int cjl_gl = jGrid.cell0 + jclusterLast;
1189             for (int i = 0; i < c_nbnxnCpuIClusterSize && !InRange; i++)
1190             {
1191                 for (int j = 0; j < c_nbnxnCpuIClusterSize; j++)
1192                 {
1193                     InRange = InRange ||
1194                         (gmx::square(x_ci[i*STRIDE_XYZ+XX] - x_j[(cjl_gl*c_nbnxnCpuIClusterSize+j)*STRIDE_XYZ+XX]) +
1195                          gmx::square(x_ci[i*STRIDE_XYZ+YY] - x_j[(cjl_gl*c_nbnxnCpuIClusterSize+j)*STRIDE_XYZ+YY]) +
1196                          gmx::square(x_ci[i*STRIDE_XYZ+ZZ] - x_j[(cjl_gl*c_nbnxnCpuIClusterSize+j)*STRIDE_XYZ+ZZ]) < rlist2);
1197                 }
1198             }
1199             *numDistanceChecks += c_nbnxnCpuIClusterSize*c_nbnxnCpuIClusterSize;
1200         }
1201         if (!InRange)
1202         {
1203             jclusterLast--;
1204         }
1205     }
1206
1207     if (jclusterFirst <= jclusterLast)
1208     {
1209         for (int jcluster = jclusterFirst; jcluster <= jclusterLast; jcluster++)
1210         {
1211             /* Store cj and the interaction mask */
1212             nbnxn_cj_t cjEntry;
1213             cjEntry.cj   = jGrid.cell0 + jcluster;
1214             cjEntry.excl = get_imask(excludeSubDiagonal, icluster, jcluster);
1215             nbl->cj.push_back(cjEntry);
1216         }
1217         /* Increase the closing index in the i list */
1218         nbl->ci.back().cj_ind_end = nbl->cj.size();
1219     }
1220 }
1221
1222 #ifdef GMX_NBNXN_SIMD_4XN
1223 #include "gromacs/mdlib/nbnxn_search_simd_4xn.h"
1224 #endif
1225 #ifdef GMX_NBNXN_SIMD_2XNN
1226 #include "gromacs/mdlib/nbnxn_search_simd_2xnn.h"
1227 #endif
1228
1229 /* Plain C or SIMD4 code for making a pair list of super-cell sci vs scj.
1230  * Checks bounding box distances and possibly atom pair distances.
1231  */
1232 static void make_cluster_list_supersub(const nbnxn_grid_t &iGrid,
1233                                        const nbnxn_grid_t &jGrid,
1234                                        NbnxnPairlistGpu   *nbl,
1235                                        const int           sci,
1236                                        const int           scj,
1237                                        const bool          excludeSubDiagonal,
1238                                        const int           stride,
1239                                        const real         *x,
1240                                        const real          rlist2,
1241                                        const float         rbb2,
1242                                        int                *numDistanceChecks)
1243 {
1244     NbnxnPairlistGpuWork &work   = *nbl->work;
1245
1246 #if NBNXN_BBXXXX
1247     const float          *pbb_ci = work.iSuperClusterData.bbPacked.data();
1248 #else
1249     const nbnxn_bb_t     *bb_ci  = work.iSuperClusterData.bb.data();
1250 #endif
1251
1252     assert(c_nbnxnGpuClusterSize == iGrid.na_c);
1253     assert(c_nbnxnGpuClusterSize == jGrid.na_c);
1254
1255     /* We generate the pairlist mainly based on bounding-box distances
1256      * and do atom pair distance based pruning on the GPU.
1257      * Only if a j-group contains a single cluster-pair, we try to prune
1258      * that pair based on atom distances on the CPU to avoid empty j-groups.
1259      */
1260 #define PRUNE_LIST_CPU_ONE 1
1261 #define PRUNE_LIST_CPU_ALL 0
1262
1263 #if PRUNE_LIST_CPU_ONE
1264     int  ci_last = -1;
1265 #endif
1266
1267     float *d2l = work.distanceBuffer.data();
1268
1269     for (int subc = 0; subc < jGrid.nsubc[scj]; subc++)
1270     {
1271         const int    cj4_ind   = work.cj_ind/c_nbnxnGpuJgroupSize;
1272         const int    cj_offset = work.cj_ind - cj4_ind*c_nbnxnGpuJgroupSize;
1273         const int    cj        = scj*c_gpuNumClusterPerCell + subc;
1274
1275         const int    cj_gl     = jGrid.cell0*c_gpuNumClusterPerCell + cj;
1276
1277         int          ci1;
1278         if (excludeSubDiagonal && sci == scj)
1279         {
1280             ci1 = subc + 1;
1281         }
1282         else
1283         {
1284             ci1 = iGrid.nsubc[sci];
1285         }
1286
1287 #if NBNXN_BBXXXX
1288         /* Determine all ci1 bb distances in one call with SIMD4 */
1289         subc_bb_dist2_simd4_xxxx(jGrid.pbb.data() + (cj >> STRIDE_PBB_2LOG)*NNBSBB_XXXX + (cj & (STRIDE_PBB-1)),
1290                                  ci1, pbb_ci, d2l);
1291         *numDistanceChecks += c_nbnxnGpuClusterSize*2;
1292 #endif
1293
1294         int          npair = 0;
1295         unsigned int imask = 0;
1296         /* We use a fixed upper-bound instead of ci1 to help optimization */
1297         for (int ci = 0; ci < c_gpuNumClusterPerCell; ci++)
1298         {
1299             if (ci == ci1)
1300             {
1301                 break;
1302             }
1303
1304 #if !NBNXN_BBXXXX
1305             /* Determine the bb distance between ci and cj */
1306             d2l[ci]             = subc_bb_dist2(ci, bb_ci, cj, jGrid.bb);
1307             *numDistanceChecks += 2;
1308 #endif
1309             float d2 = d2l[ci];
1310
1311 #if PRUNE_LIST_CPU_ALL
1312             /* Check if the distance is within the distance where
1313              * we use only the bounding box distance rbb,
1314              * or within the cut-off and there is at least one atom pair
1315              * within the cut-off. This check is very costly.
1316              */
1317             *numDistanceChecks += c_nbnxnGpuClusterSize*c_nbnxnGpuClusterSize;
1318             if (d2 < rbb2 ||
1319                 (d2 < rlist2 &&
1320                  clusterpair_in_range(work, ci, cj_gl, stride, x, rlist2)))
1321 #else
1322             /* Check if the distance between the two bounding boxes
1323              * in within the pair-list cut-off.
1324              */
1325             if (d2 < rlist2)
1326 #endif
1327             {
1328                 /* Flag this i-subcell to be taken into account */
1329                 imask |= (1U << (cj_offset*c_gpuNumClusterPerCell + ci));
1330
1331 #if PRUNE_LIST_CPU_ONE
1332                 ci_last = ci;
1333 #endif
1334
1335                 npair++;
1336             }
1337         }
1338
1339 #if PRUNE_LIST_CPU_ONE
1340         /* If we only found 1 pair, check if any atoms are actually
1341          * within the cut-off, so we could get rid of it.
1342          */
1343         if (npair == 1 && d2l[ci_last] >= rbb2 &&
1344             !clusterpair_in_range(work, ci_last, cj_gl, stride, x, rlist2))
1345         {
1346             imask &= ~(1U << (cj_offset*c_gpuNumClusterPerCell + ci_last));
1347             npair--;
1348         }
1349 #endif
1350
1351         if (npair > 0)
1352         {
1353             /* We have at least one cluster pair: add a j-entry */
1354             if (static_cast<size_t>(cj4_ind) == nbl->cj4.size())
1355             {
1356                 nbl->cj4.resize(nbl->cj4.size() + 1);
1357             }
1358             nbnxn_cj4_t *cj4   = &nbl->cj4[cj4_ind];
1359
1360             cj4->cj[cj_offset] = cj_gl;
1361
1362             /* Set the exclusions for the ci==sj entry.
1363              * Here we don't bother to check if this entry is actually flagged,
1364              * as it will nearly always be in the list.
1365              */
1366             if (excludeSubDiagonal && sci == scj)
1367             {
1368                 set_self_and_newton_excls_supersub(nbl, cj4_ind, cj_offset, subc);
1369             }
1370
1371             /* Copy the cluster interaction mask to the list */
1372             for (int w = 0; w < c_nbnxnGpuClusterpairSplit; w++)
1373             {
1374                 cj4->imei[w].imask |= imask;
1375             }
1376
1377             nbl->work->cj_ind++;
1378
1379             /* Keep the count */
1380             nbl->nci_tot += npair;
1381
1382             /* Increase the closing index in i super-cell list */
1383             nbl->sci.back().cj4_ind_end =
1384                 (nbl->work->cj_ind + c_nbnxnGpuJgroupSize - 1)/c_nbnxnGpuJgroupSize;
1385         }
1386     }
1387 }
1388
1389 /* Returns how many contiguous j-clusters we have starting in the i-list */
1390 template <typename CjListType>
1391 static int numContiguousJClusters(const int                       cjIndexStart,
1392                                   const int                       cjIndexEnd,
1393                                   gmx::ArrayRef<const CjListType> cjList)
1394 {
1395     const int firstJCluster = nblCj(cjList, cjIndexStart);
1396
1397     int       numContiguous = 0;
1398
1399     while (cjIndexStart + numContiguous < cjIndexEnd &&
1400            nblCj(cjList, cjIndexStart + numContiguous) == firstJCluster + numContiguous)
1401     {
1402         numContiguous++;
1403     }
1404
1405     return numContiguous;
1406 }
1407
1408 /*! \internal
1409  * \brief Helper struct for efficient searching for excluded atoms in a j-list
1410  */
1411 struct JListRanges
1412 {
1413     /*! \brief Constructs a j-list range from \p cjList with the given index range */
1414     template <typename CjListType>
1415     JListRanges(int                             cjIndexStart,
1416                 int                             cjIndexEnd,
1417                 gmx::ArrayRef<const CjListType> cjList);
1418
1419     int cjIndexStart; //!< The start index in the j-list
1420     int cjIndexEnd;   //!< The end index in the j-list
1421     int cjFirst;      //!< The j-cluster with index cjIndexStart
1422     int cjLast;       //!< The j-cluster with index cjIndexEnd-1
1423     int numDirect;    //!< Up to cjIndexStart+numDirect the j-clusters are cjFirst + the index offset
1424 };
1425
1426 #ifndef DOXYGEN
1427 template <typename CjListType>
1428 JListRanges::JListRanges(int                             cjIndexStart,
1429                          int                             cjIndexEnd,
1430                          gmx::ArrayRef<const CjListType> cjList) :
1431     cjIndexStart(cjIndexStart),
1432     cjIndexEnd(cjIndexEnd)
1433 {
1434     GMX_ASSERT(cjIndexEnd > cjIndexStart, "JListRanges should only be called with non-empty lists");
1435
1436     cjFirst   = nblCj(cjList, cjIndexStart);
1437     cjLast    = nblCj(cjList, cjIndexEnd - 1);
1438
1439     /* Determine how many contiguous j-cells we have starting
1440      * from the first i-cell. This number can be used to directly
1441      * calculate j-cell indices for excluded atoms.
1442      */
1443     numDirect = numContiguousJClusters(cjIndexStart, cjIndexEnd, cjList);
1444 }
1445 #endif // !DOXYGEN
1446
1447 /* Return the index of \p jCluster in the given range or -1 when not present
1448  *
1449  * Note: This code is executed very often and therefore performance is
1450  *       important. It should be inlined and fully optimized.
1451  */
1452 template <typename CjListType>
1453 static inline int
1454 findJClusterInJList(int                              jCluster,
1455                     const JListRanges               &ranges,
1456                     gmx::ArrayRef<const CjListType>  cjList)
1457 {
1458     int index;
1459
1460     if (jCluster < ranges.cjFirst + ranges.numDirect)
1461     {
1462         /* We can calculate the index directly using the offset */
1463         index = ranges.cjIndexStart + jCluster - ranges.cjFirst;
1464     }
1465     else
1466     {
1467         /* Search for jCluster using bisection */
1468         index           = -1;
1469         int rangeStart  = ranges.cjIndexStart + ranges.numDirect;
1470         int rangeEnd    = ranges.cjIndexEnd;
1471         int rangeMiddle;
1472         while (index == -1 && rangeStart < rangeEnd)
1473         {
1474             rangeMiddle = (rangeStart + rangeEnd) >> 1;
1475
1476             const int clusterMiddle = nblCj(cjList, rangeMiddle);
1477
1478             if (jCluster == clusterMiddle)
1479             {
1480                 index      = rangeMiddle;
1481             }
1482             else if (jCluster < clusterMiddle)
1483             {
1484                 rangeEnd   = rangeMiddle;
1485             }
1486             else
1487             {
1488                 rangeStart = rangeMiddle + 1;
1489             }
1490         }
1491     }
1492
1493     return index;
1494 }
1495
1496 // TODO: Get rid of the two functions below by renaming sci to ci (or something better)
1497
1498 /* Return the i-entry in the list we are currently operating on */
1499 static nbnxn_ci_t *getOpenIEntry(NbnxnPairlistCpu *nbl)
1500 {
1501     return &nbl->ci.back();
1502 }
1503
1504 /* Return the i-entry in the list we are currently operating on */
1505 static nbnxn_sci_t *getOpenIEntry(NbnxnPairlistGpu *nbl)
1506 {
1507     return &nbl->sci.back();
1508 }
1509
1510 /* Set all atom-pair exclusions for a simple type list i-entry
1511  *
1512  * Set all atom-pair exclusions from the topology stored in exclusions
1513  * as masks in the pair-list for simple list entry iEntry.
1514  */
1515 static void
1516 setExclusionsForIEntry(const nbnxn_search   *nbs,
1517                        NbnxnPairlistCpu     *nbl,
1518                        gmx_bool              diagRemoved,
1519                        int                   na_cj_2log,
1520                        const nbnxn_ci_t     &iEntry,
1521                        const t_blocka       &exclusions)
1522 {
1523     if (iEntry.cj_ind_end == iEntry.cj_ind_start)
1524     {
1525         /* Empty list: no exclusions */
1526         return;
1527     }
1528
1529     const JListRanges        ranges(iEntry.cj_ind_start, iEntry.cj_ind_end, gmx::makeConstArrayRef(nbl->cj));
1530
1531     const int                iCluster = iEntry.ci;
1532
1533     gmx::ArrayRef<const int> cell = nbs->cell;
1534
1535     /* Loop over the atoms in the i-cluster */
1536     for (int i = 0; i < nbl->na_ci; i++)
1537     {
1538         const int iIndex = iCluster*nbl->na_ci + i;
1539         const int iAtom  = nbs->a[iIndex];
1540         if (iAtom >= 0)
1541         {
1542             /* Loop over the topology-based exclusions for this i-atom */
1543             for (int exclIndex = exclusions.index[iAtom]; exclIndex < exclusions.index[iAtom + 1]; exclIndex++)
1544             {
1545                 const int jAtom = exclusions.a[exclIndex];
1546
1547                 if (jAtom == iAtom)
1548                 {
1549                     /* The self exclusion are already set, save some time */
1550                     continue;
1551                 }
1552
1553                 /* Get the index of the j-atom in the nbnxn atom data */
1554                 const int jIndex = cell[jAtom];
1555
1556                 /* Without shifts we only calculate interactions j>i
1557                  * for one-way pair-lists.
1558                  */
1559                 if (diagRemoved && jIndex <= iIndex)
1560                 {
1561                     continue;
1562                 }
1563
1564                 const int jCluster = (jIndex >> na_cj_2log);
1565
1566                 /* Could the cluster se be in our list? */
1567                 if (jCluster >= ranges.cjFirst && jCluster <= ranges.cjLast)
1568                 {
1569                     const int index =
1570                         findJClusterInJList(jCluster, ranges,
1571                                             gmx::makeConstArrayRef(nbl->cj));
1572
1573                     if (index >= 0)
1574                     {
1575                         /* We found an exclusion, clear the corresponding
1576                          * interaction bit.
1577                          */
1578                         const int innerJ     = jIndex - (jCluster << na_cj_2log);
1579
1580                         nbl->cj[index].excl &= ~(1U << ((i << na_cj_2log) + innerJ));
1581                     }
1582                 }
1583             }
1584         }
1585     }
1586 }
1587
1588 /* Add a new i-entry to the FEP list and copy the i-properties */
1589 static inline void fep_list_new_nri_copy(t_nblist *nlist)
1590 {
1591     /* Add a new i-entry */
1592     nlist->nri++;
1593
1594     assert(nlist->nri < nlist->maxnri);
1595
1596     /* Duplicate the last i-entry, except for jindex, which continues */
1597     nlist->iinr[nlist->nri]   = nlist->iinr[nlist->nri-1];
1598     nlist->shift[nlist->nri]  = nlist->shift[nlist->nri-1];
1599     nlist->gid[nlist->nri]    = nlist->gid[nlist->nri-1];
1600     nlist->jindex[nlist->nri] = nlist->nrj;
1601 }
1602
1603 /* For load balancing of the free-energy lists over threads, we set
1604  * the maximum nrj size of an i-entry to 40. This leads to good
1605  * load balancing in the worst case scenario of a single perturbed
1606  * particle on 16 threads, while not introducing significant overhead.
1607  * Note that half of the perturbed pairs will anyhow end up in very small lists,
1608  * since non perturbed i-particles will see few perturbed j-particles).
1609  */
1610 const int max_nrj_fep = 40;
1611
1612 /* Exclude the perturbed pairs from the Verlet list. This is only done to avoid
1613  * singularities for overlapping particles (0/0), since the charges and
1614  * LJ parameters have been zeroed in the nbnxn data structure.
1615  * Simultaneously make a group pair list for the perturbed pairs.
1616  */
1617 static void make_fep_list(const nbnxn_search     *nbs,
1618                           const nbnxn_atomdata_t *nbat,
1619                           NbnxnPairlistCpu       *nbl,
1620                           gmx_bool                bDiagRemoved,
1621                           nbnxn_ci_t             *nbl_ci,
1622                           real gmx_unused         shx,
1623                           real gmx_unused         shy,
1624                           real gmx_unused         shz,
1625                           real gmx_unused         rlist_fep2,
1626                           const nbnxn_grid_t     &iGrid,
1627                           const nbnxn_grid_t     &jGrid,
1628                           t_nblist               *nlist)
1629 {
1630     int      ci, cj_ind_start, cj_ind_end, cja, cjr;
1631     int      nri_max;
1632     int      ngid, gid_i = 0, gid_j, gid;
1633     int      egp_shift, egp_mask;
1634     int      gid_cj = 0;
1635     int      ind_i, ind_j, ai, aj;
1636     int      nri;
1637     gmx_bool bFEP_i, bFEP_i_all;
1638
1639     if (nbl_ci->cj_ind_end == nbl_ci->cj_ind_start)
1640     {
1641         /* Empty list */
1642         return;
1643     }
1644
1645     ci = nbl_ci->ci;
1646
1647     cj_ind_start = nbl_ci->cj_ind_start;
1648     cj_ind_end   = nbl_ci->cj_ind_end;
1649
1650     /* In worst case we have alternating energy groups
1651      * and create #atom-pair lists, which means we need the size
1652      * of a cluster pair (na_ci*na_cj) times the number of cj's.
1653      */
1654     nri_max = nbl->na_ci*nbl->na_cj*(cj_ind_end - cj_ind_start);
1655     if (nlist->nri + nri_max > nlist->maxnri)
1656     {
1657         nlist->maxnri = over_alloc_large(nlist->nri + nri_max);
1658         reallocate_nblist(nlist);
1659     }
1660
1661     const nbnxn_atomdata_t::Params &nbatParams = nbat->params();
1662
1663     ngid = nbatParams.nenergrp;
1664
1665     if (ngid*jGrid.na_cj > gmx::index(sizeof(gid_cj)*8))
1666     {
1667         gmx_fatal(FARGS, "The Verlet scheme with %dx%d kernels and free-energy only supports up to %zu energy groups",
1668                   iGrid.na_c, jGrid.na_cj, (sizeof(gid_cj)*8)/jGrid.na_cj);
1669     }
1670
1671     egp_shift = nbatParams.neg_2log;
1672     egp_mask  = (1 << egp_shift) - 1;
1673
1674     /* Loop over the atoms in the i sub-cell */
1675     bFEP_i_all = TRUE;
1676     for (int i = 0; i < nbl->na_ci; i++)
1677     {
1678         ind_i = ci*nbl->na_ci + i;
1679         ai    = nbs->a[ind_i];
1680         if (ai >= 0)
1681         {
1682             nri                  = nlist->nri;
1683             nlist->jindex[nri+1] = nlist->jindex[nri];
1684             nlist->iinr[nri]     = ai;
1685             /* The actual energy group pair index is set later */
1686             nlist->gid[nri]      = 0;
1687             nlist->shift[nri]    = nbl_ci->shift & NBNXN_CI_SHIFT;
1688
1689             bFEP_i = ((iGrid.fep[ci - iGrid.cell0] & (1 << i)) != 0u);
1690
1691             bFEP_i_all = bFEP_i_all && bFEP_i;
1692
1693             if (nlist->nrj + (cj_ind_end - cj_ind_start)*nbl->na_cj > nlist->maxnrj)
1694             {
1695                 nlist->maxnrj = over_alloc_small(nlist->nrj + (cj_ind_end - cj_ind_start)*nbl->na_cj);
1696                 srenew(nlist->jjnr,     nlist->maxnrj);
1697                 srenew(nlist->excl_fep, nlist->maxnrj);
1698             }
1699
1700             if (ngid > 1)
1701             {
1702                 gid_i = (nbatParams.energrp[ci] >> (egp_shift*i)) & egp_mask;
1703             }
1704
1705             for (int cj_ind = cj_ind_start; cj_ind < cj_ind_end; cj_ind++)
1706             {
1707                 unsigned int fep_cj;
1708
1709                 cja = nbl->cj[cj_ind].cj;
1710
1711                 if (jGrid.na_cj == jGrid.na_c)
1712                 {
1713                     cjr    = cja - jGrid.cell0;
1714                     fep_cj = jGrid.fep[cjr];
1715                     if (ngid > 1)
1716                     {
1717                         gid_cj = nbatParams.energrp[cja];
1718                     }
1719                 }
1720                 else if (2*jGrid.na_cj == jGrid.na_c)
1721                 {
1722                     cjr    = cja - jGrid.cell0*2;
1723                     /* Extract half of the ci fep/energrp mask */
1724                     fep_cj = (jGrid.fep[cjr>>1] >> ((cjr&1)*jGrid.na_cj)) & ((1<<jGrid.na_cj) - 1);
1725                     if (ngid > 1)
1726                     {
1727                         gid_cj = nbatParams.energrp[cja>>1] >> ((cja&1)*jGrid.na_cj*egp_shift) & ((1<<(jGrid.na_cj*egp_shift)) - 1);
1728                     }
1729                 }
1730                 else
1731                 {
1732                     cjr    = cja - (jGrid.cell0>>1);
1733                     /* Combine two ci fep masks/energrp */
1734                     fep_cj = jGrid.fep[cjr*2] + (jGrid.fep[cjr*2+1] << jGrid.na_c);
1735                     if (ngid > 1)
1736                     {
1737                         gid_cj = nbatParams.energrp[cja*2] + (nbatParams.energrp[cja*2+1] << (jGrid.na_c*egp_shift));
1738                     }
1739                 }
1740
1741                 if (bFEP_i || fep_cj != 0)
1742                 {
1743                     for (int j = 0; j < nbl->na_cj; j++)
1744                     {
1745                         /* Is this interaction perturbed and not excluded? */
1746                         ind_j = cja*nbl->na_cj + j;
1747                         aj    = nbs->a[ind_j];
1748                         if (aj >= 0 &&
1749                             (bFEP_i || (fep_cj & (1 << j))) &&
1750                             (!bDiagRemoved || ind_j >= ind_i))
1751                         {
1752                             if (ngid > 1)
1753                             {
1754                                 gid_j = (gid_cj >> (j*egp_shift)) & egp_mask;
1755                                 gid   = GID(gid_i, gid_j, ngid);
1756
1757                                 if (nlist->nrj > nlist->jindex[nri] &&
1758                                     nlist->gid[nri] != gid)
1759                                 {
1760                                     /* Energy group pair changed: new list */
1761                                     fep_list_new_nri_copy(nlist);
1762                                     nri = nlist->nri;
1763                                 }
1764                                 nlist->gid[nri] = gid;
1765                             }
1766
1767                             if (nlist->nrj - nlist->jindex[nri] >= max_nrj_fep)
1768                             {
1769                                 fep_list_new_nri_copy(nlist);
1770                                 nri = nlist->nri;
1771                             }
1772
1773                             /* Add it to the FEP list */
1774                             nlist->jjnr[nlist->nrj]     = aj;
1775                             nlist->excl_fep[nlist->nrj] = (nbl->cj[cj_ind].excl >> (i*nbl->na_cj + j)) & 1;
1776                             nlist->nrj++;
1777
1778                             /* Exclude it from the normal list.
1779                              * Note that the charge has been set to zero,
1780                              * but we need to avoid 0/0, as perturbed atoms
1781                              * can be on top of each other.
1782                              */
1783                             nbl->cj[cj_ind].excl &= ~(1U << (i*nbl->na_cj + j));
1784                         }
1785                     }
1786                 }
1787             }
1788
1789             if (nlist->nrj > nlist->jindex[nri])
1790             {
1791                 /* Actually add this new, non-empty, list */
1792                 nlist->nri++;
1793                 nlist->jindex[nlist->nri] = nlist->nrj;
1794             }
1795         }
1796     }
1797
1798     if (bFEP_i_all)
1799     {
1800         /* All interactions are perturbed, we can skip this entry */
1801         nbl_ci->cj_ind_end = cj_ind_start;
1802         nbl->ncjInUse     -= cj_ind_end - cj_ind_start;
1803     }
1804 }
1805
1806 /* Return the index of atom a within a cluster */
1807 static inline int cj_mod_cj4(int cj)
1808 {
1809     return cj & (c_nbnxnGpuJgroupSize - 1);
1810 }
1811
1812 /* Convert a j-cluster to a cj4 group */
1813 static inline int cj_to_cj4(int cj)
1814 {
1815     return cj/c_nbnxnGpuJgroupSize;
1816 }
1817
1818 /* Return the index of an j-atom within a warp */
1819 static inline int a_mod_wj(int a)
1820 {
1821     return a & (c_nbnxnGpuClusterSize/c_nbnxnGpuClusterpairSplit - 1);
1822 }
1823
1824 /* As make_fep_list above, but for super/sub lists. */
1825 static void make_fep_list(const nbnxn_search     *nbs,
1826                           const nbnxn_atomdata_t *nbat,
1827                           NbnxnPairlistGpu       *nbl,
1828                           gmx_bool                bDiagRemoved,
1829                           const nbnxn_sci_t      *nbl_sci,
1830                           real                    shx,
1831                           real                    shy,
1832                           real                    shz,
1833                           real                    rlist_fep2,
1834                           const nbnxn_grid_t     &iGrid,
1835                           const nbnxn_grid_t     &jGrid,
1836                           t_nblist               *nlist)
1837 {
1838     int                nri_max;
1839     int                c_abs;
1840     int                ind_i, ind_j, ai, aj;
1841     int                nri;
1842     gmx_bool           bFEP_i;
1843     real               xi, yi, zi;
1844     const nbnxn_cj4_t *cj4;
1845
1846     const int          numJClusterGroups = nbl_sci->numJClusterGroups();
1847     if (numJClusterGroups == 0)
1848     {
1849         /* Empty list */
1850         return;
1851     }
1852
1853     const int sci           = nbl_sci->sci;
1854
1855     const int cj4_ind_start = nbl_sci->cj4_ind_start;
1856     const int cj4_ind_end   = nbl_sci->cj4_ind_end;
1857
1858     /* Here we process one super-cell, max #atoms na_sc, versus a list
1859      * cj4 entries, each with max c_nbnxnGpuJgroupSize cj's, each
1860      * of size na_cj atoms.
1861      * On the GPU we don't support energy groups (yet).
1862      * So for each of the na_sc i-atoms, we need max one FEP list
1863      * for each max_nrj_fep j-atoms.
1864      */
1865     nri_max = nbl->na_sc*nbl->na_cj*(1 + (numJClusterGroups*c_nbnxnGpuJgroupSize)/max_nrj_fep);
1866     if (nlist->nri + nri_max > nlist->maxnri)
1867     {
1868         nlist->maxnri = over_alloc_large(nlist->nri + nri_max);
1869         reallocate_nblist(nlist);
1870     }
1871
1872     /* Loop over the atoms in the i super-cluster */
1873     for (int c = 0; c < c_gpuNumClusterPerCell; c++)
1874     {
1875         c_abs = sci*c_gpuNumClusterPerCell + c;
1876
1877         for (int i = 0; i < nbl->na_ci; i++)
1878         {
1879             ind_i = c_abs*nbl->na_ci + i;
1880             ai    = nbs->a[ind_i];
1881             if (ai >= 0)
1882             {
1883                 nri                  = nlist->nri;
1884                 nlist->jindex[nri+1] = nlist->jindex[nri];
1885                 nlist->iinr[nri]     = ai;
1886                 /* With GPUs, energy groups are not supported */
1887                 nlist->gid[nri]      = 0;
1888                 nlist->shift[nri]    = nbl_sci->shift & NBNXN_CI_SHIFT;
1889
1890                 bFEP_i = ((iGrid.fep[c_abs - iGrid.cell0*c_gpuNumClusterPerCell] & (1 << i)) != 0u);
1891
1892                 xi = nbat->x()[ind_i*nbat->xstride+XX] + shx;
1893                 yi = nbat->x()[ind_i*nbat->xstride+YY] + shy;
1894                 zi = nbat->x()[ind_i*nbat->xstride+ZZ] + shz;
1895
1896                 const int nrjMax = nlist->nrj + numJClusterGroups*c_nbnxnGpuJgroupSize*nbl->na_cj;
1897                 if (nrjMax > nlist->maxnrj)
1898                 {
1899                     nlist->maxnrj = over_alloc_small(nrjMax);
1900                     srenew(nlist->jjnr,     nlist->maxnrj);
1901                     srenew(nlist->excl_fep, nlist->maxnrj);
1902                 }
1903
1904                 for (int cj4_ind = cj4_ind_start; cj4_ind < cj4_ind_end; cj4_ind++)
1905                 {
1906                     cj4 = &nbl->cj4[cj4_ind];
1907
1908                     for (int gcj = 0; gcj < c_nbnxnGpuJgroupSize; gcj++)
1909                     {
1910                         unsigned int fep_cj;
1911
1912                         if ((cj4->imei[0].imask & (1U << (gcj*c_gpuNumClusterPerCell + c))) == 0)
1913                         {
1914                             /* Skip this ci for this cj */
1915                             continue;
1916                         }
1917
1918                         const int cjr =
1919                             cj4->cj[gcj] - jGrid.cell0*c_gpuNumClusterPerCell;
1920
1921                         fep_cj = jGrid.fep[cjr];
1922
1923                         if (bFEP_i || fep_cj != 0)
1924                         {
1925                             for (int j = 0; j < nbl->na_cj; j++)
1926                             {
1927                                 /* Is this interaction perturbed and not excluded? */
1928                                 ind_j = (jGrid.cell0*c_gpuNumClusterPerCell + cjr)*nbl->na_cj + j;
1929                                 aj    = nbs->a[ind_j];
1930                                 if (aj >= 0 &&
1931                                     (bFEP_i || (fep_cj & (1 << j))) &&
1932                                     (!bDiagRemoved || ind_j >= ind_i))
1933                                 {
1934                                     int           excl_pair;
1935                                     unsigned int  excl_bit;
1936                                     real          dx, dy, dz;
1937
1938                                     const int     jHalf = j/(c_nbnxnGpuClusterSize/c_nbnxnGpuClusterpairSplit);
1939                                     nbnxn_excl_t *excl  =
1940                                         get_exclusion_mask(nbl, cj4_ind, jHalf);
1941
1942                                     excl_pair = a_mod_wj(j)*nbl->na_ci + i;
1943                                     excl_bit  = (1U << (gcj*c_gpuNumClusterPerCell + c));
1944
1945                                     dx = nbat->x()[ind_j*nbat->xstride+XX] - xi;
1946                                     dy = nbat->x()[ind_j*nbat->xstride+YY] - yi;
1947                                     dz = nbat->x()[ind_j*nbat->xstride+ZZ] - zi;
1948
1949                                     /* The unpruned GPU list has more than 2/3
1950                                      * of the atom pairs beyond rlist. Using
1951                                      * this list will cause a lot of overhead
1952                                      * in the CPU FEP kernels, especially
1953                                      * relative to the fast GPU kernels.
1954                                      * So we prune the FEP list here.
1955                                      */
1956                                     if (dx*dx + dy*dy + dz*dz < rlist_fep2)
1957                                     {
1958                                         if (nlist->nrj - nlist->jindex[nri] >= max_nrj_fep)
1959                                         {
1960                                             fep_list_new_nri_copy(nlist);
1961                                             nri = nlist->nri;
1962                                         }
1963
1964                                         /* Add it to the FEP list */
1965                                         nlist->jjnr[nlist->nrj]     = aj;
1966                                         nlist->excl_fep[nlist->nrj] = (excl->pair[excl_pair] & excl_bit) ? 1 : 0;
1967                                         nlist->nrj++;
1968                                     }
1969
1970                                     /* Exclude it from the normal list.
1971                                      * Note that the charge and LJ parameters have
1972                                      * been set to zero, but we need to avoid 0/0,
1973                                      * as perturbed atoms can be on top of each other.
1974                                      */
1975                                     excl->pair[excl_pair] &= ~excl_bit;
1976                                 }
1977                             }
1978
1979                             /* Note that we could mask out this pair in imask
1980                              * if all i- and/or all j-particles are perturbed.
1981                              * But since the perturbed pairs on the CPU will
1982                              * take an order of magnitude more time, the GPU
1983                              * will finish before the CPU and there is no gain.
1984                              */
1985                         }
1986                     }
1987                 }
1988
1989                 if (nlist->nrj > nlist->jindex[nri])
1990                 {
1991                     /* Actually add this new, non-empty, list */
1992                     nlist->nri++;
1993                     nlist->jindex[nlist->nri] = nlist->nrj;
1994                 }
1995             }
1996         }
1997     }
1998 }
1999
2000 /* Set all atom-pair exclusions for a GPU type list i-entry
2001  *
2002  * Sets all atom-pair exclusions from the topology stored in exclusions
2003  * as masks in the pair-list for i-super-cluster list entry iEntry.
2004  */
2005 static void
2006 setExclusionsForIEntry(const nbnxn_search   *nbs,
2007                        NbnxnPairlistGpu     *nbl,
2008                        gmx_bool              diagRemoved,
2009                        int gmx_unused        na_cj_2log,
2010                        const nbnxn_sci_t    &iEntry,
2011                        const t_blocka       &exclusions)
2012 {
2013     if (iEntry.numJClusterGroups() == 0)
2014     {
2015         /* Empty list */
2016         return;
2017     }
2018
2019     /* Set the search ranges using start and end j-cluster indices.
2020      * Note that here we can not use cj4_ind_end, since the last cj4
2021      * can be only partially filled, so we use cj_ind.
2022      */
2023     const JListRanges ranges(iEntry.cj4_ind_start*c_nbnxnGpuJgroupSize,
2024                              nbl->work->cj_ind,
2025                              gmx::makeConstArrayRef(nbl->cj4));
2026
2027     GMX_ASSERT(nbl->na_ci == c_nbnxnGpuClusterSize, "na_ci should match the GPU cluster size");
2028     constexpr int            c_clusterSize      = c_nbnxnGpuClusterSize;
2029     constexpr int            c_superClusterSize = c_nbnxnGpuNumClusterPerSupercluster*c_nbnxnGpuClusterSize;
2030
2031     const int                iSuperCluster = iEntry.sci;
2032
2033     gmx::ArrayRef<const int> cell = nbs->cell;
2034
2035     /* Loop over the atoms in the i super-cluster */
2036     for (int i = 0; i < c_superClusterSize; i++)
2037     {
2038         const int iIndex = iSuperCluster*c_superClusterSize + i;
2039         const int iAtom  = nbs->a[iIndex];
2040         if (iAtom >= 0)
2041         {
2042             const int iCluster = i/c_clusterSize;
2043
2044             /* Loop over the topology-based exclusions for this i-atom */
2045             for (int exclIndex = exclusions.index[iAtom]; exclIndex < exclusions.index[iAtom + 1]; exclIndex++)
2046             {
2047                 const int jAtom = exclusions.a[exclIndex];
2048
2049                 if (jAtom == iAtom)
2050                 {
2051                     /* The self exclusions are already set, save some time */
2052                     continue;
2053                 }
2054
2055                 /* Get the index of the j-atom in the nbnxn atom data */
2056                 const int jIndex = cell[jAtom];
2057
2058                 /* Without shifts we only calculate interactions j>i
2059                  * for one-way pair-lists.
2060                  */
2061                 /* NOTE: We would like to use iIndex on the right hand side,
2062                  * but that makes this routine 25% slower with gcc6/7.
2063                  * Even using c_superClusterSize makes it slower.
2064                  * Either of these changes triggers peeling of the exclIndex
2065                  * loop, which apparently leads to far less efficient code.
2066                  */
2067                 if (diagRemoved && jIndex <= iSuperCluster*nbl->na_sc + i)
2068                 {
2069                     continue;
2070                 }
2071
2072                 const int jCluster = jIndex/c_clusterSize;
2073
2074                 /* Check whether the cluster is in our list? */
2075                 if (jCluster >= ranges.cjFirst && jCluster <= ranges.cjLast)
2076                 {
2077                     const int index =
2078                         findJClusterInJList(jCluster, ranges,
2079                                             gmx::makeConstArrayRef(nbl->cj4));
2080
2081                     if (index >= 0)
2082                     {
2083                         /* We found an exclusion, clear the corresponding
2084                          * interaction bit.
2085                          */
2086                         const unsigned int pairMask = (1U << (cj_mod_cj4(index)*c_gpuNumClusterPerCell + iCluster));
2087                         /* Check if the i-cluster interacts with the j-cluster */
2088                         if (nbl_imask0(nbl, index) & pairMask)
2089                         {
2090                             const int innerI = (i      & (c_clusterSize - 1));
2091                             const int innerJ = (jIndex & (c_clusterSize - 1));
2092
2093                             /* Determine which j-half (CUDA warp) we are in */
2094                             const int     jHalf = innerJ/(c_clusterSize/c_nbnxnGpuClusterpairSplit);
2095
2096                             nbnxn_excl_t *interactionMask =
2097                                 get_exclusion_mask(nbl, cj_to_cj4(index), jHalf);
2098
2099                             interactionMask->pair[a_mod_wj(innerJ)*c_clusterSize + innerI] &= ~pairMask;
2100                         }
2101                     }
2102                 }
2103             }
2104         }
2105     }
2106 }
2107
2108 /* Make a new ci entry at the back of nbl->ci */
2109 static void addNewIEntry(NbnxnPairlistCpu *nbl, int ci, int shift, int flags)
2110 {
2111     nbnxn_ci_t ciEntry;
2112     ciEntry.ci            = ci;
2113     ciEntry.shift         = shift;
2114     /* Store the interaction flags along with the shift */
2115     ciEntry.shift        |= flags;
2116     ciEntry.cj_ind_start  = nbl->cj.size();
2117     ciEntry.cj_ind_end    = nbl->cj.size();
2118     nbl->ci.push_back(ciEntry);
2119 }
2120
2121 /* Make a new sci entry at index nbl->nsci */
2122 static void addNewIEntry(NbnxnPairlistGpu *nbl, int sci, int shift, int gmx_unused flags)
2123 {
2124     nbnxn_sci_t sciEntry;
2125     sciEntry.sci           = sci;
2126     sciEntry.shift         = shift;
2127     sciEntry.cj4_ind_start = nbl->cj4.size();
2128     sciEntry.cj4_ind_end   = nbl->cj4.size();
2129
2130     nbl->sci.push_back(sciEntry);
2131 }
2132
2133 /* Sort the simple j-list cj on exclusions.
2134  * Entries with exclusions will all be sorted to the beginning of the list.
2135  */
2136 static void sort_cj_excl(nbnxn_cj_t *cj, int ncj,
2137                          NbnxnPairlistCpuWork *work)
2138 {
2139     work->cj.resize(ncj);
2140
2141     /* Make a list of the j-cells involving exclusions */
2142     int jnew = 0;
2143     for (int j = 0; j < ncj; j++)
2144     {
2145         if (cj[j].excl != NBNXN_INTERACTION_MASK_ALL)
2146         {
2147             work->cj[jnew++] = cj[j];
2148         }
2149     }
2150     /* Check if there are exclusions at all or not just the first entry */
2151     if (!((jnew == 0) ||
2152           (jnew == 1 && cj[0].excl != NBNXN_INTERACTION_MASK_ALL)))
2153     {
2154         for (int j = 0; j < ncj; j++)
2155         {
2156             if (cj[j].excl == NBNXN_INTERACTION_MASK_ALL)
2157             {
2158                 work->cj[jnew++] = cj[j];
2159             }
2160         }
2161         for (int j = 0; j < ncj; j++)
2162         {
2163             cj[j] = work->cj[j];
2164         }
2165     }
2166 }
2167
2168 /* Close this simple list i entry */
2169 static void closeIEntry(NbnxnPairlistCpu    *nbl,
2170                         int gmx_unused       sp_max_av,
2171                         gmx_bool gmx_unused  progBal,
2172                         float gmx_unused     nsp_tot_est,
2173                         int gmx_unused       thread,
2174                         int gmx_unused       nthread)
2175 {
2176     nbnxn_ci_t &ciEntry = nbl->ci.back();
2177
2178     /* All content of the new ci entry have already been filled correctly,
2179      * we only need to sort and increase counts or remove the entry when empty.
2180      */
2181     const int jlen = ciEntry.cj_ind_end - ciEntry.cj_ind_start;
2182     if (jlen > 0)
2183     {
2184         sort_cj_excl(nbl->cj.data() + ciEntry.cj_ind_start, jlen, nbl->work);
2185
2186         /* The counts below are used for non-bonded pair/flop counts
2187          * and should therefore match the available kernel setups.
2188          */
2189         if (!(ciEntry.shift & NBNXN_CI_DO_COUL(0)))
2190         {
2191             nbl->work->ncj_noq += jlen;
2192         }
2193         else if ((ciEntry.shift & NBNXN_CI_HALF_LJ(0)) ||
2194                  !(ciEntry.shift & NBNXN_CI_DO_LJ(0)))
2195         {
2196             nbl->work->ncj_hlj += jlen;
2197         }
2198     }
2199     else
2200     {
2201         /* Entry is empty: remove it  */
2202         nbl->ci.pop_back();
2203     }
2204 }
2205
2206 /* Split sci entry for load balancing on the GPU.
2207  * Splitting ensures we have enough lists to fully utilize the whole GPU.
2208  * With progBal we generate progressively smaller lists, which improves
2209  * load balancing. As we only know the current count on our own thread,
2210  * we will need to estimate the current total amount of i-entries.
2211  * As the lists get concatenated later, this estimate depends
2212  * both on nthread and our own thread index.
2213  */
2214 static void split_sci_entry(NbnxnPairlistGpu *nbl,
2215                             int nsp_target_av,
2216                             gmx_bool progBal, float nsp_tot_est,
2217                             int thread, int nthread)
2218 {
2219     int nsp_max;
2220
2221     if (progBal)
2222     {
2223         float nsp_est;
2224
2225         /* Estimate the total numbers of ci's of the nblist combined
2226          * over all threads using the target number of ci's.
2227          */
2228         nsp_est = (nsp_tot_est*thread)/nthread + nbl->nci_tot;
2229
2230         /* The first ci blocks should be larger, to avoid overhead.
2231          * The last ci blocks should be smaller, to improve load balancing.
2232          * The factor 3/2 makes the first block 3/2 times the target average
2233          * and ensures that the total number of blocks end up equal to
2234          * that of equally sized blocks of size nsp_target_av.
2235          */
2236         nsp_max = static_cast<int>(nsp_target_av*(nsp_tot_est*1.5/(nsp_est + nsp_tot_est)));
2237     }
2238     else
2239     {
2240         nsp_max = nsp_target_av;
2241     }
2242
2243     const int cj4_start = nbl->sci.back().cj4_ind_start;
2244     const int cj4_end   = nbl->sci.back().cj4_ind_end;
2245     const int j4len     = cj4_end - cj4_start;
2246
2247     if (j4len > 1 && j4len*c_gpuNumClusterPerCell*c_nbnxnGpuJgroupSize > nsp_max)
2248     {
2249         /* Modify the last ci entry and process the cj4's again */
2250
2251         int nsp        = 0;
2252         int nsp_sci    = 0;
2253         int nsp_cj4_e  = 0;
2254         int nsp_cj4    = 0;
2255         for (int cj4 = cj4_start; cj4 < cj4_end; cj4++)
2256         {
2257             int nsp_cj4_p = nsp_cj4;
2258             /* Count the number of cluster pairs in this cj4 group */
2259             nsp_cj4   = 0;
2260             for (int p = 0; p < c_gpuNumClusterPerCell*c_nbnxnGpuJgroupSize; p++)
2261             {
2262                 nsp_cj4 += (nbl->cj4[cj4].imei[0].imask >> p) & 1;
2263             }
2264
2265             /* If adding the current cj4 with nsp_cj4 pairs get us further
2266              * away from our target nsp_max, split the list before this cj4.
2267              */
2268             if (nsp > 0 && nsp_max - nsp < nsp + nsp_cj4 - nsp_max)
2269             {
2270                 /* Split the list at cj4 */
2271                 nbl->sci.back().cj4_ind_end = cj4;
2272                 /* Create a new sci entry */
2273                 nbnxn_sci_t sciNew;
2274                 sciNew.sci           = nbl->sci.back().sci;
2275                 sciNew.shift         = nbl->sci.back().shift;
2276                 sciNew.cj4_ind_start = cj4;
2277                 nbl->sci.push_back(sciNew);
2278
2279                 nsp_sci              = nsp;
2280                 nsp_cj4_e            = nsp_cj4_p;
2281                 nsp                  = 0;
2282             }
2283             nsp += nsp_cj4;
2284         }
2285
2286         /* Put the remaining cj4's in the last sci entry */
2287         nbl->sci.back().cj4_ind_end = cj4_end;
2288
2289         /* Possibly balance out the last two sci's
2290          * by moving the last cj4 of the second last sci.
2291          */
2292         if (nsp_sci - nsp_cj4_e >= nsp + nsp_cj4_e)
2293         {
2294             GMX_ASSERT(nbl->sci.size() >= 2, "We expect at least two elements");
2295             nbl->sci[nbl->sci.size() - 2].cj4_ind_end--;
2296             nbl->sci[nbl->sci.size() - 1].cj4_ind_start--;
2297         }
2298     }
2299 }
2300
2301 /* Clost this super/sub list i entry */
2302 static void closeIEntry(NbnxnPairlistGpu *nbl,
2303                         int nsp_max_av,
2304                         gmx_bool progBal, float nsp_tot_est,
2305                         int thread, int nthread)
2306 {
2307     nbnxn_sci_t &sciEntry = *getOpenIEntry(nbl);
2308
2309     /* All content of the new ci entry have already been filled correctly,
2310      * we only need to, potentially, split or remove the entry when empty.
2311      */
2312     int j4len = sciEntry.numJClusterGroups();
2313     if (j4len > 0)
2314     {
2315         /* We can only have complete blocks of 4 j-entries in a list,
2316          * so round the count up before closing.
2317          */
2318         int ncj4          = (nbl->work->cj_ind + c_nbnxnGpuJgroupSize - 1)/c_nbnxnGpuJgroupSize;
2319         nbl->work->cj_ind = ncj4*c_nbnxnGpuJgroupSize;
2320
2321         if (nsp_max_av > 0)
2322         {
2323             /* Measure the size of the new entry and potentially split it */
2324             split_sci_entry(nbl, nsp_max_av, progBal, nsp_tot_est,
2325                             thread, nthread);
2326         }
2327     }
2328     else
2329     {
2330         /* Entry is empty: remove it  */
2331         nbl->sci.pop_back();
2332     }
2333 }
2334
2335 /* Syncs the working array before adding another grid pair to the GPU list */
2336 static void sync_work(NbnxnPairlistCpu gmx_unused *nbl)
2337 {
2338 }
2339
2340 /* Syncs the working array before adding another grid pair to the GPU list */
2341 static void sync_work(NbnxnPairlistGpu *nbl)
2342 {
2343     nbl->work->cj_ind   = nbl->cj4.size()*c_nbnxnGpuJgroupSize;
2344 }
2345
2346 /* Clears an NbnxnPairlistCpu data structure */
2347 static void clear_pairlist(NbnxnPairlistCpu *nbl)
2348 {
2349     nbl->ci.clear();
2350     nbl->cj.clear();
2351     nbl->ncjInUse      = 0;
2352     nbl->nci_tot       = 0;
2353     nbl->ciOuter.clear();
2354     nbl->cjOuter.clear();
2355
2356     nbl->work->ncj_noq = 0;
2357     nbl->work->ncj_hlj = 0;
2358 }
2359
2360 /* Clears an NbnxnPairlistGpu data structure */
2361 static void clear_pairlist(NbnxnPairlistGpu *nbl)
2362 {
2363     nbl->sci.clear();
2364     nbl->cj4.clear();
2365     nbl->excl.resize(1);
2366     nbl->nci_tot = 0;
2367 }
2368
2369 /* Clears a group scheme pair list */
2370 static void clear_pairlist_fep(t_nblist *nl)
2371 {
2372     nl->nri = 0;
2373     nl->nrj = 0;
2374     if (nl->jindex == nullptr)
2375     {
2376         snew(nl->jindex, 1);
2377     }
2378     nl->jindex[0] = 0;
2379 }
2380
2381 /* Sets a simple list i-cell bounding box, including PBC shift */
2382 static inline void set_icell_bb_simple(gmx::ArrayRef<const nbnxn_bb_t> bb,
2383                                        int ci,
2384                                        real shx, real shy, real shz,
2385                                        nbnxn_bb_t *bb_ci)
2386 {
2387     bb_ci->lower[BB_X] = bb[ci].lower[BB_X] + shx;
2388     bb_ci->lower[BB_Y] = bb[ci].lower[BB_Y] + shy;
2389     bb_ci->lower[BB_Z] = bb[ci].lower[BB_Z] + shz;
2390     bb_ci->upper[BB_X] = bb[ci].upper[BB_X] + shx;
2391     bb_ci->upper[BB_Y] = bb[ci].upper[BB_Y] + shy;
2392     bb_ci->upper[BB_Z] = bb[ci].upper[BB_Z] + shz;
2393 }
2394
2395 /* Sets a simple list i-cell bounding box, including PBC shift */
2396 static inline void set_icell_bb(const nbnxn_grid_t &iGrid,
2397                                 int ci,
2398                                 real shx, real shy, real shz,
2399                                 NbnxnPairlistCpuWork *work)
2400 {
2401     set_icell_bb_simple(iGrid.bb, ci, shx, shy, shz, &work->iClusterData.bb[0]);
2402 }
2403
2404 #if NBNXN_BBXXXX
2405 /* Sets a super-cell and sub cell bounding boxes, including PBC shift */
2406 static void set_icell_bbxxxx_supersub(gmx::ArrayRef<const float> bb,
2407                                       int ci,
2408                                       real shx, real shy, real shz,
2409                                       float *bb_ci)
2410 {
2411     int ia = ci*(c_gpuNumClusterPerCell >> STRIDE_PBB_2LOG)*NNBSBB_XXXX;
2412     for (int m = 0; m < (c_gpuNumClusterPerCell >> STRIDE_PBB_2LOG)*NNBSBB_XXXX; m += NNBSBB_XXXX)
2413     {
2414         for (int i = 0; i < STRIDE_PBB; i++)
2415         {
2416             bb_ci[m+0*STRIDE_PBB+i] = bb[ia+m+0*STRIDE_PBB+i] + shx;
2417             bb_ci[m+1*STRIDE_PBB+i] = bb[ia+m+1*STRIDE_PBB+i] + shy;
2418             bb_ci[m+2*STRIDE_PBB+i] = bb[ia+m+2*STRIDE_PBB+i] + shz;
2419             bb_ci[m+3*STRIDE_PBB+i] = bb[ia+m+3*STRIDE_PBB+i] + shx;
2420             bb_ci[m+4*STRIDE_PBB+i] = bb[ia+m+4*STRIDE_PBB+i] + shy;
2421             bb_ci[m+5*STRIDE_PBB+i] = bb[ia+m+5*STRIDE_PBB+i] + shz;
2422         }
2423     }
2424 }
2425 #endif
2426
2427 /* Sets a super-cell and sub cell bounding boxes, including PBC shift */
2428 gmx_unused static void set_icell_bb_supersub(gmx::ArrayRef<const nbnxn_bb_t> bb,
2429                                              int ci,
2430                                              real shx, real shy, real shz,
2431                                              nbnxn_bb_t *bb_ci)
2432 {
2433     for (int i = 0; i < c_gpuNumClusterPerCell; i++)
2434     {
2435         set_icell_bb_simple(bb, ci*c_gpuNumClusterPerCell+i,
2436                             shx, shy, shz,
2437                             &bb_ci[i]);
2438     }
2439 }
2440
2441 /* Sets a super-cell and sub cell bounding boxes, including PBC shift */
2442 gmx_unused static void set_icell_bb(const nbnxn_grid_t &iGrid,
2443                                     int ci,
2444                                     real shx, real shy, real shz,
2445                                     NbnxnPairlistGpuWork *work)
2446 {
2447 #if NBNXN_BBXXXX
2448     set_icell_bbxxxx_supersub(iGrid.pbb, ci, shx, shy, shz,
2449                               work->iSuperClusterData.bbPacked.data());
2450 #else
2451     set_icell_bb_supersub(iGrid.bb, ci, shx, shy, shz,
2452                           work->iSuperClusterData.bb.data());
2453 #endif
2454 }
2455
2456 /* Copies PBC shifted i-cell atom coordinates x,y,z to working array */
2457 static void icell_set_x_simple(int ci,
2458                                real shx, real shy, real shz,
2459                                int stride, const real *x,
2460                                NbnxnPairlistCpuWork::IClusterData *iClusterData)
2461 {
2462     const int ia = ci*c_nbnxnCpuIClusterSize;
2463
2464     for (int i = 0; i < c_nbnxnCpuIClusterSize; i++)
2465     {
2466         iClusterData->x[i*STRIDE_XYZ+XX] = x[(ia+i)*stride+XX] + shx;
2467         iClusterData->x[i*STRIDE_XYZ+YY] = x[(ia+i)*stride+YY] + shy;
2468         iClusterData->x[i*STRIDE_XYZ+ZZ] = x[(ia+i)*stride+ZZ] + shz;
2469     }
2470 }
2471
2472 static void icell_set_x(int ci,
2473                         real shx, real shy, real shz,
2474                         int stride, const real *x,
2475                         int nb_kernel_type,
2476                         NbnxnPairlistCpuWork *work)
2477 {
2478     switch (nb_kernel_type)
2479     {
2480 #if GMX_SIMD
2481 #ifdef GMX_NBNXN_SIMD_4XN
2482         case nbnxnk4xN_SIMD_4xN:
2483             icell_set_x_simd_4xn(ci, shx, shy, shz, stride, x, work);
2484             break;
2485 #endif
2486 #ifdef GMX_NBNXN_SIMD_2XNN
2487         case nbnxnk4xN_SIMD_2xNN:
2488             icell_set_x_simd_2xnn(ci, shx, shy, shz, stride, x, work);
2489             break;
2490 #endif
2491 #endif
2492         case nbnxnk4x4_PlainC:
2493             icell_set_x_simple(ci, shx, shy, shz, stride, x, &work->iClusterData);
2494             break;
2495         default:
2496             GMX_ASSERT(false, "Unhandled case");
2497             break;
2498     }
2499 }
2500
2501 /* Copies PBC shifted super-cell atom coordinates x,y,z to working array */
2502 static void icell_set_x(int ci,
2503                         real shx, real shy, real shz,
2504                         int stride, const real *x,
2505                         int gmx_unused nb_kernel_type,
2506                         NbnxnPairlistGpuWork *work)
2507 {
2508 #if !GMX_SIMD4_HAVE_REAL
2509
2510     real * x_ci = work->iSuperClusterData.x.data();
2511
2512     int    ia = ci*c_gpuNumClusterPerCell*c_nbnxnGpuClusterSize;
2513     for (int i = 0; i < c_gpuNumClusterPerCell*c_nbnxnGpuClusterSize; i++)
2514     {
2515         x_ci[i*DIM + XX] = x[(ia+i)*stride + XX] + shx;
2516         x_ci[i*DIM + YY] = x[(ia+i)*stride + YY] + shy;
2517         x_ci[i*DIM + ZZ] = x[(ia+i)*stride + ZZ] + shz;
2518     }
2519
2520 #else /* !GMX_SIMD4_HAVE_REAL */
2521
2522     real * x_ci = work->iSuperClusterData.xSimd.data();
2523
2524     for (int si = 0; si < c_gpuNumClusterPerCell; si++)
2525     {
2526         for (int i = 0; i < c_nbnxnGpuClusterSize; i += GMX_SIMD4_WIDTH)
2527         {
2528             int io = si*c_nbnxnGpuClusterSize + i;
2529             int ia = ci*c_gpuNumClusterPerCell*c_nbnxnGpuClusterSize + io;
2530             for (int j = 0; j < GMX_SIMD4_WIDTH; j++)
2531             {
2532                 x_ci[io*DIM + j + XX*GMX_SIMD4_WIDTH] = x[(ia + j)*stride + XX] + shx;
2533                 x_ci[io*DIM + j + YY*GMX_SIMD4_WIDTH] = x[(ia + j)*stride + YY] + shy;
2534                 x_ci[io*DIM + j + ZZ*GMX_SIMD4_WIDTH] = x[(ia + j)*stride + ZZ] + shz;
2535             }
2536         }
2537     }
2538
2539 #endif /* !GMX_SIMD4_HAVE_REAL */
2540 }
2541
2542 static real minimum_subgrid_size_xy(const nbnxn_grid_t &grid)
2543 {
2544     if (grid.bSimple)
2545     {
2546         return std::min(grid.cellSize[XX], grid.cellSize[YY]);
2547     }
2548     else
2549     {
2550         return std::min(grid.cellSize[XX]/c_gpuNumClusterPerCellX,
2551                         grid.cellSize[YY]/c_gpuNumClusterPerCellY);
2552     }
2553 }
2554
2555 static real effective_buffer_1x1_vs_MxN(const nbnxn_grid_t &iGrid,
2556                                         const nbnxn_grid_t &jGrid)
2557 {
2558     const real eff_1x1_buffer_fac_overest = 0.1;
2559
2560     /* Determine an atom-pair list cut-off buffer size for atom pairs,
2561      * to be added to rlist (including buffer) used for MxN.
2562      * This is for converting an MxN list to a 1x1 list. This means we can't
2563      * use the normal buffer estimate, as we have an MxN list in which
2564      * some atom pairs beyond rlist are missing. We want to capture
2565      * the beneficial effect of buffering by extra pairs just outside rlist,
2566      * while removing the useless pairs that are further away from rlist.
2567      * (Also the buffer could have been set manually not using the estimate.)
2568      * This buffer size is an overestimate.
2569      * We add 10% of the smallest grid sub-cell dimensions.
2570      * Note that the z-size differs per cell and we don't use this,
2571      * so we overestimate.
2572      * With PME, the 10% value gives a buffer that is somewhat larger
2573      * than the effective buffer with a tolerance of 0.005 kJ/mol/ps.
2574      * Smaller tolerances or using RF lead to a smaller effective buffer,
2575      * so 10% gives a safe overestimate.
2576      */
2577     return eff_1x1_buffer_fac_overest*(minimum_subgrid_size_xy(iGrid) +
2578                                        minimum_subgrid_size_xy(jGrid));
2579 }
2580
2581 /* Clusters at the cut-off only increase rlist by 60% of their size */
2582 static real nbnxn_rlist_inc_outside_fac = 0.6;
2583
2584 /* Due to the cluster size the effective pair-list is longer than
2585  * that of a simple atom pair-list. This function gives the extra distance.
2586  */
2587 real nbnxn_get_rlist_effective_inc(int cluster_size_j, real atom_density)
2588 {
2589     int  cluster_size_i;
2590     real vol_inc_i, vol_inc_j;
2591
2592     /* We should get this from the setup, but currently it's the same for
2593      * all setups, including GPUs.
2594      */
2595     cluster_size_i = c_nbnxnCpuIClusterSize;
2596
2597     vol_inc_i = (cluster_size_i - 1)/atom_density;
2598     vol_inc_j = (cluster_size_j - 1)/atom_density;
2599
2600     return nbnxn_rlist_inc_outside_fac*std::cbrt(vol_inc_i + vol_inc_j);
2601 }
2602
2603 /* Estimates the interaction volume^2 for non-local interactions */
2604 static real nonlocal_vol2(const struct gmx_domdec_zones_t *zones, const rvec ls, real r)
2605 {
2606     real cl, ca, za;
2607     real vold_est;
2608     real vol2_est_tot;
2609
2610     vol2_est_tot = 0;
2611
2612     /* Here we simply add up the volumes of 1, 2 or 3 1D decomposition
2613      * not home interaction volume^2. As these volumes are not additive,
2614      * this is an overestimate, but it would only be significant in the limit
2615      * of small cells, where we anyhow need to split the lists into
2616      * as small parts as possible.
2617      */
2618
2619     for (int z = 0; z < zones->n; z++)
2620     {
2621         if (zones->shift[z][XX] + zones->shift[z][YY] + zones->shift[z][ZZ] == 1)
2622         {
2623             cl = 0;
2624             ca = 1;
2625             za = 1;
2626             for (int d = 0; d < DIM; d++)
2627             {
2628                 if (zones->shift[z][d] == 0)
2629                 {
2630                     cl += 0.5*ls[d];
2631                     ca *= ls[d];
2632                     za *= zones->size[z].x1[d] - zones->size[z].x0[d];
2633                 }
2634             }
2635
2636             /* 4 octants of a sphere */
2637             vold_est  = 0.25*M_PI*r*r*r*r;
2638             /* 4 quarter pie slices on the edges */
2639             vold_est += 4*cl*M_PI/6.0*r*r*r;
2640             /* One rectangular volume on a face */
2641             vold_est += ca*0.5*r*r;
2642
2643             vol2_est_tot += vold_est*za;
2644         }
2645     }
2646
2647     return vol2_est_tot;
2648 }
2649
2650 /* Estimates the average size of a full j-list for super/sub setup */
2651 static void get_nsubpair_target(const nbnxn_search   *nbs,
2652                                 int                   iloc,
2653                                 real                  rlist,
2654                                 int                   min_ci_balanced,
2655                                 int                  *nsubpair_target,
2656                                 float                *nsubpair_tot_est)
2657 {
2658     /* The target value of 36 seems to be the optimum for Kepler.
2659      * Maxwell is less sensitive to the exact value.
2660      */
2661     const int           nsubpair_target_min = 36;
2662     rvec                ls;
2663     real                r_eff_sup, vol_est, nsp_est, nsp_est_nl;
2664
2665     const nbnxn_grid_t &grid = nbs->grid[0];
2666
2667     /* We don't need to balance list sizes if:
2668      * - We didn't request balancing.
2669      * - The number of grid cells >= the number of lists requested,
2670      *   since we will always generate at least #cells lists.
2671      * - We don't have any cells, since then there won't be any lists.
2672      */
2673     if (min_ci_balanced <= 0 || grid.nc >= min_ci_balanced || grid.nc == 0)
2674     {
2675         /* nsubpair_target==0 signals no balancing */
2676         *nsubpair_target  = 0;
2677         *nsubpair_tot_est = 0;
2678
2679         return;
2680     }
2681
2682     ls[XX] = (grid.c1[XX] - grid.c0[XX])/(grid.numCells[XX]*c_gpuNumClusterPerCellX);
2683     ls[YY] = (grid.c1[YY] - grid.c0[YY])/(grid.numCells[YY]*c_gpuNumClusterPerCellY);
2684     ls[ZZ] = grid.na_c/(grid.atom_density*ls[XX]*ls[YY]);
2685
2686     /* The average length of the diagonal of a sub cell */
2687     real diagonal = std::sqrt(ls[XX]*ls[XX] + ls[YY]*ls[YY] + ls[ZZ]*ls[ZZ]);
2688
2689     /* The formulas below are a heuristic estimate of the average nsj per si*/
2690     r_eff_sup = rlist + nbnxn_rlist_inc_outside_fac*gmx::square((grid.na_c - 1.0)/grid.na_c)*0.5*diagonal;
2691
2692     if (!nbs->DomDec || nbs->zones->n == 1)
2693     {
2694         nsp_est_nl = 0;
2695     }
2696     else
2697     {
2698         nsp_est_nl =
2699             gmx::square(grid.atom_density/grid.na_c)*
2700             nonlocal_vol2(nbs->zones, ls, r_eff_sup);
2701     }
2702
2703     if (LOCAL_I(iloc))
2704     {
2705         /* Sub-cell interacts with itself */
2706         vol_est  = ls[XX]*ls[YY]*ls[ZZ];
2707         /* 6/2 rectangular volume on the faces */
2708         vol_est += (ls[XX]*ls[YY] + ls[XX]*ls[ZZ] + ls[YY]*ls[ZZ])*r_eff_sup;
2709         /* 12/2 quarter pie slices on the edges */
2710         vol_est += 2*(ls[XX] + ls[YY] + ls[ZZ])*0.25*M_PI*gmx::square(r_eff_sup);
2711         /* 4 octants of a sphere */
2712         vol_est += 0.5*4.0/3.0*M_PI*gmx::power3(r_eff_sup);
2713
2714         /* Estimate the number of cluster pairs as the local number of
2715          * clusters times the volume they interact with times the density.
2716          */
2717         nsp_est = grid.nsubc_tot*vol_est*grid.atom_density/grid.na_c;
2718
2719         /* Subtract the non-local pair count */
2720         nsp_est -= nsp_est_nl;
2721
2722         /* For small cut-offs nsp_est will be an underesimate.
2723          * With DD nsp_est_nl is an overestimate so nsp_est can get negative.
2724          * So to avoid too small or negative nsp_est we set a minimum of
2725          * all cells interacting with all 3^3 direct neighbors (3^3-1)/2+1=14.
2726          * This might be a slight overestimate for small non-periodic groups of
2727          * atoms as will occur for a local domain with DD, but for small
2728          * groups of atoms we'll anyhow be limited by nsubpair_target_min,
2729          * so this overestimation will not matter.
2730          */
2731         nsp_est = std::max(nsp_est, grid.nsubc_tot*14._real);
2732
2733         if (debug)
2734         {
2735             fprintf(debug, "nsp_est local %5.1f non-local %5.1f\n",
2736                     nsp_est, nsp_est_nl);
2737         }
2738     }
2739     else
2740     {
2741         nsp_est = nsp_est_nl;
2742     }
2743
2744     /* Thus the (average) maximum j-list size should be as follows.
2745      * Since there is overhead, we shouldn't make the lists too small
2746      * (and we can't chop up j-groups) so we use a minimum target size of 36.
2747      */
2748     *nsubpair_target  = std::max(nsubpair_target_min,
2749                                  roundToInt(nsp_est/min_ci_balanced));
2750     *nsubpair_tot_est = static_cast<int>(nsp_est);
2751
2752     if (debug)
2753     {
2754         fprintf(debug, "nbl nsp estimate %.1f, nsubpair_target %d\n",
2755                 nsp_est, *nsubpair_target);
2756     }
2757 }
2758
2759 /* Debug list print function */
2760 static void print_nblist_ci_cj(FILE *fp, const NbnxnPairlistCpu *nbl)
2761 {
2762     for (const nbnxn_ci_t &ciEntry : nbl->ci)
2763     {
2764         fprintf(fp, "ci %4d  shift %2d  ncj %3d\n",
2765                 ciEntry.ci, ciEntry.shift,
2766                 ciEntry.cj_ind_end - ciEntry.cj_ind_start);
2767
2768         for (int j = ciEntry.cj_ind_start; j < ciEntry.cj_ind_end; j++)
2769         {
2770             fprintf(fp, "  cj %5d  imask %x\n",
2771                     nbl->cj[j].cj,
2772                     nbl->cj[j].excl);
2773         }
2774     }
2775 }
2776
2777 /* Debug list print function */
2778 static void print_nblist_sci_cj(FILE *fp, const NbnxnPairlistGpu *nbl)
2779 {
2780     for (const nbnxn_sci_t &sci : nbl->sci)
2781     {
2782         fprintf(fp, "ci %4d  shift %2d  ncj4 %2d\n",
2783                 sci.sci, sci.shift,
2784                 sci.numJClusterGroups());
2785
2786         int ncp = 0;
2787         for (int j4 = sci.cj4_ind_start; j4 < sci.cj4_ind_end; j4++)
2788         {
2789             for (int j = 0; j < c_nbnxnGpuJgroupSize; j++)
2790             {
2791                 fprintf(fp, "  sj %5d  imask %x\n",
2792                         nbl->cj4[j4].cj[j],
2793                         nbl->cj4[j4].imei[0].imask);
2794                 for (int si = 0; si < c_gpuNumClusterPerCell; si++)
2795                 {
2796                     if (nbl->cj4[j4].imei[0].imask & (1U << (j*c_gpuNumClusterPerCell + si)))
2797                     {
2798                         ncp++;
2799                     }
2800                 }
2801             }
2802         }
2803         fprintf(fp, "ci %4d  shift %2d  ncj4 %2d ncp %3d\n",
2804                 sci.sci, sci.shift,
2805                 sci.numJClusterGroups(),
2806                 ncp);
2807     }
2808 }
2809
2810 /* Combine pair lists *nbl generated on multiple threads nblc */
2811 static void combine_nblists(int nnbl, NbnxnPairlistGpu **nbl,
2812                             NbnxnPairlistGpu *nblc)
2813 {
2814     int nsci  = nblc->sci.size();
2815     int ncj4  = nblc->cj4.size();
2816     int nexcl = nblc->excl.size();
2817     for (int i = 0; i < nnbl; i++)
2818     {
2819         nsci  += nbl[i]->sci.size();
2820         ncj4  += nbl[i]->cj4.size();
2821         nexcl += nbl[i]->excl.size();
2822     }
2823
2824     /* Resize with the final, combined size, so we can fill in parallel */
2825     /* NOTE: For better performance we should use default initialization */
2826     nblc->sci.resize(nsci);
2827     nblc->cj4.resize(ncj4);
2828     nblc->excl.resize(nexcl);
2829
2830     /* Each thread should copy its own data to the combined arrays,
2831      * as otherwise data will go back and forth between different caches.
2832      */
2833 #if GMX_OPENMP && !(defined __clang_analyzer__)
2834     int nthreads = gmx_omp_nthreads_get(emntPairsearch);
2835 #endif
2836
2837 #pragma omp parallel for num_threads(nthreads) schedule(static)
2838     for (int n = 0; n < nnbl; n++)
2839     {
2840         try
2841         {
2842             /* Determine the offset in the combined data for our thread.
2843              * Note that the original sizes in nblc are lost.
2844              */
2845             int sci_offset  = nsci;
2846             int cj4_offset  = ncj4;
2847             int excl_offset = nexcl;
2848
2849             for (int i = n; i < nnbl; i++)
2850             {
2851                 sci_offset  -= nbl[i]->sci.size();
2852                 cj4_offset  -= nbl[i]->cj4.size();
2853                 excl_offset -= nbl[i]->excl.size();
2854             }
2855
2856             const NbnxnPairlistGpu &nbli = *nbl[n];
2857
2858             for (size_t i = 0; i < nbli.sci.size(); i++)
2859             {
2860                 nblc->sci[sci_offset + i]                = nbli.sci[i];
2861                 nblc->sci[sci_offset + i].cj4_ind_start += cj4_offset;
2862                 nblc->sci[sci_offset + i].cj4_ind_end   += cj4_offset;
2863             }
2864
2865             for (size_t j4 = 0; j4 < nbli.cj4.size(); j4++)
2866             {
2867                 nblc->cj4[cj4_offset + j4]                   = nbli.cj4[j4];
2868                 nblc->cj4[cj4_offset + j4].imei[0].excl_ind += excl_offset;
2869                 nblc->cj4[cj4_offset + j4].imei[1].excl_ind += excl_offset;
2870             }
2871
2872             for (size_t j4 = 0; j4 < nbli.excl.size(); j4++)
2873             {
2874                 nblc->excl[excl_offset + j4] = nbli.excl[j4];
2875             }
2876         }
2877         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
2878     }
2879
2880     for (int n = 0; n < nnbl; n++)
2881     {
2882         nblc->nci_tot += nbl[n]->nci_tot;
2883     }
2884 }
2885
2886 static void balance_fep_lists(const nbnxn_search   *nbs,
2887                               nbnxn_pairlist_set_t *nbl_lists)
2888 {
2889     int       nnbl;
2890     int       nri_tot, nrj_tot, nrj_target;
2891     int       th_dest;
2892     t_nblist *nbld;
2893
2894     nnbl = nbl_lists->nnbl;
2895
2896     if (nnbl == 1)
2897     {
2898         /* Nothing to balance */
2899         return;
2900     }
2901
2902     /* Count the total i-lists and pairs */
2903     nri_tot = 0;
2904     nrj_tot = 0;
2905     for (int th = 0; th < nnbl; th++)
2906     {
2907         nri_tot += nbl_lists->nbl_fep[th]->nri;
2908         nrj_tot += nbl_lists->nbl_fep[th]->nrj;
2909     }
2910
2911     nrj_target = (nrj_tot + nnbl - 1)/nnbl;
2912
2913     assert(gmx_omp_nthreads_get(emntNonbonded) == nnbl);
2914
2915 #pragma omp parallel for schedule(static) num_threads(nnbl)
2916     for (int th = 0; th < nnbl; th++)
2917     {
2918         try
2919         {
2920             t_nblist *nbl = nbs->work[th].nbl_fep.get();
2921
2922             /* Note that here we allocate for the total size, instead of
2923              * a per-thread esimate (which is hard to obtain).
2924              */
2925             if (nri_tot > nbl->maxnri)
2926             {
2927                 nbl->maxnri = over_alloc_large(nri_tot);
2928                 reallocate_nblist(nbl);
2929             }
2930             if (nri_tot > nbl->maxnri || nrj_tot > nbl->maxnrj)
2931             {
2932                 nbl->maxnrj = over_alloc_small(nrj_tot);
2933                 srenew(nbl->jjnr, nbl->maxnrj);
2934                 srenew(nbl->excl_fep, nbl->maxnrj);
2935             }
2936
2937             clear_pairlist_fep(nbl);
2938         }
2939         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
2940     }
2941
2942     /* Loop over the source lists and assign and copy i-entries */
2943     th_dest = 0;
2944     nbld    = nbs->work[th_dest].nbl_fep.get();
2945     for (int th = 0; th < nnbl; th++)
2946     {
2947         t_nblist *nbls;
2948
2949         nbls = nbl_lists->nbl_fep[th];
2950
2951         for (int i = 0; i < nbls->nri; i++)
2952         {
2953             int nrj;
2954
2955             /* The number of pairs in this i-entry */
2956             nrj = nbls->jindex[i+1] - nbls->jindex[i];
2957
2958             /* Decide if list th_dest is too large and we should procede
2959              * to the next destination list.
2960              */
2961             if (th_dest+1 < nnbl && nbld->nrj > 0 &&
2962                 nbld->nrj + nrj - nrj_target > nrj_target - nbld->nrj)
2963             {
2964                 th_dest++;
2965                 nbld = nbs->work[th_dest].nbl_fep.get();
2966             }
2967
2968             nbld->iinr[nbld->nri]  = nbls->iinr[i];
2969             nbld->gid[nbld->nri]   = nbls->gid[i];
2970             nbld->shift[nbld->nri] = nbls->shift[i];
2971
2972             for (int j = nbls->jindex[i]; j < nbls->jindex[i+1]; j++)
2973             {
2974                 nbld->jjnr[nbld->nrj]     = nbls->jjnr[j];
2975                 nbld->excl_fep[nbld->nrj] = nbls->excl_fep[j];
2976                 nbld->nrj++;
2977             }
2978             nbld->nri++;
2979             nbld->jindex[nbld->nri] = nbld->nrj;
2980         }
2981     }
2982
2983     /* Swap the list pointers */
2984     for (int th = 0; th < nnbl; th++)
2985     {
2986         t_nblist *nbl_tmp      = nbs->work[th].nbl_fep.release();
2987         nbs->work[th].nbl_fep.reset(nbl_lists->nbl_fep[th]);
2988         nbl_lists->nbl_fep[th] = nbl_tmp;
2989
2990         if (debug)
2991         {
2992             fprintf(debug, "nbl_fep[%d] nri %4d nrj %4d\n",
2993                     th,
2994                     nbl_lists->nbl_fep[th]->nri,
2995                     nbl_lists->nbl_fep[th]->nrj);
2996         }
2997     }
2998 }
2999
3000 /* Returns the next ci to be processes by our thread */
3001 static gmx_bool next_ci(const nbnxn_grid_t &grid,
3002                         int nth, int ci_block,
3003                         int *ci_x, int *ci_y,
3004                         int *ci_b, int *ci)
3005 {
3006     (*ci_b)++;
3007     (*ci)++;
3008
3009     if (*ci_b == ci_block)
3010     {
3011         /* Jump to the next block assigned to this task */
3012         *ci   += (nth - 1)*ci_block;
3013         *ci_b  = 0;
3014     }
3015
3016     if (*ci >= grid.nc)
3017     {
3018         return FALSE;
3019     }
3020
3021     while (*ci >= grid.cxy_ind[*ci_x*grid.numCells[YY] + *ci_y + 1])
3022     {
3023         *ci_y += 1;
3024         if (*ci_y == grid.numCells[YY])
3025         {
3026             *ci_x += 1;
3027             *ci_y  = 0;
3028         }
3029     }
3030
3031     return TRUE;
3032 }
3033
3034 /* Returns the distance^2 for which we put cell pairs in the list
3035  * without checking atom pair distances. This is usually < rlist^2.
3036  */
3037 static float boundingbox_only_distance2(const nbnxn_grid_t &iGrid,
3038                                         const nbnxn_grid_t &jGrid,
3039                                         real                rlist,
3040                                         gmx_bool            simple)
3041 {
3042     /* If the distance between two sub-cell bounding boxes is less
3043      * than this distance, do not check the distance between
3044      * all particle pairs in the sub-cell, since then it is likely
3045      * that the box pair has atom pairs within the cut-off.
3046      * We use the nblist cut-off minus 0.5 times the average x/y diagonal
3047      * spacing of the sub-cells. Around 40% of the checked pairs are pruned.
3048      * Using more than 0.5 gains at most 0.5%.
3049      * If forces are calculated more than twice, the performance gain
3050      * in the force calculation outweighs the cost of checking.
3051      * Note that with subcell lists, the atom-pair distance check
3052      * is only performed when only 1 out of 8 sub-cells in within range,
3053      * this is because the GPU is much faster than the cpu.
3054      */
3055     real bbx, bby;
3056     real rbb2;
3057
3058     bbx = 0.5*(iGrid.cellSize[XX] + jGrid.cellSize[XX]);
3059     bby = 0.5*(iGrid.cellSize[YY] + jGrid.cellSize[YY]);
3060     if (!simple)
3061     {
3062         bbx /= c_gpuNumClusterPerCellX;
3063         bby /= c_gpuNumClusterPerCellY;
3064     }
3065
3066     rbb2 = std::max(0.0, rlist - 0.5*std::sqrt(bbx*bbx + bby*bby));
3067     rbb2 = rbb2 * rbb2;
3068
3069 #if !GMX_DOUBLE
3070     return rbb2;
3071 #else
3072     return (float)((1+GMX_FLOAT_EPS)*rbb2);
3073 #endif
3074 }
3075
3076 static int get_ci_block_size(const nbnxn_grid_t &iGrid,
3077                              gmx_bool bDomDec, int nth)
3078 {
3079     const int ci_block_enum      = 5;
3080     const int ci_block_denom     = 11;
3081     const int ci_block_min_atoms = 16;
3082     int       ci_block;
3083
3084     /* Here we decide how to distribute the blocks over the threads.
3085      * We use prime numbers to try to avoid that the grid size becomes
3086      * a multiple of the number of threads, which would lead to some
3087      * threads getting "inner" pairs and others getting boundary pairs,
3088      * which in turns will lead to load imbalance between threads.
3089      * Set the block size as 5/11/ntask times the average number of cells
3090      * in a y,z slab. This should ensure a quite uniform distribution
3091      * of the grid parts of the different thread along all three grid
3092      * zone boundaries with 3D domain decomposition. At the same time
3093      * the blocks will not become too small.
3094      */
3095     ci_block = (iGrid.nc*ci_block_enum)/(ci_block_denom*iGrid.numCells[XX]*nth);
3096
3097     /* Ensure the blocks are not too small: avoids cache invalidation */
3098     if (ci_block*iGrid.na_sc < ci_block_min_atoms)
3099     {
3100         ci_block = (ci_block_min_atoms + iGrid.na_sc - 1)/iGrid.na_sc;
3101     }
3102
3103     /* Without domain decomposition
3104      * or with less than 3 blocks per task, divide in nth blocks.
3105      */
3106     if (!bDomDec || nth*3*ci_block > iGrid.nc)
3107     {
3108         ci_block = (iGrid.nc + nth - 1)/nth;
3109     }
3110
3111     if (ci_block > 1 && (nth - 1)*ci_block >= iGrid.nc)
3112     {
3113         /* Some threads have no work. Although reducing the block size
3114          * does not decrease the block count on the first few threads,
3115          * with GPUs better mixing of "upper" cells that have more empty
3116          * clusters results in a somewhat lower max load over all threads.
3117          * Without GPUs the regime of so few atoms per thread is less
3118          * performance relevant, but with 8-wide SIMD the same reasoning
3119          * applies, since the pair list uses 4 i-atom "sub-clusters".
3120          */
3121         ci_block--;
3122     }
3123
3124     return ci_block;
3125 }
3126
3127 /* Returns the number of bits to right-shift a cluster index to obtain
3128  * the corresponding force buffer flag index.
3129  */
3130 static int getBufferFlagShift(int numAtomsPerCluster)
3131 {
3132     int bufferFlagShift = 0;
3133     while ((numAtomsPerCluster << bufferFlagShift) < NBNXN_BUFFERFLAG_SIZE)
3134     {
3135         bufferFlagShift++;
3136     }
3137
3138     return bufferFlagShift;
3139 }
3140
3141 static bool pairlistIsSimple(const NbnxnPairlistCpu gmx_unused &pairlist)
3142 {
3143     return true;
3144 }
3145
3146 static bool pairlistIsSimple(const NbnxnPairlistGpu gmx_unused &pairlist)
3147 {
3148     return false;
3149 }
3150
3151 static void makeClusterListWrapper(NbnxnPairlistCpu              *nbl,
3152                                    const nbnxn_grid_t gmx_unused &iGrid,
3153                                    const int                      ci,
3154                                    const nbnxn_grid_t            &jGrid,
3155                                    const int                      firstCell,
3156                                    const int                      lastCell,
3157                                    const bool                     excludeSubDiagonal,
3158                                    const nbnxn_atomdata_t        *nbat,
3159                                    const real                     rlist2,
3160                                    const real                     rbb2,
3161                                    const int                      nb_kernel_type,
3162                                    int                           *numDistanceChecks)
3163 {
3164     switch (nb_kernel_type)
3165     {
3166         case nbnxnk4x4_PlainC:
3167             makeClusterListSimple(jGrid,
3168                                   nbl, ci, firstCell, lastCell,
3169                                   excludeSubDiagonal,
3170                                   nbat->x().data(),
3171                                   rlist2, rbb2,
3172                                   numDistanceChecks);
3173             break;
3174 #ifdef GMX_NBNXN_SIMD_4XN
3175         case nbnxnk4xN_SIMD_4xN:
3176             makeClusterListSimd4xn(jGrid,
3177                                    nbl, ci, firstCell, lastCell,
3178                                    excludeSubDiagonal,
3179                                    nbat->x().data(),
3180                                    rlist2, rbb2,
3181                                    numDistanceChecks);
3182             break;
3183 #endif
3184 #ifdef GMX_NBNXN_SIMD_2XNN
3185         case nbnxnk4xN_SIMD_2xNN:
3186             makeClusterListSimd2xnn(jGrid,
3187                                     nbl, ci, firstCell, lastCell,
3188                                     excludeSubDiagonal,
3189                                     nbat->x().data(),
3190                                     rlist2, rbb2,
3191                                     numDistanceChecks);
3192             break;
3193 #endif
3194     }
3195 }
3196
3197 static void makeClusterListWrapper(NbnxnPairlistGpu              *nbl,
3198                                    const nbnxn_grid_t &gmx_unused iGrid,
3199                                    const int                      ci,
3200                                    const nbnxn_grid_t            &jGrid,
3201                                    const int                      firstCell,
3202                                    const int                      lastCell,
3203                                    const bool                     excludeSubDiagonal,
3204                                    const nbnxn_atomdata_t        *nbat,
3205                                    const real                     rlist2,
3206                                    const real                     rbb2,
3207                                    const int gmx_unused           nb_kernel_type,
3208                                    int                           *numDistanceChecks)
3209 {
3210     for (int cj = firstCell; cj <= lastCell; cj++)
3211     {
3212         make_cluster_list_supersub(iGrid, jGrid,
3213                                    nbl, ci, cj,
3214                                    excludeSubDiagonal,
3215                                    nbat->xstride, nbat->x().data(),
3216                                    rlist2, rbb2,
3217                                    numDistanceChecks);
3218     }
3219 }
3220
3221 static int getNumSimpleJClustersInList(const NbnxnPairlistCpu &nbl)
3222 {
3223     return nbl.cj.size();
3224 }
3225
3226 static int getNumSimpleJClustersInList(const gmx_unused NbnxnPairlistGpu &nbl)
3227 {
3228     return 0;
3229 }
3230
3231 static void incrementNumSimpleJClustersInList(NbnxnPairlistCpu *nbl,
3232                                               int               ncj_old_j)
3233 {
3234     nbl->ncjInUse += nbl->cj.size() - ncj_old_j;
3235 }
3236
3237 static void incrementNumSimpleJClustersInList(NbnxnPairlistGpu gmx_unused *nbl,
3238                                               int              gmx_unused  ncj_old_j)
3239 {
3240 }
3241
3242 static void checkListSizeConsistency(const NbnxnPairlistCpu &nbl,
3243                                      const bool              haveFreeEnergy)
3244 {
3245     GMX_RELEASE_ASSERT(static_cast<size_t>(nbl.ncjInUse) == nbl.cj.size() || haveFreeEnergy,
3246                        "Without free-energy all cj pair-list entries should be in use. "
3247                        "Note that subsequent code does not make use of the equality, "
3248                        "this check is only here to catch bugs");
3249 }
3250
3251 static void checkListSizeConsistency(const NbnxnPairlistGpu gmx_unused &nbl,
3252                                      bool gmx_unused                    haveFreeEnergy)
3253 {
3254     /* We currently can not check consistency here */
3255 }
3256
3257 /* Set the buffer flags for newly added entries in the list */
3258 static void setBufferFlags(const NbnxnPairlistCpu &nbl,
3259                            const int               ncj_old_j,
3260                            const int               gridj_flag_shift,
3261                            gmx_bitmask_t          *gridj_flag,
3262                            const int               th)
3263 {
3264     if (gmx::ssize(nbl.cj) > ncj_old_j)
3265     {
3266         int cbFirst = nbl.cj[ncj_old_j].cj >> gridj_flag_shift;
3267         int cbLast  = nbl.cj.back().cj >> gridj_flag_shift;
3268         for (int cb = cbFirst; cb <= cbLast; cb++)
3269         {
3270             bitmask_init_bit(&gridj_flag[cb], th);
3271         }
3272     }
3273 }
3274
3275 static void setBufferFlags(const NbnxnPairlistGpu gmx_unused &nbl,
3276                            int gmx_unused                     ncj_old_j,
3277                            int gmx_unused                     gridj_flag_shift,
3278                            gmx_bitmask_t gmx_unused          *gridj_flag,
3279                            int gmx_unused                     th)
3280 {
3281     GMX_ASSERT(false, "This function should never be called");
3282 }
3283
3284 /* Generates the part of pair-list nbl assigned to our thread */
3285 template <typename T>
3286 static void nbnxn_make_pairlist_part(const nbnxn_search *nbs,
3287                                      const nbnxn_grid_t &iGrid,
3288                                      const nbnxn_grid_t &jGrid,
3289                                      nbnxn_search_work_t *work,
3290                                      const nbnxn_atomdata_t *nbat,
3291                                      const t_blocka &exclusions,
3292                                      real rlist,
3293                                      int nb_kernel_type,
3294                                      int ci_block,
3295                                      gmx_bool bFBufferFlag,
3296                                      int nsubpair_max,
3297                                      gmx_bool progBal,
3298                                      float nsubpair_tot_est,
3299                                      int th, int nth,
3300                                      T *nbl,
3301                                      t_nblist *nbl_fep)
3302 {
3303     int               na_cj_2log;
3304     matrix            box;
3305     real              rlist2, rl_fep2 = 0;
3306     float             rbb2;
3307     int               ci_b, ci, ci_x, ci_y, ci_xy;
3308     ivec              shp;
3309     real              bx0, bx1, by0, by1, bz0, bz1;
3310     real              bz1_frac;
3311     real              d2cx, d2z, d2z_cx, d2z_cy, d2zx, d2zxy, d2xy;
3312     int               cxf, cxl, cyf, cyf_x, cyl;
3313     int               numDistanceChecks;
3314     int               gridi_flag_shift = 0, gridj_flag_shift = 0;
3315     gmx_bitmask_t    *gridj_flag       = nullptr;
3316     int               ncj_old_i, ncj_old_j;
3317
3318     nbs_cycle_start(&work->cc[enbsCCsearch]);
3319
3320     if (jGrid.bSimple != pairlistIsSimple(*nbl) ||
3321         iGrid.bSimple != pairlistIsSimple(*nbl))
3322     {
3323         gmx_incons("Grid incompatible with pair-list");
3324     }
3325
3326     sync_work(nbl);
3327     GMX_ASSERT(nbl->na_ci == jGrid.na_c, "The cluster sizes in the list and grid should match");
3328     nbl->na_cj = nbnxn_kernel_to_cluster_j_size(nb_kernel_type);
3329     na_cj_2log = get_2log(nbl->na_cj);
3330
3331     nbl->rlist  = rlist;
3332
3333     if (bFBufferFlag)
3334     {
3335         /* Determine conversion of clusters to flag blocks */
3336         gridi_flag_shift = getBufferFlagShift(nbl->na_ci);
3337         gridj_flag_shift = getBufferFlagShift(nbl->na_cj);
3338
3339         gridj_flag       = work->buffer_flags.flag;
3340     }
3341
3342     copy_mat(nbs->box, box);
3343
3344     rlist2 = nbl->rlist*nbl->rlist;
3345
3346     if (nbs->bFEP && !pairlistIsSimple(*nbl))
3347     {
3348         /* Determine an atom-pair list cut-off distance for FEP atom pairs.
3349          * We should not simply use rlist, since then we would not have
3350          * the small, effective buffering of the NxN lists.
3351          * The buffer is on overestimate, but the resulting cost for pairs
3352          * beyond rlist is neglible compared to the FEP pairs within rlist.
3353          */
3354         rl_fep2 = nbl->rlist + effective_buffer_1x1_vs_MxN(iGrid, jGrid);
3355
3356         if (debug)
3357         {
3358             fprintf(debug, "nbl_fep atom-pair rlist %f\n", rl_fep2);
3359         }
3360         rl_fep2 = rl_fep2*rl_fep2;
3361     }
3362
3363     rbb2 = boundingbox_only_distance2(iGrid, jGrid, nbl->rlist, pairlistIsSimple(*nbl));
3364
3365     if (debug)
3366     {
3367         fprintf(debug, "nbl bounding box only distance %f\n", std::sqrt(rbb2));
3368     }
3369
3370     const bool isIntraGridList = (&iGrid == &jGrid);
3371
3372     /* Set the shift range */
3373     for (int d = 0; d < DIM; d++)
3374     {
3375         /* Check if we need periodicity shifts.
3376          * Without PBC or with domain decomposition we don't need them.
3377          */
3378         if (d >= ePBC2npbcdim(nbs->ePBC) || nbs->dd_dim[d])
3379         {
3380             shp[d] = 0;
3381         }
3382         else
3383         {
3384             const real listRangeCellToCell = listRangeForGridCellToGridCell(rlist, iGrid, jGrid);
3385             if (d == XX &&
3386                 box[XX][XX] - fabs(box[YY][XX]) - fabs(box[ZZ][XX]) < listRangeCellToCell)
3387             {
3388                 shp[d] = 2;
3389             }
3390             else
3391             {
3392                 shp[d] = 1;
3393             }
3394         }
3395     }
3396     const bool bSimple = pairlistIsSimple(*nbl);
3397     gmx::ArrayRef<const nbnxn_bb_t> bb_i;
3398 #if NBNXN_BBXXXX
3399     gmx::ArrayRef<const float>      pbb_i;
3400     if (bSimple)
3401     {
3402         bb_i  = iGrid.bb;
3403     }
3404     else
3405     {
3406         pbb_i = iGrid.pbb;
3407     }
3408 #else
3409     /* We use the normal bounding box format for both grid types */
3410     bb_i  = iGrid.bb;
3411 #endif
3412     gmx::ArrayRef<const float> bbcz_i  = iGrid.bbcz;
3413     gmx::ArrayRef<const int>   flags_i = iGrid.flags;
3414     gmx::ArrayRef<const float> bbcz_j  = jGrid.bbcz;
3415     int                        cell0_i = iGrid.cell0;
3416
3417     if (debug)
3418     {
3419         fprintf(debug, "nbl nc_i %d col.av. %.1f ci_block %d\n",
3420                 iGrid.nc, iGrid.nc/static_cast<double>(iGrid.numCells[XX]*iGrid.numCells[YY]), ci_block);
3421     }
3422
3423     numDistanceChecks = 0;
3424
3425     const real listRangeBBToJCell2 = gmx::square(listRangeForBoundingBoxToGridCell(rlist, jGrid));
3426
3427     /* Initially ci_b and ci to 1 before where we want them to start,
3428      * as they will both be incremented in next_ci.
3429      */
3430     ci_b = -1;
3431     ci   = th*ci_block - 1;
3432     ci_x = 0;
3433     ci_y = 0;
3434     while (next_ci(iGrid, nth, ci_block, &ci_x, &ci_y, &ci_b, &ci))
3435     {
3436         if (bSimple && flags_i[ci] == 0)
3437         {
3438             continue;
3439         }
3440
3441         ncj_old_i = getNumSimpleJClustersInList(*nbl);
3442
3443         d2cx = 0;
3444         if (!isIntraGridList && shp[XX] == 0)
3445         {
3446             if (bSimple)
3447             {
3448                 bx1 = bb_i[ci].upper[BB_X];
3449             }
3450             else
3451             {
3452                 bx1 = iGrid.c0[XX] + (ci_x+1)*iGrid.cellSize[XX];
3453             }
3454             if (bx1 < jGrid.c0[XX])
3455             {
3456                 d2cx = gmx::square(jGrid.c0[XX] - bx1);
3457
3458                 if (d2cx >= listRangeBBToJCell2)
3459                 {
3460                     continue;
3461                 }
3462             }
3463         }
3464
3465         ci_xy = ci_x*iGrid.numCells[YY] + ci_y;
3466
3467         /* Loop over shift vectors in three dimensions */
3468         for (int tz = -shp[ZZ]; tz <= shp[ZZ]; tz++)
3469         {
3470             const real shz = tz*box[ZZ][ZZ];
3471
3472             bz0 = bbcz_i[ci*NNBSBB_D  ] + shz;
3473             bz1 = bbcz_i[ci*NNBSBB_D+1] + shz;
3474
3475             if (tz == 0)
3476             {
3477                 d2z = 0;
3478             }
3479             else if (tz < 0)
3480             {
3481                 d2z = gmx::square(bz1);
3482             }
3483             else
3484             {
3485                 d2z = gmx::square(bz0 - box[ZZ][ZZ]);
3486             }
3487
3488             d2z_cx = d2z + d2cx;
3489
3490             if (d2z_cx >= rlist2)
3491             {
3492                 continue;
3493             }
3494
3495             bz1_frac = bz1/(iGrid.cxy_ind[ci_xy+1] - iGrid.cxy_ind[ci_xy]);
3496             if (bz1_frac < 0)
3497             {
3498                 bz1_frac = 0;
3499             }
3500             /* The check with bz1_frac close to or larger than 1 comes later */
3501
3502             for (int ty = -shp[YY]; ty <= shp[YY]; ty++)
3503             {
3504                 const real shy = ty*box[YY][YY] + tz*box[ZZ][YY];
3505
3506                 if (bSimple)
3507                 {
3508                     by0 = bb_i[ci].lower[BB_Y] + shy;
3509                     by1 = bb_i[ci].upper[BB_Y] + shy;
3510                 }
3511                 else
3512                 {
3513                     by0 = iGrid.c0[YY] + (ci_y  )*iGrid.cellSize[YY] + shy;
3514                     by1 = iGrid.c0[YY] + (ci_y+1)*iGrid.cellSize[YY] + shy;
3515                 }
3516
3517                 get_cell_range<YY>(by0, by1,
3518                                    jGrid,
3519                                    d2z_cx, rlist,
3520                                    &cyf, &cyl);
3521
3522                 if (cyf > cyl)
3523                 {
3524                     continue;
3525                 }
3526
3527                 d2z_cy = d2z;
3528                 if (by1 < jGrid.c0[YY])
3529                 {
3530                     d2z_cy += gmx::square(jGrid.c0[YY] - by1);
3531                 }
3532                 else if (by0 > jGrid.c1[YY])
3533                 {
3534                     d2z_cy += gmx::square(by0 - jGrid.c1[YY]);
3535                 }
3536
3537                 for (int tx = -shp[XX]; tx <= shp[XX]; tx++)
3538                 {
3539                     const int  shift              = XYZ2IS(tx, ty, tz);
3540
3541                     const bool excludeSubDiagonal = (isIntraGridList && shift == CENTRAL);
3542
3543                     if (c_pbcShiftBackward && isIntraGridList && shift > CENTRAL)
3544                     {
3545                         continue;
3546                     }
3547
3548                     const real shx = tx*box[XX][XX] + ty*box[YY][XX] + tz*box[ZZ][XX];
3549
3550                     if (bSimple)
3551                     {
3552                         bx0 = bb_i[ci].lower[BB_X] + shx;
3553                         bx1 = bb_i[ci].upper[BB_X] + shx;
3554                     }
3555                     else
3556                     {
3557                         bx0 = iGrid.c0[XX] + (ci_x  )*iGrid.cellSize[XX] + shx;
3558                         bx1 = iGrid.c0[XX] + (ci_x+1)*iGrid.cellSize[XX] + shx;
3559                     }
3560
3561                     get_cell_range<XX>(bx0, bx1,
3562                                        jGrid,
3563                                        d2z_cy, rlist,
3564                                        &cxf, &cxl);
3565
3566                     if (cxf > cxl)
3567                     {
3568                         continue;
3569                     }
3570
3571                     addNewIEntry(nbl, cell0_i+ci, shift, flags_i[ci]);
3572
3573                     if ((!c_pbcShiftBackward || excludeSubDiagonal) &&
3574                         cxf < ci_x)
3575                     {
3576                         /* Leave the pairs with i > j.
3577                          * x is the major index, so skip half of it.
3578                          */
3579                         cxf = ci_x;
3580                     }
3581
3582                     set_icell_bb(iGrid, ci, shx, shy, shz,
3583                                  nbl->work);
3584
3585                     icell_set_x(cell0_i+ci, shx, shy, shz,
3586                                 nbat->xstride, nbat->x().data(),
3587                                 nb_kernel_type,
3588                                 nbl->work);
3589
3590                     for (int cx = cxf; cx <= cxl; cx++)
3591                     {
3592                         d2zx = d2z;
3593                         if (jGrid.c0[XX] + cx*jGrid.cellSize[XX] > bx1)
3594                         {
3595                             d2zx += gmx::square(jGrid.c0[XX] + cx*jGrid.cellSize[XX] - bx1);
3596                         }
3597                         else if (jGrid.c0[XX] + (cx+1)*jGrid.cellSize[XX] < bx0)
3598                         {
3599                             d2zx += gmx::square(jGrid.c0[XX] + (cx+1)*jGrid.cellSize[XX] - bx0);
3600                         }
3601
3602                         if (isIntraGridList &&
3603                             cx == 0 &&
3604                             (!c_pbcShiftBackward || shift == CENTRAL) &&
3605                             cyf < ci_y)
3606                         {
3607                             /* Leave the pairs with i > j.
3608                              * Skip half of y when i and j have the same x.
3609                              */
3610                             cyf_x = ci_y;
3611                         }
3612                         else
3613                         {
3614                             cyf_x = cyf;
3615                         }
3616
3617                         for (int cy = cyf_x; cy <= cyl; cy++)
3618                         {
3619                             const int columnStart = jGrid.cxy_ind[cx*jGrid.numCells[YY] + cy];
3620                             const int columnEnd   = jGrid.cxy_ind[cx*jGrid.numCells[YY] + cy + 1];
3621
3622                             d2zxy = d2zx;
3623                             if (jGrid.c0[YY] + cy*jGrid.cellSize[YY] > by1)
3624                             {
3625                                 d2zxy += gmx::square(jGrid.c0[YY] + cy*jGrid.cellSize[YY] - by1);
3626                             }
3627                             else if (jGrid.c0[YY] + (cy+1)*jGrid.cellSize[YY] < by0)
3628                             {
3629                                 d2zxy += gmx::square(jGrid.c0[YY] + (cy+1)*jGrid.cellSize[YY] - by0);
3630                             }
3631                             if (columnStart < columnEnd && d2zxy < listRangeBBToJCell2)
3632                             {
3633                                 /* To improve efficiency in the common case
3634                                  * of a homogeneous particle distribution,
3635                                  * we estimate the index of the middle cell
3636                                  * in range (midCell). We search down and up
3637                                  * starting from this index.
3638                                  *
3639                                  * Note that the bbcz_j array contains bounds
3640                                  * for i-clusters, thus for clusters of 4 atoms.
3641                                  * For the common case where the j-cluster size
3642                                  * is 8, we could step with a stride of 2,
3643                                  * but we do not do this because it would
3644                                  * complicate this code even more.
3645                                  */
3646                                 int midCell = columnStart + static_cast<int>(bz1_frac*(columnEnd - columnStart));
3647                                 if (midCell >= columnEnd)
3648                                 {
3649                                     midCell = columnEnd - 1;
3650                                 }
3651
3652                                 d2xy = d2zxy - d2z;
3653
3654                                 /* Find the lowest cell that can possibly
3655                                  * be within range.
3656                                  * Check if we hit the bottom of the grid,
3657                                  * if the j-cell is below the i-cell and if so,
3658                                  * if it is within range.
3659                                  */
3660                                 int downTestCell = midCell;
3661                                 while (downTestCell >= columnStart &&
3662                                        (bbcz_j[downTestCell*NNBSBB_D + 1] >= bz0 ||
3663                                         d2xy + gmx::square(bbcz_j[downTestCell*NNBSBB_D + 1] - bz0) < rlist2))
3664                                 {
3665                                     downTestCell--;
3666                                 }
3667                                 int firstCell = downTestCell + 1;
3668
3669                                 /* Find the highest cell that can possibly
3670                                  * be within range.
3671                                  * Check if we hit the top of the grid,
3672                                  * if the j-cell is above the i-cell and if so,
3673                                  * if it is within range.
3674                                  */
3675                                 int upTestCell = midCell + 1;
3676                                 while (upTestCell < columnEnd &&
3677                                        (bbcz_j[upTestCell*NNBSBB_D] <= bz1 ||
3678                                         d2xy + gmx::square(bbcz_j[upTestCell*NNBSBB_D] - bz1) < rlist2))
3679                                 {
3680                                     upTestCell++;
3681                                 }
3682                                 int lastCell = upTestCell - 1;
3683
3684 #define NBNXN_REFCODE 0
3685 #if NBNXN_REFCODE
3686                                 {
3687                                     /* Simple reference code, for debugging,
3688                                      * overrides the more complex code above.
3689                                      */
3690                                     firstCell = columnEnd;
3691                                     lastCell  = -1;
3692                                     for (int k = columnStart; k < columnEnd; k++)
3693                                     {
3694                                         if (d2xy + gmx::square(bbcz_j[k*NNBSBB_D + 1] - bz0) < rlist2 &&
3695                                             k < firstCell)
3696                                         {
3697                                             firstCell = k;
3698                                         }
3699                                         if (d2xy + gmx::square(bbcz_j[k*NNBSBB_D] - bz1) < rlist2 &&
3700                                             k > lastCell)
3701                                         {
3702                                             lastCell = k;
3703                                         }
3704                                     }
3705                                 }
3706 #endif
3707
3708                                 if (isIntraGridList)
3709                                 {
3710                                     /* We want each atom/cell pair only once,
3711                                      * only use cj >= ci.
3712                                      */
3713                                     if (!c_pbcShiftBackward || shift == CENTRAL)
3714                                     {
3715                                         firstCell = std::max(firstCell, ci);
3716                                     }
3717                                 }
3718
3719                                 if (firstCell <= lastCell)
3720                                 {
3721                                     GMX_ASSERT(firstCell >= columnStart && lastCell < columnEnd, "The range should reside within the current grid column");
3722
3723                                     /* For f buffer flags with simple lists */
3724                                     ncj_old_j = getNumSimpleJClustersInList(*nbl);
3725
3726                                     makeClusterListWrapper(nbl,
3727                                                            iGrid, ci,
3728                                                            jGrid, firstCell, lastCell,
3729                                                            excludeSubDiagonal,
3730                                                            nbat,
3731                                                            rlist2, rbb2,
3732                                                            nb_kernel_type,
3733                                                            &numDistanceChecks);
3734
3735                                     if (bFBufferFlag)
3736                                     {
3737                                         setBufferFlags(*nbl, ncj_old_j, gridj_flag_shift,
3738                                                        gridj_flag, th);
3739                                     }
3740
3741                                     incrementNumSimpleJClustersInList(nbl, ncj_old_j);
3742                                 }
3743                             }
3744                         }
3745                     }
3746
3747                     /* Set the exclusions for this ci list */
3748                     setExclusionsForIEntry(nbs,
3749                                            nbl,
3750                                            excludeSubDiagonal,
3751                                            na_cj_2log,
3752                                            *getOpenIEntry(nbl),
3753                                            exclusions);
3754
3755                     if (nbs->bFEP)
3756                     {
3757                         make_fep_list(nbs, nbat, nbl,
3758                                       excludeSubDiagonal,
3759                                       getOpenIEntry(nbl),
3760                                       shx, shy, shz,
3761                                       rl_fep2,
3762                                       iGrid, jGrid, nbl_fep);
3763                     }
3764
3765                     /* Close this ci list */
3766                     closeIEntry(nbl,
3767                                 nsubpair_max,
3768                                 progBal, nsubpair_tot_est,
3769                                 th, nth);
3770                 }
3771             }
3772         }
3773
3774         if (bFBufferFlag && getNumSimpleJClustersInList(*nbl) > ncj_old_i)
3775         {
3776             bitmask_init_bit(&(work->buffer_flags.flag[(iGrid.cell0+ci) >> gridi_flag_shift]), th);
3777         }
3778     }
3779
3780     work->ndistc = numDistanceChecks;
3781
3782     nbs_cycle_stop(&work->cc[enbsCCsearch]);
3783
3784     checkListSizeConsistency(*nbl, nbs->bFEP);
3785
3786     if (debug)
3787     {
3788         fprintf(debug, "number of distance checks %d\n", numDistanceChecks);
3789
3790         print_nblist_statistics(debug, nbl, nbs, rlist);
3791
3792         if (nbs->bFEP)
3793         {
3794             fprintf(debug, "nbl FEP list pairs: %d\n", nbl_fep->nrj);
3795         }
3796     }
3797 }
3798
3799 static void reduce_buffer_flags(const nbnxn_search         *nbs,
3800                                 int                         nsrc,
3801                                 const nbnxn_buffer_flags_t *dest)
3802 {
3803     for (int s = 0; s < nsrc; s++)
3804     {
3805         gmx_bitmask_t * flag = nbs->work[s].buffer_flags.flag;
3806
3807         for (int b = 0; b < dest->nflag; b++)
3808         {
3809             bitmask_union(&(dest->flag[b]), flag[b]);
3810         }
3811     }
3812 }
3813
3814 static void print_reduction_cost(const nbnxn_buffer_flags_t *flags, int nout)
3815 {
3816     int           nelem, nkeep, ncopy, nred, out;
3817     gmx_bitmask_t mask_0;
3818
3819     nelem = 0;
3820     nkeep = 0;
3821     ncopy = 0;
3822     nred  = 0;
3823     bitmask_init_bit(&mask_0, 0);
3824     for (int b = 0; b < flags->nflag; b++)
3825     {
3826         if (bitmask_is_equal(flags->flag[b], mask_0))
3827         {
3828             /* Only flag 0 is set, no copy of reduction required */
3829             nelem++;
3830             nkeep++;
3831         }
3832         else if (!bitmask_is_zero(flags->flag[b]))
3833         {
3834             int c = 0;
3835             for (out = 0; out < nout; out++)
3836             {
3837                 if (bitmask_is_set(flags->flag[b], out))
3838                 {
3839                     c++;
3840                 }
3841             }
3842             nelem += c;
3843             if (c == 1)
3844             {
3845                 ncopy++;
3846             }
3847             else
3848             {
3849                 nred += c;
3850             }
3851         }
3852     }
3853
3854     fprintf(debug, "nbnxn reduction: #flag %d #list %d elem %4.2f, keep %4.2f copy %4.2f red %4.2f\n",
3855             flags->nflag, nout,
3856             nelem/static_cast<double>(flags->nflag),
3857             nkeep/static_cast<double>(flags->nflag),
3858             ncopy/static_cast<double>(flags->nflag),
3859             nred/static_cast<double>(flags->nflag));
3860 }
3861
3862 /* Copies the list entries from src to dest when cjStart <= *cjGlobal < cjEnd.
3863  * *cjGlobal is updated with the cj count in src.
3864  * When setFlags==true, flag bit t is set in flag for all i and j clusters.
3865  */
3866 template<bool setFlags>
3867 static void copySelectedListRange(const nbnxn_ci_t * gmx_restrict srcCi,
3868                                   const NbnxnPairlistCpu * gmx_restrict src,
3869                                   NbnxnPairlistCpu * gmx_restrict dest,
3870                                   gmx_bitmask_t *flag,
3871                                   int iFlagShift, int jFlagShift, int t)
3872 {
3873     const int ncj = srcCi->cj_ind_end - srcCi->cj_ind_start;
3874
3875     dest->ci.push_back(*srcCi);
3876     dest->ci.back().cj_ind_start = dest->cj.size();
3877     dest->ci.back().cj_ind_end   = dest->cj.size() + ncj;
3878
3879     if (setFlags)
3880     {
3881         bitmask_init_bit(&flag[srcCi->ci >> iFlagShift], t);
3882     }
3883
3884     for (int j = srcCi->cj_ind_start; j < srcCi->cj_ind_end; j++)
3885     {
3886         dest->cj.push_back(src->cj[j]);
3887
3888         if (setFlags)
3889         {
3890             /* NOTE: This is relatively expensive, since this
3891              * operation is done for all elements in the list,
3892              * whereas at list generation this is done only
3893              * once for each flag entry.
3894              */
3895             bitmask_init_bit(&flag[src->cj[j].cj >> jFlagShift], t);
3896         }
3897     }
3898 }
3899
3900 /* This routine re-balances the pairlists such that all are nearly equally
3901  * sized. Only whole i-entries are moved between lists. These are moved
3902  * between the ends of the lists, such that the buffer reduction cost should
3903  * not change significantly.
3904  * Note that all original reduction flags are currently kept. This can lead
3905  * to reduction of parts of the force buffer that could be avoided. But since
3906  * the original lists are quite balanced, this will only give minor overhead.
3907  */
3908 static void rebalanceSimpleLists(int                                  numLists,
3909                                  NbnxnPairlistCpu * const * const     srcSet,
3910                                  NbnxnPairlistCpu                   **destSet,
3911                                  gmx::ArrayRef<nbnxn_search_work_t>   searchWork)
3912 {
3913     int ncjTotal = 0;
3914     for (int s = 0; s < numLists; s++)
3915     {
3916         ncjTotal += srcSet[s]->ncjInUse;
3917     }
3918     int ncjTarget = (ncjTotal + numLists - 1)/numLists;
3919
3920 #pragma omp parallel num_threads(numLists)
3921     {
3922         int t       = gmx_omp_get_thread_num();
3923
3924         int cjStart = ncjTarget* t;
3925         int cjEnd   = ncjTarget*(t + 1);
3926
3927         /* The destination pair-list for task/thread t */
3928         NbnxnPairlistCpu *dest = destSet[t];
3929
3930         clear_pairlist(dest);
3931         dest->na_cj   = srcSet[0]->na_cj;
3932
3933         /* Note that the flags in the work struct (still) contain flags
3934          * for all entries that are present in srcSet->nbl[t].
3935          */
3936         gmx_bitmask_t *flag       = searchWork[t].buffer_flags.flag;
3937
3938         int            iFlagShift = getBufferFlagShift(dest->na_ci);
3939         int            jFlagShift = getBufferFlagShift(dest->na_cj);
3940
3941         int            cjGlobal   = 0;
3942         for (int s = 0; s < numLists && cjGlobal < cjEnd; s++)
3943         {
3944             const NbnxnPairlistCpu *src = srcSet[s];
3945
3946             if (cjGlobal + src->ncjInUse > cjStart)
3947             {
3948                 for (gmx::index i = 0; i < gmx::ssize(src->ci) && cjGlobal < cjEnd; i++)
3949                 {
3950                     const nbnxn_ci_t *srcCi = &src->ci[i];
3951                     int               ncj   = srcCi->cj_ind_end - srcCi->cj_ind_start;
3952                     if (cjGlobal >= cjStart)
3953                     {
3954                         /* If the source list is not our own, we need to set
3955                          * extra flags (the template bool parameter).
3956                          */
3957                         if (s != t)
3958                         {
3959                             copySelectedListRange
3960                             <true>
3961                                 (srcCi, src, dest,
3962                                 flag, iFlagShift, jFlagShift, t);
3963                         }
3964                         else
3965                         {
3966                             copySelectedListRange
3967                             <false>
3968                                 (srcCi, src,
3969                                 dest, flag, iFlagShift, jFlagShift, t);
3970                         }
3971                     }
3972                     cjGlobal += ncj;
3973                 }
3974             }
3975             else
3976             {
3977                 cjGlobal += src->ncjInUse;
3978             }
3979         }
3980
3981         dest->ncjInUse = dest->cj.size();
3982     }
3983
3984 #ifndef NDEBUG
3985     int ncjTotalNew = 0;
3986     for (int s = 0; s < numLists; s++)
3987     {
3988         ncjTotalNew += destSet[s]->ncjInUse;
3989     }
3990     GMX_RELEASE_ASSERT(ncjTotalNew == ncjTotal, "The total size of the lists before and after rebalancing should match");
3991 #endif
3992 }
3993
3994 /* Returns if the pairlists are so imbalanced that it is worth rebalancing. */
3995 static bool checkRebalanceSimpleLists(const nbnxn_pairlist_set_t *listSet)
3996 {
3997     int numLists = listSet->nnbl;
3998     int ncjMax   = 0;
3999     int ncjTotal = 0;
4000     for (int s = 0; s < numLists; s++)
4001     {
4002         ncjMax    = std::max(ncjMax, listSet->nbl[s]->ncjInUse);
4003         ncjTotal += listSet->nbl[s]->ncjInUse;
4004     }
4005     if (debug)
4006     {
4007         fprintf(debug, "Pair-list ncjMax %d ncjTotal %d\n", ncjMax, ncjTotal);
4008     }
4009     /* The rebalancing adds 3% extra time to the search. Heuristically we
4010      * determined that under common conditions the non-bonded kernel balance
4011      * improvement will outweigh this when the imbalance is more than 3%.
4012      * But this will, obviously, depend on search vs kernel time and nstlist.
4013      */
4014     const real rebalanceTolerance = 1.03;
4015
4016     return numLists*ncjMax > ncjTotal*rebalanceTolerance;
4017 }
4018
4019 /* Perform a count (linear) sort to sort the smaller lists to the end.
4020  * This avoids load imbalance on the GPU, as large lists will be
4021  * scheduled and executed first and the smaller lists later.
4022  * Load balancing between multi-processors only happens at the end
4023  * and there smaller lists lead to more effective load balancing.
4024  * The sorting is done on the cj4 count, not on the actual pair counts.
4025  * Not only does this make the sort faster, but it also results in
4026  * better load balancing than using a list sorted on exact load.
4027  * This function swaps the pointer in the pair list to avoid a copy operation.
4028  */
4029 static void sort_sci(NbnxnPairlistGpu *nbl)
4030 {
4031     if (nbl->cj4.size() <= nbl->sci.size())
4032     {
4033         /* nsci = 0 or all sci have size 1, sorting won't change the order */
4034         return;
4035     }
4036
4037     NbnxnPairlistGpuWork &work = *nbl->work;
4038
4039     /* We will distinguish differences up to double the average */
4040     const int m = (2*nbl->cj4.size())/nbl->sci.size();
4041
4042     /* Resize work.sci_sort so we can sort into it */
4043     work.sci_sort.resize(nbl->sci.size());
4044
4045     std::vector<int> &sort = work.sortBuffer;
4046     /* Set up m + 1 entries in sort, initialized at 0 */
4047     sort.clear();
4048     sort.resize(m + 1, 0);
4049     /* Count the entries of each size */
4050     for (const nbnxn_sci_t &sci : nbl->sci)
4051     {
4052         int i = std::min(m, sci.numJClusterGroups());
4053         sort[i]++;
4054     }
4055     /* Calculate the offset for each count */
4056     int s0  = sort[m];
4057     sort[m] = 0;
4058     for (int i = m - 1; i >= 0; i--)
4059     {
4060         int s1  = sort[i];
4061         sort[i] = sort[i + 1] + s0;
4062         s0      = s1;
4063     }
4064
4065     /* Sort entries directly into place */
4066     gmx::ArrayRef<nbnxn_sci_t> sci_sort = work.sci_sort;
4067     for (const nbnxn_sci_t &sci : nbl->sci)
4068     {
4069         int i = std::min(m, sci.numJClusterGroups());
4070         sci_sort[sort[i]++] = sci;
4071     }
4072
4073     /* Swap the sci pointers so we use the new, sorted list */
4074     std::swap(nbl->sci, work.sci_sort);
4075 }
4076
4077 /* Make a local or non-local pair-list, depending on iloc */
4078 void nbnxn_make_pairlist(nbnxn_search         *nbs,
4079                          nbnxn_atomdata_t     *nbat,
4080                          const t_blocka       *excl,
4081                          real                  rlist,
4082                          int                   min_ci_balanced,
4083                          nbnxn_pairlist_set_t *nbl_list,
4084                          int                   iloc,
4085                          int                   nb_kernel_type,
4086                          t_nrnb               *nrnb)
4087 {
4088     int                nsubpair_target;
4089     float              nsubpair_tot_est;
4090     int                nnbl;
4091     int                ci_block;
4092     gmx_bool           CombineNBLists;
4093     gmx_bool           progBal;
4094     int                np_tot, np_noq, np_hlj, nap;
4095
4096     nnbl            = nbl_list->nnbl;
4097     CombineNBLists  = nbl_list->bCombined;
4098
4099     if (debug)
4100     {
4101         fprintf(debug, "ns making %d nblists\n", nnbl);
4102     }
4103
4104     nbat->bUseBufferFlags = (nbat->out.size() > 1);
4105     /* We should re-init the flags before making the first list */
4106     if (nbat->bUseBufferFlags && LOCAL_I(iloc))
4107     {
4108         init_buffer_flags(&nbat->buffer_flags, nbat->numAtoms());
4109     }
4110
4111     int nzi;
4112     if (LOCAL_I(iloc))
4113     {
4114         /* Only zone (grid) 0 vs 0 */
4115         nzi = 1;
4116     }
4117     else
4118     {
4119         nzi = nbs->zones->nizone;
4120     }
4121
4122     if (!nbl_list->bSimple && min_ci_balanced > 0)
4123     {
4124         get_nsubpair_target(nbs, iloc, rlist, min_ci_balanced,
4125                             &nsubpair_target, &nsubpair_tot_est);
4126     }
4127     else
4128     {
4129         nsubpair_target  = 0;
4130         nsubpair_tot_est = 0;
4131     }
4132
4133     /* Clear all pair-lists */
4134     for (int th = 0; th < nnbl; th++)
4135     {
4136         if (nbl_list->bSimple)
4137         {
4138             clear_pairlist(nbl_list->nbl[th]);
4139         }
4140         else
4141         {
4142             clear_pairlist(nbl_list->nblGpu[th]);
4143         }
4144
4145         if (nbs->bFEP)
4146         {
4147             clear_pairlist_fep(nbl_list->nbl_fep[th]);
4148         }
4149     }
4150
4151     for (int zi = 0; zi < nzi; zi++)
4152     {
4153         const nbnxn_grid_t &iGrid = nbs->grid[zi];
4154
4155         int                 zj0;
4156         int                 zj1;
4157         if (LOCAL_I(iloc))
4158         {
4159             zj0 = 0;
4160             zj1 = 1;
4161         }
4162         else
4163         {
4164             zj0 = nbs->zones->izone[zi].j0;
4165             zj1 = nbs->zones->izone[zi].j1;
4166             if (zi == 0)
4167             {
4168                 zj0++;
4169             }
4170         }
4171         for (int zj = zj0; zj < zj1; zj++)
4172         {
4173             const nbnxn_grid_t &jGrid = nbs->grid[zj];
4174
4175             if (debug)
4176             {
4177                 fprintf(debug, "ns search grid %d vs %d\n", zi, zj);
4178             }
4179
4180             nbs_cycle_start(&nbs->cc[enbsCCsearch]);
4181
4182             ci_block = get_ci_block_size(iGrid, nbs->DomDec, nnbl);
4183
4184             /* With GPU: generate progressively smaller lists for
4185              * load balancing for local only or non-local with 2 zones.
4186              */
4187             progBal = (LOCAL_I(iloc) || nbs->zones->n <= 2);
4188
4189 #pragma omp parallel for num_threads(nnbl) schedule(static)
4190             for (int th = 0; th < nnbl; th++)
4191             {
4192                 try
4193                 {
4194                     /* Re-init the thread-local work flag data before making
4195                      * the first list (not an elegant conditional).
4196                      */
4197                     if (nbat->bUseBufferFlags && ((zi == 0 && zj == 0)))
4198                     {
4199                         init_buffer_flags(&nbs->work[th].buffer_flags, nbat->numAtoms());
4200                     }
4201
4202                     if (CombineNBLists && th > 0)
4203                     {
4204                         GMX_ASSERT(!nbl_list->bSimple, "Can only combine GPU lists");
4205
4206                         clear_pairlist(nbl_list->nblGpu[th]);
4207                     }
4208
4209                     /* Divide the i super cell equally over the nblists */
4210                     if (nbl_list->bSimple)
4211                     {
4212                         nbnxn_make_pairlist_part(nbs, iGrid, jGrid,
4213                                                  &nbs->work[th], nbat, *excl,
4214                                                  rlist,
4215                                                  nb_kernel_type,
4216                                                  ci_block,
4217                                                  nbat->bUseBufferFlags,
4218                                                  nsubpair_target,
4219                                                  progBal, nsubpair_tot_est,
4220                                                  th, nnbl,
4221                                                  nbl_list->nbl[th],
4222                                                  nbl_list->nbl_fep[th]);
4223                     }
4224                     else
4225                     {
4226                         nbnxn_make_pairlist_part(nbs, iGrid, jGrid,
4227                                                  &nbs->work[th], nbat, *excl,
4228                                                  rlist,
4229                                                  nb_kernel_type,
4230                                                  ci_block,
4231                                                  nbat->bUseBufferFlags,
4232                                                  nsubpair_target,
4233                                                  progBal, nsubpair_tot_est,
4234                                                  th, nnbl,
4235                                                  nbl_list->nblGpu[th],
4236                                                  nbl_list->nbl_fep[th]);
4237                     }
4238                 }
4239                 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
4240             }
4241             nbs_cycle_stop(&nbs->cc[enbsCCsearch]);
4242
4243             np_tot = 0;
4244             np_noq = 0;
4245             np_hlj = 0;
4246             for (int th = 0; th < nnbl; th++)
4247             {
4248                 inc_nrnb(nrnb, eNR_NBNXN_DIST2, nbs->work[th].ndistc);
4249
4250                 if (nbl_list->bSimple)
4251                 {
4252                     NbnxnPairlistCpu *nbl = nbl_list->nbl[th];
4253                     np_tot += nbl->cj.size();
4254                     np_noq += nbl->work->ncj_noq;
4255                     np_hlj += nbl->work->ncj_hlj;
4256                 }
4257                 else
4258                 {
4259                     NbnxnPairlistGpu *nbl = nbl_list->nblGpu[th];
4260                     /* This count ignores potential subsequent pair pruning */
4261                     np_tot += nbl->nci_tot;
4262                 }
4263             }
4264             if (nbl_list->bSimple)
4265             {
4266                 nap               = nbl_list->nbl[0]->na_ci*nbl_list->nbl[0]->na_cj;
4267             }
4268             else
4269             {
4270                 nap               = gmx::square(nbl_list->nblGpu[0]->na_ci);
4271             }
4272             nbl_list->natpair_ljq = (np_tot - np_noq)*nap - np_hlj*nap/2;
4273             nbl_list->natpair_lj  = np_noq*nap;
4274             nbl_list->natpair_q   = np_hlj*nap/2;
4275
4276             if (CombineNBLists && nnbl > 1)
4277             {
4278                 GMX_ASSERT(!nbl_list->bSimple, "Can only combine GPU lists");
4279                 NbnxnPairlistGpu **nbl = nbl_list->nblGpu;
4280
4281                 nbs_cycle_start(&nbs->cc[enbsCCcombine]);
4282
4283                 combine_nblists(nnbl-1, nbl+1, nbl[0]);
4284
4285                 nbs_cycle_stop(&nbs->cc[enbsCCcombine]);
4286             }
4287         }
4288     }
4289
4290     if (nbl_list->bSimple)
4291     {
4292         if (nnbl > 1 && checkRebalanceSimpleLists(nbl_list))
4293         {
4294             rebalanceSimpleLists(nbl_list->nnbl, nbl_list->nbl, nbl_list->nbl_work, nbs->work);
4295
4296             /* Swap the pointer of the sets of pair lists */
4297             NbnxnPairlistCpu **tmp = nbl_list->nbl;
4298             nbl_list->nbl          = nbl_list->nbl_work;
4299             nbl_list->nbl_work     = tmp;
4300         }
4301     }
4302     else
4303     {
4304         /* Sort the entries on size, large ones first */
4305         if (CombineNBLists || nnbl == 1)
4306         {
4307             sort_sci(nbl_list->nblGpu[0]);
4308         }
4309         else
4310         {
4311 #pragma omp parallel for num_threads(nnbl) schedule(static)
4312             for (int th = 0; th < nnbl; th++)
4313             {
4314                 try
4315                 {
4316                     sort_sci(nbl_list->nblGpu[th]);
4317                 }
4318                 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
4319             }
4320         }
4321     }
4322
4323     if (nbat->bUseBufferFlags)
4324     {
4325         reduce_buffer_flags(nbs, nbl_list->nnbl, &nbat->buffer_flags);
4326     }
4327
4328     if (nbs->bFEP)
4329     {
4330         /* Balance the free-energy lists over all the threads */
4331         balance_fep_lists(nbs, nbl_list);
4332     }
4333
4334     if (nbl_list->bSimple)
4335     {
4336         /* This is a fresh list, so not pruned, stored using ci.
4337          * ciOuter is invalid at this point.
4338          */
4339         GMX_ASSERT(nbl_list->nbl[0]->ciOuter.empty(), "ciOuter is invalid so it should be empty");
4340     }
4341
4342     /* Special performance logging stuff (env.var. GMX_NBNXN_CYCLE) */
4343     if (LOCAL_I(iloc))
4344     {
4345         nbs->search_count++;
4346     }
4347     if (nbs->print_cycles &&
4348         (!nbs->DomDec || !LOCAL_I(iloc)) &&
4349         nbs->search_count % 100 == 0)
4350     {
4351         nbs_cycle_print(stderr, nbs);
4352     }
4353
4354     /* If we have more than one list, they either got rebalancing (CPU)
4355      * or combined (GPU), so we should dump the final result to debug.
4356      */
4357     if (debug && nbl_list->nnbl > 1)
4358     {
4359         if (nbl_list->bSimple)
4360         {
4361             for (int t = 0; t < nbl_list->nnbl; t++)
4362             {
4363                 print_nblist_statistics(debug, nbl_list->nbl[t], nbs, rlist);
4364             }
4365         }
4366         else
4367         {
4368             print_nblist_statistics(debug, nbl_list->nblGpu[0], nbs, rlist);
4369         }
4370     }
4371
4372     if (debug)
4373     {
4374         if (gmx_debug_at)
4375         {
4376             if (nbl_list->bSimple)
4377             {
4378                 for (int t = 0; t < nbl_list->nnbl; t++)
4379                 {
4380                     print_nblist_ci_cj(debug, nbl_list->nbl[t]);
4381                 }
4382             }
4383             else
4384             {
4385                 print_nblist_sci_cj(debug, nbl_list->nblGpu[0]);
4386             }
4387         }
4388
4389         if (nbat->bUseBufferFlags)
4390         {
4391             print_reduction_cost(&nbat->buffer_flags, nbl_list->nnbl);
4392         }
4393     }
4394 }
4395
4396 void nbnxnPrepareListForDynamicPruning(nbnxn_pairlist_set_t *listSet)
4397 {
4398     GMX_RELEASE_ASSERT(listSet->bSimple, "Should only be called for simple lists");
4399
4400     /* TODO: Restructure the lists so we have actual outer and inner
4401      *       list objects so we can set a single pointer instead of
4402      *       swapping several pointers.
4403      */
4404
4405     for (int i = 0; i < listSet->nnbl; i++)
4406     {
4407         NbnxnPairlistCpu &list = *listSet->nbl[i];
4408
4409         /* The search produced a list in ci/cj.
4410          * Swap the list pointers so we get the outer list is ciOuter,cjOuter
4411          * and we can prune that to get an inner list in ci/cj.
4412          */
4413         GMX_RELEASE_ASSERT(list.ciOuter.empty() && list.cjOuter.empty(),
4414                            "The outer lists should be empty before preparation");
4415
4416         std::swap(list.ci, list.ciOuter);
4417         std::swap(list.cj, list.cjOuter);
4418     }
4419 }