src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief
38  *  NBNXM SYCL kernels
39  *
40  *  \ingroup module_nbnxm
41  */
42 #include "gmxpre.h"
43
44 #include "nbnxm_sycl_kernel.h"
45
46 #include "gromacs/gpu_utils/devicebuffer.h"
47 #include "gromacs/gpu_utils/gmxsycl.h"
48 #include "gromacs/math/functions.h"
49 #include "gromacs/mdtypes/simulation_workload.h"
50 #include "gromacs/pbcutil/ishift.h"
51 #include "gromacs/utility/template_mp.h"
52
53 #include "nbnxm_sycl_kernel_utils.h"
54 #include "nbnxm_sycl_types.h"
55
56 //! \brief Class name for NBNXM kernel
57 template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
58 class NbnxmKernel;
59
60 namespace Nbnxm
61 {
62
63 //! \brief Set of boolean constants mimicking preprocessor macros.
64 template<enum ElecType elecType, enum VdwType vdwType>
65 struct EnergyFunctionProperties {
66     static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
67     static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
68     static constexpr bool elecEwaldAna =
69             (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
70     static constexpr bool elecEwaldTab =
71             (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
72     static constexpr bool elecEwaldTwin =
73             (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin); ///< Use twin cut-off.
74     static constexpr bool elecEwald = (elecEwaldAna || elecEwaldTab);  ///< EL_EWALD_ANY
75     static constexpr bool vdwCombLB = (vdwType == VdwType::CutCombLB); ///< LJ_COMB && !LJ_COMB_GEOM
76     static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
77     static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
78     static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
79     static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
80     static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
81     static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
82     static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
83 };
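// These flags replace the preprocessor macros named in the ///< comments above: the kernel
// instantiates the properties once per template combination and guards flavor-specific code
// with `if constexpr`, so untaken branches are discarded at compile time. A minimal usage
// sketch, mirroring what nbnxmKernel below does:
//
//     constexpr EnergyFunctionProperties<ElecType::RF, VdwType::Cut> props;
//     if constexpr (props.elecRF)
//     {
//         // reaction-field-specific code
//     }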
84
85 //! \brief Templated constants to shorten kernel function declaration.
86 //@{
87 template<enum VdwType vdwType>
88 constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;
89
90 template<enum ElecType elecType>
91 constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;
92
93 template<enum ElecType elecType>
94 constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;
95
96 template<enum VdwType vdwType>
97 constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
98 //@}
99
100 using cl::sycl::access::fence_space;
101 using cl::sycl::access::mode;
102 using cl::sycl::access::target;
103
104 //! \brief Convert \p sigma and \p epsilon VdW parameters to \c c6,c12 pair.
105 static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
106 {
107     const float sigma2 = sigma * sigma;
108     const float sigma6 = sigma2 * sigma2 * sigma2;
109     const float c6     = epsilon * sigma6;
110     const float c12    = c6 * sigma6;
111
112     return { c6, c12 };
113 }
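// Illustrative check: sigma = 2 and epsilon = 3 give sigma6 = 64 and thus { c6 = 192, c12 = 12288 }.
// In the kernel below this conversion is only needed for the LB combination rule when energies are
// requested (see the props.vdwCombLB branch).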
114
115 //! \brief Calculate force and energy for a pair of atoms, VdW force-switch flavor.
116 template<bool doCalcEnergies>
117 static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
118                                  const shift_consts_t         repulsionShift,
119                                  const float                  rVdwSwitch,
120                                  const float                  c6,
121                                  const float                  c12,
122                                  const float                  rInv,
123                                  const float                  r2,
124                                  cl::sycl::private_ptr<float> fInvR,
125                                  cl::sycl::private_ptr<float> eLJ)
126 {
127     /* force switch constants */
128     const float dispShiftV2 = dispersionShift.c2;
129     const float dispShiftV3 = dispersionShift.c3;
130     const float repuShiftV2 = repulsionShift.c2;
131     const float repuShiftV3 = repulsionShift.c3;
132
133     const float r       = r2 * rInv;
134     const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)
135
136     *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
137               + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;
138
139     if constexpr (doCalcEnergies)
140     {
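        // The energy coefficients below are the force coefficients divided by 3 and 4, since the
        // switch contribution to the energy is the integral of the force contribution over rSwitch:
        // A*rSwitch^2 integrates to A*rSwitch^3/3 and B*rSwitch^3 to B*rSwitch^4/4.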
141         const float dispShiftF2 = dispShiftV2 / 3;
142         const float dispShiftF3 = dispShiftV3 / 4;
143         const float repuShiftF2 = repuShiftV2 / 3;
144         const float repuShiftF3 = repuShiftV3 / 4;
145         *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
146                 - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
147     }
148 }
149
150 //! \brief Fetch C6 grid contribution coefficients and return the product of these.
151 template<enum VdwType vdwType>
152 static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
153                                            const int                                typeI,
154                                            const int                                typeJ)
155 {
156     if constexpr (vdwType == VdwType::EwaldGeom)
157     {
158         return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
159     }
160     else
161     {
162         static_assert(vdwType == VdwType::EwaldLB);
163         /* sigma and epsilon are scaled to give 6*C6 */
164         const Float2 c6c12_i = a_nbfpComb[typeI];
165         const Float2 c6c12_j = a_nbfpComb[typeJ];
166
167         const float sigma   = c6c12_i[0] + c6c12_j[0];
168         const float epsilon = c6c12_i[1] * c6c12_j[1];
169
170         const float sigma2 = sigma * sigma;
171         return epsilon * sigma2 * sigma2 * sigma2;
172     }
173 }
174
175 //! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
176 template<bool doCalcEnergies, enum VdwType vdwType>
177 static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
178                                const float                              sh_lj_ewald,
179                                const int                                typeI,
180                                const int                                typeJ,
181                                const float                              r2,
182                                const float                              r2Inv,
183                                const float                              lje_coeff2,
184                                const float                              lje_coeff6_6,
185                                const float                              int_bit,
186                                cl::sycl::private_ptr<float>             fInvR,
187                                cl::sycl::private_ptr<float>             eLJ)
188 {
189     const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);
190
191     /* Recalculate inv_r6 without exclusion mask */
192     const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
193     const float cr2       = lje_coeff2 * r2;
194     const float expmcr2   = cl::sycl::exp(-cr2);
195     const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;
196
197     /* Subtract the grid force from the total LJ force */
198     *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;
199
200     if constexpr (doCalcEnergies)
201     {
202         /* Shift should be applied only to real LJ pairs */
203         const float sh_mask = sh_lj_ewald * int_bit;
204         *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
205     }
206 }
207
208 /*! \brief Apply potential switch. */
209 template<bool doCalcEnergies>
210 static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
211                                      const float                  rVdwSwitch,
212                                      const float                  rInv,
213                                      const float                  r2,
214                                      cl::sycl::private_ptr<float> fInvR,
215                                      cl::sycl::private_ptr<float> eLJ)
216 {
217     /* potential switch constants */
218     const float switchV3 = vdwSwitch.c3;
219     const float switchV4 = vdwSwitch.c4;
220     const float switchV5 = vdwSwitch.c5;
221     const float switchF2 = 3 * vdwSwitch.c3;
222     const float switchF3 = 4 * vdwSwitch.c4;
223     const float switchF4 = 5 * vdwSwitch.c5;
224
225     const float r       = r2 * rInv;
226     const float rSwitch = r - rVdwSwitch;
227
228     if (rSwitch > 0.0F)
229     {
230         const float sw =
231                 1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
232         const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;
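        // sw scales the potential and dsw is d(sw)/dr, so by the product rule the scalar force
        // F/r picks up an extra -(eLJ * dsw) / r term, applied below.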
233
234         *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
235         if constexpr (doCalcEnergies)
236         {
237             *eLJ *= sw;
238         }
239     }
240 }
241
242
243 /*! \brief Calculate analytical Ewald correction term. */
244 static inline float pmeCorrF(const float z2)
245 {
246     constexpr float FN6 = -1.7357322914161492954e-8F;
247     constexpr float FN5 = 1.4703624142580877519e-6F;
248     constexpr float FN4 = -0.000053401640219807709149F;
249     constexpr float FN3 = 0.0010054721316683106153F;
250     constexpr float FN2 = -0.019278317264888380590F;
251     constexpr float FN1 = 0.069670166153766424023F;
252     constexpr float FN0 = -0.75225204789749321333F;
253
254     constexpr float FD4 = 0.0011193462567257629232F;
255     constexpr float FD3 = 0.014866955030185295499F;
256     constexpr float FD2 = 0.11583842382862377919F;
257     constexpr float FD1 = 0.50736591960530292870F;
258     constexpr float FD0 = 1.0F;
259
260     const float z4 = z2 * z2;
261
262     float       polyFD0 = FD4 * z4 + FD2;
263     const float polyFD1 = FD3 * z4 + FD1;
264     polyFD0             = polyFD0 * z4 + FD0;
265     polyFD0             = polyFD1 * z2 + polyFD0;
266
267     polyFD0 = 1.0F / polyFD0;
268
269     float polyFN0 = FN6 * z4 + FN4;
270     float polyFN1 = FN5 * z4 + FN3;
271     polyFN0       = polyFN0 * z4 + FN2;
272     polyFN1       = polyFN1 * z4 + FN1;
273     polyFN0       = polyFN0 * z4 + FN0;
274     polyFN0       = polyFN1 * z2 + polyFN0;
275
276     return polyFN0 * polyFD0;
277 }
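// pmeCorrF evaluates a ratio of two polynomials in z2 = (beta*r)^2; it is consumed below in the
// analytical Ewald branch as
//     fInvR += qi * qj * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);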
278
279 /*! \brief Linear interpolation using exactly two FMA operations.
280  *
281  *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
282  */
283 template<typename T>
284 static inline T lerp(T d0, T d1, T t)
285 {
286     return fma(t, d1, fma(-t, d0, d0));
287 }
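// Numeric check (illustrative): lerp(2.0F, 4.0F, 0.25F) computes fma(-0.25, 2, 2) = 1.5 and then
// fma(0.25, 4, 1.5) = 2.5, matching (1 - 0.25)*2 + 0.25*4.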
288
289 /*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
290 static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
291                                              const float coulombTabScale,
292                                              const float r)
293 {
294     const float normalized = coulombTabScale * r;
295     const int   index      = static_cast<int>(normalized);
296     const float fraction   = normalized - index;
297
298     const float left  = a_coulombTab[index];
299     const float right = a_coulombTab[index + 1];
300
301     return lerp(left, right, fraction); // TODO: cl::sycl::mix
302 }
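// Example with made-up values: coulombTabScale = 100 and r = 0.3175 give normalized = 31.75, so
// index = 31, fraction = 0.75, and the result is 0.25 * a_coulombTab[31] + 0.75 * a_coulombTab[32].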
303
304 /*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
305  *
306  * c_clSize consecutive threads hold the force components of a j-atom, which we
307  * reduce in log2(c_clSize) steps using shifts and then atomically accumulate into \p a_f.
308  */
309 static inline void reduceForceJShuffle(Float3                             f,
310                                        const cl::sycl::nd_item<1>         itemIdx,
311                                        const int                          tidxi,
312                                        const int                          aidx,
313                                        DeviceAccessor<float, mode_atomic> a_f)
314 {
315     static_assert(c_clSize == 8 || c_clSize == 4);
316     sycl_2020::sub_group sg = itemIdx.get_sub_group();
317
318     f[0] += sycl_2020::shift_left(sg, f[0], 1);
319     f[1] += sycl_2020::shift_right(sg, f[1], 1);
320     f[2] += sycl_2020::shift_left(sg, f[2], 1);
321     if (tidxi & 1)
322     {
323         f[0] = f[1];
324     }
325
326     f[0] += sycl_2020::shift_left(sg, f[0], 2);
327     f[2] += sycl_2020::shift_right(sg, f[2], 2);
328     if (tidxi & 2)
329     {
330         f[0] = f[2];
331     }
332
333     if constexpr (c_clSize == 8)
334     {
335         f[0] += sycl_2020::shift_left(sg, f[0], 4);
336     }
337
338     if (tidxi < 3)
339     {
340         atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
341     }
342 }
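// Lane-by-lane sketch for c_clSize == 4, where lanes 0..3 hold the contributions (x0,y0,z0)..(x3,y3,z3)
// to the same j atom: after the first shift/select step, f[0] on lanes 0..3 holds x0+x1, y0+y1, x2+x3,
// y2+y3, while f[2] on lanes 0 and 2 holds z0+z1 and z2+z3; after the second step, lane 0 holds the
// full x sum, lane 1 the y sum, and lane 2 the z sum, which are exactly the three values accumulated
// by the tidxi < 3 atomic adds above.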
343
344 /*!
345  * \brief Do workgroup-level reduction of a single \c float.
346  *
347  * While SYCL has \c sycl::reduce_over_group, it currently (oneAPI 2021.3.0) uses a very large
348  * shared memory buffer, which leads to a reduced occupancy.
349  *
350  * \tparam subGroupSize Size of a sub-group.
351  * \tparam groupSize Size of a work-group.
352  * \param itemIdx Current thread's \c sycl::nd_item.
353  * \param tidxi Current thread's linearized local index.
354  * \param sm_buf Accessor for local reduction buffer; must have space for at least one element.
355  * \param valueToReduce Current thread's value.
356  * \return For thread with \p tidxi 0: sum of all \p valueToReduce. Other threads: unspecified.
357  */
358 template<int subGroupSize, int groupSize>
359 static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
360                                 const unsigned int         tidxi,
361                                 cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
362                                 float valueToReduce)
363 {
364     constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
365     static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
366     sycl_2020::sub_group sg = itemIdx.get_sub_group();
367     valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
368     // If we have two sub-groups, we should reduce across them.
369     if constexpr (numSubGroupsInGroup == 2)
370     {
371         if (tidxi == subGroupSize)
372         {
373             sm_buf[0] = valueToReduce;
374         }
375         itemIdx.barrier(fence_space::local_space);
376         if (tidxi == 0)
377         {
378             valueToReduce += sm_buf[0];
379         }
380     }
381     return valueToReduce;
382 }
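// In this file groupReduce is used for the final energy accumulation in nbnxmKernel, e.g.:
//     const float energyVdwGroup =
//             groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);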
383
384 /*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
385  *
386  * c_clSize consecutive threads hold the force components of a j-atom, which we
387  * reduce in c_clSize steps using local memory and then atomically accumulate into \p a_f.
388  *
389  * TODO: implement a binary-reduction flavor for the case where c_clSize is a power of two.
390  */
391 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
392                                        Float3                             f,
393                                        const cl::sycl::nd_item<1>         itemIdx,
394                                        const int                          tidxi,
395                                        const int                          tidxj,
396                                        const int                          aidx,
397                                        DeviceAccessor<float, mode_atomic> a_f)
398 {
399     static constexpr int sc_fBufferStride = c_clSizeSq;
400     int                  tidx             = tidxi + tidxj * c_clSize;
401     sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
402     sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
403     sm_buf[2 * sc_fBufferStride + tidx]   = f[2];
404
405     subGroupBarrier(itemIdx);
406
407     // Reduce the data c_clSize elements at a time: the threads with tidxi < 3 each sum one force component stored by their row above.
408     assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);
409
410     if (tidxi < 3)
411     {
412         float fSum = 0.0F;
413         for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
414         {
415             fSum += sm_buf[sc_fBufferStride * tidxi + j];
416         }
417
418         atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
419     }
420 }
421
422
423 /*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
424  */
425 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
426                                 Float3                                                        f,
427                                 const cl::sycl::nd_item<1>         itemIdx,
428                                 const int                          tidxi,
429                                 const int                          tidxj,
430                                 const int                          aidx,
431                                 DeviceAccessor<float, mode_atomic> a_f)
432 {
433     if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
434     {
435         reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
436     }
437     else
438     {
439         reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
440     }
441 }
442
443
444 /*! \brief Final i-force reduction.
445  *
446  * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
447  * accumulating atomically into \p a_f.
448  * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
449  *
450  * This implementation works only with power of two array sizes.
451  */
452 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
453                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
454                                          const bool   calcFShift,
455                                          const cl::sycl::nd_item<1>         itemIdx,
456                                          const int                          tidxi,
457                                          const int                          tidxj,
458                                          const int                          sci,
459                                          const int                          shift,
460                                          DeviceAccessor<float, mode_atomic> a_f,
461                                          DeviceAccessor<float, mode_atomic> a_fShift)
462 {
463     // must have power of two elements in fCiBuf
464     static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
465
466     static constexpr int bufStride  = c_clSize * c_clSize;
467     static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
468     const int            tidx       = tidxi + tidxj * c_clSize;
469     float                fShiftBuf  = 0.0F;
470     for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
471     {
472         const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
473         /* store i forces in shmem */
474         sm_buf[tidx]                 = fCiBuf[ciOffset][0];
475         sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
476         sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
477         itemIdx.barrier(fence_space::local_space);
478
479         /* Reduce the initial c_clSize values for each i atom to half
480          * every step by using c_clSize * i threads. */
481         int i = c_clSize / 2;
482         for (int j = clSizeLog2 - 1; j > 0; j--)
483         {
484             if (tidxj < i)
485             {
486                 sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
487                 sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
488                         sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
489                 sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
490                         sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
491             }
492             i >>= 1;
493             itemIdx.barrier(fence_space::local_space);
494         }
495
496         /* i == 1, last reduction step, writing to global mem */
497         /* Split the reduction between the first 3 line threads
498            Threads with line id 0 will do the reduction for (float3).x components
499            Threads with line id 1 will do the reduction for (float3).y components
500            Threads with line id 2 will do the reduction for (float3).z components. */
501         if (tidxj < 3)
502         {
503             const float f =
504                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
505             atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
506             if (calcFShift)
507             {
508                 fShiftBuf += f;
509             }
510         }
511         itemIdx.barrier(fence_space::local_space);
512     }
513     /* add up local shift forces into global mem */
514     if (calcFShift)
515     {
516         /* Only threads with tidxj < 3 will update fshift.
517            The threads performing the update must be the same as the threads
518            storing the reduction result above. */
519         if (tidxj < 3)
520         {
521             if constexpr (c_clSize == 4)
522             {
523                 /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
524                  * a compare-and-swap (CAS) loop. It has particularly poor performance when
525                  * updating the same memory location from the same work-group.
526                  * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
527                  * but it is unlikely to make a big difference and thus was not evaluated.
528                  */
529                 auto sg = itemIdx.get_sub_group();
530                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
531                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
532                 if (tidxi == 0)
533                 {
534                     atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
535                 }
536             }
537             else
538             {
539                 atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
540             }
541         }
542     }
543 }
544
545 /*! \brief Main kernel for NBNXM.
546  *
547  */
548 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
549 auto nbnxmKernel(cl::sycl::handler&                                   cgh,
550                  DeviceAccessor<Float4, mode::read>                   a_xq,
551                  DeviceAccessor<float, mode_atomic>                   a_f,
552                  DeviceAccessor<Float3, mode::read>                   a_shiftVec,
553                  DeviceAccessor<float, mode_atomic>                   a_fShift,
554                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
555                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
556                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
557                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
558                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
559                  OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
560                  OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
561                  OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
562                  OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
563                  OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
564                  const int                                                   numTypes,
565                  const float                                                 rCoulombSq,
566                  const float                                                 rVdwSq,
567                  const float                                                 twoKRf,
568                  const float                                                 ewaldBeta,
569                  const float                                                 rlistOuterSq,
570                  const float                                                 ewaldShift,
571                  const float                                                 epsFac,
572                  const float                                                 ewaldCoeffLJ,
573                  const float                                                 cRF,
574                  const shift_consts_t                                        dispersionShift,
575                  const shift_consts_t                                        repulsionShift,
576                  const switch_consts_t                                       vdwSwitch,
577                  const float                                                 rVdwSwitch,
578                  const float                                                 ljEwaldShift,
579                  const float                                                 coulombTabScale,
580                  const bool                                                  calcShift)
581 {
582     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
583
584     cgh.require(a_xq);
585     cgh.require(a_f);
586     cgh.require(a_shiftVec);
587     cgh.require(a_fShift);
588     if constexpr (doCalcEnergies)
589     {
590         cgh.require(a_energyElec);
591         cgh.require(a_energyVdw);
592     }
593     cgh.require(a_plistCJ4);
594     cgh.require(a_plistSci);
595     cgh.require(a_plistExcl);
596     if constexpr (!props.vdwComb)
597     {
598         cgh.require(a_atomTypes);
599         cgh.require(a_nbfp);
600     }
601     else
602     {
603         cgh.require(a_ljComb);
604     }
605     if constexpr (props.vdwEwald)
606     {
607         cgh.require(a_nbfpComb);
608     }
609     if constexpr (props.elecEwaldTab)
610     {
611         cgh.require(a_coulombTab);
612     }
613
614     // shmem buffer for i x+q pre-loading
615     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
616             cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
617
618     // shmem buffer for force reduction
619     // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
620     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
621             cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);
622
623     auto sm_atomTypeI = [&]() {
624         if constexpr (!props.vdwComb)
625         {
626             return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
627                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
628         }
629         else
630         {
631             return nullptr;
632         }
633     }();
634
635     auto sm_ljCombI = [&]() {
636         if constexpr (props.vdwComb)
637         {
638             return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
639                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
640         }
641         else
642         {
643             return nullptr;
644         }
645     }();
646
647     /* Flag to control the calculation of exclusion forces in the kernel.
648      * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
649      * energy terms. */
650     constexpr bool doExclusionForces =
651             (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));
652
653     // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
654     // Currently, this is ideally suited for 32-wide subgroup size but slightly less so for others,
655     // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
656     // Hence, the two are decoupled.
657     // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
658     constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
659 #if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
660     gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
661 #else
662     gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
663 #endif
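    // For example, with the values typically used on GPUs (c_clSize = 8, c_splitClSize = 4),
    // prunedClusterPairSize is 32, matching the 32-wide layout mentioned above, while the ROCm
    // branch selects subGroupSize = 64 to match the 64-wide wavefronts of AMD GCN/CDNA devices.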
664
665     return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
666     {
667         /* thread/block/warp id-s */
668         const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
669         const unsigned        tidxi   = localId[0];
670         const unsigned        tidxj   = localId[1];
671         const unsigned        tidx    = tidxj * c_clSize + tidxi;
672         const unsigned        tidxz   = 0;
673
674         // Group indexing was flat originally, no need to unflatten it.
675         const unsigned bidx = itemIdx.get_group(0);
676
677         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
678         // We could use sg.get_group_range to compute imeiIdx and the exclusion index, but much of the logic relies on tidx,
679         // and when prunedClusterPairSize != subGroupSize it cannot be used anyway.
680         const unsigned imeiIdx = tidx / prunedClusterPairSize;
681
682         Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
683         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
684         {
685             fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
686         }
687
688         const nbnxn_sci_t nbSci     = a_plistSci[bidx];
689         const int         sci       = nbSci.sci;
690         const int         cij4Start = nbSci.cj4_ind_start;
691         const int         cij4End   = nbSci.cj4_ind_end;
692
693         // Only needed if props.elecEwaldAna
694         const float beta2 = ewaldBeta * ewaldBeta;
695         const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;
696
697         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
698         {
699             /* Pre-load i-atom x and q into shared memory */
700             const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
701             const int             ai       = ci * c_clSize + tidxi;
702             const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);
703
704             const Float3 shift = a_shiftVec[nbSci.shift];
705             Float4       xqi   = a_xq[ai];
706             xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
707             xqi[3] *= epsFac;
708             sm_xq[cacheIdx] = xqi;
709
710             if constexpr (!props.vdwComb)
711             {
712                 // Pre-load the i-atom types into shared memory
713                 sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
714             }
715             else
716             {
717                 // Pre-load the LJ combination parameters into shared memory
718                 sm_ljCombI[cacheIdx] = a_ljComb[ai];
719             }
720         }
721         itemIdx.barrier(fence_space::local_space);
722
723         float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
724         if constexpr (props.vdwEwald)
725         {
726             ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
727             ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
728         }
729
730         float energyVdw, energyElec; // Only needed if (doCalcEnergies)
731         if constexpr (doCalcEnergies)
732         {
733             energyVdw = energyElec = 0.0F;
734         }
735         if constexpr (doCalcEnergies && doExclusionForces)
736         {
737             if (nbSci.shift == gmx::c_centralShiftIndex
738                 && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
739             {
740                 // we have the diagonal: add the charge and LJ self interaction energy term
741                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
742                 {
743                     // TODO: Are there other options?
744                     if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
745                     {
746                         const float qi = sm_xq[i][tidxi][3];
747                         energyElec += qi * qi;
748                     }
749                     if constexpr (props.vdwEwald)
750                     {
751                         energyVdw +=
752                                 a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
753                                        * (numTypes + 1)][0];
754                     }
755                 }
756                 /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
757                 if constexpr (props.vdwEwald)
758                 {
759                     energyVdw /= c_clSize;
760                     energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
761                 }
762                 if constexpr (props.elecRF || props.elecCutoff)
763                 {
764                     // Correct for epsfac^2 due to adding qi^2
765                     energyElec /= epsFac * c_clSize;
766                     energyElec *= -0.5F * cRF;
767                 }
768                 if constexpr (props.elecEwald)
769                 {
770                     // Correct for epsfac^2 due to adding qi^2
771                     energyElec /= epsFac * c_clSize;
772                     energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
773                 }
774             } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
775         }     // (doCalcEnergies && doExclusionForces)
776
777         // Only needed if (doExclusionForces)
778         const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);
779
780         // loop over the j clusters seen by any of the atoms in the current super-cluster
781         for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
782         {
783             unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
784             if (!doPruneNBL && !imask)
785             {
786                 continue;
787             }
788             const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
789             static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
790             const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
791             for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
792             {
793                 const bool maskSet =
794                         imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
795                 if (!maskSet)
796                 {
797                     continue;
798                 }
799                 unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
800                 const int cj     = a_plistCJ4[j4].cj[jm];
801                 const int aj     = cj * c_clSize + tidxj;
802
803                 // load j atom data
804                 const Float4 xqj = a_xq[aj];
805
806                 const Float3 xj(xqj[0], xqj[1], xqj[2]);
807                 const float  qj = xqj[3];
808                 int          atomTypeJ; // Only needed if (!props.vdwComb)
809                 Float2       ljCombJ;   // Only needed if (props.vdwComb)
810                 if constexpr (props.vdwComb)
811                 {
812                     ljCombJ = a_ljComb[aj];
813                 }
814                 else
815                 {
816                     atomTypeJ = a_atomTypes[aj];
817                 }
818
819                 Float3 fCjBuf(0.0F, 0.0F, 0.0F);
820
821                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
822                 {
823                     if (imask & maskJI)
824                     {
825                         // i cluster index
826                         const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
827                         // every thread reads one atom of i-cluster ci, pre-loaded into shmem above
828                         const Float4 xqi = sm_xq[i][tidxi];
829                         const Float3 xi(xqi[0], xqi[1], xqi[2]);
830
831                         // distance between i and j atoms
832                         const Float3 rv = xi - xj;
833                         float        r2 = norm2(rv);
834
835                         if constexpr (doPruneNBL)
836                         {
837                             /* If _none_ of the atoms pairs are in cutoff range,
838                              * the bit corresponding to the current
839                              * cluster-pair in imask gets set to 0. */
840                             if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
841                             {
842                                 imask &= ~maskJI;
843                             }
844                         }
845                         const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;
846
847                         // cutoff & exclusion check
848
849                         const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
850                                                                    : (wexcl & maskJI);
851
852                         // SYCL-TODO: Check optimal way of branching here.
853                         if ((r2 < rCoulombSq) && notExcluded)
854                         {
855                             const float qi = xqi[3];
856                             int         atomTypeI; // Only needed if (!props.vdwComb)
857                             float       sigma, epsilon;
858                             Float2      c6c12;
859
860                             if constexpr (!props.vdwComb)
861                             {
862                                 /* LJ 6*C6 and 12*C12 */
863                                 atomTypeI = sm_atomTypeI[i][tidxi];
864                                 c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
865                             }
866                             else
867                             {
868                                 const Float2 ljCombI = sm_ljCombI[i][tidxi];
869                                 if constexpr (props.vdwCombGeom)
870                                 {
871                                     c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
872                                 }
873                                 else
874                                 {
875                                     static_assert(props.vdwCombLB);
876                                     // LJ 2^(1/6)*sigma and 12*epsilon
877                                     sigma   = ljCombI[0] + ljCombJ[0];
878                                     epsilon = ljCombI[1] * ljCombJ[1];
879                                     if constexpr (doCalcEnergies)
880                                     {
881                                         c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
882                                     }
883                                 } // props.vdwCombGeom
884                             }     // !props.vdwComb
885
886                             // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
887                             const float c6  = c6c12[0];
888                             const float c12 = c6c12[1];
889
890                             // Ensure the distance does not become so small that r^-12 overflows
891                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
892 #if GMX_SYCL_HIPSYCL
893                             // No fast/native functions in some compilation passes
894                             const float rInv = cl::sycl::rsqrt(r2);
895 #else
896                             // SYCL-TODO: sycl::half_precision::rsqrt?
897                             const float rInv = cl::sycl::native::rsqrt(r2);
898 #endif
899                             const float r2Inv = rInv * rInv;
900                             float       r6Inv, fInvR, energyLJPair;
901                             if constexpr (!props.vdwCombLB || doCalcEnergies)
902                             {
903                                 r6Inv = r2Inv * r2Inv * r2Inv;
904                                 if constexpr (doExclusionForces)
905                                 {
906                                     // SYCL-TODO: Check if true for SYCL
907                                     /* We could mask r2Inv, but with Ewald masking both
908                                      * r6Inv and fInvR is faster */
909                                     r6Inv *= pairExclMask;
910                                 }
911                                 fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
912                             }
913                             else
914                             {
915                                 float sig_r  = sigma * rInv;
916                                 float sig_r2 = sig_r * sig_r;
917                                 float sig_r6 = sig_r2 * sig_r2 * sig_r2;
918                                 if constexpr (doExclusionForces)
919                                 {
920                                     sig_r6 *= pairExclMask;
921                                 }
922                                 fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
923                             } // (!props.vdwCombLB || doCalcEnergies)
924                             if constexpr (doCalcEnergies || props.vdwPSwitch)
925                             {
926                                 energyLJPair = pairExclMask
927                                                * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
928                                                   - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
929                             }
930                             if constexpr (props.vdwFSwitch)
931                             {
932                                 ljForceSwitch<doCalcEnergies>(
933                                         dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
934                             }
935                             if constexpr (props.vdwEwald)
936                             {
937                                 ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
938                                                                      ljEwaldShift,
939                                                                      atomTypeI,
940                                                                      atomTypeJ,
941                                                                      r2,
942                                                                      r2Inv,
943                                                                      ewaldCoeffLJ_2,
944                                                                      ewaldCoeffLJ_6_6,
945                                                                      pairExclMask,
946                                                                      &fInvR,
947                                                                      &energyLJPair);
948                             } // (props.vdwEwald)
949                             if constexpr (props.vdwPSwitch)
950                             {
951                                 ljPotentialSwitch<doCalcEnergies>(
952                                         vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
953                             }
954                             if constexpr (props.elecEwaldTwin)
955                             {
956                                 // Separate VDW cut-off check to enable twin-range cut-offs
957                                 // (rVdw < rCoulomb <= rList)
958                                 const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
959                                 fInvR *= vdwInRange;
960                                 if constexpr (doCalcEnergies)
961                                 {
962                                     energyLJPair *= vdwInRange;
963                                 }
964                             }
965                             if constexpr (doCalcEnergies)
966                             {
967                                 energyVdw += energyLJPair;
968                             }
969
970                             if constexpr (props.elecCutoff)
971                             {
972                                 if constexpr (doExclusionForces)
973                                 {
974                                     fInvR += qi * qj * pairExclMask * r2Inv * rInv;
975                                 }
976                                 else
977                                 {
978                                     fInvR += qi * qj * r2Inv * rInv;
979                                 }
980                             }
981                             if constexpr (props.elecRF)
982                             {
983                                 fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
984                             }
985                             if constexpr (props.elecEwaldAna)
986                             {
987                                 fInvR += qi * qj
988                                          * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
989                             }
990                             if constexpr (props.elecEwaldTab)
991                             {
992                                 fInvR += qi * qj
993                                          * (pairExclMask * r2Inv
994                                             - interpolateCoulombForceR(
995                                                     a_coulombTab, coulombTabScale, r2 * rInv))
996                                          * rInv;
997                             }
998
999                             if constexpr (doCalcEnergies)
1000                             {
1001                                 if constexpr (props.elecCutoff)
1002                                 {
1003                                     energyElec += qi * qj * (pairExclMask * rInv - cRF);
1004                                 }
1005                                 if constexpr (props.elecRF)
1006                                 {
1007                                     energyElec +=
1008                                             qi * qj * (pairExclMask * rInv + 0.5F * twoKRf * r2 - cRF);
1009                                 }
1010                                 if constexpr (props.elecEwald)
1011                                 {
1012                                     energyElec +=
1013                                             qi * qj
1014                                             * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
1015                                                - pairExclMask * ewaldShift);
1016                                 }
1017                             }
1018
1019                             const Float3 forceIJ = rv * fInvR;
1020
1021                             /* accumulate j forces in registers */
1022                             fCjBuf -= forceIJ;
1023                             /* accumulate i forces in registers */
1024                             fCiBuf[i] += forceIJ;
1025                         } // (r2 < rCoulombSq) && notExcluded
1026                     }     // (imask & maskJI)
1027                     /* shift the mask bit by 1 */
1028                     maskJI += maskJI;
1029                 } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
1030                 /* reduce j forces */
1031                 reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
1032             } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
1033             if constexpr (doPruneNBL)
1034             {
1035                 /* Update the imask with the new one which does not contain the
1036                  * out of range clusters anymore. */
1037                 a_plistCJ4[j4].imei[imeiIdx].imask = imask;
1038             }
1039         } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)
1040
1041         /* skip central shifts when summing shift forces */
1042         const bool doCalcShift = (calcShift && nbSci.shift != gmx::c_centralShiftIndex);
1043
1044         reduceForceIAndFShift(
1045                 sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);
1046
1047         if constexpr (doCalcEnergies)
1048         {
1049             const float energyVdwGroup =
1050                     groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
1051             const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
1052                     itemIdx, tidx, sm_reductionBuffer, energyElec);
1053
1054             if (tidx == 0)
1055             {
1056                 atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
1057                 atomicFetchAdd(a_energyElec, 0, energyElecGroup);
1058             }
1059         }
1060     };
1061 }
1062
1063 //! \brief NBNXM kernel launch code.
1064 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
1065 cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
1066 {
1067     using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;
1068
1069     /* Kernel launch config:
1070      * - The thread block dimensions match the size of i-clusters, j-clusters,
1071      *   and j-cluster concurrency, in x, y, and z, respectively.
1072      * - The 1D block-grid contains as many blocks as super-clusters.
1073      */
1074     const int                   numBlocks = numSci;
1075     const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
1076     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
1077     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
1078
1079     cl::sycl::queue q = deviceStream.stream();
1080
1081     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
1082         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
1083                 cgh, std::forward<Args>(args)...);
1084         cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
1085     });
1086
1087     return e;
1088 }
1089
1090 //! \brief Select templated kernel and launch it.
1091 template<class... Args>
1092 cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
1093                                            bool          doCalcEnergies,
1094                                            enum ElecType elecType,
1095                                            enum VdwType  vdwType,
1096                                            Args&&... args)
1097 {
1098     return gmx::dispatchTemplatedFunction(
1099             [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
1100                 return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
1101                         std::forward<Args>(args)...);
1102             },
1103             doPruneNBL,
1104             doCalcEnergies,
1105             elecType,
1106             vdwType);
1107 }
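// A minimal sketch of how such a dispatch works for a single bool argument (illustrative only;
// the real gmx::dispatchTemplatedFunction lives in gromacs/utility/template_mp.h and also handles
// enums and multiple arguments):
//
//     template<class Function>
//     auto dispatchOnBool(Function&& f, bool value)
//     {
//         return value ? f(std::true_type{}) : f(std::false_type{});
//     }
//
// The `auto` parameters of the lambda above thus arrive as compile-time constants (e.g.
// std::integral_constant values), which is what allows them to be used as template arguments
// of launchNbnxmKernel.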
1108
1109 void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
1110 {
1111     NBAtomDataGpu*      adat         = nb->atdat;
1112     NBParamGpu*         nbp          = nb->nbparam;
1113     gpu_plist*          plist        = nb->plist[iloc];
1114     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
1115     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
1116
1117     // Casting to float simplifies using atomic ops in the kernel
1118     cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
1119     auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
1120     cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
1121     auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
1122
1123     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
1124                                                    stepWork.computeEnergy,
1125                                                    nbp->elecType,
1126                                                    nbp->vdwType,
1127                                                    deviceStream,
1128                                                    plist->nsci,
1129                                                    adat->xq,
1130                                                    fAsFloat,
1131                                                    adat->shiftVec,
1132                                                    fShiftAsFloat,
1133                                                    adat->eElec,
1134                                                    adat->eLJ,
1135                                                    plist->cj4,
1136                                                    plist->sci,
1137                                                    plist->excl,
1138                                                    adat->ljComb,
1139                                                    adat->atomTypes,
1140                                                    nbp->nbfp,
1141                                                    nbp->nbfp_comb,
1142                                                    nbp->coulomb_tab,
1143                                                    adat->numTypes,
1144                                                    nbp->rcoulomb_sq,
1145                                                    nbp->rvdw_sq,
1146                                                    nbp->two_k_rf,
1147                                                    nbp->ewald_beta,
1148                                                    nbp->rlistOuter_sq,
1149                                                    nbp->sh_ewald,
1150                                                    nbp->epsfac,
1151                                                    nbp->ewaldcoeff_lj,
1152                                                    nbp->c_rf,
1153                                                    nbp->dispersion_shift,
1154                                                    nbp->repulsion_shift,
1155                                                    nbp->vdw_switch,
1156                                                    nbp->rvdw_switch,
1157                                                    nbp->sh_lj_ewald,
1158                                                    nbp->coulomb_tab_scale,
1159                                                    stepWork.computeVirial);
1160 }
1161
1162 } // namespace Nbnxm