src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief
 *  NBNXM SYCL kernels
 *
 *  \ingroup module_nbnxm
 */
#include "gmxpre.h"

#include "nbnxm_sycl_kernel.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/math/functions.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/pbcutil/ishift.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"

//! \brief Class name for NBNXM kernel
template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
class NbnxmKernel;

namespace Nbnxm
{

//! \brief Set of boolean constants mimicking preprocessor macros.
template<enum ElecType elecType, enum VdwType vdwType>
struct EnergyFunctionProperties {
    static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
    static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
    static constexpr bool elecEwaldAna =
            (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
    static constexpr bool elecEwaldTab =
            (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
    static constexpr bool elecEwaldTwin =
            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin); ///< Use twin cut-off.
    static constexpr bool elecEwald = (elecEwaldAna || elecEwaldTab);  ///< EL_EWALD_ANY
    static constexpr bool vdwCombLB = (vdwType == VdwType::CutCombLB); ///< LJ_COMB && !LJ_COMB_GEOM
    static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
    static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
    static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
    static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
    static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
    static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
    static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
};
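
// Illustrative (hypothetical) usage: instantiating the struct mirrors setting the
// corresponding preprocessor macros of the CUDA/OpenCL kernels, e.g.
//   constexpr EnergyFunctionProperties<ElecType::RF, VdwType::Cut> p;
//   static_assert(p.elecRF && !p.elecEwald && !p.vdwComb);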

//! \brief Templated constants to shorten kernel function declaration.
//@{
template<enum VdwType vdwType>
constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;

template<enum ElecType elecType>
constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;

template<enum ElecType elecType>
constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;

template<enum VdwType vdwType>
constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
//@}

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;

//! \brief Convert \p sigma and \p epsilon VdW parameters to \c c6,c12 pair.
static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
{
    const float sigma2 = sigma * sigma;
    const float sigma6 = sigma2 * sigma2 * sigma2;
    const float c6     = epsilon * sigma6;
    const float c12    = c6 * sigma6;

    return { c6, c12 };
}
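
// Given the scaled inputs noted in the pair loop below (sigma carrying a 2^(1/6)
// factor and epsilon a factor of 12), the products above work out to
//   c6  = 12*eps * (2^(1/6)*sig)^6 = 24*eps*sig^6  = 6*C6,
//   c12 = c6 * (2^(1/6)*sig)^6     = 48*eps*sig^12 = 12*C12,
// matching the 6*C6 / 12*C12 convention used in the force expression
// r6Inv * (c12 * r6Inv - c6) * r2Inv.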

//! \brief Calculate force and energy for a pair of atoms, VdW force-switch flavor.
template<bool doCalcEnergies>
static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                 const shift_consts_t         repulsionShift,
                                 const float                  rVdwSwitch,
                                 const float                  c6,
                                 const float                  c12,
                                 const float                  rInv,
                                 const float                  r2,
                                 cl::sycl::private_ptr<float> fInvR,
                                 cl::sycl::private_ptr<float> eLJ)
{
    /* force switch constants */
    const float dispShiftV2 = dispersionShift.c2;
    const float dispShiftV3 = dispersionShift.c3;
    const float repuShiftV2 = repulsionShift.c2;
    const float repuShiftV3 = repulsionShift.c3;

    const float r       = r2 * rInv;
    const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)

    *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
              + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;

    if constexpr (doCalcEnergies)
    {
        const float dispShiftF2 = dispShiftV2 / 3;
        const float dispShiftF3 = dispShiftV3 / 4;
        const float repuShiftF2 = repuShiftV2 / 3;
        const float repuShiftF3 = repuShiftV3 / 4;
        *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
                - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
    }
}

//! \brief Fetch C6 grid contribution coefficients and return the product of these.
template<enum VdwType vdwType>
static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                                           const int                                typeI,
                                           const int                                typeJ)
{
    if constexpr (vdwType == VdwType::EwaldGeom)
    {
        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
    }
    else
    {
        static_assert(vdwType == VdwType::EwaldLB);
        /* sigma and epsilon are scaled to give 6*C6 */
        const Float2 c6c12_i = a_nbfpComb[typeI];
        const Float2 c6c12_j = a_nbfpComb[typeJ];

        const float sigma   = c6c12_i[0] + c6c12_j[0];
        const float epsilon = c6c12_i[1] * c6c12_j[1];

        const float sigma2 = sigma * sigma;
        return epsilon * sigma2 * sigma2 * sigma2;
    }
}

//! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
template<bool doCalcEnergies, enum VdwType vdwType>
static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                               const float                              sh_lj_ewald,
                               const int                                typeI,
                               const int                                typeJ,
                               const float                              r2,
                               const float                              r2Inv,
                               const float                              lje_coeff2,
                               const float                              lje_coeff6_6,
                               const float                              int_bit,
                               cl::sycl::private_ptr<float>             fInvR,
                               cl::sycl::private_ptr<float>             eLJ)
{
    const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);

    /* Recalculate inv_r6 without exclusion mask */
    const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
    const float cr2       = lje_coeff2 * r2;
    const float expmcr2   = cl::sycl::exp(-cr2);
    const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;

    /* Subtract the grid force from the total LJ force */
    *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;

    if constexpr (doCalcEnergies)
    {
        /* Shift should be applied only to real LJ pairs */
        const float sh_mask = sh_lj_ewald * int_bit;
        *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
    }
}
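
// A note on the expression above: poly is the second-order Taylor expansion of
// exp(cr2), so (1 - expmcr2 * poly) decays as cr2^3 for small cr2 and the
// subtracted grid contribution stays finite at short range.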

/*! \brief Apply potential switch. */
template<bool doCalcEnergies>
static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
                                     const float                  rVdwSwitch,
                                     const float                  rInv,
                                     const float                  r2,
                                     cl::sycl::private_ptr<float> fInvR,
                                     cl::sycl::private_ptr<float> eLJ)
{
    /* potential switch constants */
    const float switchV3 = vdwSwitch.c3;
    const float switchV4 = vdwSwitch.c4;
    const float switchV5 = vdwSwitch.c5;
    const float switchF2 = 3 * vdwSwitch.c3;
    const float switchF3 = 4 * vdwSwitch.c4;
    const float switchF4 = 5 * vdwSwitch.c5;
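    // switchF2..switchF4 are the coefficients of d(sw)/dr, the analytic derivative of the
    // switching polynomial sw defined below, so dsw is d(sw)/dr evaluated at rSwitch.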

    const float r       = r2 * rInv;
    const float rSwitch = r - rVdwSwitch;

    if (rSwitch > 0.0F)
    {
        const float sw =
                1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
        const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;

        *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
        if constexpr (doCalcEnergies)
        {
            *eLJ *= sw;
        }
    }
}


/*! \brief Calculate analytical Ewald correction term. */
static inline float pmeCorrF(const float z2)
{
    constexpr float FN6 = -1.7357322914161492954e-8F;
    constexpr float FN5 = 1.4703624142580877519e-6F;
    constexpr float FN4 = -0.000053401640219807709149F;
    constexpr float FN3 = 0.0010054721316683106153F;
    constexpr float FN2 = -0.019278317264888380590F;
    constexpr float FN1 = 0.069670166153766424023F;
    constexpr float FN0 = -0.75225204789749321333F;

    constexpr float FD4 = 0.0011193462567257629232F;
    constexpr float FD3 = 0.014866955030185295499F;
    constexpr float FD2 = 0.11583842382862377919F;
    constexpr float FD1 = 0.50736591960530292870F;
    constexpr float FD0 = 1.0F;

    const float z4 = z2 * z2;

    float       polyFD0 = FD4 * z4 + FD2;
    const float polyFD1 = FD3 * z4 + FD1;
    polyFD0             = polyFD0 * z4 + FD0;
    polyFD0             = polyFD1 * z2 + polyFD0;

    polyFD0 = 1.0F / polyFD0;

    float polyFN0 = FN6 * z4 + FN4;
    float polyFN1 = FN5 * z4 + FN3;
    polyFN0       = polyFN0 * z4 + FN2;
    polyFN1       = polyFN1 * z4 + FN1;
    polyFN0       = polyFN0 * z4 + FN0;
    polyFN0       = polyFN1 * z2 + polyFN0;

    return polyFN0 * polyFD0;
}
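
// The function above evaluates a single-precision rational approximation in z2:
// a numerator polynomial (FN6..FN0) over a denominator polynomial (FD4..FD0),
// each split into even/odd parts in z4 for instruction-level parallelism.
// The kernel below uses it as pmeCorrF(beta2 * r2) * beta3 in the analytical
// Ewald force correction.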

/*! \brief Linear interpolation using exactly two FMA operations.
 *
 *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    return fma(t, d1, fma(-t, d0, d0));
}
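
// The identity behind the two-FMA form above:
//   (1 - t) * d0 + t * d1 = d0 - t * d0 + t * d1 = fma(t, d1, fma(-t, d0, d0)).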

/*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
                                             const float coulombTabScale,
                                             const float r)
{
    const float normalized = coulombTabScale * r;
    const int   index      = static_cast<int>(normalized);
    const float fraction   = normalized - index;

    const float left  = a_coulombTab[index];
    const float right = a_coulombTab[index + 1];

    return lerp(left, right, fraction); // TODO: cl::sycl::mix
}
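
// Hypothetical example of the lookup above: with coulombTabScale = 1000 and
// r = 0.10235, normalized = 102.35, so the result blends a_coulombTab[102] and
// a_coulombTab[103] with fraction 0.35 (values chosen purely for illustration).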

/*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in log2(c_clSize) steps using sub-group shifts and then atomically accumulate into \p a_f.
 */
static inline void reduceForceJShuffle(Float3                                   f,
                                       const cl::sycl::nd_item<3>               itemIdx,
                                       const int                                tidxi,
                                       const int                                aidx,
                                       DeviceAccessor<Float3, mode::read_write> a_f)
{
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    f[0] += sycl_2020::shift_left(sg, f[0], 1);
    f[1] += sycl_2020::shift_right(sg, f[1], 1);
    f[2] += sycl_2020::shift_left(sg, f[2], 1);
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    f[0] += sycl_2020::shift_left(sg, f[0], 2);
    f[2] += sycl_2020::shift_right(sg, f[2], 2);
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    if constexpr (c_clSize == 8)
    {
        f[0] += sycl_2020::shift_left(sg, f[0], 4);
    }

    if (tidxi < 3)
    {
        atomicFetchAdd(a_f[aidx][tidxi], f[0]);
    }
}
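
// Illustration of the shuffle pattern above for c_clSize == 4, within one row of
// four lanes (tidxi = 0..3): after the first shift pair, lane 0 holds x0+x1 and
// lane 1 (via the tidxi & 1 swap) holds y0+y1; after the second shift pair,
// lanes 0, 1 and 2 hold the full x, y and z sums, which the tidxi < 3 branch
// atomically adds to a_f[aidx].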

/*!
 * \brief Do workgroup-level reduction of a single \c float.
 *
 * While SYCL has \c sycl::reduce_over_group, it currently (oneAPI 2021.3.0) uses a very large
 * shared memory buffer, which leads to reduced occupancy.
 *
 * \tparam subGroupSize Size of a sub-group.
 * \tparam groupSize Size of a work-group.
 * \param itemIdx Current thread's \c sycl::nd_item.
 * \param tidxi Current thread's linearized local index.
 * \param sm_buf Accessor for the local reduction buffer; must have a length of at least 1.
 * \param valueToReduce Current thread's value.
 * \return For the thread with \p tidxi 0: sum of all \p valueToReduce. Other threads: unspecified.
 */
template<int subGroupSize, int groupSize>
static inline float groupReduce(const cl::sycl::nd_item<3> itemIdx,
                                const unsigned int         tidxi,
                                cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                float valueToReduce)
{
    constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
    static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();
    valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
    // If we have two sub-groups, we should reduce across them.
    if constexpr (numSubGroupsInGroup == 2)
    {
        if (tidxi == subGroupSize)
        {
            sm_buf[0] = valueToReduce;
        }
        itemIdx.barrier(fence_space::local_space);
        if (tidxi == 0)
        {
            valueToReduce += sm_buf[0];
        }
    }
    return valueToReduce;
}
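
// Assumed usage pattern (it matches the energy reduction at the end of nbnxmKernel
// below): every work-item passes its private value, and only the work-item with
// linearized index 0 consumes the returned sum.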

/*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce sequentially in c_clSize steps using local memory and then atomically accumulate into \p a_f.
 *
 * TODO: implement a binary-reduction flavor for the case where c_clSize is a power of two.
 */
static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                       Float3                                   f,
                                       const cl::sycl::nd_item<3>               itemIdx,
                                       const int                                tidxi,
                                       const int                                tidxj,
                                       const int                                aidx,
                                       DeviceAccessor<Float3, mode::read_write> a_f)
{
    static constexpr int sc_fBufferStride = c_clSizeSq;
    int                  tidx             = tidxi + tidxj * c_clSize;
    sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
    sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
    sm_buf[2 * sc_fBufferStride + tidx]   = f[2];

    subGroupBarrier(itemIdx);

    // Reduce the data, c_clSize elements at a time, within the same sub-group that stored it above
    assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);

    if (tidxi < 3)
    {
        float fSum = 0.0F;
        for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
        {
            fSum += sm_buf[sc_fBufferStride * tidxi + j];
        }

        atomicFetchAdd(a_f[aidx][tidxi], fSum);
    }
}


/*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
 */
static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                Float3                                                        f,
                                const cl::sycl::nd_item<3>               itemIdx,
                                const int                                tidxi,
                                const int                                tidxj,
                                const int                                aidx,
                                DeviceAccessor<Float3, mode::read_write> a_f)
{
    if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
    {
        reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
    }
    else
    {
        reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
    }
}


/*! \brief Final i-force reduction.
 *
 * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
 * accumulating atomically into \p a_f.
 * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
 *
 * This implementation works only with power of two array sizes.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<3>               itemIdx,
                                         const int                                tidxi,
                                         const int                                tidxj,
                                         const int                                sci,
                                         const int                                shift,
                                         DeviceAccessor<Float3, mode::read_write> a_f,
                                         DeviceAccessor<Float3, mode::read_write> a_fShift)
{
    // must have power of two elements in fCiBuf
    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));

    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    float                fShiftBuf  = 0.0F;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f[aidx][tidxj], f);
            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            if constexpr (c_clSize == 4)
            {
                /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
                 * a compare-and-swap (CAS) loop. It has particularly poor performance when
                 * updating the same memory location from the same work-group.
                 * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
                 * but it is unlikely to make a big difference and thus was not evaluated.
                 */
                auto sg = itemIdx.get_sub_group();
                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
                if (tidxi == 0)
                {
                    atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
                }
            }
            else
            {
                atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
            }
        }
    }
}

/*! \brief Main kernel for NBNXM. */
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                 DeviceAccessor<Float4, mode::read>                        a_xq,
                 DeviceAccessor<Float3, mode::read_write>                  a_f,
                 DeviceAccessor<Float3, mode::read>                        a_shiftVec,
                 DeviceAccessor<Float3, mode::read_write>                  a_fShift,
                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
                 DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                 DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                 DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                 OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
                 OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
                 OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                 const int                                                   numTypes,
                 const float                                                 rCoulombSq,
                 const float                                                 rVdwSq,
                 const float                                                 twoKRf,
                 const float                                                 ewaldBeta,
                 const float                                                 rlistOuterSq,
                 const float                                                 ewaldShift,
                 const float                                                 epsFac,
                 const float                                                 ewaldCoeffLJ,
                 const float                                                 cRF,
                 const shift_consts_t                                        dispersionShift,
                 const shift_consts_t                                        repulsionShift,
                 const switch_consts_t                                       vdwSwitch,
                 const float                                                 rVdwSwitch,
                 const float                                                 ljEwaldShift,
                 const float                                                 coulombTabScale,
                 const bool                                                  calcShift)
{
    static constexpr EnergyFunctionProperties<elecType, vdwType> props;

    cgh.require(a_xq);
    cgh.require(a_f);
    cgh.require(a_shiftVec);
    cgh.require(a_fShift);
    if constexpr (doCalcEnergies)
    {
        cgh.require(a_energyElec);
        cgh.require(a_energyVdw);
    }
    cgh.require(a_plistCJ4);
    cgh.require(a_plistSci);
    cgh.require(a_plistExcl);
    if constexpr (!props.vdwComb)
    {
        cgh.require(a_atomTypes);
        cgh.require(a_nbfp);
    }
    else
    {
        cgh.require(a_ljComb);
    }
    if constexpr (props.vdwEwald)
    {
        cgh.require(a_nbfpComb);
    }
    if constexpr (props.elecEwaldTab)
    {
        cgh.require(a_coulombTab);
    }

    // shmem buffer for i x+q pre-loading
    cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

    // shmem buffer for force reduction
    // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
            cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);

    auto sm_atomTypeI = [&]() {
        if constexpr (!props.vdwComb)
        {
            return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    auto sm_ljCombI = [&]() {
        if constexpr (props.vdwComb)
        {
            return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    /* Flag to control the calculation of exclusion forces in the kernel
     * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
     * energy terms. */
    constexpr bool doExclusionForces =
            (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));

    // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
    // Currently, this is ideally suited for 32-wide subgroup size but slightly less so for others,
    // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
    // Hence, the two are decoupled.
    // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
    constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
#if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
    gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
#else
    gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
#endif

    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        /* thread/block/warp id-s */
        const unsigned tidxi = itemIdx.get_local_id(2);
        const unsigned tidxj = itemIdx.get_local_id(1);
        const unsigned tidx  = tidxj * c_clSize + tidxi;
        const unsigned tidxz = 0;

        const unsigned bidx = itemIdx.get_group(0);

        const sycl_2020::sub_group sg = itemIdx.get_sub_group();
        // We could use sg.get_group_range to compute the imask and exclusion indices, but much of
        // the logic below relies on the tidx-based index, and it could not be used at all when
        // prunedClusterPairSize != subGroupSize.
        const unsigned imeiIdx = tidx / prunedClusterPairSize;

        Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
        {
            fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
        }

        const nbnxn_sci_t nbSci     = a_plistSci[bidx];
        const int         sci       = nbSci.sci;
        const int         cij4Start = nbSci.cj4_ind_start;
        const int         cij4End   = nbSci.cj4_ind_end;

        // Only needed if props.elecEwaldAna
        const float beta2 = ewaldBeta * ewaldBeta;
        const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;

        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
        {
            /* Pre-load i-atom x and q into shared memory */
            const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
            const int             ai       = ci * c_clSize + tidxi;
            const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);

            const Float3 shift = a_shiftVec[nbSci.shift];
            Float4       xqi   = a_xq[ai];
            xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
            xqi[3] *= epsFac;
            sm_xq[cacheIdx] = xqi;

            if constexpr (!props.vdwComb)
            {
                // Pre-load the i-atom types into shared memory
                sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
            }
            else
            {
                // Pre-load the LJ combination parameters into shared memory
                sm_ljCombI[cacheIdx] = a_ljComb[ai];
            }
        }
        itemIdx.barrier(fence_space::local_space);

        float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
        if constexpr (props.vdwEwald)
        {
            ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
            ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
        }

        float energyVdw, energyElec; // Only needed if (doCalcEnergies)
        if constexpr (doCalcEnergies)
        {
            energyVdw = energyElec = 0.0F;
        }
        if constexpr (doCalcEnergies && doExclusionForces)
        {
            if (nbSci.shift == gmx::c_centralShiftIndex
                && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
            {
                // we have the diagonal: add the charge and LJ self interaction energy term
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    // TODO: Are there other options?
                    if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
                    {
                        const float qi = sm_xq[i][tidxi][3];
                        energyElec += qi * qi;
                    }
                    if constexpr (props.vdwEwald)
                    {
                        energyVdw +=
                                a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
                                       * (numTypes + 1)][0];
                    }
                }
                /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
                if constexpr (props.vdwEwald)
                {
                    energyVdw /= c_clSize;
                    energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
                }
                if constexpr (props.elecRF || props.elecCutoff)
                {
                    /* Correct for epsFac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -0.5F * cRF;
                }
                if constexpr (props.elecEwald)
                {
                    /* Correct for epsFac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                }
            } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
        }     // (doCalcEnergies && doExclusionForces)

        // Only needed if (doExclusionForces)
        const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);

        // loop over the j clusters seen by any of the atoms in the current super-cluster
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
        {
            unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
            if (!doPruneNBL && !imask)
            {
                continue;
            }
            const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
            static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
            for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            {
                const bool maskSet =
                        imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                if (!maskSet)
                {
                    continue;
                }
                unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                const int cj     = a_plistCJ4[j4].cj[jm];
                const int aj     = cj * c_clSize + tidxj;

                // load j atom data
                const Float4 xqj = a_xq[aj];

                const Float3 xj(xqj[0], xqj[1], xqj[2]);
                const float  qj = xqj[3];
                int          atomTypeJ; // Only needed if (!props.vdwComb)
                Float2       ljCombJ;   // Only needed if (props.vdwComb)
                if constexpr (props.vdwComb)
                {
                    ljCombJ = a_ljComb[aj];
                }
                else
                {
                    atomTypeJ = a_atomTypes[aj];
                }

                Float3 fCjBuf(0.0F, 0.0F, 0.0F);

                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    if (imask & maskJI)
                    {
                        // i cluster index
                        const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
                        // all threads load an atom from i cluster ci into shmem!
                        const Float4 xqi = sm_xq[i][tidxi];
                        const Float3 xi(xqi[0], xqi[1], xqi[2]);

                        // distance between i and j atoms
                        const Float3 rv = xi - xj;
                        float        r2 = norm2(rv);

                        if constexpr (doPruneNBL)
                        {
                            /* If _none_ of the atoms pairs are in cutoff range,
                             * the bit corresponding to the current
                             * cluster-pair in imask gets set to 0. */
                            if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
                            {
                                imask &= ~maskJI;
                            }
                        }
                        const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;

                        // cutoff & exclusion check

                        const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
                                                                   : (wexcl & maskJI);

                        // SYCL-TODO: Check optimal way of branching here.
                        if ((r2 < rCoulombSq) && notExcluded)
                        {
                            const float qi = xqi[3];
                            int         atomTypeI; // Only needed if (!props.vdwComb)
                            float       sigma, epsilon;
                            Float2      c6c12;

                            if constexpr (!props.vdwComb)
                            {
                                /* LJ 6*C6 and 12*C12 */
                                atomTypeI = sm_atomTypeI[i][tidxi];
                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                            }
                            else
                            {
                                const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                if constexpr (props.vdwCombGeom)
                                {
                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                }
                                else
                                {
                                    static_assert(props.vdwCombLB);
                                    // LJ 2^(1/6)*sigma and 12*epsilon
                                    sigma   = ljCombI[0] + ljCombJ[0];
                                    epsilon = ljCombI[1] * ljCombJ[1];
                                    if constexpr (doCalcEnergies)
                                    {
                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                    }
                                } // props.vdwCombGeom
                            }     // !props.vdwComb

                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
                            const float c6  = c6c12[0];
                            const float c12 = c6c12[1];

                            // Ensure the distance does not become so small that r^-12 overflows
                            r2 = std::max(r2, c_nbnxnMinDistanceSquared);
#if GMX_SYCL_HIPSYCL
                            // No fast/native functions in some compilation passes
                            const float rInv = cl::sycl::rsqrt(r2);
#else
                            // SYCL-TODO: sycl::half_precision::rsqrt?
                            const float rInv = cl::sycl::native::rsqrt(r2);
#endif
                            const float r2Inv = rInv * rInv;
                            float       r6Inv, fInvR, energyLJPair;
                            if constexpr (!props.vdwCombLB || doCalcEnergies)
                            {
                                r6Inv = r2Inv * r2Inv * r2Inv;
                                if constexpr (doExclusionForces)
                                {
                                    // SYCL-TODO: Check if true for SYCL
                                    /* We could mask r2Inv, but with Ewald masking both
                                     * r6Inv and fInvR is faster */
                                    r6Inv *= pairExclMask;
                                }
                                fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
                            }
                            else
                            {
                                float sig_r  = sigma * rInv;
                                float sig_r2 = sig_r * sig_r;
                                float sig_r6 = sig_r2 * sig_r2 * sig_r2;
                                if constexpr (doExclusionForces)
                                {
                                    sig_r6 *= pairExclMask;
                                }
                                fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
                            } // (!props.vdwCombLB || doCalcEnergies)
                            if constexpr (doCalcEnergies || props.vdwPSwitch)
                            {
                                energyLJPair = pairExclMask
                                               * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
                                                  - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
                            }
                            if constexpr (props.vdwFSwitch)
                            {
                                ljForceSwitch<doCalcEnergies>(
                                        dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.vdwEwald)
                            {
                                ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
                                                                     ljEwaldShift,
                                                                     atomTypeI,
                                                                     atomTypeJ,
                                                                     r2,
                                                                     r2Inv,
                                                                     ewaldCoeffLJ_2,
                                                                     ewaldCoeffLJ_6_6,
                                                                     pairExclMask,
                                                                     &fInvR,
                                                                     &energyLJPair);
                            } // (props.vdwEwald)
                            if constexpr (props.vdwPSwitch)
                            {
                                ljPotentialSwitch<doCalcEnergies>(
                                        vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.elecEwaldTwin)
                            {
                                // Separate VDW cut-off check to enable twin-range cut-offs
                                // (rVdw < rCoulomb <= rList)
                                const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
                                fInvR *= vdwInRange;
                                if constexpr (doCalcEnergies)
                                {
                                    energyLJPair *= vdwInRange;
                                }
                            }
                            if constexpr (doCalcEnergies)
                            {
                                energyVdw += energyLJPair;
                            }

                            if constexpr (props.elecCutoff)
                            {
                                if constexpr (doExclusionForces)
                                {
                                    fInvR += qi * qj * pairExclMask * r2Inv * rInv;
                                }
                                else
                                {
                                    fInvR += qi * qj * r2Inv * rInv;
                                }
                            }
                            if constexpr (props.elecRF)
                            {
                                fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
                            }
                            if constexpr (props.elecEwaldAna)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
                            }
                            if constexpr (props.elecEwaldTab)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv
                                            - interpolateCoulombForceR(
                                                    a_coulombTab, coulombTabScale, r2 * rInv))
                                         * rInv;
                            }

                            if constexpr (doCalcEnergies)
                            {
                                if constexpr (props.elecCutoff)
                                {
                                    energyElec += qi * qj * (pairExclMask * rInv - cRF);
                                }
                                if constexpr (props.elecRF)
                                {
                                    energyElec +=
                                            qi * qj * (pairExclMask * rInv + 0.5F * twoKRf * r2 - cRF);
                                }
                                if constexpr (props.elecEwald)
                                {
                                    energyElec +=
                                            qi * qj
                                            * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
                                               - pairExclMask * ewaldShift);
                                }
                            }

                            const Float3 forceIJ = rv * fInvR;

                            /* accumulate j forces in registers */
                            fCjBuf -= forceIJ;
                            /* accumulate i forces in registers */
                            fCiBuf[i] += forceIJ;
                        } // (r2 < rCoulombSq) && notExcluded
                    }     // (imask & maskJI)
                    /* shift the mask bit by 1 */
                    maskJI += maskJI;
                } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                /* reduce j forces */
                reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
            } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            if constexpr (doPruneNBL)
            {
                /* Update the imask with the new one which does not contain the
                 * out of range clusters anymore. */
                a_plistCJ4[j4].imei[imeiIdx].imask = imask;
            }
        } // for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)

        /* skip central shifts when summing shift forces */
        const bool doCalcShift = (calcShift && nbSci.shift != gmx::c_centralShiftIndex);

        reduceForceIAndFShift(
                sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);

        if constexpr (doCalcEnergies)
        {
            const float energyVdwGroup =
                    groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
            const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
                    itemIdx, tidx, sm_reductionBuffer, energyElec);

            if (tidx == 0)
            {
                atomicFetchAdd(a_energyVdw[0], energyVdwGroup);
                atomicFetchAdd(a_energyElec[0], energyElecGroup);
            }
        }
    };
}

//! \brief NBNXM kernel launch code.
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
{
    using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const int                   numBlocks = numSci;
    const cl::sycl::range<3>    blockSize{ 1, c_clSize, c_clSize };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(range, kernel);
    });

    return e;
}
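
// Example of the resulting launch geometry (illustrative, assuming c_clSize == 8):
// each work-group is range<3>{ 1, 8, 8 }, i.e. 64 work-items, and the global range
// { numSci, 8, 8 } yields exactly one work-group per super-cluster (sci) list entry.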

//! \brief Select templated kernel and launch it.
template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                           bool          doCalcEnergies,
                                           enum ElecType elecType,
                                           enum VdwType  vdwType,
                                           Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
                return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
                        std::forward<Args>(args)...);
            },
            doPruneNBL,
            doCalcEnergies,
            elecType,
            vdwType);
}

void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
{
    NBAtomDataGpu*      adat         = nb->atdat;
    NBParamGpu*         nbp          = nb->nbparam;
    gpu_plist*          plist        = nb->plist[iloc];
    const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];

    cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                   stepWork.computeEnergy,
                                                   nbp->elecType,
                                                   nbp->vdwType,
                                                   deviceStream,
                                                   plist->nsci,
                                                   adat->xq,
                                                   adat->f,
                                                   adat->shiftVec,
                                                   adat->fShift,
                                                   adat->eElec,
                                                   adat->eLJ,
                                                   plist->cj4,
                                                   plist->sci,
                                                   plist->excl,
                                                   adat->ljComb,
                                                   adat->atomTypes,
                                                   nbp->nbfp,
                                                   nbp->nbfp_comb,
                                                   nbp->coulomb_tab,
                                                   adat->numTypes,
                                                   nbp->rcoulomb_sq,
                                                   nbp->rvdw_sq,
                                                   nbp->two_k_rf,
                                                   nbp->ewald_beta,
                                                   nbp->rlistOuter_sq,
                                                   nbp->sh_ewald,
                                                   nbp->epsfac,
                                                   nbp->ewaldcoeff_lj,
                                                   nbp->c_rf,
                                                   nbp->dispersion_shift,
                                                   nbp->repulsion_shift,
                                                   nbp->vdw_switch,
                                                   nbp->rvdw_switch,
                                                   nbp->sh_lj_ewald,
                                                   nbp->coulomb_tab_scale,
                                                   stepWork.computeVirial);
}

} // namespace Nbnxm