src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief
38  *  NBNXM SYCL kernels
39  *
40  *  \ingroup module_nbnxm
41  */
42 #include "gmxpre.h"
43
44 #include "nbnxm_sycl_kernel.h"
45
46 #include "gromacs/gpu_utils/devicebuffer.h"
47 #include "gromacs/gpu_utils/gmxsycl.h"
48 #include "gromacs/math/functions.h"
49 #include "gromacs/mdtypes/simulation_workload.h"
50 #include "gromacs/pbcutil/ishift.h"
51 #include "gromacs/utility/template_mp.h"
52
53 #include "nbnxm_sycl_kernel_utils.h"
54 #include "nbnxm_sycl_types.h"
55
56 //! \brief Class name for NBNXM kernel
57 template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
58 class NbnxmKernel;
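// (The forward declaration above is used only as a unique kernel-name type for
// cgh.parallel_for() in launchNbnxmKernel() below; it needs no definition.)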
59
60 namespace Nbnxm
61 {
62
63 //! \brief Set of boolean constants mimicking preprocessor macros.
64 template<enum ElecType elecType, enum VdwType vdwType>
65 struct EnergyFunctionProperties {
66     static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
67     static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
68     static constexpr bool elecEwaldAna =
69             (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
70     static constexpr bool elecEwaldTab =
71             (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
72     static constexpr bool elecEwaldTwin =
73             (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
74     static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
75     static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
76     static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
77     static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
78     static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
79     static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
80     static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
81     static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
82     static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
83 };
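// Illustrative sanity checks (these follow directly from the definitions above; shown here
// for RF electrostatics combined with a geometric-combination-rule LJ cut-off):
static_assert(EnergyFunctionProperties<ElecType::RF, VdwType::CutCombGeom>().elecRF);
static_assert(EnergyFunctionProperties<ElecType::RF, VdwType::CutCombGeom>().vdwCombGeom);
static_assert(!EnergyFunctionProperties<ElecType::RF, VdwType::CutCombGeom>().elecEwald);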
84
85 //! \brief Templated constants to shorten kernel function declaration.
86 //@{
87 template<enum VdwType vdwType>
88 constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;
89
90 template<enum ElecType elecType> // Yes, ElecType: the need for a separate VdW cut-off check is determined by the (twin-range) electrostatics type
91 constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;
92
93 template<enum ElecType elecType>
94 constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;
95
96 template<enum ElecType elecType>
97 constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;
98
99 template<enum VdwType vdwType>
100 constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
101 //@}
102
103 using cl::sycl::access::fence_space;
104 using cl::sycl::access::mode;
105 using cl::sycl::access::target;
106
107 static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
108 {
109     const float sigma2 = sigma * sigma;
110     const float sigma6 = sigma2 * sigma2 * sigma2;
111     const float c6     = epsilon * sigma6;
112     const float c12    = c6 * sigma6;
113
114     return Float2(c6, c12);
115 }
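/* Note: the only caller passes the scaled LB parameters described in the kernel below
 * (sigma stored as 2^(1/6)*sigma, epsilon as 12*epsilon), so the products computed here
 * land in the same convention as the a_nbfp table entries:
 *   c6  = 12*eps * (2^(1/6)*sigma)^6 = 24*eps*sigma^6  = 6*C6
 *   c12 = c6 * (2^(1/6)*sigma)^6     = 48*eps*sigma^12 = 12*C12
 */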
116
117 template<bool doCalcEnergies>
118 static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
119                                  const shift_consts_t         repulsionShift,
120                                  const float                  rVdwSwitch,
121                                  const float                  c6,
122                                  const float                  c12,
123                                  const float                  rInv,
124                                  const float                  r2,
125                                  cl::sycl::private_ptr<float> fInvR,
126                                  cl::sycl::private_ptr<float> eLJ)
127 {
128     /* force switch constants */
129     const float dispShiftV2 = dispersionShift.c2;
130     const float dispShiftV3 = dispersionShift.c3;
131     const float repuShiftV2 = repulsionShift.c2;
132     const float repuShiftV3 = repulsionShift.c3;
133
134     const float r       = r2 * rInv;
135     const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)
136
137     *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
138               + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;
139
140     if constexpr (doCalcEnergies)
141     {
142         const float dispShiftF2 = dispShiftV2 / 3;
143         const float dispShiftF3 = dispShiftV3 / 4;
144         const float repuShiftF2 = repuShiftV2 / 3;
145         const float repuShiftF3 = repuShiftV3 / 4;
146         *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
147                 - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
148     }
149 }
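/* The divisions by 3 and 4 above come from integrating the force-switch terms: the correction
 * to the force is proportional to (c2 + c3*x)*x^2 with x = r - rVdwSwitch, and integrating
 * c2*x^2 + c3*x^3 over x gives c2*x^3/3 + c3*x^4/4, which is the energy correction applied above. */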
150
151 //! \brief Fetch C6 grid contribution coefficients and return the product of these.
152 template<enum VdwType vdwType>
153 static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
154                                            const int                                typeI,
155                                            const int                                typeJ)
156 {
157     if constexpr (vdwType == VdwType::EwaldGeom)
158     {
159         return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
160     }
161     else
162     {
163         static_assert(vdwType == VdwType::EwaldLB);
164         /* sigma and epsilon are scaled to give 6*C6 */
165         const Float2 c6c12_i = a_nbfpComb[typeI];
166         const Float2 c6c12_j = a_nbfpComb[typeJ];
167
168         const float sigma   = c6c12_i[0] + c6c12_j[0];
169         const float epsilon = c6c12_i[1] * c6c12_j[1];
170
171         const float sigma2 = sigma * sigma;
172         return epsilon * sigma2 * sigma2 * sigma2;
173     }
174 }
175
176 //! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
177 template<bool doCalcEnergies, enum VdwType vdwType>
178 static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
179                                const float                              sh_lj_ewald,
180                                const int                                typeI,
181                                const int                                typeJ,
182                                const float                              r2,
183                                const float                              r2Inv,
184                                const float                              lje_coeff2,
185                                const float                              lje_coeff6_6,
186                                const float                              int_bit,
187                                cl::sycl::private_ptr<float>             fInvR,
188                                cl::sycl::private_ptr<float>             eLJ)
189 {
190     const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);
191
192     /* Recalculate inv_r6 without exclusion mask */
193     const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
194     const float cr2       = lje_coeff2 * r2;
195     const float expmcr2   = cl::sycl::exp(-cr2);
196     const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;
197
198     /* Subtract the grid force from the total LJ force */
199     *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;
200
201     if constexpr (doCalcEnergies)
202     {
203         /* Shift should be applied only to real LJ pairs */
204         const float sh_mask = sh_lj_ewald * int_bit;
205         *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
206     }
207 }
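/* Note: poly = 1 + cr2 + cr2^2/2 is the second-order Taylor expansion of exp(cr2), so
 * 1 - expmcr2*poly = cr2^3/6 + O(cr2^4). The energy term inv_r6_nm * (1 - expmcr2*poly)
 * therefore stays finite as r -> 0, approaching lje_coeff6_6 = ewaldCoeffLJ^6 / 6. */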
208
209 /*! \brief Apply potential switch. */
210 template<bool doCalcEnergies>
211 static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
212                                      const float                  rVdwSwitch,
213                                      const float                  rInv,
214                                      const float                  r2,
215                                      cl::sycl::private_ptr<float> fInvR,
216                                      cl::sycl::private_ptr<float> eLJ)
217 {
218     /* potential switch constants */
219     const float switchV3 = vdwSwitch.c3;
220     const float switchV4 = vdwSwitch.c4;
221     const float switchV5 = vdwSwitch.c5;
222     const float switchF2 = 3 * vdwSwitch.c3;
223     const float switchF3 = 4 * vdwSwitch.c4;
224     const float switchF4 = 5 * vdwSwitch.c5;
225
226     const float r       = r2 * rInv;
227     const float rSwitch = r - rVdwSwitch;
228
229     if (rSwitch > 0.0F)
230     {
231         const float sw =
232                 1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
233         const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;
234
235         *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
236         if constexpr (doCalcEnergies)
237         {
238             *eLJ *= sw;
239         }
240     }
241 }
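/* Here sw(x) = 1 + c3*x^3 + c4*x^4 + c5*x^5 and dsw = sw'(x), with x = r - rVdwSwitch.
 * The switched potential is V*sw, so the switched force is F*sw - V*sw'; dividing by r gives
 * the update applied above: fInvR*sw - rInv*eLJ*dsw. */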
242
243
244 /*! \brief Calculate analytical Ewald correction term. */
245 static inline float pmeCorrF(const float z2)
246 {
247     constexpr float FN6 = -1.7357322914161492954e-8F;
248     constexpr float FN5 = 1.4703624142580877519e-6F;
249     constexpr float FN4 = -0.000053401640219807709149F;
250     constexpr float FN3 = 0.0010054721316683106153F;
251     constexpr float FN2 = -0.019278317264888380590F;
252     constexpr float FN1 = 0.069670166153766424023F;
253     constexpr float FN0 = -0.75225204789749321333F;
254
255     constexpr float FD4 = 0.0011193462567257629232F;
256     constexpr float FD3 = 0.014866955030185295499F;
257     constexpr float FD2 = 0.11583842382862377919F;
258     constexpr float FD1 = 0.50736591960530292870F;
259     constexpr float FD0 = 1.0F;
260
261     const float z4 = z2 * z2;
262
263     float       polyFD0 = FD4 * z4 + FD2;
264     const float polyFD1 = FD3 * z4 + FD1;
265     polyFD0             = polyFD0 * z4 + FD0;
266     polyFD0             = polyFD1 * z2 + polyFD0;
267
268     polyFD0 = 1.0F / polyFD0;
269
270     float polyFN0 = FN6 * z4 + FN4;
271     float polyFN1 = FN5 * z4 + FN3;
272     polyFN0       = polyFN0 * z4 + FN2;
273     polyFN1       = polyFN1 * z4 + FN1;
274     polyFN0       = polyFN0 * z4 + FN0;
275     polyFN0       = polyFN1 * z2 + polyFN0;
276
277     return polyFN0 * polyFD0;
278 }
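/* This appears to be a rational-polynomial fit, in z2 = (beta*r)^2, to
 *   f(z) = 2/sqrt(pi) * exp(-z^2)/z^2 - erf(z)/z^3,
 * consistent with the z -> 0 limit -4/(3*sqrt(pi)) ~ -0.75225 matching FN0/FD0.
 * In the elecEwaldAna branch of the kernel it is multiplied by beta^3, turning the plain 1/r^3
 * term into the short-range Ewald fInvR = erfc(beta*r)/r^3 + 2*beta*exp(-(beta*r)^2)/(sqrt(pi)*r^2). */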
279
280 /*! \brief Linear interpolation using exactly two FMA operations.
281  *
282  *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
283  */
284 template<typename T>
285 static inline T lerp(T d0, T d1, T t)
286 {
287     return fma(t, d1, fma(-t, d0, d0));
288 }
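/* Expanding the two FMAs: fma(-t, d0, d0) = d0 - t*d0 = (1-t)*d0, and then
 * fma(t, d1, (1-t)*d0) = t*d1 + (1-t)*d0. For example, lerp(0.0F, 10.0F, 0.25F) == 2.5F. */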
289
290 /*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
291 static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
292                                              const float coulombTabScale,
293                                              const float r)
294 {
295     const float normalized = coulombTabScale * r;
296     const int   index      = static_cast<int>(normalized);
297     const float fraction   = normalized - index;
298
299     const float left  = a_coulombTab[index];
300     const float right = a_coulombTab[index + 1];
301
302     return lerp(left, right, fraction); // TODO: cl::sycl::mix
303 }
304
305 /*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
306  *
307  * c_clSize consecutive threads hold the force components of a j-atom, which we
308  * reduce in log2(c_clSize) steps using sub-group shifts and atomically accumulate into \p a_f.
309  */
310 static inline void reduceForceJShuffle(Float3                             f,
311                                        const cl::sycl::nd_item<1>         itemIdx,
312                                        const int                          tidxi,
313                                        const int                          aidx,
314                                        DeviceAccessor<float, mode_atomic> a_f)
315 {
316     static_assert(c_clSize == 8 || c_clSize == 4);
317     sycl_2020::sub_group sg = itemIdx.get_sub_group();
318
319     f[0] += sycl_2020::shift_left(sg, f[0], 1);
320     f[1] += sycl_2020::shift_right(sg, f[1], 1);
321     f[2] += sycl_2020::shift_left(sg, f[2], 1);
322     if (tidxi & 1)
323     {
324         f[0] = f[1];
325     }
326
327     f[0] += sycl_2020::shift_left(sg, f[0], 2);
328     f[2] += sycl_2020::shift_right(sg, f[2], 2);
329     if (tidxi & 2)
330     {
331         f[0] = f[2];
332     }
333
334     if constexpr (c_clSize == 8)
335     {
336         f[0] += sycl_2020::shift_left(sg, f[0], 4);
337     }
338
339     if (tidxi < 3)
340     {
341         atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
342     }
343 }
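/* How the shuffles above fold the three components into f[0]: after the stride-1 step, even
 * lanes hold partial x sums and odd lanes partial y sums; after the stride-2 step, lanes with
 * bit 1 set hold partial z sums; after the final fold, lanes 0, 1 and 2 of each c_clSize-wide
 * row hold the complete x, y and z sums for that row's j-atom, which the tidxi < 3 threads
 * then accumulate atomically into a_f. */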
344
345 /*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
346  *
347  * c_clSize consecutive threads hold the force components of a j-atom, which we
348  * reduce in c_clSize steps using local memory and atomically accumulate into \p a_f.
349  *
350  * TODO: implement binary reduction flavor for the case where c_clSize is a power of two.
351  */
352 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
353                                        Float3                             f,
354                                        const cl::sycl::nd_item<1>         itemIdx,
355                                        const int                          tidxi,
356                                        const int                          tidxj,
357                                        const int                          aidx,
358                                        DeviceAccessor<float, mode_atomic> a_f)
359 {
360     static constexpr int sc_fBufferStride = c_clSizeSq;
361     int                  tidx             = tidxi + tidxj * c_clSize;
362     sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
363     sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
364     sm_buf[2 * sc_fBufferStride + tidx]   = f[2];
365
366     subGroupBarrier(itemIdx);
367
368     // Reduce: threads with tidxi < 3 each sum, for their own j-atom (same tidxj row), the c_clSize values of force component tidxi stored above
369     assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);
370
371     if (tidxi < 3)
372     {
373         float fSum = 0.0F;
374         for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
375         {
376             fSum += sm_buf[sc_fBufferStride * tidxi + j];
377         }
378
379         atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
380     }
381 }
382
383
384 /*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
385  */
386 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
387                                 Float3                                                        f,
388                                 const cl::sycl::nd_item<1>         itemIdx,
389                                 const int                          tidxi,
390                                 const int                          tidxj,
391                                 const int                          aidx,
392                                 DeviceAccessor<float, mode_atomic> a_f)
393 {
394     if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
395     {
396         reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
397     }
398     else
399     {
400         reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
401     }
402 }
403
404
405 /*! \brief Final i-force reduction.
406  *
407  * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
408  * accumulating atomically into \p a_f.
409  * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
410  *
411  * This implementation works only with power of two array sizes.
412  */
413 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
414                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
415                                          const bool   calcFShift,
416                                          const cl::sycl::nd_item<1>         itemIdx,
417                                          const int                          tidxi,
418                                          const int                          tidxj,
419                                          const int                          sci,
420                                          const int                          shift,
421                                          DeviceAccessor<float, mode_atomic> a_f,
422                                          DeviceAccessor<float, mode_atomic> a_fShift)
423 {
424     // must have power of two elements in fCiBuf
425     static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
426
427     static constexpr int bufStride  = c_clSize * c_clSize;
428     static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
429     const int            tidx       = tidxi + tidxj * c_clSize;
430     float                fShiftBuf  = 0.0F;
431     for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
432     {
433         const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
434         /* store i forces in shmem */
435         sm_buf[tidx]                 = fCiBuf[ciOffset][0];
436         sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
437         sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
438         itemIdx.barrier(fence_space::local_space);
439
440         /* Reduce the initial c_clSize values for each i atom to half
441          * every step by using c_clSize * i threads. */
442         int i = c_clSize / 2;
443         for (int j = clSizeLog2 - 1; j > 0; j--)
444         {
445             if (tidxj < i)
446             {
447                 sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
448                 sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
449                         sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
450                 sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
451                         sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
452             }
453             i >>= 1;
454             itemIdx.barrier(fence_space::local_space);
455         }
456
457         /* i == 1, last reduction step, writing to global mem */
458         /* Split the reduction between the first 3 line threads
459            Threads with line id 0 will do the reduction for (float3).x components
460            Threads with line id 1 will do the reduction for (float3).y components
461            Threads with line id 2 will do the reduction for (float3).z components. */
462         if (tidxj < 3)
463         {
464             const float f =
465                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
466             atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
467             if (calcFShift)
468             {
469                 fShiftBuf += f;
470             }
471         }
472         itemIdx.barrier(fence_space::local_space);
473     }
474     /* add up local shift forces into global mem */
475     if (calcFShift)
476     {
477         /* Only threads with tidxj < 3 will update fshift.
478            The threads performing the update must be the same as the threads
479            storing the reduction result above. */
480         if (tidxj < 3)
481         {
482             if constexpr (c_clSize == 4)
483             {
484                 /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
485                  * a compare-and-swap (CAS) loop. It has particularly poor performance when
486                  * updating the same memory location from the same work-group.
487                  * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
488                  * but it is unlikely to make a big difference and thus was not evaluated.
489                  */
490                 auto sg = itemIdx.get_sub_group();
491                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
492                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
493                 if (tidxi == 0)
494                 {
495                     atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
496                 }
497             }
498             else
499             {
500                 atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
501             }
502         }
503     }
504 }
505
506 /*! \brief Main kernel for NBNXM.
507  *
508  */
509 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
510 auto nbnxmKernel(cl::sycl::handler&                                   cgh,
511                  DeviceAccessor<Float4, mode::read>                   a_xq,
512                  DeviceAccessor<float, mode_atomic>                   a_f,
513                  DeviceAccessor<Float3, mode::read>                   a_shiftVec,
514                  DeviceAccessor<float, mode_atomic>                   a_fShift,
515                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
516                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
517                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
518                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
519                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
520                  OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
521                  OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
522                  OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
523                  OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
524                  OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
525                  const int                                                   numTypes,
526                  const float                                                 rCoulombSq,
527                  const float                                                 rVdwSq,
528                  const float                                                 twoKRf,
529                  const float                                                 ewaldBeta,
530                  const float                                                 rlistOuterSq,
531                  const float                                                 ewaldShift,
532                  const float                                                 epsFac,
533                  const float                                                 ewaldCoeffLJ,
534                  const float                                                 cRF,
535                  const shift_consts_t                                        dispersionShift,
536                  const shift_consts_t                                        repulsionShift,
537                  const switch_consts_t                                       vdwSwitch,
538                  const float                                                 rVdwSwitch,
539                  const float                                                 ljEwaldShift,
540                  const float                                                 coulombTabScale,
541                  const bool                                                  calcShift)
542 {
543     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
544
545     cgh.require(a_xq);
546     cgh.require(a_f);
547     cgh.require(a_shiftVec);
548     cgh.require(a_fShift);
549     if constexpr (doCalcEnergies)
550     {
551         cgh.require(a_energyElec);
552         cgh.require(a_energyVdw);
553     }
554     cgh.require(a_plistCJ4);
555     cgh.require(a_plistSci);
556     cgh.require(a_plistExcl);
557     if constexpr (!props.vdwComb)
558     {
559         cgh.require(a_atomTypes);
560         cgh.require(a_nbfp);
561     }
562     else
563     {
564         cgh.require(a_ljComb);
565     }
566     if constexpr (props.vdwEwald)
567     {
568         cgh.require(a_nbfpComb);
569     }
570     if constexpr (props.elecEwaldTab)
571     {
572         cgh.require(a_coulombTab);
573     }
574
575     // shmem buffer for i x+q pre-loading
576     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
577             cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
578
579     // shmem buffer for force reduction
580     // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
581     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
582             cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);
583
584     auto sm_atomTypeI = [&]() {
585         if constexpr (!props.vdwComb)
586         {
587             return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
588                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
589         }
590         else
591         {
592             return nullptr;
593         }
594     }();
595
596     auto sm_ljCombI = [&]() {
597         if constexpr (props.vdwComb)
598         {
599             return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
600                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
601         }
602         else
603         {
604             return nullptr;
605         }
606     }();
607
608     /* Flag to control the calculation of exclusion forces in the kernel.
609      * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
610      * energy terms. */
611     constexpr bool doExclusionForces =
612             (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));
613
614     // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
615     // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
616     // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
617     // Hence, the two are decoupled.
618     // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
619     constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
620 #if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
621     gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
622 #else
623     gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
624 #endif
625
626     return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
627     {
628         /* thread/block/warp id-s */
629         const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
630         const unsigned        tidxi   = localId[0];
631         const unsigned        tidxj   = localId[1];
632         const unsigned        tidx    = tidxj * c_clSize + tidxi;
633         const unsigned        tidxz   = 0;
634
635         // Group indexing was flat originally, no need to unflatten it.
636         const unsigned bidx = itemIdx.get_group(0);
637
638         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
639         // Could use sg.get_group_range() to compute the imask & exclusion index, but too much of the
640         // logic relies on the current tidx-based indexing, and when prunedClusterPairSize != subGroupSize it cannot be used at all.
641         const unsigned imeiIdx = tidx / prunedClusterPairSize;
642
643         Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
644         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
645         {
646             fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
647         }
648
649         const nbnxn_sci_t nbSci     = a_plistSci[bidx];
650         const int         sci       = nbSci.sci;
651         const int         cij4Start = nbSci.cj4_ind_start;
652         const int         cij4End   = nbSci.cj4_ind_end;
653
654         // Only needed if props.elecEwaldAna
655         const float beta2 = ewaldBeta * ewaldBeta;
656         const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;
657
658         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
659         {
660             /* Pre-load i-atom x and q into shared memory */
661             const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
662             const int             ai       = ci * c_clSize + tidxi;
663             const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);
664
665             const Float3 shift = a_shiftVec[nbSci.shift];
666             Float4       xqi   = a_xq[ai];
667             xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
668             xqi[3] *= epsFac;
669             sm_xq[cacheIdx] = xqi;
670
671             if constexpr (!props.vdwComb)
672             {
673                 // Pre-load the i-atom types into shared memory
674                 sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
675             }
676             else
677             {
678                 // Pre-load the LJ combination parameters into shared memory
679                 sm_ljCombI[cacheIdx] = a_ljComb[ai];
680             }
681         }
682         itemIdx.barrier(fence_space::local_space);
683
684         float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
685         if constexpr (props.vdwEwald)
686         {
687             ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
688             ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
689         }
690
691         float energyVdw, energyElec; // Only needed if (doCalcEnergies)
692         if constexpr (doCalcEnergies)
693         {
694             energyVdw = energyElec = 0.0F;
695         }
696         if constexpr (doCalcEnergies && doExclusionForces)
697         {
698             if (nbSci.shift == gmx::c_centralShiftIndex
699                 && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
700             {
701                 // we have the diagonal: add the charge and LJ self interaction energy term
702                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
703                 {
704                     // TODO: Are there other options?
705                     if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
706                     {
707                         const float qi = sm_xq[i][tidxi][3];
708                         energyElec += qi * qi;
709                     }
710                     if constexpr (props.vdwEwald)
711                     {
712                         energyVdw +=
713                                 a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
714                                        * (numTypes + 1)][0];
715                     }
716                 }
717                 /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
718                 if constexpr (props.vdwEwald)
719                 {
720                     energyVdw /= c_clSize;
721                     energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
722                 }
723                 if constexpr (props.elecRF || props.elecCutoff)
724                 {
725                 // Correct for epsfac^2 due to adding qi^2
726                     energyElec /= epsFac * c_clSize;
727                     energyElec *= -0.5F * cRF;
728                 }
729                 if constexpr (props.elecEwald)
730                 {
731                 // Correct for epsfac^2 due to adding qi^2
732                     energyElec /= epsFac * c_clSize;
733                     energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
734                 }
735             } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
736         }     // (doCalcEnergies && doExclusionForces)
737
738         // Only needed if (doExclusionForces)
739         const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);
740
741         // loop over the j clusters seen by any of the atoms in the current super-cluster
742         for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
743         {
744             unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
745             if (!doPruneNBL && !imask)
746             {
747                 continue;
748             }
749             const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
750             static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
751             const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
752             for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
753             {
754                 const bool maskSet =
755                         imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
756                 if (!maskSet)
757                 {
758                     continue;
759                 }
760                 unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
761                 const int cj     = a_plistCJ4[j4].cj[jm];
762                 const int aj     = cj * c_clSize + tidxj;
763
764                 // load j atom data
765                 const Float4 xqj = a_xq[aj];
766
767                 const Float3 xj(xqj[0], xqj[1], xqj[2]);
768                 const float  qj = xqj[3];
769                 int          atomTypeJ; // Only needed if (!props.vdwComb)
770                 Float2       ljCombJ;   // Only needed if (props.vdwComb)
771                 if constexpr (props.vdwComb)
772                 {
773                     ljCombJ = a_ljComb[aj];
774                 }
775                 else
776                 {
777                     atomTypeJ = a_atomTypes[aj];
778                 }
779
780                 Float3 fCjBuf(0.0F, 0.0F, 0.0F);
781
782                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
783                 {
784                     if (imask & maskJI)
785                     {
786                         // i cluster index
787                         const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
788                         // i-atom data for cluster ci was pre-loaded into shmem above; each thread reads one atom
789                         const Float4 xqi = sm_xq[i][tidxi];
790                         const Float3 xi(xqi[0], xqi[1], xqi[2]);
791
792                         // distance between i and j atoms
793                         const Float3 rv = xi - xj;
794                         float        r2 = norm2(rv);
795
796                         if constexpr (doPruneNBL)
797                         {
798                             /* If _none_ of the atoms pairs are in cutoff range,
799                              * the bit corresponding to the current
800                              * cluster-pair in imask gets set to 0. */
801                             if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
802                             {
803                                 imask &= ~maskJI;
804                             }
805                         }
806                         const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;
807
808                         // cutoff & exclusion check
809
810                         const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
811                                                                    : (wexcl & maskJI);
812
813                         // SYCL-TODO: Check optimal way of branching here.
814                         if ((r2 < rCoulombSq) && notExcluded)
815                         {
816                             const float qi = xqi[3];
817                             int         atomTypeI; // Only needed if (!props.vdwComb)
818                             float       sigma, epsilon;
819                             Float2      c6c12;
820
821                             if constexpr (!props.vdwComb)
822                             {
823                                 /* LJ 6*C6 and 12*C12 */
824                                 atomTypeI = sm_atomTypeI[i][tidxi];
825                                 c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
826                             }
827                             else
828                             {
829                                 const Float2 ljCombI = sm_ljCombI[i][tidxi];
830                                 if constexpr (props.vdwCombGeom)
831                                 {
832                                     c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
833                                 }
834                                 else
835                                 {
836                                     static_assert(props.vdwCombLB);
837                                     // LJ 2^(1/6)*sigma and 12*epsilon
838                                     sigma   = ljCombI[0] + ljCombJ[0];
839                                     epsilon = ljCombI[1] * ljCombJ[1];
840                                     if constexpr (doCalcEnergies)
841                                     {
842                                         c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
843                                     }
844                                 } // props.vdwCombGeom
845                             }     // !props.vdwComb
846
847                             // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
848                             const float c6  = c6c12[0];
849                             const float c12 = c6c12[1];
850
851                             // Ensure the distance does not become so small that r^-12 overflows
852                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
853 #if GMX_SYCL_HIPSYCL
854                             // No fast/native functions in some compilation passes
855                             const float rInv = cl::sycl::rsqrt(r2);
856 #else
857                             // SYCL-TODO: sycl::half_precision::rsqrt?
858                             const float rInv = cl::sycl::native::rsqrt(r2);
859 #endif
860                             const float r2Inv = rInv * rInv;
861                             float       r6Inv, fInvR, energyLJPair;
862                             if constexpr (!props.vdwCombLB || doCalcEnergies)
863                             {
864                                 r6Inv = r2Inv * r2Inv * r2Inv;
865                                 if constexpr (doExclusionForces)
866                                 {
867                                     // SYCL-TODO: Check if true for SYCL
868                                     /* We could mask r2Inv, but with Ewald masking both
869                                      * r6Inv and fInvR is faster */
870                                     r6Inv *= pairExclMask;
871                                 }
872                                 fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
873                             }
874                             else
875                             {
876                                 float sig_r  = sigma * rInv;
877                                 float sig_r2 = sig_r * sig_r;
878                                 float sig_r6 = sig_r2 * sig_r2 * sig_r2;
879                                 if constexpr (doExclusionForces)
880                                 {
881                                     sig_r6 *= pairExclMask;
882                                 }
883                                 fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
884                             } // (!props.vdwCombLB || doCalcEnergies)
885                             if constexpr (doCalcEnergies || props.vdwPSwitch)
886                             {
887                                 energyLJPair = pairExclMask
888                                                * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
889                                                   - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
890                             }
891                             if constexpr (props.vdwFSwitch)
892                             {
893                                 ljForceSwitch<doCalcEnergies>(
894                                         dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
895                             }
896                             if constexpr (props.vdwEwald)
897                             {
898                                 ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
899                                                                      ljEwaldShift,
900                                                                      atomTypeI,
901                                                                      atomTypeJ,
902                                                                      r2,
903                                                                      r2Inv,
904                                                                      ewaldCoeffLJ_2,
905                                                                      ewaldCoeffLJ_6_6,
906                                                                      pairExclMask,
907                                                                      &fInvR,
908                                                                      &energyLJPair);
909                             } // (props.vdwEwald)
910                             if constexpr (props.vdwPSwitch)
911                             {
912                                 ljPotentialSwitch<doCalcEnergies>(
913                                         vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
914                             }
915                             if constexpr (props.elecEwaldTwin)
916                             {
917                                 // Separate VDW cut-off check to enable twin-range cut-offs
918                                 // (rVdw < rCoulomb <= rList)
919                                 const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
920                                 fInvR *= vdwInRange;
921                                 if constexpr (doCalcEnergies)
922                                 {
923                                     energyLJPair *= vdwInRange;
924                                 }
925                             }
926                             if constexpr (doCalcEnergies)
927                             {
928                                 energyVdw += energyLJPair;
929                             }
930
931                             if constexpr (props.elecCutoff)
932                             {
933                                 if constexpr (doExclusionForces)
934                                 {
935                                     fInvR += qi * qj * pairExclMask * r2Inv * rInv;
936                                 }
937                                 else
938                                 {
939                                     fInvR += qi * qj * r2Inv * rInv;
940                                 }
941                             }
942                             if constexpr (props.elecRF)
943                             {
944                                 fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
945                             }
946                             if constexpr (props.elecEwaldAna)
947                             {
948                                 fInvR += qi * qj
949                                          * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
950                             }
951                             if constexpr (props.elecEwaldTab)
952                             {
953                                 fInvR += qi * qj
954                                          * (pairExclMask * r2Inv
955                                             - interpolateCoulombForceR(
956                                                     a_coulombTab, coulombTabScale, r2 * rInv))
957                                          * rInv;
958                             }
959
960                             if constexpr (doCalcEnergies)
961                             {
962                                 if constexpr (props.elecCutoff)
963                                 {
964                                     energyElec += qi * qj * (pairExclMask * rInv - cRF);
965                                 }
966                                 if constexpr (props.elecRF)
967                                 {
968                                     energyElec +=
969                                             qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
970                                 }
971                                 if constexpr (props.elecEwald)
972                                 {
973                                     energyElec +=
974                                             qi * qj
975                                             * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
976                                                - pairExclMask * ewaldShift);
977                                 }
978                             }
979
980                             const Float3 forceIJ = rv * fInvR;
981
982                             /* accumulate j forces in registers */
983                             fCjBuf -= forceIJ;
984                             /* accumulate i forces in registers */
985                             fCiBuf[i] += forceIJ;
986                         } // (r2 < rCoulombSq) && notExcluded
987                     }     // (imask & maskJI)
988                     /* shift the mask bit by 1 */
989                     maskJI += maskJI;
990                 } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
991                 /* reduce j forces */
992                 reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
993             } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
994             if constexpr (doPruneNBL)
995             {
996                 /* Update the imask with the new one which does not contain the
997                  * out of range clusters anymore. */
998                 a_plistCJ4[j4].imei[imeiIdx].imask = imask;
999             }
1000         } // for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
1001
1002         /* skip central shifts when summing shift forces */
1003         const bool doCalcShift = (calcShift && !(nbSci.shift == gmx::c_centralShiftIndex));
1004
1005         reduceForceIAndFShift(
1006                 sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);
1007
1008         if constexpr (doCalcEnergies)
1009         {
1010             const float energyVdwGroup = sycl_2020::group_reduce(
1011                     itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
1012             const float energyElecGroup = sycl_2020::group_reduce(
1013                     itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());
1014
1015             if (tidx == 0)
1016             {
1017                 atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
1018                 atomicFetchAdd(a_energyElec, 0, energyElecGroup);
1019             }
1020         }
1021     };
1022 }
1023
1024 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
1025 cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
1026 {
1027     using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;
1028
1029     /* Kernel launch config:
1030      * - The thread block dimensions match the size of i-clusters, j-clusters,
1031      *   and j-cluster concurrency, in x, y, and z, respectively.
1032      * - The 1D block-grid contains as many blocks as super-clusters.
1033      */
1034     const int                   numBlocks = numSci;
1035     const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
1036     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
1037     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
1038
1039     cl::sycl::queue q = deviceStream.stream();
1040
1041     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
1042         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
1043                 cgh, std::forward<Args>(args)...);
1044         cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
1045     });
1046
1047     return e;
1048 }
1049
1050 template<class... Args>
1051 cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
1052                                            bool          doCalcEnergies,
1053                                            enum ElecType elecType,
1054                                            enum VdwType  vdwType,
1055                                            Args&&... args)
1056 {
1057     return gmx::dispatchTemplatedFunction(
1058             [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
1059                 return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
1060                         std::forward<Args>(args)...);
1061             },
1062             doPruneNBL,
1063             doCalcEnergies,
1064             elecType,
1065             vdwType);
1066 }
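// gmx::dispatchTemplatedFunction maps the runtime (doPruneNBL, doCalcEnergies, elecType, vdwType)
// values onto the matching compile-time template arguments, so each combination used at runtime
// dispatches to its own fully specialized kernel instantiation.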
1067
1068 void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
1069 {
1070     NBAtomDataGpu*      adat         = nb->atdat;
1071     NBParamGpu*         nbp          = nb->nbparam;
1072     gpu_plist*          plist        = nb->plist[iloc];
1073     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
1074     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
1075
1076     // Casting to float simplifies using atomic ops in the kernel
1077     cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
1078     auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
1079     cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
1080     auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
1081
1082     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
1083                                                    stepWork.computeEnergy,
1084                                                    nbp->elecType,
1085                                                    nbp->vdwType,
1086                                                    deviceStream,
1087                                                    plist->nsci,
1088                                                    adat->xq,
1089                                                    fAsFloat,
1090                                                    adat->shiftVec,
1091                                                    fShiftAsFloat,
1092                                                    adat->eElec,
1093                                                    adat->eLJ,
1094                                                    plist->cj4,
1095                                                    plist->sci,
1096                                                    plist->excl,
1097                                                    adat->ljComb,
1098                                                    adat->atomTypes,
1099                                                    nbp->nbfp,
1100                                                    nbp->nbfp_comb,
1101                                                    nbp->coulomb_tab,
1102                                                    adat->numTypes,
1103                                                    nbp->rcoulomb_sq,
1104                                                    nbp->rvdw_sq,
1105                                                    nbp->two_k_rf,
1106                                                    nbp->ewald_beta,
1107                                                    nbp->rlistOuter_sq,
1108                                                    nbp->sh_ewald,
1109                                                    nbp->epsfac,
1110                                                    nbp->ewaldcoeff_lj,
1111                                                    nbp->c_rf,
1112                                                    nbp->dispersion_shift,
1113                                                    nbp->repulsion_shift,
1114                                                    nbp->vdw_switch,
1115                                                    nbp->rvdw_switch,
1116                                                    nbp->sh_lj_ewald,
1117                                                    nbp->coulomb_tab_scale,
1118                                                    stepWork.computeVirial);
1119 }
1120
1121 } // namespace Nbnxm