/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief
 *  NBNXM SYCL prune-only kernel
 *
 *  \ingroup module_nbnxm
 */
#include "gmxpre.h"

#include "nbnxm_sycl_kernel_pruneonly.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;

//! \brief Class name for NBNXM prune-only kernel
template<bool haveFreshList>
class NbnxmKernelPruneOnly;

namespace Nbnxm
{

/*! \brief Prune-only kernel for NBNXM.
 *
 *  With a fresh list (\p haveFreshList), cluster pairs with no atoms within
 *  rlistOuter are removed from the list for good; in both cases, the interaction
 *  mask is updated to contain only pairs with atoms within rlistInner.
 */
template<bool haveFreshList>
auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
                          DeviceAccessor<Float4, mode::read>            a_xq,
                          DeviceAccessor<Float3, mode::read>            a_shiftVec,
                          DeviceAccessor<nbnxn_cj4_t, mode::read_write> a_plistCJ4,
                          DeviceAccessor<nbnxn_sci_t, mode::read>       a_plistSci,
                          DeviceAccessor<unsigned int, haveFreshList ? mode::write : mode::read> a_plistIMask,
                          const float rlistOuterSq,
                          const float rlistInnerSq,
                          const int   numParts,
                          const int   part)
{
    cgh.require(a_xq);
    cgh.require(a_shiftVec);
    cgh.require(a_plistCJ4);
    cgh.require(a_plistSci);
    cgh.require(a_plistIMask);

    /* shmem buffer for i x+q pre-loading */
    cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

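    /* A "warp" here is one half of the c_clSize x c_clSize thread block: each half
     * handles one part of the cluster-pair split (cf. the
     * j4 * c_nbnxnGpuClusterpairSplit + widx indexing below). */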
    constexpr int warpSize = c_clSize * c_clSize / 2;

    /* Somewhat weird behavior inherited from OpenCL.
     * With clSize == 4, we use a sub_group size of 16 (not enforced in the OpenCL
     * implementation, but chosen by the IGC compiler), while for the data layout
     * we consider it to be 8. Setting the sub_group size to 8 slows down the
     * prune-only kernel 1.5-2 times. For clSize == 8, we need to set a specific
     * sub_group size >= 32 for correctness, but this causes very poor performance.
     */
    constexpr int gmx_unused requiredSubGroupSize = (c_clSize == 4) ? 16 : warpSize;

    /* Requirements:
     * Work group (block) must have range (c_clSize, c_clSize, ...) (for localId
     * calculation, easy to change). */
    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(requiredSubGroupSize)]]
    {
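        /* Recover the 3D (i, j, z) thread coordinates from the flattened 1D index */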
        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
        // thread/block/warp IDs
        const unsigned tidxi = localId[0];
        const unsigned tidxj = localId[1];
        const int      tidx  = tidxj * c_clSize + tidxi;
        const unsigned tidxz = localId[2];
        const unsigned bidx  = itemIdx.get_group(0);

        const sycl_2020::sub_group sg   = itemIdx.get_sub_group();
        const unsigned             widx = tidx / warpSize;

        // my i super-cluster's index = current bidx * numParts + part
        const nbnxn_sci_t nbSci     = a_plistSci[bidx * numParts + part];
        const int         sci       = nbSci.sci;           /* super-cluster */
        const int         cij4Start = nbSci.cj4_ind_start; /* first ... */
        const int         cij4End   = nbSci.cj4_ind_end;   /* and last index of j clusters */

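        /* Only the tidxz == 0 plane of threads pre-loads the i-atom data; after the
         * barrier below, all c_syclPruneKernelJ4Concurrency planes read from sm_xq. */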
        if (tidxz == 0)
        {
            for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
            {
                /* Pre-load i-atom x and q into shared memory */
                const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
                const int ai = ci * c_clSize + tidxi;

                /* We don't need q, but using float4 in shmem avoids bank conflicts
                   (but it also wastes L2 bandwidth). */
                const Float4 xq    = a_xq[ai];
                const Float3 shift = a_shiftVec[nbSci.shift];
                const Float4 xi(xq[0] + shift[0], xq[1] + shift[1], xq[2] + shift[2], xq[3]);
                sm_xq[tidxj + i][tidxi] = xi;
            }
        }
        itemIdx.barrier(fence_space::local_space);

        /* Loop over the j clusters seen by any of the atoms in the current super-cluster.
         * The loop stride c_syclPruneKernelJ4Concurrency ensures that consecutive
         * warp-pairs are assigned consecutive cj4 entries. */
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += c_syclPruneKernelJ4Concurrency)
        {
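            /* imaskFull:  all cluster pairs of this cj4 entry present in the outer list;
             * imaskCheck: the pairs that need a distance check in this pass;
             * imaskNew:   the updated mask of pairs within rlistInner. */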
            unsigned imaskFull, imaskCheck, imaskNew;

            if constexpr (haveFreshList)
            {
                /* Read the mask from the list transferred from the CPU */
                imaskFull = a_plistCJ4[j4].imei[widx].imask;
                /* We attempt to prune all pairs present in the original list */
                imaskCheck = imaskFull;
                imaskNew   = 0;
            }
            else
            {
                /* Read the mask from the array of masks warp-pruned against rlistOuter */
                imaskFull = a_plistIMask[j4 * c_nbnxnGpuClusterpairSplit + widx];
                /* Read the old rolling-pruned mask, use it as a base for the new one */
                imaskNew = a_plistCJ4[j4].imei[widx].imask;
                /* We only need to check pairs whose masks differ */
                imaskCheck = (imaskNew ^ imaskFull);
            }

            if (imaskCheck)
            {
                for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
                {
                    if (imaskCheck & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster)))
                    {
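                        /* Bit of the (i = 0, jm) cluster pair; doubled after each
                         * iteration of the i loop below to move to the next i-cluster. */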
                        unsigned mask_ji = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                        // SYCL-TODO: Reevaluate prefetching methods
                        const int cj = a_plistCJ4[j4].cj[jm];
                        const int aj = cj * c_clSize + tidxj;

                        /* load j atom data */
                        const Float4 tmp = a_xq[aj];
                        const Float3 xj(tmp[0], tmp[1], tmp[2]);

                        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                        {
                            if (imaskCheck & mask_ji)
                            {
                                // load i-cluster coordinates from shmem
                                const Float4 xi = sm_xq[i][tidxi];
                                // distance between i and j atoms
                                Float3 rv(xi[0], xi[1], xi[2]);
                                rv -= xj;
                                const float r2 = norm2(rv);

                                /* If _none_ of the atom pairs are in rlistOuter
                                 * range, the bit corresponding to the current
                                 * cluster-pair in imask gets set to 0. */
                                if (haveFreshList && !(sycl_2020::group_any_of(sg, r2 < rlistOuterSq)))
                                {
                                    imaskFull &= ~mask_ji;
                                }
                                /* If any atom pair is within range, set the bit
                                 * corresponding to the current cluster-pair. */
                                if (sycl_2020::group_any_of(sg, r2 < rlistInnerSq))
                                {
                                    imaskNew |= mask_ji;
                                }
                            } // (imaskCheck & mask_ji)
                            /* shift the mask bit by 1 */
                            mask_ji += mask_ji;
                        } // (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                    } // (imaskCheck & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster)))
                } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)

                if constexpr (haveFreshList)
                {
                    /* copy the list pruned to rlistOuter to a separate buffer */
                    a_plistIMask[j4 * c_nbnxnGpuClusterpairSplit + widx] = imaskFull;
                }
                /* update the imask with only the pairs up to rlistInner */
                a_plistCJ4[j4].imei[widx].imask = imaskNew;
            } // (imaskCheck)
        } // for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += c_syclPruneKernelJ4Concurrency)
    };
}

template<bool haveFreshList, class... Args>
cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
                                           const int           numSciInPart,
                                           Args&&... args)
{
    using kernelNameType = NbnxmKernelPruneOnly<haveFreshList>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const unsigned long         numBlocks = numSciInPart;
    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, c_syclPruneKernelJ4Concurrency };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };

    cl::sycl::queue q = deviceStream.stream();

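    /* The 3D nd_range is flattened to 1D for submission (flattenNDRange) and is
     * unflattened back inside the kernel via unflattenId. */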
    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernelPruneOnly<haveFreshList>(cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
    });

    return e;
}

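/*! \brief Select the templated kernel at run time.
 *
 * gmx::dispatchTemplatedFunction converts the runtime \p haveFreshList flag
 * into a compile-time template argument for launchNbnxmKernelPruneOnly.
 */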
template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernelPruneOnly(bool haveFreshList, Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto haveFreshList_) {
                return launchNbnxmKernelPruneOnly<haveFreshList_>(std::forward<Args>(args)...);
            },
            haveFreshList);
}

void launchNbnxmKernelPruneOnly(NbnxmGpu*                 nb,
                                const InteractionLocality iloc,
                                const int                 numParts,
                                const int                 part,
                                const int                 numSciInPart)
{
    NBAtomDataGpu*      adat          = nb->atdat;
    NBParamGpu*         nbp           = nb->nbparam;
    gpu_plist*          plist         = nb->plist[iloc];
    const bool          haveFreshList = plist->haveFreshList;
    const DeviceStream& deviceStream  = *nb->deviceStreams[iloc];

    cl::sycl::event e = chooseAndLaunchNbnxmKernelPruneOnly(haveFreshList,
                                                            deviceStream,
                                                            numSciInPart,
                                                            adat->xq,
                                                            adat->shiftVec,
                                                            plist->cj4,
                                                            plist->sci,
                                                            plist->imask,
                                                            nbp->rlistOuter_sq,
                                                            nbp->rlistInner_sq,
                                                            numParts,
                                                            part);
}

} // namespace Nbnxm