SYCL: Reduce the number of atomic ops in NBNXM fShift calculation
[alexxy/gromacs.git] src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief
38  *  NBNXM SYCL kernels
39  *
40  *  \ingroup module_nbnxm
41  */
42 #include "gmxpre.h"
43
44 #include "nbnxm_sycl_kernel.h"
45
46 #include "gromacs/gpu_utils/devicebuffer.h"
47 #include "gromacs/gpu_utils/gmxsycl.h"
48 #include "gromacs/math/functions.h"
49 #include "gromacs/mdtypes/simulation_workload.h"
50 #include "gromacs/pbcutil/ishift.h"
51 #include "gromacs/utility/template_mp.h"
52
53 #include "nbnxm_sycl_kernel_utils.h"
54 #include "nbnxm_sycl_types.h"
55
56 //! \brief Class name for NBNXM kernel
57 template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
58 class NbnxmKernel;
59
60 namespace Nbnxm
61 {
62
63 //! \brief Set of boolean constants mimicking preprocessor macros.
64 template<enum ElecType elecType, enum VdwType vdwType>
65 struct EnergyFunctionProperties {
66     static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
67     static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
68     static constexpr bool elecEwaldAna =
69             (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
70     static constexpr bool elecEwaldTab =
71             (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
72     static constexpr bool elecEwaldTwin =
73             (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
74     static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
75     static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
76     static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
77     static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
78     static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
79     static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
80     static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
81     static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
82     static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
83 };
84
85 //! \brief Templated constants to shorten kernel function declaration.
86 //@{
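// Note: ElecType::Count and VdwType::Count serve as dummy arguments when only the other template parameter is relevant.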
87 template<enum VdwType vdwType>
88 constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;
89
90 template<enum ElecType elecType> // Yes, ElecType: the separate VdW cut-off check is tied to twin-range Ewald electrostatics
91 constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;
92
93 template<enum ElecType elecType>
94 constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;
95
96 template<enum ElecType elecType>
97 constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;
98
99 template<enum VdwType vdwType>
100 constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
101 //@}
102
103 using cl::sycl::access::fence_space;
104 using cl::sycl::access::mode;
105 using cl::sycl::access::target;
106
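//! \brief Convert sigma and epsilon (as stored for the LB combination rule) into c6/c12 coefficients.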
107 static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
108 {
109     const float sigma2 = sigma * sigma;
110     const float sigma6 = sigma2 * sigma2 * sigma2;
111     const float c6     = epsilon * sigma6;
112     const float c12    = c6 * sigma6;
113
114     return Float2(c6, c12);
115 }
116
117 template<bool doCalcEnergies>
118 static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
119                                  const shift_consts_t         repulsionShift,
120                                  const float                  rVdwSwitch,
121                                  const float                  c6,
122                                  const float                  c12,
123                                  const float                  rInv,
124                                  const float                  r2,
125                                  cl::sycl::private_ptr<float> fInvR,
126                                  cl::sycl::private_ptr<float> eLJ)
127 {
128     /* force switch constants */
129     const float dispShiftV2 = dispersionShift.c2;
130     const float dispShiftV3 = dispersionShift.c3;
131     const float repuShiftV2 = repulsionShift.c2;
132     const float repuShiftV3 = repulsionShift.c3;
133
134     const float r       = r2 * rInv;
135     const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)
136
137     *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
138               + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;
139
140     if constexpr (doCalcEnergies)
141     {
142         const float dispShiftF2 = dispShiftV2 / 3;
143         const float dispShiftF3 = dispShiftV3 / 4;
144         const float repuShiftF2 = repuShiftV2 / 3;
145         const float repuShiftF3 = repuShiftV3 / 4;
146         *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
147                 - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
148     }
149 }
150
151 //! \brief Fetch C6 grid contribution coefficients and return the product of these.
152 template<enum VdwType vdwType>
153 static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
154                                            const int                                typeI,
155                                            const int                                typeJ)
156 {
157     if constexpr (vdwType == VdwType::EwaldGeom)
158     {
159         return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
160     }
161     else
162     {
163         static_assert(vdwType == VdwType::EwaldLB);
164         /* sigma and epsilon are scaled to give 6*C6 */
165         const Float2 c6c12_i = a_nbfpComb[typeI];
166         const Float2 c6c12_j = a_nbfpComb[typeJ];
167
168         const float sigma   = c6c12_i[0] + c6c12_j[0];
169         const float epsilon = c6c12_i[1] * c6c12_j[1];
170
171         const float sigma2 = sigma * sigma;
172         return epsilon * sigma2 * sigma2 * sigma2;
173     }
174 }
175
176 //! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
177 template<bool doCalcEnergies, enum VdwType vdwType>
178 static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
179                                const float                              sh_lj_ewald,
180                                const int                                typeI,
181                                const int                                typeJ,
182                                const float                              r2,
183                                const float                              r2Inv,
184                                const float                              lje_coeff2,
185                                const float                              lje_coeff6_6,
186                                const float                              int_bit,
187                                cl::sycl::private_ptr<float>             fInvR,
188                                cl::sycl::private_ptr<float>             eLJ)
189 {
190     const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);
191
192     /* Recalculate inv_r6 without exclusion mask */
193     const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
194     const float cr2       = lje_coeff2 * r2;
195     const float expmcr2   = cl::sycl::exp(-cr2);
196     const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;
197
198     /* Subtract the grid force from the total LJ force */
199     *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;
200
201     if constexpr (doCalcEnergies)
202     {
203         /* Shift should be applied only to real LJ pairs */
204         const float sh_mask = sh_lj_ewald * int_bit;
205         *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
206     }
207 }
208
209 /*! \brief Apply potential switch. */
210 template<bool doCalcEnergies>
211 static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
212                                      const float                  rVdwSwitch,
213                                      const float                  rInv,
214                                      const float                  r2,
215                                      cl::sycl::private_ptr<float> fInvR,
216                                      cl::sycl::private_ptr<float> eLJ)
217 {
218     /* potential switch constants */
219     const float switchV3 = vdwSwitch.c3;
220     const float switchV4 = vdwSwitch.c4;
221     const float switchV5 = vdwSwitch.c5;
222     const float switchF2 = 3 * vdwSwitch.c3;
223     const float switchF3 = 4 * vdwSwitch.c4;
224     const float switchF4 = 5 * vdwSwitch.c5;
225
226     const float r       = r2 * rInv;
227     const float rSwitch = r - rVdwSwitch;
228
229     if (rSwitch > 0.0F)
230     {
231         const float sw =
232                 1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
233         const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;
234
235         *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
236         if constexpr (doCalcEnergies)
237         {
238             *eLJ *= sw;
239         }
240     }
241 }
242
243
244 /*! \brief Calculate analytical Ewald correction term. */
245 static inline float pmeCorrF(const float z2)
246 {
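    // Evaluated as a rational polynomial approximation in z2: a degree-6 numerator (FN0..FN6)
    // over a degree-4 denominator (FD0..FD4), in single precision.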
247     constexpr float FN6 = -1.7357322914161492954e-8F;
248     constexpr float FN5 = 1.4703624142580877519e-6F;
249     constexpr float FN4 = -0.000053401640219807709149F;
250     constexpr float FN3 = 0.0010054721316683106153F;
251     constexpr float FN2 = -0.019278317264888380590F;
252     constexpr float FN1 = 0.069670166153766424023F;
253     constexpr float FN0 = -0.75225204789749321333F;
254
255     constexpr float FD4 = 0.0011193462567257629232F;
256     constexpr float FD3 = 0.014866955030185295499F;
257     constexpr float FD2 = 0.11583842382862377919F;
258     constexpr float FD1 = 0.50736591960530292870F;
259     constexpr float FD0 = 1.0F;
260
261     const float z4 = z2 * z2;
262
263     float       polyFD0 = FD4 * z4 + FD2;
264     const float polyFD1 = FD3 * z4 + FD1;
265     polyFD0             = polyFD0 * z4 + FD0;
266     polyFD0             = polyFD1 * z2 + polyFD0;
267
268     polyFD0 = 1.0F / polyFD0;
269
270     float polyFN0 = FN6 * z4 + FN4;
271     float polyFN1 = FN5 * z4 + FN3;
272     polyFN0       = polyFN0 * z4 + FN2;
273     polyFN1       = polyFN1 * z4 + FN1;
274     polyFN0       = polyFN0 * z4 + FN0;
275     polyFN0       = polyFN1 * z2 + polyFN0;
276
277     return polyFN0 * polyFD0;
278 }
279
280 /*! \brief Linear interpolation using exactly two FMA operations.
281  *
282  *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
283  */
284 template<typename T>
285 static inline T lerp(T d0, T d1, T t)
286 {
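    // Expands to t*d1 + (d0 - t*d0) = (1 - t)*d0 + t*d1, i.e. the interpolation stated above.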
287     return fma(t, d1, fma(-t, d0, d0));
288 }
289
290 /*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
291 static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
292                                              const float coulombTabScale,
293                                              const float r)
294 {
295     const float normalized = coulombTabScale * r;
296     const int   index      = static_cast<int>(normalized);
297     const float fraction   = normalized - index;
298
299     const float left  = a_coulombTab[index];
300     const float right = a_coulombTab[index + 1];
301
302     return lerp(left, right, fraction); // TODO: cl::sycl::mix
303 }
304
305 /*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
306  *
307  * c_clSize consecutive threads hold the force components of a j-atom, which we
308  * reduce in log2(c_clSize) steps using sub-group shifts and then accumulate atomically into \p a_f.
309  */
310 static inline void reduceForceJShuffle(Float3                             f,
311                                        const cl::sycl::nd_item<1>         itemIdx,
312                                        const int                          tidxi,
313                                        const int                          aidx,
314                                        DeviceAccessor<float, mode_atomic> a_f)
315 {
316     static_assert(c_clSize == 8 || c_clSize == 4);
317     sycl_2020::sub_group sg = itemIdx.get_sub_group();
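    // The three components are interleaved across lanes: after the shift/add steps below, lane 0 of
    // each c_clSize-wide group holds the x sum, lane 1 the y sum and lane 2 the z sum, so only three
    // lanes per j-atom need to issue atomic adds.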
318
319     f[0] += sycl_2020::shift_left(sg, f[0], 1);
320     f[1] += sycl_2020::shift_right(sg, f[1], 1);
321     f[2] += sycl_2020::shift_left(sg, f[2], 1);
322     if (tidxi & 1)
323     {
324         f[0] = f[1];
325     }
326
327     f[0] += sycl_2020::shift_left(sg, f[0], 2);
328     f[2] += sycl_2020::shift_right(sg, f[2], 2);
329     if (tidxi & 2)
330     {
331         f[0] = f[2];
332     }
333
334     if constexpr (c_clSize == 8)
335     {
336         f[0] += sycl_2020::shift_left(sg, f[0], 4);
337     }
338
339     if (tidxi < 3)
340     {
341         atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
342     }
343 }
344
345 /*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
346  *
347  * c_clSize consecutive threads hold the force components of a j-atom, which we
348  * reduce sequentially in local memory and then accumulate atomically into \p a_f.
349  *
350  * TODO: implement a binary-reduction flavor for the case where c_clSize is a power of two.
351  */
352 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
353                                        Float3                             f,
354                                        const cl::sycl::nd_item<1>         itemIdx,
355                                        const int                          tidxi,
356                                        const int                          tidxj,
357                                        const int                          aidx,
358                                        DeviceAccessor<float, mode_atomic> a_f)
359 {
360     static constexpr int sc_fBufferStride = c_clSizeSq;
361     int                  tidx             = tidxi + tidxj * c_clSize;
362     sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
363     sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
364     sm_buf[2 * sc_fBufferStride + tidx]   = f[2];
365
366     subGroupBarrier(itemIdx);
367
368     // Reduce the stored data: in each row, the three threads with tidxi < 3 each sum the c_clSize contributions of one force component; they belong to the same sub-group as the storing threads, so a sub-group barrier suffices.
369     assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);
370
371     if (tidxi < 3)
372     {
373         float fSum = 0.0F;
374         for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
375         {
376             fSum += sm_buf[sc_fBufferStride * tidxi + j];
377         }
378
379         atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
380     }
381 }
382
383
384 /*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
385  */
386 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
387                                 Float3                                                        f,
388                                 const cl::sycl::nd_item<1>         itemIdx,
389                                 const int                          tidxi,
390                                 const int                          tidxj,
391                                 const int                          aidx,
392                                 DeviceAccessor<float, mode_atomic> a_f)
393 {
394     if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
395     {
396         reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
397     }
398     else
399     {
400         reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
401     }
402 }
403
404
405 /*! \brief Final i-force reduction.
406  *
407  * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
408  * accumulating atomically into \p a_f.
409  * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
410  *
411  * This implementation works only with power of two array sizes.
412  */
413 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
414                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
415                                          const bool   calcFShift,
416                                          const cl::sycl::nd_item<1>         itemIdx,
417                                          const int                          tidxi,
418                                          const int                          tidxj,
419                                          const int                          sci,
420                                          const int                          shift,
421                                          DeviceAccessor<float, mode_atomic> a_f,
422                                          DeviceAccessor<float, mode_atomic> a_fShift)
423 {
424     // must have power of two elements in fCiBuf
425     static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
426
427     static constexpr int bufStride  = c_clSize * c_clSize;
428     static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
429     const int            tidx       = tidxi + tidxj * c_clSize;
430     float                fShiftBuf  = 0.0F;
431     for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
432     {
433         const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
434         /* store i forces in shmem */
435         sm_buf[tidx]                 = fCiBuf[ciOffset][0];
436         sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
437         sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
438         itemIdx.barrier(fence_space::local_space);
439
440         /* Reduce the c_clSize values for each i-atom, halving the data in every
441          * step; c_clSize * i threads take part in each step. */
442         int i = c_clSize / 2;
443         for (int j = clSizeLog2 - 1; j > 0; j--)
444         {
445             if (tidxj < i)
446             {
447                 sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
448                 sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
449                         sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
450                 sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
451                         sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
452             }
453             i >>= 1;
454             itemIdx.barrier(fence_space::local_space);
455         }
456
457         /* i == 1, last reduction step, writing to global mem */
458         /* Split the reduction between the first 3 line threads
459            Threads with line id 0 will do the reduction for (float3).x components
460            Threads with line id 1 will do the reduction for (float3).y components
461            Threads with line id 2 will do the reduction for (float3).z components. */
462         if (tidxj < 3)
463         {
464             const float f =
465                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
466             atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
467             if (calcFShift)
468             {
469                 fShiftBuf += f;
470             }
471         }
472         itemIdx.barrier(fence_space::local_space);
473     }
474     /* add up local shift forces into global mem */
475     if (calcFShift)
476     {
477         /* Only threads with tidxj < 3 will update fshift.
478            The threads performing the update must be the same as the threads
479            storing the reduction result above. */
480         if (tidxj < 3)
481         {
482             if constexpr (c_clSize == 4)
483             {
484                 /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
485                  * a compare-and-swap (CAS) loop, which performs particularly poorly when
486                  * threads of the same work-group update the same memory location.
487                  * Pre-reducing within the sub-group might be slightly beneficial on NVIDIA and AMD as well,
488                  * but it is unlikely to make a big difference and thus was not evaluated.
489                  */
490                 auto sg = itemIdx.get_sub_group();
491                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
492                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
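                // After the two shifts above, the lane with tidxi == 0 in each of the three rows holds
                // the sum of that row's c_clSize partial shift forces, so only three atomic adds per
                // work-group are issued instead of 3 * c_clSize.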
493                 if (tidxi == 0)
494                 {
495                     atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
496                 }
497             }
498             else
499             {
500                 atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
501             }
502         }
503     }
504 }
505
506 /*! \brief Main kernel for NBNXM.
507  *
508  */
509 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
510 auto nbnxmKernel(cl::sycl::handler&                                   cgh,
511                  DeviceAccessor<Float4, mode::read>                   a_xq,
512                  DeviceAccessor<float, mode_atomic>                   a_f,
513                  DeviceAccessor<Float3, mode::read>                   a_shiftVec,
514                  DeviceAccessor<float, mode_atomic>                   a_fShift,
515                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
516                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
517                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
518                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
519                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
520                  OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
521                  OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
522                  OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
523                  OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
524                  OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
525                  const int                                                   numTypes,
526                  const float                                                 rCoulombSq,
527                  const float                                                 rVdwSq,
528                  const float                                                 twoKRf,
529                  const float                                                 ewaldBeta,
530                  const float                                                 rlistOuterSq,
531                  const float                                                 ewaldShift,
532                  const float                                                 epsFac,
533                  const float                                                 ewaldCoeffLJ,
534                  const float                                                 cRF,
535                  const shift_consts_t                                        dispersionShift,
536                  const shift_consts_t                                        repulsionShift,
537                  const switch_consts_t                                       vdwSwitch,
538                  const float                                                 rVdwSwitch,
539                  const float                                                 ljEwaldShift,
540                  const float                                                 coulombTabScale,
541                  const bool                                                  calcShift)
542 {
543     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
544
545     cgh.require(a_xq);
546     cgh.require(a_f);
547     cgh.require(a_shiftVec);
548     cgh.require(a_fShift);
549     if constexpr (doCalcEnergies)
550     {
551         cgh.require(a_energyElec);
552         cgh.require(a_energyVdw);
553     }
554     cgh.require(a_plistCJ4);
555     cgh.require(a_plistSci);
556     cgh.require(a_plistExcl);
557     if constexpr (!props.vdwComb)
558     {
559         cgh.require(a_atomTypes);
560         cgh.require(a_nbfp);
561     }
562     else
563     {
564         cgh.require(a_ljComb);
565     }
566     if constexpr (props.vdwEwald)
567     {
568         cgh.require(a_nbfpComb);
569     }
570     if constexpr (props.elecEwaldTab)
571     {
572         cgh.require(a_coulombTab);
573     }
574
575     // shmem buffer for i x+q pre-loading
576     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
577             cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
578
579     // shmem buffer for force reduction
580     // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
581     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
582             cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);
583
584     auto sm_atomTypeI = [&]() {
585         if constexpr (!props.vdwComb)
586         {
587             return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
588                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
589         }
590         else
591         {
592             return nullptr;
593         }
594     }();
595
596     auto sm_ljCombI = [&]() {
597         if constexpr (props.vdwComb)
598         {
599             return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
600                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
601         }
602         else
603         {
604             return nullptr;
605         }
606     }();
607
608     /* Flag to control the calculation of exclusion forces in the kernel
609      * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
610      * energy terms. */
611     constexpr bool doExclusionForces =
612             (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));
613
614     // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
615     // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
616     // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
617     // Hence, the two are decoupled.
618     constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
619 #if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
620     gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
621 #else
622     gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
623 #endif
624
625     return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
626     {
627         /* thread/block/warp id-s */
628         const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
629         const unsigned        tidxi   = localId[0];
630         const unsigned        tidxj   = localId[1];
631         const unsigned        tidx    = tidxj * c_clSize + tidxi;
632         const unsigned        tidxz   = 0;
633
634         // Group indexing was flat originally, no need to unflatten it.
635         const unsigned bidx = itemIdx.get_group(0);
636
637         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
638         // We could use sg.get_group_range to compute the imask & exclusion index, but too much of the
639         // logic relies on the current scheme, and it would not work when prunedClusterPairSize != subGroupSize.
640         const unsigned imeiIdx = tidx / prunedClusterPairSize;
641
642         Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
643         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
644         {
645             fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
646         }
647
648         const nbnxn_sci_t nbSci     = a_plistSci[bidx];
649         const int         sci       = nbSci.sci;
650         const int         cij4Start = nbSci.cj4_ind_start;
651         const int         cij4End   = nbSci.cj4_ind_end;
652
653         // Only needed if props.elecEwaldAna
654         const float beta2 = ewaldBeta * ewaldBeta;
655         const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;
656
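        // In the pre-load loop below, (tidxj + i) selects the i-cluster within the super-cluster and
        // tidxi the atom within that cluster, so each thread loads one i-atom per iteration.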
657         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
658         {
659             /* Pre-load i-atom x and q into shared memory */
660             const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
661             const int             ai       = ci * c_clSize + tidxi;
662             const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);
663
664             const Float3 shift = a_shiftVec[nbSci.shift];
665             Float4       xqi   = a_xq[ai];
666             xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
667             xqi[3] *= epsFac;
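            // The charge (w component) is pre-scaled by epsFac, so the qi * qj products in the inner
            // loop already carry the Coulomb prefactor.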
668             sm_xq[cacheIdx] = xqi;
669
670             if constexpr (!props.vdwComb)
671             {
672                 // Pre-load the i-atom types into shared memory
673                 sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
674             }
675             else
676             {
677                 // Pre-load the LJ combination parameters into shared memory
678                 sm_ljCombI[cacheIdx] = a_ljComb[ai];
679             }
680         }
681         itemIdx.barrier(fence_space::local_space);
682
683         float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
684         if constexpr (props.vdwEwald)
685         {
686             ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
687             ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
688         }
689
690         float energyVdw, energyElec; // Only needed if (doCalcEnergies)
691         if constexpr (doCalcEnergies)
692         {
693             energyVdw = energyElec = 0.0F;
694         }
695         if constexpr (doCalcEnergies && doExclusionForces)
696         {
697             if (nbSci.shift == gmx::c_centralShiftIndex
698                 && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
699             {
700                 // we have the diagonal: add the charge and LJ self interaction energy term
701                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
702                 {
703                     // TODO: Are there other options?
704                     if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
705                     {
706                         const float qi = sm_xq[i][tidxi][3];
707                         energyElec += qi * qi;
708                     }
709                     if constexpr (props.vdwEwald)
710                     {
711                         energyVdw +=
712                                 a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
713                                        * (numTypes + 1)][0];
714                     }
715                 }
716                 /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
717                 if constexpr (props.vdwEwald)
718                 {
719                     energyVdw /= c_clSize;
720                     energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
721                 }
722                 if constexpr (props.elecRF || props.elecCutoff)
723                 {
724                     // Correct for epsfac^2 due to adding qi^2
725                     energyElec /= epsFac * c_clSize;
726                     energyElec *= -0.5F * cRF;
727                 }
728                 if constexpr (props.elecEwald)
729                 {
730                     // Correct for epsfac^2 due to adding qi^2
731                     energyElec /= epsFac * c_clSize;
732                     energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
733                 }
734             } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
735         }     // (doCalcEnergies && doExclusionForces)
736
737         // Only needed if (doExclusionForces)
738         const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);
739
740         // loop over the j clusters seen by any of the atoms in the current super-cluster
741         for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
742         {
743             unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
744             if (!doPruneNBL && !imask)
745             {
746                 continue;
747             }
748             const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
749             static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
750             const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
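            // imask and this thread's exclusion mask wexcl use one bit per (i-cluster, j-cluster) pair
            // of this cj4 group: bit jm * c_nbnxnGpuNumClusterPerSupercluster + i (cf. maskJI below).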
751             for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
752             {
753                 const bool maskSet =
754                         imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
755                 if (!maskSet)
756                 {
757                     continue;
758                 }
759                 unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
760                 const int cj     = a_plistCJ4[j4].cj[jm];
761                 const int aj     = cj * c_clSize + tidxj;
762
763                 // load j atom data
764                 const Float4 xqj = a_xq[aj];
765
766                 const Float3 xj(xqj[0], xqj[1], xqj[2]);
767                 const float  qj = xqj[3];
768                 int          atomTypeJ; // Only needed if (!props.vdwComb)
769                 Float2       ljCombJ;   // Only needed if (props.vdwComb)
770                 if constexpr (props.vdwComb)
771                 {
772                     ljCombJ = a_ljComb[aj];
773                 }
774                 else
775                 {
776                     atomTypeJ = a_atomTypes[aj];
777                 }
778
779                 Float3 fCjBuf(0.0F, 0.0F, 0.0F);
780
781                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
782                 {
783                     if (imask & maskJI)
784                     {
785                         // i cluster index
786                         const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
787                         // each thread reads one i-atom of cluster ci from the shmem preloaded above
788                         const Float4 xqi = sm_xq[i][tidxi];
789                         const Float3 xi(xqi[0], xqi[1], xqi[2]);
790
791                         // distance between i and j atoms
792                         const Float3 rv = xi - xj;
793                         float        r2 = norm2(rv);
794
795                         if constexpr (doPruneNBL)
796                         {
797                             /* If _none_ of the atom pairs is within cutoff range,
798                              * the bit corresponding to the current
799                              * cluster-pair in imask is cleared. */
800                             if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
801                             {
802                                 imask &= ~maskJI;
803                             }
804                         }
805                         const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;
806
807                         // cutoff & exclusion check
808
809                         const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
810                                                                    : (wexcl & maskJI);
811
812                         // SYCL-TODO: Check optimal way of branching here.
813                         if ((r2 < rCoulombSq) && notExcluded)
814                         {
815                             const float qi = xqi[3];
816                             int         atomTypeI; // Only needed if (!props.vdwComb)
817                             float       sigma, epsilon;
818                             Float2      c6c12;
819
820                             if constexpr (!props.vdwComb)
821                             {
822                                 /* LJ 6*C6 and 12*C12 */
823                                 atomTypeI = sm_atomTypeI[i][tidxi];
824                                 c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
825                             }
826                             else
827                             {
828                                 const Float2 ljCombI = sm_ljCombI[i][tidxi];
829                                 if constexpr (props.vdwCombGeom)
830                                 {
831                                     c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
832                                 }
833                                 else
834                                 {
835                                     static_assert(props.vdwCombLB);
836                                     // LJ 2^(1/6)*sigma and 12*epsilon
837                                     sigma   = ljCombI[0] + ljCombJ[0];
838                                     epsilon = ljCombI[1] * ljCombJ[1];
839                                     if constexpr (doCalcEnergies)
840                                     {
841                                         c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
842                                     }
843                                 } // props.vdwCombGeom
844                             }     // !props.vdwComb
845
846                             // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
847                             const float c6  = c6c12[0];
848                             const float c12 = c6c12[1];
849
850                             // Ensure the distance does not become so small that r^-12 overflows
851                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
852 #if GMX_SYCL_HIPSYCL
853                             // No fast/native functions in some compilation passes
854                             const float rInv = cl::sycl::rsqrt(r2);
855 #else
856                             // SYCL-TODO: sycl::half_precision::rsqrt?
857                             const float rInv = cl::sycl::native::rsqrt(r2);
858 #endif
859                             const float r2Inv = rInv * rInv;
860                             float       r6Inv, fInvR, energyLJPair;
861                             if constexpr (!props.vdwCombLB || doCalcEnergies)
862                             {
863                                 r6Inv = r2Inv * r2Inv * r2Inv;
864                                 if constexpr (doExclusionForces)
865                                 {
866                                     // SYCL-TODO: Check if true for SYCL
867                                     /* We could mask r2Inv, but with Ewald masking both
868                                      * r6Inv and fInvR is faster */
869                                     r6Inv *= pairExclMask;
870                                 }
871                                 fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
872                             }
873                             else
874                             {
875                                 float sig_r  = sigma * rInv;
876                                 float sig_r2 = sig_r * sig_r;
877                                 float sig_r6 = sig_r2 * sig_r2 * sig_r2;
878                                 if constexpr (doExclusionForces)
879                                 {
880                                     sig_r6 *= pairExclMask;
881                                 }
882                                 fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
883                             } // (!props.vdwCombLB || doCalcEnergies)
884                             if constexpr (doCalcEnergies || props.vdwPSwitch)
885                             {
886                                 energyLJPair = pairExclMask
887                                                * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
888                                                   - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
889                             }
890                             if constexpr (props.vdwFSwitch)
891                             {
892                                 ljForceSwitch<doCalcEnergies>(
893                                         dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
894                             }
895                             if constexpr (props.vdwEwald)
896                             {
897                                 ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
898                                                                      ljEwaldShift,
899                                                                      atomTypeI,
900                                                                      atomTypeJ,
901                                                                      r2,
902                                                                      r2Inv,
903                                                                      ewaldCoeffLJ_2,
904                                                                      ewaldCoeffLJ_6_6,
905                                                                      pairExclMask,
906                                                                      &fInvR,
907                                                                      &energyLJPair);
908                             } // (props.vdwEwald)
909                             if constexpr (props.vdwPSwitch)
910                             {
911                                 ljPotentialSwitch<doCalcEnergies>(
912                                         vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
913                             }
914                             if constexpr (props.elecEwaldTwin)
915                             {
916                                 // Separate VDW cut-off check to enable twin-range cut-offs
917                                 // (rVdw < rCoulomb <= rList)
918                                 const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
919                                 fInvR *= vdwInRange;
920                                 if constexpr (doCalcEnergies)
921                                 {
922                                     energyLJPair *= vdwInRange;
923                                 }
924                             }
925                             if constexpr (doCalcEnergies)
926                             {
927                                 energyVdw += energyLJPair;
928                             }
929
930                             if constexpr (props.elecCutoff)
931                             {
932                                 if constexpr (doExclusionForces)
933                                 {
934                                     fInvR += qi * qj * pairExclMask * r2Inv * rInv;
935                                 }
936                                 else
937                                 {
938                                     fInvR += qi * qj * r2Inv * rInv;
939                                 }
940                             }
941                             if constexpr (props.elecRF)
942                             {
943                                 fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
944                             }
945                             if constexpr (props.elecEwaldAna)
946                             {
947                                 fInvR += qi * qj
948                                          * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
949                             }
950                             if constexpr (props.elecEwaldTab)
951                             {
952                                 fInvR += qi * qj
953                                          * (pairExclMask * r2Inv
954                                             - interpolateCoulombForceR(
955                                                     a_coulombTab, coulombTabScale, r2 * rInv))
956                                          * rInv;
957                             }
958
959                             if constexpr (doCalcEnergies)
960                             {
961                                 if constexpr (props.elecCutoff)
962                                 {
963                                     energyElec += qi * qj * (pairExclMask * rInv - cRF);
964                                 }
965                                 if constexpr (props.elecRF)
966                                 {
967                                     energyElec +=
968                                             qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
969                                 }
970                                 if constexpr (props.elecEwald)
971                                 {
972                                     energyElec +=
973                                             qi * qj
974                                             * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
975                                                - pairExclMask * ewaldShift);
976                                 }
977                             }
978
979                             const Float3 forceIJ = rv * fInvR;
980
981                             /* accumulate j forces in registers */
982                             fCjBuf -= forceIJ;
983                             /* accumulate i forces in registers */
984                             fCiBuf[i] += forceIJ;
985                         } // (r2 < rCoulombSq) && notExcluded
986                     }     // (imask & maskJI)
987                     /* shift the mask bit by 1 */
988                     maskJI += maskJI;
989                 } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
990                 /* reduce j forces */
991                 reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
992             } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
993             if constexpr (doPruneNBL)
994             {
995                 /* Update the imask with the new one which does not contain the
996                  * out of range clusters anymore. */
997                 a_plistCJ4[j4].imei[imeiIdx].imask = imask;
998             }
999         } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)
1000
1001         /* skip central shifts when summing shift forces */
1002         const bool doCalcShift = (calcShift && !(nbSci.shift == gmx::c_centralShiftIndex));
1003
1004         reduceForceIAndFShift(
1005                 sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);
1006
1007         if constexpr (doCalcEnergies)
1008         {
1009             const float energyVdwGroup = sycl_2020::group_reduce(
1010                     itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
1011             const float energyElecGroup = sycl_2020::group_reduce(
1012                     itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());
1013
1014             if (tidx == 0)
1015             {
1016                 atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
1017                 atomicFetchAdd(a_energyElec, 0, energyElecGroup);
1018             }
1019         }
1020     };
1021 }
1022
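//! \brief Launch the templated NBNXM kernel on \p deviceStream, using one work-group per pair-list sci entry.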
1023 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
1024 cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
1025 {
1026     using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;
1027
1028     /* Kernel launch config:
1029      * - The thread block dimensions match the size of i-clusters, j-clusters,
1030      *   and j-cluster concurrency, in x, y, and z, respectively.
1031      * - The 1D block-grid contains as many blocks as super-clusters.
1032      */
1033     const int                   numBlocks = numSci;
1034     const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
1035     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
1036     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
1037
1038     cl::sycl::queue q = deviceStream.stream();
1039
1040     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
1041         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
1042                 cgh, std::forward<Args>(args)...);
1043         cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
1044     });
1045
1046     return e;
1047 }
1048
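//! \brief Convert the runtime doPruneNBL, doCalcEnergies, elecType and vdwType values into template
//! arguments (via gmx::dispatchTemplatedFunction) and launch the matching kernel instantiation.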
1049 template<class... Args>
1050 cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
1051                                            bool          doCalcEnergies,
1052                                            enum ElecType elecType,
1053                                            enum VdwType  vdwType,
1054                                            Args&&... args)
1055 {
1056     return gmx::dispatchTemplatedFunction(
1057             [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
1058                 return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
1059                         std::forward<Args>(args)...);
1060             },
1061             doPruneNBL,
1062             doCalcEnergies,
1063             elecType,
1064             vdwType);
1065 }
1066
1067 void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
1068 {
1069     NBAtomDataGpu*      adat         = nb->atdat;
1070     NBParamGpu*         nbp          = nb->nbparam;
1071     gpu_plist*          plist        = nb->plist[iloc];
1072     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
1073     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
1074
1075     // Reinterpreting the Float3 buffers as plain float simplifies the use of atomic ops in the kernel
1076     cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
1077     auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
1078     cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
1079     auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
1080
1081     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
1082                                                    stepWork.computeEnergy,
1083                                                    nbp->elecType,
1084                                                    nbp->vdwType,
1085                                                    deviceStream,
1086                                                    plist->nsci,
1087                                                    adat->xq,
1088                                                    fAsFloat,
1089                                                    adat->shiftVec,
1090                                                    fShiftAsFloat,
1091                                                    adat->eElec,
1092                                                    adat->eLJ,
1093                                                    plist->cj4,
1094                                                    plist->sci,
1095                                                    plist->excl,
1096                                                    adat->ljComb,
1097                                                    adat->atomTypes,
1098                                                    nbp->nbfp,
1099                                                    nbp->nbfp_comb,
1100                                                    nbp->coulomb_tab,
1101                                                    adat->numTypes,
1102                                                    nbp->rcoulomb_sq,
1103                                                    nbp->rvdw_sq,
1104                                                    nbp->two_k_rf,
1105                                                    nbp->ewald_beta,
1106                                                    nbp->rlistOuter_sq,
1107                                                    nbp->sh_ewald,
1108                                                    nbp->epsfac,
1109                                                    nbp->ewaldcoeff_lj,
1110                                                    nbp->c_rf,
1111                                                    nbp->dispersion_shift,
1112                                                    nbp->repulsion_shift,
1113                                                    nbp->vdw_switch,
1114                                                    nbp->rvdw_switch,
1115                                                    nbp->sh_lj_ewald,
1116                                                    nbp->coulomb_tab_scale,
1117                                                    stepWork.computeVirial);
1118 }
1119
1120 } // namespace Nbnxm