src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief
 *  NBNXM SYCL kernels
 *
 *  \ingroup module_nbnxm
 */
#include "gmxpre.h"

#include "nbnxm_sycl_kernel.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/math/functions.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/pbcutil/ishift.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"

//! \brief Class name for NBNXM kernel
template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
class NbnxmKernel;

namespace Nbnxm
{

//! \brief Set of boolean constants mimicking preprocessor macros.
template<enum ElecType elecType, enum VdwType vdwType>
struct EnergyFunctionProperties {
    static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
    static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
    static constexpr bool elecEwaldAna =
            (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
    static constexpr bool elecEwaldTab =
            (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
    static constexpr bool elecEwaldTwin =
            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
    static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
    static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
    static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
    static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
    static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
    static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
    static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
};
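
/* Added illustration (not part of the original file): the struct above lets the
 * kernel branch at compile time the way the CUDA/OpenCL kernels do with
 * preprocessor macros. For example, a hypothetical analytical-Ewald plus
 * geometric-combination-rule instantiation would satisfy:
 *
 *   constexpr EnergyFunctionProperties<ElecType::EwaldAna, VdwType::CutCombGeom> p;
 *   static_assert(p.elecEwald && p.elecEwaldAna && p.vdwCombGeom);
 *   static_assert(!p.elecEwaldTwin && !p.vdwEwald);
 */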

//! \brief Templated constants to shorten kernel function declaration.
//@{
template<enum VdwType vdwType>
constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;

template<enum ElecType elecType> // Yes, ElecType
constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;

template<enum ElecType elecType>
constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;

template<enum ElecType elecType>
constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;

template<enum VdwType vdwType>
constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
//@}

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;

static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
{
    const float sigma2 = sigma * sigma;
    const float sigma6 = sigma2 * sigma2 * sigma2;
    const float c6     = epsilon * sigma6;
    const float c12    = c6 * sigma6;

    return Float2(c6, c12);
}
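
/* Added note on the scaling convention (derived from the comments in the main
 * kernel below, which pass in the LB-combined 2^(1/6)*sigma and 12*epsilon):
 * with sigma' = 2^(1/6)*sigma and epsilon' = 12*epsilon, the products above give
 *   c6  = epsilon' * sigma'^6 = 24*epsilon*sigma^6  = 6*C6
 *   c12 = c6 * sigma'^6       = 48*epsilon*sigma^12 = 12*C12
 * matching the 6*C6 / 12*C12 convention of the a_nbfp parameter table. */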

template<bool doCalcEnergies>
static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                 const shift_consts_t         repulsionShift,
                                 const float                  rVdwSwitch,
                                 const float                  c6,
                                 const float                  c12,
                                 const float                  rInv,
                                 const float                  r2,
                                 cl::sycl::private_ptr<float> fInvR,
                                 cl::sycl::private_ptr<float> eLJ)
{
    /* force switch constants */
    const float dispShiftV2 = dispersionShift.c2;
    const float dispShiftV3 = dispersionShift.c3;
    const float repuShiftV2 = repulsionShift.c2;
    const float repuShiftV3 = repulsionShift.c3;

    const float r       = r2 * rInv;
    const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)

    *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
              + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;

    if constexpr (doCalcEnergies)
    {
        const float dispShiftF2 = dispShiftV2 / 3;
        const float dispShiftF3 = dispShiftV3 / 4;
        const float repuShiftF2 = repuShiftV2 / 3;
        const float repuShiftF3 = repuShiftV3 / 4;
        *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
                - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
    }
}

//! \brief Fetch C6 grid contribution coefficients and return the product of these.
template<enum VdwType vdwType>
static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                                           const int                                typeI,
                                           const int                                typeJ)
{
    if constexpr (vdwType == VdwType::EwaldGeom)
    {
        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
    }
    else
    {
        static_assert(vdwType == VdwType::EwaldLB);
        /* sigma and epsilon are scaled to give 6*C6 */
        const Float2 c6c12_i = a_nbfpComb[typeI];
        const Float2 c6c12_j = a_nbfpComb[typeJ];

        const float sigma   = c6c12_i[0] + c6c12_j[0];
        const float epsilon = c6c12_i[1] * c6c12_j[1];

        const float sigma2 = sigma * sigma;
        return epsilon * sigma2 * sigma2 * sigma2;
    }
}

//! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
template<bool doCalcEnergies, enum VdwType vdwType>
static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                               const float                              sh_lj_ewald,
                               const int                                typeI,
                               const int                                typeJ,
                               const float                              r2,
                               const float                              r2Inv,
                               const float                              lje_coeff2,
                               const float                              lje_coeff6_6,
                               const float                              int_bit,
                               cl::sycl::private_ptr<float>             fInvR,
                               cl::sycl::private_ptr<float>             eLJ)
{
    const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);

    /* Recalculate inv_r6 without exclusion mask */
    const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
    const float cr2       = lje_coeff2 * r2;
    const float expmcr2   = cl::sycl::exp(-cr2);
    const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;

    /* Subtract the grid force from the total LJ force */
    *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;

    if constexpr (doCalcEnergies)
    {
        /* Shift should be applied only to real LJ pairs */
        const float sh_mask = sh_lj_ewald * int_bit;
        *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
    }
}

/*! \brief Apply potential switch. */
template<bool doCalcEnergies>
static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
                                     const float                  rVdwSwitch,
                                     const float                  rInv,
                                     const float                  r2,
                                     cl::sycl::private_ptr<float> fInvR,
                                     cl::sycl::private_ptr<float> eLJ)
{
    /* potential switch constants */
    const float switchV3 = vdwSwitch.c3;
    const float switchV4 = vdwSwitch.c4;
    const float switchV5 = vdwSwitch.c5;
    const float switchF2 = 3 * vdwSwitch.c3;
    const float switchF3 = 4 * vdwSwitch.c4;
    const float switchF4 = 5 * vdwSwitch.c5;

    const float r       = r2 * rInv;
    const float rSwitch = r - rVdwSwitch;

    if (rSwitch > 0.0F)
    {
        const float sw =
                1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
        const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;

        *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
        if constexpr (doCalcEnergies)
        {
            *eLJ *= sw;
        }
    }
}
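
/* Added note: sw above is the quintic switching function S(r), and dsw is its
 * derivative dS/dr (rSwitch = r - rVdwSwitch, so d/dr == d/drSwitch). Since the
 * switched energy is E*S, the switched force is F*S - E*dS/dr; dividing through
 * by r gives exactly the update applied above:
 *   fInvR_new = fInvR*sw - rInv*eLJ*dsw. */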


/*! \brief Calculate analytical Ewald correction term. */
static inline float pmeCorrF(const float z2)
{
    constexpr float FN6 = -1.7357322914161492954e-8F;
    constexpr float FN5 = 1.4703624142580877519e-6F;
    constexpr float FN4 = -0.000053401640219807709149F;
    constexpr float FN3 = 0.0010054721316683106153F;
    constexpr float FN2 = -0.019278317264888380590F;
    constexpr float FN1 = 0.069670166153766424023F;
    constexpr float FN0 = -0.75225204789749321333F;

    constexpr float FD4 = 0.0011193462567257629232F;
    constexpr float FD3 = 0.014866955030185295499F;
    constexpr float FD2 = 0.11583842382862377919F;
    constexpr float FD1 = 0.50736591960530292870F;
    constexpr float FD0 = 1.0F;

    const float z4 = z2 * z2;

    float       polyFD0 = FD4 * z4 + FD2;
    const float polyFD1 = FD3 * z4 + FD1;
    polyFD0             = polyFD0 * z4 + FD0;
    polyFD0             = polyFD1 * z2 + polyFD0;

    polyFD0 = 1.0F / polyFD0;

    float polyFN0 = FN6 * z4 + FN4;
    float polyFN1 = FN5 * z4 + FN3;
    polyFN0       = polyFN0 * z4 + FN2;
    polyFN1       = polyFN1 * z4 + FN1;
    polyFN0       = polyFN0 * z4 + FN0;
    polyFN0       = polyFN1 * z2 + polyFN0;

    return polyFN0 * polyFD0;
}
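
/* Added note: this is a rational polynomial approximation (a degree-6
 * numerator over a degree-4 denominator in z^2), evaluated in an interleaved
 * even/odd Horner form for instruction-level parallelism. In the kernel below
 * it appears as pmeCorrF(beta^2 * r^2) * beta^3, i.e. as the correction added
 * to the masked 1/r^3 Coulomb term to obtain the short-range Ewald force. */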

/*! \brief Linear interpolation using exactly two FMA operations.
 *
 *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    return fma(t, d1, fma(-t, d0, d0));
}
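
/* Added derivation: expanding the two FMAs shows the equivalence claimed in
 * the brief above:
 *   fma(t, d1, fma(-t, d0, d0)) = t*d1 + (d0 - t*d0) = (1 - t)*d0 + t*d1. */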

/*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
                                             const float coulombTabScale,
                                             const float r)
{
    const float normalized = coulombTabScale * r;
    const int   index      = static_cast<int>(normalized);
    const float fraction   = normalized - index;

    const float left  = a_coulombTab[index];
    const float right = a_coulombTab[index + 1];

    return lerp(left, right, fraction); // TODO: cl::sycl::mix
}
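
/* Added example (hypothetical values): with coulombTabScale = 100 and
 * r = 0.915, normalized = 91.5, so index = 91 and fraction = 0.5; the function
 * then blends a_coulombTab[91] and a_coulombTab[92] in equal parts. */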

/*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in log2(c_clSize) steps using sub-group shifts and atomically accumulate into \p a_f.
 */
static inline void reduceForceJShuffle(Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    f[0] += sycl_2020::shift_left(sg, f[0], 1);
    f[1] += sycl_2020::shift_right(sg, f[1], 1);
    f[2] += sycl_2020::shift_left(sg, f[2], 1);
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    f[0] += sycl_2020::shift_left(sg, f[0], 2);
    f[2] += sycl_2020::shift_right(sg, f[2], 2);
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    if constexpr (c_clSize == 8)
    {
        f[0] += sycl_2020::shift_left(sg, f[0], 4);
    }

    if (tidxi < 3)
    {
        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
    }
}
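
/* Added walk-through of the shuffle reduction above, for the c_clSize == 4
 * case: after the first shift/select step, even lanes hold x pair-sums and odd
 * lanes hold y pair-sums (z pair-sums stay in f[2]); after the second step,
 * lane 0 holds the full x sum, lane 1 the full y sum, and lane 2 the full z
 * sum, which is why only the threads with tidxi < 3 issue the atomic adds. */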

// This function requires sm_buf to have a length of at least 1.
// The function returns:
//     - for thread #0 in the group: the sum of all valueToReduce values in the group
//     - for other threads: unspecified
template<int subGroupSize, int groupSize>
static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
                                const unsigned int         tidxi,
                                cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                float valueToReduce)
{
    constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
    static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();
    valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
    // If we have two sub-groups, we should reduce across them.
    if constexpr (numSubGroupsInGroup == 2)
    {
        if (tidxi == subGroupSize)
        {
            sm_buf[0] = valueToReduce;
        }
        itemIdx.barrier(fence_space::local_space);
        if (tidxi == 0)
        {
            valueToReduce += sm_buf[0];
        }
    }
    return valueToReduce;
}
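
/* Added usage sketch (hypothetical sizes): with subGroupSize = 32 and
 * groupSize = 64 there are two sub-groups. The thread with tidxi == 32 (the
 * first lane of the second sub-group) stashes its sub-group's partial sum in
 * sm_buf[0]; after the barrier, thread 0 adds it to its own partial sum, so
 * only thread 0 is guaranteed to hold the group total. */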

/*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in c_clSize steps using local memory and atomically accumulate into \p a_f.
 *
 * TODO: implement binary reduction flavor for the case where c_clSize is a power of two.
 */
static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                       Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          tidxj,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    static constexpr int sc_fBufferStride = c_clSizeSq;
    int                  tidx             = tidxi + tidxj * c_clSize;
    sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
    sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
    sm_buf[2 * sc_fBufferStride + tidx]   = f[2];

    subGroupBarrier(itemIdx);

    // Reduce c_clSize elements per component on the leader threads of the same sub-group that stored them above
    assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);

    if (tidxi < 3)
    {
        float fSum = 0.0F;
        for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
        {
            fSum += sm_buf[sc_fBufferStride * tidxi + j];
        }

        atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
    }
}


/*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
 */
static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                Float3                                                        f,
                                const cl::sycl::nd_item<1>         itemIdx,
                                const int                          tidxi,
                                const int                          tidxj,
                                const int                          aidx,
                                DeviceAccessor<float, mode_atomic> a_f)
{
    if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
    {
        reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
    }
    else
    {
        reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
    }
}


/*! \brief Final i-force reduction.
 *
 * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[],
 * accumulating atomically into \p a_f.
 * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
 *
 * This implementation works only with power of two array sizes.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<1>         itemIdx,
                                         const int                          tidxi,
                                         const int                          tidxj,
                                         const int                          sci,
                                         const int                          shift,
                                         DeviceAccessor<float, mode_atomic> a_f,
                                         DeviceAccessor<float, mode_atomic> a_fShift)
{
    // must have power of two elements in fCiBuf
    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));

    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    float                fShiftBuf  = 0.0F;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            if constexpr (c_clSize == 4)
            {
                /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
                 * a compare-and-swap (CAS) loop. It has particularly poor performance when
                 * updating the same memory location from the same work-group.
                 * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
                 * but it is unlikely to make a big difference and thus was not evaluated.
                 */
                auto sg = itemIdx.get_sub_group();
                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
                if (tidxi == 0)
                {
                    atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
                }
            }
            else
            {
                atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
            }
        }
    }
}

/*! \brief Main kernel for NBNXM. */
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                 DeviceAccessor<Float4, mode::read>                   a_xq,
                 DeviceAccessor<float, mode_atomic>                   a_f,
                 DeviceAccessor<Float3, mode::read>                   a_shiftVec,
                 DeviceAccessor<float, mode_atomic>                   a_fShift,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
                 DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                 DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                 DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                 OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
                 OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
                 OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                 const int                                                   numTypes,
                 const float                                                 rCoulombSq,
                 const float                                                 rVdwSq,
                 const float                                                 twoKRf,
                 const float                                                 ewaldBeta,
                 const float                                                 rlistOuterSq,
                 const float                                                 ewaldShift,
                 const float                                                 epsFac,
                 const float                                                 ewaldCoeffLJ,
                 const float                                                 cRF,
                 const shift_consts_t                                        dispersionShift,
                 const shift_consts_t                                        repulsionShift,
                 const switch_consts_t                                       vdwSwitch,
                 const float                                                 rVdwSwitch,
                 const float                                                 ljEwaldShift,
                 const float                                                 coulombTabScale,
                 const bool                                                  calcShift)
{
    static constexpr EnergyFunctionProperties<elecType, vdwType> props;

    cgh.require(a_xq);
    cgh.require(a_f);
    cgh.require(a_shiftVec);
    cgh.require(a_fShift);
    if constexpr (doCalcEnergies)
    {
        cgh.require(a_energyElec);
        cgh.require(a_energyVdw);
    }
    cgh.require(a_plistCJ4);
    cgh.require(a_plistSci);
    cgh.require(a_plistExcl);
    if constexpr (!props.vdwComb)
    {
        cgh.require(a_atomTypes);
        cgh.require(a_nbfp);
    }
    else
    {
        cgh.require(a_ljComb);
    }
    if constexpr (props.vdwEwald)
    {
        cgh.require(a_nbfpComb);
    }
    if constexpr (props.elecEwaldTab)
    {
        cgh.require(a_coulombTab);
    }

    // shmem buffer for i x+q pre-loading
    cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

    // shmem buffer for force reduction
    // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
            cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);

    auto sm_atomTypeI = [&]() {
        if constexpr (!props.vdwComb)
        {
            return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    auto sm_ljCombI = [&]() {
        if constexpr (props.vdwComb)
        {
            return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    /* Flag to control the calculation of exclusion forces in the kernel
     * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
     * energy terms. */
    constexpr bool doExclusionForces =
            (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));

    // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
    // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
    // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
    // Hence, the two are decoupled.
    // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
    constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
#if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
    gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
#else
    gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
#endif

    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        /* thread/block/warp id-s */
        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
        const unsigned        tidxi   = localId[0];
        const unsigned        tidxj   = localId[1];
        const unsigned        tidx    = tidxj * c_clSize + tidxi;
        const unsigned        tidxz   = 0;

        // Group indexing was flat originally, no need to unflatten it.
        const unsigned bidx = itemIdx.get_group(0);

        const sycl_2020::sub_group sg = itemIdx.get_sub_group();
        // We could use sg.get_group_range to compute the imask & exclusion index, but too much
        // of the logic relies on this layout, and it would not work anyway in cases
        // where prunedClusterPairSize != subGroupSize.
        const unsigned imeiIdx = tidx / prunedClusterPairSize;

        Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
        {
            fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
        }

        const nbnxn_sci_t nbSci     = a_plistSci[bidx];
        const int         sci       = nbSci.sci;
        const int         cij4Start = nbSci.cj4_ind_start;
        const int         cij4End   = nbSci.cj4_ind_end;

        // Only needed if props.elecEwaldAna
        const float beta2 = ewaldBeta * ewaldBeta;
        const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;

        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
        {
            /* Pre-load i-atom x and q into shared memory */
            const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
            const int             ai       = ci * c_clSize + tidxi;
            const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);

            const Float3 shift = a_shiftVec[nbSci.shift];
            Float4       xqi   = a_xq[ai];
            xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
            xqi[3] *= epsFac;
            sm_xq[cacheIdx] = xqi;

            if constexpr (!props.vdwComb)
            {
                // Pre-load the i-atom types into shared memory
                sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
            }
            else
            {
                // Pre-load the LJ combination parameters into shared memory
                sm_ljCombI[cacheIdx] = a_ljComb[ai];
            }
        }
        itemIdx.barrier(fence_space::local_space);

        float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
        if constexpr (props.vdwEwald)
        {
            ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
            ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
        }

        float energyVdw, energyElec; // Only needed if (doCalcEnergies)
        if constexpr (doCalcEnergies)
        {
            energyVdw = energyElec = 0.0F;
        }
        if constexpr (doCalcEnergies && doExclusionForces)
        {
            if (nbSci.shift == gmx::c_centralShiftIndex
                && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
            {
                // we have the diagonal: add the charge and LJ self interaction energy term
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    // TODO: Are there other options?
                    if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
                    {
                        const float qi = sm_xq[i][tidxi][3];
                        energyElec += qi * qi;
                    }
                    if constexpr (props.vdwEwald)
                    {
                        energyVdw +=
                                a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
                                       * (numTypes + 1)][0];
                    }
                }
                /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
                if constexpr (props.vdwEwald)
                {
                    energyVdw /= c_clSize;
                    energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
                }
                if constexpr (props.elecRF || props.elecCutoff)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -0.5F * cRF;
                }
                if constexpr (props.elecEwald)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                }
            } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
        }     // (doCalcEnergies && doExclusionForces)

        // Only needed if (doExclusionForces)
        const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);

        // loop over the j clusters = seen by any of the atoms in the current super-cluster
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
        {
            unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
            if (!doPruneNBL && !imask)
            {
                continue;
            }
            const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
            static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
            for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            {
                const bool maskSet =
                        imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                if (!maskSet)
                {
                    continue;
                }
                unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                const int cj     = a_plistCJ4[j4].cj[jm];
                const int aj     = cj * c_clSize + tidxj;

                // load j atom data
                const Float4 xqj = a_xq[aj];

                const Float3 xj(xqj[0], xqj[1], xqj[2]);
                const float  qj = xqj[3];
                int          atomTypeJ; // Only needed if (!props.vdwComb)
                Float2       ljCombJ;   // Only needed if (props.vdwComb)
                if constexpr (props.vdwComb)
                {
                    ljCombJ = a_ljComb[aj];
                }
                else
                {
                    atomTypeJ = a_atomTypes[aj];
                }

                Float3 fCjBuf(0.0F, 0.0F, 0.0F);

                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    if (imask & maskJI)
                    {
                        // i cluster index
                        const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
                        // all threads load an atom from i cluster ci into shmem!
                        const Float4 xqi = sm_xq[i][tidxi];
                        const Float3 xi(xqi[0], xqi[1], xqi[2]);

                        // distance between i and j atoms
                        const Float3 rv = xi - xj;
                        float        r2 = norm2(rv);

                        if constexpr (doPruneNBL)
                        {
                            /* If _none_ of the atom pairs are in cutoff range,
                             * the bit corresponding to the current
                             * cluster-pair in imask gets set to 0. */
                            if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
                            {
                                imask &= ~maskJI;
                            }
                        }
                        const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;

                        // cutoff & exclusion check

                        const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
                                                                   : (wexcl & maskJI);

                        // SYCL-TODO: Check optimal way of branching here.
                        if ((r2 < rCoulombSq) && notExcluded)
                        {
                            const float qi = xqi[3];
                            int         atomTypeI; // Only needed if (!props.vdwComb)
                            float       sigma, epsilon;
                            Float2      c6c12;

                            if constexpr (!props.vdwComb)
                            {
                                /* LJ 6*C6 and 12*C12 */
                                atomTypeI = sm_atomTypeI[i][tidxi];
                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                            }
                            else
                            {
                                const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                if constexpr (props.vdwCombGeom)
                                {
                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                }
                                else
                                {
                                    static_assert(props.vdwCombLB);
                                    // LJ 2^(1/6)*sigma and 12*epsilon
                                    sigma   = ljCombI[0] + ljCombJ[0];
                                    epsilon = ljCombI[1] * ljCombJ[1];
                                    if constexpr (doCalcEnergies)
                                    {
                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                    }
                                } // props.vdwCombGeom
                            }     // !props.vdwComb

                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
                            const float c6  = c6c12[0];
                            const float c12 = c6c12[1];

                            // Ensure the distance does not become so small that r^-12 overflows
                            r2 = std::max(r2, c_nbnxnMinDistanceSquared);
#if GMX_SYCL_HIPSYCL
                            // No fast/native functions in some compilation passes
                            const float rInv = cl::sycl::rsqrt(r2);
#else
                            // SYCL-TODO: sycl::half_precision::rsqrt?
                            const float rInv = cl::sycl::native::rsqrt(r2);
#endif
                            const float r2Inv = rInv * rInv;
                            float       r6Inv, fInvR, energyLJPair;
                            if constexpr (!props.vdwCombLB || doCalcEnergies)
                            {
                                r6Inv = r2Inv * r2Inv * r2Inv;
                                if constexpr (doExclusionForces)
                                {
                                    // SYCL-TODO: Check if true for SYCL
                                    /* We could mask r2Inv, but with Ewald masking both
                                     * r6Inv and fInvR is faster */
                                    r6Inv *= pairExclMask;
                                }
                                fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
                            }
                            else
                            {
                                float sig_r  = sigma * rInv;
                                float sig_r2 = sig_r * sig_r;
                                float sig_r6 = sig_r2 * sig_r2 * sig_r2;
                                if constexpr (doExclusionForces)
                                {
                                    sig_r6 *= pairExclMask;
                                }
                                fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
                            } // (!props.vdwCombLB || doCalcEnergies)
                            if constexpr (doCalcEnergies || props.vdwPSwitch)
                            {
                                energyLJPair = pairExclMask
                                               * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
                                                  - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
                            }
                            if constexpr (props.vdwFSwitch)
                            {
                                ljForceSwitch<doCalcEnergies>(
                                        dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.vdwEwald)
                            {
                                ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
                                                                     ljEwaldShift,
                                                                     atomTypeI,
                                                                     atomTypeJ,
                                                                     r2,
                                                                     r2Inv,
                                                                     ewaldCoeffLJ_2,
                                                                     ewaldCoeffLJ_6_6,
                                                                     pairExclMask,
                                                                     &fInvR,
                                                                     &energyLJPair);
                            } // (props.vdwEwald)
                            if constexpr (props.vdwPSwitch)
                            {
                                ljPotentialSwitch<doCalcEnergies>(
                                        vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.elecEwaldTwin)
                            {
                                // Separate VDW cut-off check to enable twin-range cut-offs
                                // (rVdw < rCoulomb <= rList)
                                const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
                                fInvR *= vdwInRange;
                                if constexpr (doCalcEnergies)
                                {
                                    energyLJPair *= vdwInRange;
                                }
                            }
                            if constexpr (doCalcEnergies)
                            {
                                energyVdw += energyLJPair;
                            }

                            if constexpr (props.elecCutoff)
                            {
                                if constexpr (doExclusionForces)
                                {
                                    fInvR += qi * qj * pairExclMask * r2Inv * rInv;
                                }
                                else
                                {
                                    fInvR += qi * qj * r2Inv * rInv;
                                }
                            }
                            if constexpr (props.elecRF)
                            {
                                fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
                            }
                            if constexpr (props.elecEwaldAna)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
                            }
                            if constexpr (props.elecEwaldTab)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv
                                            - interpolateCoulombForceR(
                                                    a_coulombTab, coulombTabScale, r2 * rInv))
                                         * rInv;
                            }

                            if constexpr (doCalcEnergies)
                            {
                                if constexpr (props.elecCutoff)
                                {
                                    energyElec += qi * qj * (pairExclMask * rInv - cRF);
                                }
                                if constexpr (props.elecRF)
                                {
                                    energyElec +=
                                            qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
                                }
                                if constexpr (props.elecEwald)
                                {
                                    energyElec +=
                                            qi * qj
                                            * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
                                               - pairExclMask * ewaldShift);
                                }
                            }

                            const Float3 forceIJ = rv * fInvR;

                            /* accumulate j forces in registers */
                            fCjBuf -= forceIJ;
                            /* accumulate i forces in registers */
                            fCiBuf[i] += forceIJ;
                        } // (r2 < rCoulombSq) && notExcluded
                    }     // (imask & maskJI)
                    /* shift the mask bit by 1 */
                    maskJI += maskJI;
                } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                /* reduce j forces */
                reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
            } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            if constexpr (doPruneNBL)
            {
                /* Update the imask with the new one which does not contain the
                 * out of range clusters anymore. */
                a_plistCJ4[j4].imei[imeiIdx].imask = imask;
            }
        } // for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)

        /* skip central shifts when summing shift forces */
        const bool doCalcShift = (calcShift && !(nbSci.shift == gmx::c_centralShiftIndex));

        reduceForceIAndFShift(
                sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);

        if constexpr (doCalcEnergies)
        {
            const float energyVdwGroup =
                    groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
            const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
                    itemIdx, tidx, sm_reductionBuffer, energyElec);

            if (tidx == 0)
            {
                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
            }
        }
    };
}

template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
{
    using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const int                   numBlocks = numSci;
    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
    });

    return e;
}
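
/* Added illustration (assuming c_clSize == 8): each super-cluster maps to one
 * 8x8x1 = 64-item work-group, and numSci super-clusters give a 1D grid of
 * numSci work-groups; flattenNDRange then converts this 3D description to the
 * 1D nd_range that the kernel lambda's nd_item<1> expects. */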

template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                           bool          doCalcEnergies,
                                           enum ElecType elecType,
                                           enum VdwType  vdwType,
                                           Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
                return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
                        std::forward<Args>(args)...);
            },
            doPruneNBL,
            doCalcEnergies,
            elecType,
            vdwType);
}
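
/* Added note: gmx::dispatchTemplatedFunction (from
 * gromacs/utility/template_mp.h, included above) converts the four runtime
 * arguments into compile-time template parameters, so a separate kernel is
 * instantiated per (doPruneNBL, doCalcEnergies, elecType, vdwType)
 * combination and the runtime dispatch happens only once per launch. */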

void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
{
    NBAtomDataGpu*      adat         = nb->atdat;
    NBParamGpu*         nbp          = nb->nbparam;
    gpu_plist*          plist        = nb->plist[iloc];
    const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];

    // Casting to float simplifies using atomic ops in the kernel
    cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
    cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);

    cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                   stepWork.computeEnergy,
                                                   nbp->elecType,
                                                   nbp->vdwType,
                                                   deviceStream,
                                                   plist->nsci,
                                                   adat->xq,
                                                   fAsFloat,
                                                   adat->shiftVec,
                                                   fShiftAsFloat,
                                                   adat->eElec,
                                                   adat->eLJ,
                                                   plist->cj4,
                                                   plist->sci,
                                                   plist->excl,
                                                   adat->ljComb,
                                                   adat->atomTypes,
                                                   nbp->nbfp,
                                                   nbp->nbfp_comb,
                                                   nbp->coulomb_tab,
                                                   adat->numTypes,
                                                   nbp->rcoulomb_sq,
                                                   nbp->rvdw_sq,
                                                   nbp->two_k_rf,
                                                   nbp->ewald_beta,
                                                   nbp->rlistOuter_sq,
                                                   nbp->sh_ewald,
                                                   nbp->epsfac,
                                                   nbp->ewaldcoeff_lj,
                                                   nbp->c_rf,
                                                   nbp->dispersion_shift,
                                                   nbp->repulsion_shift,
                                                   nbp->vdw_switch,
                                                   nbp->rvdw_switch,
                                                   nbp->sh_lj_ewald,
                                                   nbp->coulomb_tab_scale,
                                                   stepWork.computeVirial);
}

} // namespace Nbnxm