SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief
38  *  NBNXM SYCL kernels
39  *
40  *  \ingroup module_nbnxm
41  */
42 #include "gmxpre.h"
43
44 #include "nbnxm_sycl_kernel.h"
45
46 #include "gromacs/gpu_utils/devicebuffer.h"
47 #include "gromacs/gpu_utils/gmxsycl.h"
48 #include "gromacs/math/functions.h"
49 #include "gromacs/mdtypes/simulation_workload.h"
50 #include "gromacs/pbcutil/ishift.h"
51 #include "gromacs/utility/template_mp.h"
52
53 #include "nbnxm_sycl_kernel_utils.h"
54 #include "nbnxm_sycl_types.h"
55
56 //! \brief Class name for NBNXM kernel
57 template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
58 class NbnxmKernel;
59
60 namespace Nbnxm
61 {
62
63 //! \brief Set of boolean constants mimicking preprocessor macros.
64 template<enum ElecType elecType, enum VdwType vdwType>
65 struct EnergyFunctionProperties {
66     static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
67     static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
68     static constexpr bool elecEwaldAna =
69             (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
70     static constexpr bool elecEwaldTab =
71             (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
72     static constexpr bool elecEwaldTwin =
73             (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin); ///< Use twin cut-off.
74     static constexpr bool elecEwald = (elecEwaldAna || elecEwaldTab);  ///< EL_EWALD_ANY
75     static constexpr bool vdwCombLB = (vdwType == VdwType::CutCombLB); ///< LJ_COMB && !LJ_COMB_GEOM
76     static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
77     static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
78     static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
79     static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
80     static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
81     static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
82     static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
83 };
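
/* For illustration (not used by the kernels themselves): the properties can be queried
 * at compile time, e.g.
 *   constexpr EnergyFunctionProperties<ElecType::RF, VdwType::CutCombGeom> p;
 *   static_assert(p.elecRF && p.vdwCombGeom && !p.elecEwald);
 * Each member mirrors a preprocessor macro of the CUDA/OpenCL kernels, so the single
 * templated kernel body below can branch with `if constexpr` instead of #ifdef.
 */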
84
85 //! \brief Templated constants to shorten kernel function declaration.
86 //@{
87 template<enum VdwType vdwType>
88 constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;
89
90 template<enum ElecType elecType>
91 constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;
92
93 template<enum ElecType elecType>
94 constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;
95
96 template<enum VdwType vdwType>
97 constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
98 //@}
99
100 using cl::sycl::access::fence_space;
101 using cl::sycl::access::mode;
102 using cl::sycl::access::target;
103
104 //! \brief Convert \p sigma and \p epsilon VdW parameters to \c c6,c12 pair.
105 static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
106 {
107     const float sigma2 = sigma * sigma;
108     const float sigma6 = sigma2 * sigma2 * sigma2;
109     const float c6     = epsilon * sigma6;
110     const float c12    = c6 * sigma6;
111
112     return { c6, c12 };
113 }
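
/* Scaling note (a sketch based on the conventions noted later in this file): the LB
 * combination inputs are stored as 2^(1/6)*sigma and 12*epsilon. Writing
 * sigma' = 2^(1/6)*sigma and epsilon' = 12*epsilon, the function returns
 *   c6  = epsilon' * sigma'^6 = 24*epsilon*sigma^6  = 6  * (4*epsilon*sigma^6)  = 6*C6
 *   c12 = c6 * sigma'^6       = 48*epsilon*sigma^12 = 12 * (4*epsilon*sigma^12) = 12*C12
 * which is exactly the 6*C6 / 12*C12 convention the rest of the kernel expects.
 */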
114
115 //! \brief Calculate force and energy for a pair of atoms, VdW force-switch flavor.
116 template<bool doCalcEnergies>
117 static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
118                                  const shift_consts_t         repulsionShift,
119                                  const float                  rVdwSwitch,
120                                  const float                  c6,
121                                  const float                  c12,
122                                  const float                  rInv,
123                                  const float                  r2,
124                                  cl::sycl::private_ptr<float> fInvR,
125                                  cl::sycl::private_ptr<float> eLJ)
126 {
127     /* force switch constants */
128     const float dispShiftV2 = dispersionShift.c2;
129     const float dispShiftV3 = dispersionShift.c3;
130     const float repuShiftV2 = repulsionShift.c2;
131     const float repuShiftV3 = repulsionShift.c3;
132
133     const float r       = r2 * rInv;
134     const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)
135
136     *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
137               + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;
138
139     if constexpr (doCalcEnergies)
140     {
141         const float dispShiftF2 = dispShiftV2 / 3;
142         const float dispShiftF3 = dispShiftV3 / 4;
143         const float repuShiftF2 = repuShiftV2 / 3;
144         const float repuShiftF3 = repuShiftV3 / 4;
145         *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
146                 - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
147     }
148 }
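
/* Bookkeeping sketch for ljForceSwitch: the force-switch correction has the form
 *   dF(r) = -c * (A + B*(r - rVdwSwitch)) * (r - rVdwSwitch)^2
 * per LJ term, with (A, B) = (c2, c3) of {dispersion,repulsion}Shift. Integrating -dF
 * over r yields matching energy terms with coefficients A/3 and B/4, which is why the
 * energy branch divides the c2 and c3 constants by 3 and 4, respectively.
 */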
149
150 //! \brief Compute the combined C6 grid coefficient for an atom pair: the product of per-atom coefficients (geometric rule) or the LB-combined value.
151 template<enum VdwType vdwType>
152 static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
153                                            const int                                typeI,
154                                            const int                                typeJ)
155 {
156     if constexpr (vdwType == VdwType::EwaldGeom)
157     {
158         return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
159     }
160     else
161     {
162         static_assert(vdwType == VdwType::EwaldLB);
163         /* sigma and epsilon are scaled to give 6*C6 */
164         const Float2 c6c12_i = a_nbfpComb[typeI];
165         const Float2 c6c12_j = a_nbfpComb[typeJ];
166
167         const float sigma   = c6c12_i[0] + c6c12_j[0];
168         const float epsilon = c6c12_i[1] * c6c12_j[1];
169
170         const float sigma2 = sigma * sigma;
171         return epsilon * sigma2 * sigma2 * sigma2;
172     }
173 }
174
175 //! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
176 template<bool doCalcEnergies, enum VdwType vdwType>
177 static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
178                                const float                              sh_lj_ewald,
179                                const int                                typeI,
180                                const int                                typeJ,
181                                const float                              r2,
182                                const float                              r2Inv,
183                                const float                              lje_coeff2,
184                                const float                              lje_coeff6_6,
185                                const float                              int_bit,
186                                cl::sycl::private_ptr<float>             fInvR,
187                                cl::sycl::private_ptr<float>             eLJ)
188 {
189     const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);
190
191     /* Recalculate inv_r6 without exclusion mask */
192     const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
193     const float cr2       = lje_coeff2 * r2;
194     const float expmcr2   = cl::sycl::exp(-cr2);
195     const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;
196
197     /* Subtract the grid force from the total LJ force */
198     *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;
199
200     if constexpr (doCalcEnergies)
201     {
202         /* Shift should be applied only to real LJ pairs */
203         const float sh_mask = sh_lj_ewald * int_bit;
204         *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
205     }
206 }
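
/* A note on `poly` above: it is the second-order Taylor expansion of exp(cr2), so
 * 1 - exp(-cr2)*poly = (cr2)^3/6 + O(cr2^4). The grid term (1 - exp(-cr2)*poly)/r^6
 * therefore tends to the finite constant lje_coeff2^3/6 (= lje_coeff6_6) instead of
 * diverging as r -> 0, which keeps the correction well-behaved for excluded pairs.
 * (A sketch of the reasoning; see the LJ-PME literature for full derivations.)
 */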
207
208 /*! \brief Apply potential switch. */
209 template<bool doCalcEnergies>
210 static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
211                                      const float                  rVdwSwitch,
212                                      const float                  rInv,
213                                      const float                  r2,
214                                      cl::sycl::private_ptr<float> fInvR,
215                                      cl::sycl::private_ptr<float> eLJ)
216 {
217     /* potential switch constants */
218     const float switchV3 = vdwSwitch.c3;
219     const float switchV4 = vdwSwitch.c4;
220     const float switchV5 = vdwSwitch.c5;
221     const float switchF2 = 3 * vdwSwitch.c3;
222     const float switchF3 = 4 * vdwSwitch.c4;
223     const float switchF4 = 5 * vdwSwitch.c5;
224
225     const float r       = r2 * rInv;
226     const float rSwitch = r - rVdwSwitch;
227
228     if (rSwitch > 0.0F)
229     {
230         const float sw =
231                 1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
232         const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;
233
234         *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
235         if constexpr (doCalcEnergies)
236         {
237             *eLJ *= sw;
238         }
239     }
240 }
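
/* Sketch of the update rule in ljPotentialSwitch: with E_sw(r) = E(r)*sw(r) and
 * F = -dE/dr, the switched force is
 *   F_sw(r) = F(r)*sw(r) - E(r)*dsw(r).
 * The kernel accumulates F/r rather than F, hence fInvR <- fInvR*sw - rInv*eLJ*dsw.
 * Here sw = 1 + c3*d^3 + c4*d^4 + c5*d^5 with d = r - rVdwSwitch, and dsw = d(sw)/dr,
 * which is where the switchF* = {3,4,5} * c{3,4,5} constants come from.
 */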
241
242
243 /*! \brief Calculate analytical Ewald correction term. */
244 static inline float pmeCorrF(const float z2)
245 {
246     constexpr float FN6 = -1.7357322914161492954e-8F;
247     constexpr float FN5 = 1.4703624142580877519e-6F;
248     constexpr float FN4 = -0.000053401640219807709149F;
249     constexpr float FN3 = 0.0010054721316683106153F;
250     constexpr float FN2 = -0.019278317264888380590F;
251     constexpr float FN1 = 0.069670166153766424023F;
252     constexpr float FN0 = -0.75225204789749321333F;
253
254     constexpr float FD4 = 0.0011193462567257629232F;
255     constexpr float FD3 = 0.014866955030185295499F;
256     constexpr float FD2 = 0.11583842382862377919F;
257     constexpr float FD1 = 0.50736591960530292870F;
258     constexpr float FD0 = 1.0F;
259
260     const float z4 = z2 * z2;
261
262     float       polyFD0 = FD4 * z4 + FD2;
263     const float polyFD1 = FD3 * z4 + FD1;
264     polyFD0             = polyFD0 * z4 + FD0;
265     polyFD0             = polyFD1 * z2 + polyFD0;
266
267     polyFD0 = 1.0F / polyFD0;
268
269     float polyFN0 = FN6 * z4 + FN4;
270     float polyFN1 = FN5 * z4 + FN3;
271     polyFN0       = polyFN0 * z4 + FN2;
272     polyFN1       = polyFN1 * z4 + FN1;
273     polyFN0       = polyFN0 * z4 + FN0;
274     polyFN0       = polyFN1 * z2 + polyFN0;
275
276     return polyFN0 * polyFD0;
277 }
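
/* pmeCorrF is a single-precision rational (Pade-style) fit. With z = ewaldBeta * r it
 * approximates, assuming the same convention as the CUDA flavor of this kernel,
 *   pmeCorrF(z^2) ~= (2/sqrt(pi)) * exp(-z^2) / z^2 - erf(z) / z^3,
 * so that pairExclMask/r^3 + pmeCorrF(beta^2 * r^2) * beta^3 reconstructs the
 * erfc(beta*r)-based real-space Ewald force over r without evaluating erfc directly.
 */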
278
279 /*! \brief Linear interpolation using exactly two FMA operations.
280  *
281  *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
282  */
283 template<typename T>
284 static inline T lerp(T d0, T d1, T t)
285 {
286     return fma(t, d1, fma(-t, d0, d0));
287 }
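
/* Why two FMAs suffice: fma(-t, d0, d0) evaluates d0 - t*d0 = (1-t)*d0 with a single
 * rounding, and the outer fma(t, d1, ...) adds t*d1 with one more rounding, giving
 * (1-t)*d0 + t*d1 more accurately (and usually faster) than the naive expression.
 */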
288
289 /*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
290 static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
291                                              const float coulombTabScale,
292                                              const float r)
293 {
294     const float normalized = coulombTabScale * r;
295     const int   index      = static_cast<int>(normalized);
296     const float fraction   = normalized - index;
297
298     const float left  = a_coulombTab[index];
299     const float right = a_coulombTab[index + 1];
300
301     return lerp(left, right, fraction); // TODO: cl::sycl::mix
302 }
303
304 /*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
305  *
306  * c_clSize consecutive threads hold the force components of a j-atom, which we
307  * reduce in log2(c_clSize) steps using sub-group shifts and atomically accumulate into \p a_f.
308  */
309 static inline void reduceForceJShuffle(Float3                                   f,
310                                        const cl::sycl::nd_item<3>               itemIdx,
311                                        const int                                tidxi,
312                                        const int                                aidx,
313                                        DeviceAccessor<Float3, mode::read_write> a_f)
314 {
315     static_assert(c_clSize == 8 || c_clSize == 4);
316     sycl_2020::sub_group sg = itemIdx.get_sub_group();
317
318     f[0] += sycl_2020::shift_left(sg, f[0], 1);
319     f[1] += sycl_2020::shift_right(sg, f[1], 1);
320     f[2] += sycl_2020::shift_left(sg, f[2], 1);
321     if (tidxi & 1)
322     {
323         f[0] = f[1];
324     }
325
326     f[0] += sycl_2020::shift_left(sg, f[0], 2);
327     f[2] += sycl_2020::shift_right(sg, f[2], 2);
328     if (tidxi & 2)
329     {
330         f[0] = f[2];
331     }
332
333     if constexpr (c_clSize == 8)
334     {
335         f[0] += sycl_2020::shift_left(sg, f[0], 4);
336     }
337
338     if (tidxi < 3)
339     {
340         atomicFetchAdd(a_f[aidx][tidxi], f[0]);
341     }
342 }
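
/* Shape of the shuffle reduction above, sketched for c_clSize == 8: after the first
 * shift/select step, even lanes hold pairwise x sums and odd lanes pairwise y sums in
 * f[0], while f[2] holds pairwise z sums. The second step folds the z sums into lanes
 * 2 and 3, and the final shift-by-4 completes the sums, so lanes 0..2 of each 8-lane
 * group hold the total x, y, z components, matching the tidxi < 3 atomic store.
 */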
343
344 /*!
345  * \brief Do workgroup-level reduction of a single \c float.
346  *
347  * While SYCL has \c sycl::reduce_over_group, it currently (oneAPI 2021.3.0) uses a very large
348  * shared memory buffer, which leads to reduced occupancy.
349  *
350  * \note The caller must make sure there are no races when reusing the \p sm_buf.
351  *
352  * \tparam subGroupSize Size of a sub-group.
353  * \tparam groupSize Size of a work-group.
354  * \param itemIdx Current thread's \c sycl::nd_item.
355  * \param tidxi Current thread's linearized local index.
356  * \param sm_buf Accessor for the local reduction buffer; must hold at least one float.
357  * \param valueToReduce Current thread's value to reduce.
358  * \return For thread with \p tidxi 0: sum of all \p valueToReduce. Other threads: unspecified.
359  */
360 template<int subGroupSize, int groupSize>
361 static inline float groupReduce(const cl::sycl::nd_item<3> itemIdx,
362                                 const unsigned int         tidxi,
363                                 cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
364                                 float valueToReduce)
365 {
366     constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
367     static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
368     sycl_2020::sub_group sg = itemIdx.get_sub_group();
369     valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
370     // If we have two sub-groups, we should reduce across them.
371     if constexpr (numSubGroupsInGroup == 2)
372     {
373         if (tidxi == subGroupSize)
374         {
375             sm_buf[0] = valueToReduce;
376         }
377         itemIdx.barrier(fence_space::local_space);
378         if (tidxi == 0)
379         {
380             valueToReduce += sm_buf[0];
381         }
382     }
383     return valueToReduce;
384 }
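
/* Typical use, as at the end of nbnxmKernel below: every thread passes its private
 * energy term, each sub-group reduces via shuffles, and a single float of local memory
 * bridges the (at most two) sub-groups:
 *   float e = groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_buf, energyVdw);
 * Only the thread with tidx == 0 may use the result, and a local barrier is required
 * before sm_buf can be reused.
 */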
385
386 /*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
387  *
388  * c_clSize consecutive threads hold the force components of a j-atom, which we
389  * reduce in c_clSize sequential steps and atomically accumulate into \p a_f.
390  *
391  * TODO: implement a binary-reduction flavor for the case where c_clSize is a power of two.
392  */
393 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
394                                        Float3                                   f,
395                                        const cl::sycl::nd_item<3>               itemIdx,
396                                        const int                                tidxi,
397                                        const int                                tidxj,
398                                        const int                                aidx,
399                                        DeviceAccessor<Float3, mode::read_write> a_f)
400 {
401     static constexpr int sc_fBufferStride = c_clSizeSq;
402     int                  tidx             = tidxi + tidxj * c_clSize;
403     sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
404     sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
405     sm_buf[2 * sc_fBufferStride + tidx]   = f[2];
406
407     subGroupBarrier(itemIdx);
408
409     // Reduce the data c_clSize elements at a time; the reducing threads belong to the same sub-group as those that stored the data above
410     assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);
411
412     if (tidxi < 3)
413     {
414         float fSum = 0.0F;
415         for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
416         {
417             fSum += sm_buf[sc_fBufferStride * tidxi + j];
418         }
419
420         atomicFetchAdd(a_f[aidx][tidxi], fSum);
421     }
422 }
423
424
425 /*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
426  */
427 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
428                                 Float3                                                        f,
429                                 const cl::sycl::nd_item<3>               itemIdx,
430                                 const int                                tidxi,
431                                 const int                                tidxj,
432                                 const int                                aidx,
433                                 DeviceAccessor<Float3, mode::read_write> a_f)
434 {
435     if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
436     {
437         reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
438     }
439     else
440     {
441         reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
442     }
443 }
444
445
446 /*! \brief Final i-force reduction.
447  *
448  * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
449  * accumulating atomically into \p a_f.
450  * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
451  *
452  * This implementation works only with power of two array sizes.
453  */
454 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
455                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
456                                          const bool   calcFShift,
457                                          const cl::sycl::nd_item<3>               itemIdx,
458                                          const int                                tidxi,
459                                          const int                                tidxj,
460                                          const int                                sci,
461                                          const int                                shift,
462                                          DeviceAccessor<Float3, mode::read_write> a_f,
463                                          DeviceAccessor<Float3, mode::read_write> a_fShift)
464 {
465     // must have power of two elements in fCiBuf
466     static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
467
468     static constexpr int bufStride  = c_clSize * c_clSize;
469     static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
470     const int            tidx       = tidxi + tidxj * c_clSize;
471     float                fShiftBuf  = 0.0F;
472     for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
473     {
474         const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
475         /* store i forces in shmem */
476         sm_buf[tidx]                 = fCiBuf[ciOffset][0];
477         sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
478         sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
479         itemIdx.barrier(fence_space::local_space);
480
481         /* Halve the number of partial sums per i-atom in every step,
482          * using c_clSize * i threads per step. */
483         int i = c_clSize / 2;
484         for (int j = clSizeLog2 - 1; j > 0; j--)
485         {
486             if (tidxj < i)
487             {
488                 sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
489                 sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
490                         sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
491                 sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
492                         sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
493             }
494             i >>= 1;
495             itemIdx.barrier(fence_space::local_space);
496         }
497
498         /* i == 1, last reduction step, writing to global mem */
499         /* Split the reduction between the first 3 line threads
500            Threads with line id 0 will do the reduction for (float3).x components
501            Threads with line id 1 will do the reduction for (float3).y components
502            Threads with line id 2 will do the reduction for (float3).z components. */
503         if (tidxj < 3)
504         {
505             const float f =
506                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
507             atomicFetchAdd(a_f[aidx][tidxj], f);
508             if (calcFShift)
509             {
510                 fShiftBuf += f;
511             }
512         }
513         itemIdx.barrier(fence_space::local_space);
514     }
515     /* add up local shift forces into global mem */
516     if (calcFShift)
517     {
518         /* Only threads with tidxj < 3 will update fshift.
519            The threads performing the update must be the same as the threads
520            storing the reduction result above. */
521         if (tidxj < 3)
522         {
523             if constexpr (c_clSize == 4)
524             {
525                 /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
526                  * a compare-and-swap (CAS) loop. It has particularly poor performance when
527                  * updating the same memory location from the same work-group.
528                  * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
529                  * but it is unlikely to make a big difference and thus was not evaluated.
530                  */
531                 auto sg = itemIdx.get_sub_group();
532                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
533                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
534                 if (tidxi == 0)
535                 {
536                     atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
537                 }
538             }
539             else
540             {
541                 atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
542             }
543         }
544     }
545 }
546
547 /*! \brief Main kernel for NBNXM.
548  *
549  */
550 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
551 auto nbnxmKernel(cl::sycl::handler&                                        cgh,
552                  DeviceAccessor<Float4, mode::read>                        a_xq,
553                  DeviceAccessor<Float3, mode::read_write>                  a_f,
554                  DeviceAccessor<Float3, mode::read>                        a_shiftVec,
555                  DeviceAccessor<Float3, mode::read_write>                  a_fShift,
556                  OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
557                  OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
558                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
559                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
560                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
561                  OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
562                  OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
563                  OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
564                  OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
565                  OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
566                  const int                                                   numTypes,
567                  const float                                                 rCoulombSq,
568                  const float                                                 rVdwSq,
569                  const float                                                 twoKRf,
570                  const float                                                 ewaldBeta,
571                  const float                                                 rlistOuterSq,
572                  const float                                                 ewaldShift,
573                  const float                                                 epsFac,
574                  const float                                                 ewaldCoeffLJ,
575                  const float                                                 cRF,
576                  const shift_consts_t                                        dispersionShift,
577                  const shift_consts_t                                        repulsionShift,
578                  const switch_consts_t                                       vdwSwitch,
579                  const float                                                 rVdwSwitch,
580                  const float                                                 ljEwaldShift,
581                  const float                                                 coulombTabScale,
582                  const bool                                                  calcShift)
583 {
584     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
585
586     a_xq.bind(cgh);
587     a_f.bind(cgh);
588     a_shiftVec.bind(cgh);
589     a_fShift.bind(cgh);
590     if constexpr (doCalcEnergies)
591     {
592         a_energyElec.bind(cgh);
593         a_energyVdw.bind(cgh);
594     }
595     a_plistCJ4.bind(cgh);
596     a_plistSci.bind(cgh);
597     a_plistExcl.bind(cgh);
598     if constexpr (!props.vdwComb)
599     {
600         a_atomTypes.bind(cgh);
601         a_nbfp.bind(cgh);
602     }
603     else
604     {
605         a_ljComb.bind(cgh);
606     }
607     if constexpr (props.vdwEwald)
608     {
609         a_nbfpComb.bind(cgh);
610     }
611     if constexpr (props.elecEwaldTab)
612     {
613         a_coulombTab.bind(cgh);
614     }
615
616     // shmem buffer for i x+q pre-loading
617     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
618             cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
619
620     // shmem buffer for force reduction
621     // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
622     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
623             cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);
624
625     auto sm_atomTypeI = [&]() {
626         if constexpr (!props.vdwComb)
627         {
628             return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
629                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
630         }
631         else
632         {
633             return nullptr;
634         }
635     }();
636
637     auto sm_ljCombI = [&]() {
638         if constexpr (props.vdwComb)
639         {
640             return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
641                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
642         }
643         else
644         {
645             return nullptr;
646         }
647     }();
648
649     /* Flag to control the calculation of exclusion forces in the kernel
650      * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
651      * energy terms. */
652     constexpr bool doExclusionForces =
653             (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));
654
655     // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
656     // Currently, this is ideally suited for 32-wide subgroup size but slightly less so for others,
657     // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
658     // Hence, the two are decoupled.
659     // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
660     constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
661 #if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
662     gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
663 #else
664     gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
665 #endif
666
667     return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
668     {
669         /* thread/block/warp id-s */
670         const unsigned tidxi = itemIdx.get_local_id(2);
671         const unsigned tidxj = itemIdx.get_local_id(1);
672         const unsigned tidx  = tidxj * c_clSize + tidxi;
673         const unsigned tidxz = 0;
674
675         const unsigned bidx = itemIdx.get_group(0);
676
677         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
678         // We could use sg.get_group_range() to compute the imask & exclusion indices, but much of
679         // the logic below relies on tidx, and when prunedClusterPairSize != subGroupSize it cannot be used anyway.
680         const unsigned imeiIdx = tidx / prunedClusterPairSize;
681
682         Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
683         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
684         {
685             fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
686         }
687
688         const nbnxn_sci_t nbSci     = a_plistSci[bidx];
689         const int         sci       = nbSci.sci;
690         const int         cij4Start = nbSci.cj4_ind_start;
691         const int         cij4End   = nbSci.cj4_ind_end;
692
693         // Only needed if props.elecEwaldAna
694         const float beta2 = ewaldBeta * ewaldBeta;
695         const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;
696
697         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
698         {
699             /* Pre-load i-atom x and q into shared memory */
700             const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
701             const int             ai       = ci * c_clSize + tidxi;
702             const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);
703
704             const Float3 shift = a_shiftVec[nbSci.shift];
705             Float4       xqi   = a_xq[ai];
706             xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
707             xqi[3] *= epsFac;
708             sm_xq[cacheIdx] = xqi;
709
710             if constexpr (!props.vdwComb)
711             {
712                 // Pre-load the i-atom types into shared memory
713                 sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
714             }
715             else
716             {
717                 // Pre-load the LJ combination parameters into shared memory
718                 sm_ljCombI[cacheIdx] = a_ljComb[ai];
719             }
720         }
721         itemIdx.barrier(fence_space::local_space);
722
723         float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
724         if constexpr (props.vdwEwald)
725         {
726             ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
727             ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
728         }
729
730         float energyVdw, energyElec; // Only needed if (doCalcEnergies)
731         if constexpr (doCalcEnergies)
732         {
733             energyVdw = energyElec = 0.0F;
734         }
735         if constexpr (doCalcEnergies && doExclusionForces)
736         {
737             if (nbSci.shift == gmx::c_centralShiftIndex
738                 && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
739             {
740                 // we have the diagonal: add the charge and LJ self interaction energy term
741                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
742                 {
743                     // TODO: Are there other options?
744                     if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
745                     {
746                         const float qi = sm_xq[i][tidxi][3];
747                         energyElec += qi * qi;
748                     }
749                     if constexpr (props.vdwEwald)
750                     {
751                         energyVdw +=
752                                 a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
753                                        * (numTypes + 1)][0];
754                     }
755                 }
756                 /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
757                 if constexpr (props.vdwEwald)
758                 {
759                     energyVdw /= c_clSize;
760                     energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
761                 }
762                 if constexpr (props.elecRF || props.elecCutoff)
763                 {
764                     // Correct for epsfac^2 due to adding qi^2
765                     energyElec /= epsFac * c_clSize;
766                     energyElec *= -0.5F * cRF;
767                 }
768                 if constexpr (props.elecEwald)
769                 {
770                     // Correct for epsfac^2 due to adding qi^2
771                     energyElec /= epsFac * c_clSize;
772                     energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
773                 }
774             } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
775         }     // (doCalcEnergies && doExclusionForces)
776
777         // Only needed if (doExclusionForces)
778         const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);
779
780         // loop over the j clusters = seen by any of the atoms in the current super-cluster
781         for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
782         {
783             unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
784             if (!doPruneNBL && !imask)
785             {
786                 continue;
787             }
788             const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
789             static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
790             const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
791             for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
792             {
793                 const bool maskSet =
794                         imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
795                 if (!maskSet)
796                 {
797                     continue;
798                 }
799                 unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
800                 const int cj     = a_plistCJ4[j4].cj[jm];
801                 const int aj     = cj * c_clSize + tidxj;
802
803                 // load j atom data
804                 const Float4 xqj = a_xq[aj];
805
806                 const Float3 xj(xqj[0], xqj[1], xqj[2]);
807                 const float  qj = xqj[3];
808                 int          atomTypeJ; // Only needed if (!props.vdwComb)
809                 Float2       ljCombJ;   // Only needed if (props.vdwComb)
810                 if constexpr (props.vdwComb)
811                 {
812                     ljCombJ = a_ljComb[aj];
813                 }
814                 else
815                 {
816                     atomTypeJ = a_atomTypes[aj];
817                 }
818
819                 Float3 fCjBuf(0.0F, 0.0F, 0.0F);
820
821                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
822                 {
823                     if (imask & maskJI)
824                     {
825                         // i cluster index
826                         const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
827                         // read the i-atom data (pre-loaded into shmem above)
828                         const Float4 xqi = sm_xq[i][tidxi];
829                         const Float3 xi(xqi[0], xqi[1], xqi[2]);
830
831                         // distance between i and j atoms
832                         const Float3 rv = xi - xj;
833                         float        r2 = norm2(rv);
834
835                         if constexpr (doPruneNBL)
836                         {
837                              /* If _none_ of the atom pairs are within cutoff range,
838                              * the bit corresponding to the current
839                              * cluster-pair in imask gets set to 0. */
840                             if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
841                             {
842                                 imask &= ~maskJI;
843                             }
844                         }
845                         const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;
846
847                         // cutoff & exclusion check
848
849                         const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
850                                                                    : (wexcl & maskJI);
851
852                         // SYCL-TODO: Check optimal way of branching here.
853                         if ((r2 < rCoulombSq) && notExcluded)
854                         {
855                             const float qi = xqi[3];
856                             int         atomTypeI; // Only needed if (!props.vdwComb)
857                             float       sigma, epsilon;
858                             Float2      c6c12;
859
860                             if constexpr (!props.vdwComb)
861                             {
862                                 /* LJ 6*C6 and 12*C12 */
863                                 atomTypeI = sm_atomTypeI[i][tidxi];
864                                 c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
865                             }
866                             else
867                             {
868                                 const Float2 ljCombI = sm_ljCombI[i][tidxi];
869                                 if constexpr (props.vdwCombGeom)
870                                 {
871                                     c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
872                                 }
873                                 else
874                                 {
875                                     static_assert(props.vdwCombLB);
876                                     // LJ 2^(1/6)*sigma and 12*epsilon
877                                     sigma   = ljCombI[0] + ljCombJ[0];
878                                     epsilon = ljCombI[1] * ljCombJ[1];
879                                     if constexpr (doCalcEnergies)
880                                     {
881                                         c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
882                                     }
883                                 } // props.vdwCombGeom
884                             }     // !props.vdwComb
885
886                             // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
887                             const float c6  = c6c12[0];
888                             const float c12 = c6c12[1];
889
890                             // Ensure the distance does not become so small that r^-12 overflows
891                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
892 #if GMX_SYCL_HIPSYCL
893                             // No fast/native functions in some compilation passes
894                             const float rInv = cl::sycl::rsqrt(r2);
895 #else
896                             // SYCL-TODO: sycl::half_precision::rsqrt?
897                             const float rInv = cl::sycl::native::rsqrt(r2);
898 #endif
899                             const float r2Inv = rInv * rInv;
900                             float       r6Inv, fInvR, energyLJPair;
901                             if constexpr (!props.vdwCombLB || doCalcEnergies)
902                             {
903                                 r6Inv = r2Inv * r2Inv * r2Inv;
904                                 if constexpr (doExclusionForces)
905                                 {
906                                     // SYCL-TODO: Check if true for SYCL
907                                     /* We could mask r2Inv, but with Ewald masking both
908                                      * r6Inv and fInvR is faster */
909                                     r6Inv *= pairExclMask;
910                                 }
911                                 fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
912                             }
913                             else
914                             {
915                                 float sig_r  = sigma * rInv;
916                                 float sig_r2 = sig_r * sig_r;
917                                 float sig_r6 = sig_r2 * sig_r2 * sig_r2;
918                                 if constexpr (doExclusionForces)
919                                 {
920                                     sig_r6 *= pairExclMask;
921                                 }
922                                 fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
923                             } // (!props.vdwCombLB || doCalcEnergies)
924                             if constexpr (doCalcEnergies || props.vdwPSwitch)
925                             {
926                                 energyLJPair = pairExclMask
927                                                * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
928                                                   - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
929                             }
930                             if constexpr (props.vdwFSwitch)
931                             {
932                                 ljForceSwitch<doCalcEnergies>(
933                                         dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
934                             }
935                             if constexpr (props.vdwEwald)
936                             {
937                                 ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
938                                                                      ljEwaldShift,
939                                                                      atomTypeI,
940                                                                      atomTypeJ,
941                                                                      r2,
942                                                                      r2Inv,
943                                                                      ewaldCoeffLJ_2,
944                                                                      ewaldCoeffLJ_6_6,
945                                                                      pairExclMask,
946                                                                      &fInvR,
947                                                                      &energyLJPair);
948                             } // (props.vdwEwald)
949                             if constexpr (props.vdwPSwitch)
950                             {
951                                 ljPotentialSwitch<doCalcEnergies>(
952                                         vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
953                             }
954                             if constexpr (props.elecEwaldTwin)
955                             {
956                                 // Separate VDW cut-off check to enable twin-range cut-offs
957                                 // (rVdw < rCoulomb <= rList)
958                                 const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
959                                 fInvR *= vdwInRange;
960                                 if constexpr (doCalcEnergies)
961                                 {
962                                     energyLJPair *= vdwInRange;
963                                 }
964                             }
965                             if constexpr (doCalcEnergies)
966                             {
967                                 energyVdw += energyLJPair;
968                             }
969
970                             if constexpr (props.elecCutoff)
971                             {
972                                 if constexpr (doExclusionForces)
973                                 {
974                                     fInvR += qi * qj * pairExclMask * r2Inv * rInv;
975                                 }
976                                 else
977                                 {
978                                     fInvR += qi * qj * r2Inv * rInv;
979                                 }
980                             }
981                             if constexpr (props.elecRF)
982                             {
983                                 fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
984                             }
985                             if constexpr (props.elecEwaldAna)
986                             {
987                                 fInvR += qi * qj
988                                          * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
989                             }
990                             if constexpr (props.elecEwaldTab)
991                             {
992                                 fInvR += qi * qj
993                                          * (pairExclMask * r2Inv
994                                             - interpolateCoulombForceR(
995                                                     a_coulombTab, coulombTabScale, r2 * rInv))
996                                          * rInv;
997                             }
998
999                             if constexpr (doCalcEnergies)
1000                             {
1001                                 if constexpr (props.elecCutoff)
1002                                 {
1003                                     energyElec += qi * qj * (pairExclMask * rInv - cRF);
1004                                 }
1005                                 if constexpr (props.elecRF)
1006                                 {
1007                                     energyElec +=
1008                                             qi * qj * (pairExclMask * rInv + 0.5F * twoKRf * r2 - cRF);
1009                                 }
1010                                 if constexpr (props.elecEwald)
1011                                 {
1012                                     energyElec +=
1013                                             qi * qj
1014                                             * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
1015                                                - pairExclMask * ewaldShift);
1016                                 }
1017                             }
1018
1019                             const Float3 forceIJ = rv * fInvR;
1020
1021                             /* accumulate j forces in registers */
1022                             fCjBuf -= forceIJ;
1023                             /* accumulate i forces in registers */
1024                             fCiBuf[i] += forceIJ;
1025                         } // (r2 < rCoulombSq) && notExcluded
1026                     }     // (imask & maskJI)
1027                     /* shift the mask bit by 1 */
1028                     maskJI += maskJI;
1029                 } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
1030                 /* reduce j forces */
1031                 reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
1032             } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
1033             if constexpr (doPruneNBL)
1034             {
1035                 /* Update the imask with the new one which does not contain the
1036                  * out of range clusters anymore. */
1037                 a_plistCJ4[j4].imei[imeiIdx].imask = imask;
1038             }
1039         } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)
1040
1041         /* skip central shifts when summing shift forces */
1042         const bool doCalcShift = (calcShift && nbSci.shift != gmx::c_centralShiftIndex);
1043
1044         reduceForceIAndFShift(
1045                 sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);
1046
1047         if constexpr (doCalcEnergies)
1048         {
1049             const float energyVdwGroup =
1050                     groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
1051             itemIdx.barrier(fence_space::local_space); // Prevent the race on sm_reductionBuffer.
1052             const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
1053                     itemIdx, tidx, sm_reductionBuffer, energyElec);
1054
1055             if (tidx == 0)
1056             {
1057                 atomicFetchAdd(a_energyVdw[0], energyVdwGroup);
1058                 atomicFetchAdd(a_energyElec[0], energyElecGroup);
1059             }
1060         }
1061     };
1062 }
1063
1064 //! \brief NBNXM kernel launch code.
1065 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
1066 cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
1067 {
1068     using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;
1069
1070     /* Kernel launch config:
1071      * - The thread block dimensions match the size of i-clusters, j-clusters,
1072      *   and j-cluster concurrency, in x, y, and z, respectively.
1073      * - The 1D block-grid contains as many blocks as super-clusters.
1074      */
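    /* Illustrative numbers: with c_clSize == 8 each work-group is 1 x 8 x 8 = 64
     * work-items (two 32-wide sub-groups when subGroupSize == 32), and the grid has
     * numSci work-groups along dimension 0. */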
1075     const int                   numBlocks = numSci;
1076     const cl::sycl::range<3>    blockSize{ 1, c_clSize, c_clSize };
1077     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
1078     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
1079
1080     cl::sycl::queue q = deviceStream.stream();
1081
1082     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
1083         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
1084                 cgh, std::forward<Args>(args)...);
1085         cgh.parallel_for<kernelNameType>(range, kernel);
1086     });
1087
1088     return e;
1089 }
1090
1091 //! \brief Select templated kernel and launch it.
1092 template<class... Args>
1093 cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
1094                                            bool          doCalcEnergies,
1095                                            enum ElecType elecType,
1096                                            enum VdwType  vdwType,
1097                                            Args&&... args)
1098 {
1099     return gmx::dispatchTemplatedFunction(
1100             [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
1101                 return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
1102                         std::forward<Args>(args)...);
1103             },
1104             doPruneNBL,
1105             doCalcEnergies,
1106             elecType,
1107             vdwType);
1108 }
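
/* gmx::dispatchTemplatedFunction turns the runtime (doPruneNBL, doCalcEnergies,
 * elecType, vdwType) values into compile-time template arguments. Conceptually (a
 * sketch, not its real implementation) it expands to nested dispatch:
 *   if (doPruneNBL) { if (doCalcEnergies) { switch (elecType) { ... switch (vdwType) ... } } }
 * so each ElecType x VdwType x bool x bool combination gets its own kernel instantiation.
 */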
1109
1110 void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
1111 {
1112     NBAtomDataGpu*      adat         = nb->atdat;
1113     NBParamGpu*         nbp          = nb->nbparam;
1114     gpu_plist*          plist        = nb->plist[iloc];
1115     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
1116     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
1117
1118     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
1119                                                    stepWork.computeEnergy,
1120                                                    nbp->elecType,
1121                                                    nbp->vdwType,
1122                                                    deviceStream,
1123                                                    plist->nsci,
1124                                                    adat->xq,
1125                                                    adat->f,
1126                                                    adat->shiftVec,
1127                                                    adat->fShift,
1128                                                    adat->eElec,
1129                                                    adat->eLJ,
1130                                                    plist->cj4,
1131                                                    plist->sci,
1132                                                    plist->excl,
1133                                                    adat->ljComb,
1134                                                    adat->atomTypes,
1135                                                    nbp->nbfp,
1136                                                    nbp->nbfp_comb,
1137                                                    nbp->coulomb_tab,
1138                                                    adat->numTypes,
1139                                                    nbp->rcoulomb_sq,
1140                                                    nbp->rvdw_sq,
1141                                                    nbp->two_k_rf,
1142                                                    nbp->ewald_beta,
1143                                                    nbp->rlistOuter_sq,
1144                                                    nbp->sh_ewald,
1145                                                    nbp->epsfac,
1146                                                    nbp->ewaldcoeff_lj,
1147                                                    nbp->c_rf,
1148                                                    nbp->dispersion_shift,
1149                                                    nbp->repulsion_shift,
1150                                                    nbp->vdw_switch,
1151                                                    nbp->rvdw_switch,
1152                                                    nbp->sh_lj_ewald,
1153                                                    nbp->coulomb_tab_scale,
1154                                                    stepWork.computeVirial);
1155 }
1156
1157 } // namespace Nbnxm