Fix nbnxm hipSYCL kernels 64-wide exec on AMD
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief
38  *  NBNXM SYCL kernels
39  *
40  *  \ingroup module_nbnxm
41  */
42 #include "gmxpre.h"
43
44 #include "nbnxm_sycl_kernel.h"
45
46 #include "gromacs/gpu_utils/devicebuffer.h"
47 #include "gromacs/gpu_utils/gmxsycl.h"
48 #include "gromacs/math/functions.h"
49 #include "gromacs/mdtypes/simulation_workload.h"
50 #include "gromacs/pbcutil/ishift.h"
51 #include "gromacs/utility/template_mp.h"
52
53 #include "nbnxm_sycl_kernel_utils.h"
54 #include "nbnxm_sycl_types.h"
55
56 namespace Nbnxm
57 {
58
59 //! \brief Set of boolean constants mimicking preprocessor macros.
60 template<enum ElecType elecType, enum VdwType vdwType>
61 struct EnergyFunctionProperties {
62     static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
63     static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
64     static constexpr bool elecEwaldAna =
65             (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
66     static constexpr bool elecEwaldTab =
67             (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
68     static constexpr bool elecEwaldTwin =
69             (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
70     static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
71     static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
72     static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
73     static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
74     static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
75     static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
76     static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
77     static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
78     static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
79 };
80
81 //! \brief Templated constants to shorten kernel function declaration.
82 //@{
83 template<enum VdwType vdwType>
84 constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;
85
86 template<enum ElecType elecType> // Yes, ElecType
87 constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;
88
89 template<enum ElecType elecType>
90 constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;
91
92 template<enum ElecType elecType>
93 constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;
94
95 template<enum VdwType vdwType>
96 constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
97 //@}
98
99 using cl::sycl::access::fence_space;
100 using cl::sycl::access::mode;
101 using cl::sycl::access::target;
102
103 static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
104 {
105     const float sigma2 = sigma * sigma;
106     const float sigma6 = sigma2 * sigma2 * sigma2;
107     const float c6     = epsilon * sigma6;
108     const float c12    = c6 * sigma6;
109
110     return Float2(c6, c12);
111 }
112
113 template<bool doCalcEnergies>
114 static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
115                                  const shift_consts_t         repulsionShift,
116                                  const float                  rVdwSwitch,
117                                  const float                  c6,
118                                  const float                  c12,
119                                  const float                  rInv,
120                                  const float                  r2,
121                                  cl::sycl::private_ptr<float> fInvR,
122                                  cl::sycl::private_ptr<float> eLJ)
123 {
124     /* force switch constants */
125     const float dispShiftV2 = dispersionShift.c2;
126     const float dispShiftV3 = dispersionShift.c3;
127     const float repuShiftV2 = repulsionShift.c2;
128     const float repuShiftV3 = repulsionShift.c3;
129
130     const float r       = r2 * rInv;
131     const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)
132
133     *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
134               + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;
135
136     if constexpr (doCalcEnergies)
137     {
138         const float dispShiftF2 = dispShiftV2 / 3;
139         const float dispShiftF3 = dispShiftV3 / 4;
140         const float repuShiftF2 = repuShiftV2 / 3;
141         const float repuShiftF3 = repuShiftV3 / 4;
142         *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
143                 - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
144     }
145 }
146
147 //! \brief Fetch C6 grid contribution coefficients and return the product of these.
148 template<enum VdwType vdwType>
149 static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
150                                            const int                                typeI,
151                                            const int                                typeJ)
152 {
153     if constexpr (vdwType == VdwType::EwaldGeom)
154     {
155         return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
156     }
157     else
158     {
159         static_assert(vdwType == VdwType::EwaldLB);
160         /* sigma and epsilon are scaled to give 6*C6 */
161         const Float2 c6c12_i = a_nbfpComb[typeI];
162         const Float2 c6c12_j = a_nbfpComb[typeJ];
163
164         const float sigma   = c6c12_i[0] + c6c12_j[0];
165         const float epsilon = c6c12_i[1] * c6c12_j[1];
166
167         const float sigma2 = sigma * sigma;
168         return epsilon * sigma2 * sigma2 * sigma2;
169     }
170 }
171
172 //! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
173 template<bool doCalcEnergies, enum VdwType vdwType>
174 static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
175                                const float                              sh_lj_ewald,
176                                const int                                typeI,
177                                const int                                typeJ,
178                                const float                              r2,
179                                const float                              r2Inv,
180                                const float                              lje_coeff2,
181                                const float                              lje_coeff6_6,
182                                const float                              int_bit,
183                                cl::sycl::private_ptr<float>             fInvR,
184                                cl::sycl::private_ptr<float>             eLJ)
185 {
186     const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);
187
188     /* Recalculate inv_r6 without exclusion mask */
189     const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
190     const float cr2       = lje_coeff2 * r2;
191     const float expmcr2   = cl::sycl::exp(-cr2);
192     const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;
193
194     /* Subtract the grid force from the total LJ force */
195     *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;
196
197     if constexpr (doCalcEnergies)
198     {
199         /* Shift should be applied only to real LJ pairs */
200         const float sh_mask = sh_lj_ewald * int_bit;
201         *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
202     }
203 }
204
205 /*! \brief Apply potential switch. */
206 template<bool doCalcEnergies>
207 static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
208                                      const float                  rVdwSwitch,
209                                      const float                  rInv,
210                                      const float                  r2,
211                                      cl::sycl::private_ptr<float> fInvR,
212                                      cl::sycl::private_ptr<float> eLJ)
213 {
214     /* potential switch constants */
215     const float switchV3 = vdwSwitch.c3;
216     const float switchV4 = vdwSwitch.c4;
217     const float switchV5 = vdwSwitch.c5;
218     const float switchF2 = 3 * vdwSwitch.c3;
219     const float switchF3 = 4 * vdwSwitch.c4;
220     const float switchF4 = 5 * vdwSwitch.c5;
221
222     const float r       = r2 * rInv;
223     const float rSwitch = r - rVdwSwitch;
224
225     if (rSwitch > 0.0F)
226     {
227         const float sw =
228                 1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
229         const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;
230
231         *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
232         if constexpr (doCalcEnergies)
233         {
234             *eLJ *= sw;
235         }
236     }
237 }
238
239
/*! \brief Calculate analytical Ewald correction term.
 *
 * Evaluates a rational function of \p z2: the numerator is a polynomial of
 * degree six and the denominator one of degree four. Both are split into
 * their even and odd powers of \p z2, which are evaluated separately in
 * terms of z2 squared and then combined.
 *
 * \param[in] z2  Argument of the rational function.
 * \returns Numerator divided by denominator.
 */
static inline float pmeCorrF(const float z2)
{
    // Numerator coefficients, index = power of z2
    constexpr float numC6 = -1.7357322914161492954e-8F;
    constexpr float numC5 = 1.4703624142580877519e-6F;
    constexpr float numC4 = -0.000053401640219807709149F;
    constexpr float numC3 = 0.0010054721316683106153F;
    constexpr float numC2 = -0.019278317264888380590F;
    constexpr float numC1 = 0.069670166153766424023F;
    constexpr float numC0 = -0.75225204789749321333F;

    // Denominator coefficients, index = power of z2
    constexpr float denC4 = 0.0011193462567257629232F;
    constexpr float denC3 = 0.014866955030185295499F;
    constexpr float denC2 = 0.11583842382862377919F;
    constexpr float denC1 = 0.50736591960530292870F;
    constexpr float denC0 = 1.0F;

    const float z4 = z2 * z2;

    // Denominator: even powers, odd powers, then the combination and its reciprocal
    const float denEven     = (denC4 * z4 + denC2) * z4 + denC0;
    const float denOdd      = denC3 * z4 + denC1;
    const float denominator = denOdd * z2 + denEven;
    const float denInverse  = 1.0F / denominator;

    // Numerator: same even/odd split
    const float numEven   = ((numC6 * z4 + numC4) * z4 + numC2) * z4 + numC0;
    const float numOdd    = (numC5 * z4 + numC3) * z4 + numC1;
    const float numerator = numOdd * z2 + numEven;

    return numerator * denInverse;
}
275
/*! \brief Linear interpolation using exactly two FMA operations.
 *
 *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
 *
 *  \tparam    T   Arithmetic type accepted by fma.
 *  \param[in] d0  Value returned for t == 0.
 *  \param[in] d1  Value returned for t == 1.
 *  \param[in] t   Interpolation weight.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    // First FMA: d0 - t * d0, i.e. (1 - t) * d0
    const T scaledStart = fma(-t, d0, d0);
    // Second FMA adds t * d1 on top
    return fma(t, d1, scaledStart);
}
285
286 /*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
287 static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
288                                              const float coulombTabScale,
289                                              const float r)
290 {
291     const float normalized = coulombTabScale * r;
292     const int   index      = static_cast<int>(normalized);
293     const float fraction   = normalized - index;
294
295     const float left  = a_coulombTab[index];
296     const float right = a_coulombTab[index + 1];
297
298     return lerp(left, right, fraction); // TODO: cl::sycl::mix
299 }
300
/*! \brief Reduce c_clSize j-force components and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which are
 * reduced in log2(c_clSize) steps using sub-group shifts and then atomically
 * accumulated into \p a_f.
 *
 * \param[in]     f        This thread's j-force contribution; taken by value and used as scratch.
 * \param[in]     itemIdx  Work-item handle, used to obtain the sub-group.
 * \param[in]     tidxi    Thread index; its two lowest bits select which partial sum a thread
 *                         carries on, and threads with \p tidxi below 3 do the final write.
 * \param[in]     aidx     Index of the j-atom; its components are at a_f[3 * aidx + {0, 1, 2}].
 * \param[in,out] a_f      Force buffer, updated with atomic additions.
 */
static inline void reduceForceJShuffle(Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    // The shift sequence below is written out for these two cluster sizes only.
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    // Step 1, distance 1: x and z are combined via shift_left, y via shift_right.
    f[0] += sycl_2020::shift_left(sg, f[0], 1);
    f[1] += sycl_2020::shift_right(sg, f[1], 1);
    f[2] += sycl_2020::shift_left(sg, f[2], 1);
    // Threads with bit 0 of tidxi set carry the y partial sum in slot 0 from here on.
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    // Step 2, distance 2: slot 0 via shift_left, z via shift_right.
    f[0] += sycl_2020::shift_left(sg, f[0], 2);
    f[2] += sycl_2020::shift_right(sg, f[2], 2);
    // Threads with bit 1 of tidxi set carry the z partial sum in slot 0 from here on.
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    // Step 3, distance 4: only present for clusters of eight.
    if constexpr (c_clSize == 8)
    {
        f[0] += sycl_2020::shift_left(sg, f[0], 4);
    }

    // Threads 0, 1 and 2 each add their slot-0 value to component tidxi of atom aidx.
    if (tidxi < 3)
    {
        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
    }
}
340
341
/*! \brief Final i-force reduction.
 *
 * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
 * accumulating atomically into \p a_f.
 * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
 *
 * This implementation works only with power of two array sizes.
 *
 * \param[in,out] sm_buf      Local-memory scratch buffer; indices up to
 *                            3 * c_clSize * c_clSize - 1 are accessed.
 * \param[in]     fCiBuf      Per-thread i-forces, one entry per cluster of the super-cluster.
 * \param[in]     calcFShift  Whether to also accumulate the shift force.
 * \param[in]     itemIdx     Work-item handle, used for the local-memory barriers.
 * \param[in]     tidxi       Thread index along i (fastest-varying part of the flat thread index).
 * \param[in]     tidxj       Thread index along j; in the final step values 0..2 select x/y/z.
 * \param[in]     sci         Super-cluster index, used to compute the i-atom index.
 * \param[in]     shift       Shift index; its components are at a_fShift[3 * shift + {0, 1, 2}].
 * \param[in,out] a_f         Force buffer, updated with atomic additions.
 * \param[in,out] a_fShift    Shift-force buffer, updated with atomic additions.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<1>         itemIdx,
                                         const int                          tidxi,
                                         const int                          tidxj,
                                         const int                          sci,
                                         const int                          shift,
                                         DeviceAccessor<float, mode_atomic> a_f,
                                         DeviceAccessor<float, mode_atomic> a_fShift)
{
    // must have power of two elements in fCiBuf
    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));

    // sm_buf is used as three consecutive planes (x, y, z) of bufStride floats each
    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    // Running sum over all clusters of the component handled by this thread (used if calcFShift)
    float                fShiftBuf  = 0;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        // Runs clSizeLog2 - 1 times, which leaves two rows per plane for the final step below
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            // Here tidxj picks the plane; rows 0 and 1 of that plane are summed
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        // All reads of sm_buf must be done before the next iteration overwrites it
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
        }
    }
}
423
424 /*! \brief Main kernel for NBNXM.
425  *
426  */
427 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
428 auto nbnxmKernel(cl::sycl::handler&                                   cgh,
429                  DeviceAccessor<Float4, mode::read>                   a_xq,
430                  DeviceAccessor<float, mode_atomic>                   a_f,
431                  DeviceAccessor<Float3, mode::read>                   a_shiftVec,
432                  DeviceAccessor<float, mode_atomic>                   a_fShift,
433                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
434                  OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
435                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
436                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
437                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
438                  OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
439                  OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
440                  OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
441                  OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
442                  OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
443                  const int                                                   numTypes,
444                  const float                                                 rCoulombSq,
445                  const float                                                 rVdwSq,
446                  const float                                                 twoKRf,
447                  const float                                                 ewaldBeta,
448                  const float                                                 rlistOuterSq,
449                  const float                                                 ewaldShift,
450                  const float                                                 epsFac,
451                  const float                                                 ewaldCoeffLJ,
452                  const float                                                 cRF,
453                  const shift_consts_t                                        dispersionShift,
454                  const shift_consts_t                                        repulsionShift,
455                  const switch_consts_t                                       vdwSwitch,
456                  const float                                                 rVdwSwitch,
457                  const float                                                 ljEwaldShift,
458                  const float                                                 coulombTabScale,
459                  const bool                                                  calcShift)
460 {
461     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
462
463     cgh.require(a_xq);
464     cgh.require(a_f);
465     cgh.require(a_shiftVec);
466     cgh.require(a_fShift);
467     if constexpr (doCalcEnergies)
468     {
469         cgh.require(a_energyElec);
470         cgh.require(a_energyVdw);
471     }
472     cgh.require(a_plistCJ4);
473     cgh.require(a_plistSci);
474     cgh.require(a_plistExcl);
475     if constexpr (!props.vdwComb)
476     {
477         cgh.require(a_atomTypes);
478         cgh.require(a_nbfp);
479     }
480     else
481     {
482         cgh.require(a_ljComb);
483     }
484     if constexpr (props.vdwEwald)
485     {
486         cgh.require(a_nbfpComb);
487     }
488     if constexpr (props.elecEwaldTab)
489     {
490         cgh.require(a_coulombTab);
491     }
492
493     // shmem buffer for i x+q pre-loading
494     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
495             cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
496
497     // shmem buffer for force reduction
498     // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
499     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
500             cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);
501
502     auto sm_atomTypeI = [&]() {
503         if constexpr (!props.vdwComb)
504         {
505             return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
506                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
507         }
508         else
509         {
510             return nullptr;
511         }
512     }();
513
514     auto sm_ljCombI = [&]() {
515         if constexpr (props.vdwComb)
516         {
517             return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
518                     cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
519         }
520         else
521         {
522             return nullptr;
523         }
524     }();
525
526     /* Flag to control the calculation of exclusion forces in the kernel
527      * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
528      * energy terms. */
529     constexpr bool doExclusionForces =
530             (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));
531
532     // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
533     // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
534     // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
535     // Hence, the two are decoupled.
536     constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
537 #if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
538     gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
539 #else
540     gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
541 #endif
542
543     return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
544     {
545         /* thread/block/warp id-s */
546         const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
547         const unsigned        tidxi   = localId[0];
548         const unsigned        tidxj   = localId[1];
549         const unsigned        tidx    = tidxj * c_clSize + tidxi;
550         const unsigned        tidxz   = 0;
551
552         // Group indexing was flat originally, no need to unflatten it.
553         const unsigned bidx = itemIdx.get_group(0);
554
555         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
556         // Could use sg.get_group_range to compute the imask & exclusion Idx, but too much of the logic relies on it anyway
557         // and in cases where prunedClusterPairSize != subGroupSize we can't use it anyway
558         const unsigned imeiIdx = tidx / prunedClusterPairSize;
559
560         Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
561         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
562         {
563             fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
564         }
565
566         const nbnxn_sci_t nbSci     = a_plistSci[bidx];
567         const int         sci       = nbSci.sci;
568         const int         cij4Start = nbSci.cj4_ind_start;
569         const int         cij4End   = nbSci.cj4_ind_end;
570
571         // Only needed if props.elecEwaldAna
572         const float beta2 = ewaldBeta * ewaldBeta;
573         const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;
574
575         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
576         {
577             /* Pre-load i-atom x and q into shared memory */
578             const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
579             const int             ai       = ci * c_clSize + tidxi;
580             const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);
581
582             const Float3 shift = a_shiftVec[nbSci.shift];
583             Float4       xqi   = a_xq[ai];
584             xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
585             xqi[3] *= epsFac;
586             sm_xq[cacheIdx] = xqi;
587
588             if constexpr (!props.vdwComb)
589             {
590                 // Pre-load the i-atom types into shared memory
591                 sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
592             }
593             else
594             {
595                 // Pre-load the LJ combination parameters into shared memory
596                 sm_ljCombI[cacheIdx] = a_ljComb[ai];
597             }
598         }
599         itemIdx.barrier(fence_space::local_space);
600
601         float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
602         if constexpr (props.vdwEwald)
603         {
604             ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
605             ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
606         }
607
608         float energyVdw, energyElec; // Only needed if (doCalcEnergies)
609         if constexpr (doCalcEnergies)
610         {
611             energyVdw = energyElec = 0.0F;
612         }
613         if constexpr (doCalcEnergies && doExclusionForces)
614         {
615             if (nbSci.shift == gmx::c_centralShiftIndex
616                 && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
617             {
618                 // we have the diagonal: add the charge and LJ self interaction energy term
619                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
620                 {
621                     // TODO: Are there other options?
622                     if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
623                     {
624                         const float qi = sm_xq[i][tidxi][3];
625                         energyElec += qi * qi;
626                     }
627                     if constexpr (props.vdwEwald)
628                     {
629                         energyVdw +=
630                                 a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
631                                        * (numTypes + 1)][0];
632                     }
633                 }
634                 /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
635                 if constexpr (props.vdwEwald)
636                 {
637                     energyVdw /= c_clSize;
638                     energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
639                 }
640                 if constexpr (props.elecRF || props.elecCutoff)
641                 {
642                     // Correct for epsfac^2 due to adding qi^2 */
643                     energyElec /= epsFac * c_clSize;
644                     energyElec *= -0.5F * cRF;
645                 }
646                 if constexpr (props.elecEwald)
647                 {
648                     // Correct for epsfac^2 due to adding qi^2 */
649                     energyElec /= epsFac * c_clSize;
650                     energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
651                 }
652             } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
653         }     // (doCalcEnergies && doExclusionForces)
654
655         // Only needed if (doExclusionForces)
656         const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);
657
658         // loop over the j clusters = seen by any of the atoms in the current super-cluster
659         for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
660         {
661             unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
662             if (!doPruneNBL && !imask)
663             {
664                 continue;
665             }
666             const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
667             static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
668             const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
669             for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
670             {
671                 const bool maskSet =
672                         imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
673                 if (!maskSet)
674                 {
675                     continue;
676                 }
677                 unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
678                 const int cj     = a_plistCJ4[j4].cj[jm];
679                 const int aj     = cj * c_clSize + tidxj;
680
681                 // load j atom data
682                 const Float4 xqj = a_xq[aj];
683
684                 const Float3 xj(xqj[0], xqj[1], xqj[2]);
685                 const float  qj = xqj[3];
686                 int          atomTypeJ; // Only needed if (!props.vdwComb)
687                 Float2       ljCombJ;   // Only needed if (props.vdwComb)
688                 if constexpr (props.vdwComb)
689                 {
690                     ljCombJ = a_ljComb[aj];
691                 }
692                 else
693                 {
694                     atomTypeJ = a_atomTypes[aj];
695                 }
696
697                 Float3 fCjBuf(0.0F, 0.0F, 0.0F);
698
699                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
700                 {
701                     if (imask & maskJI)
702                     {
703                         // i cluster index
704                         const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
705                         // all threads load an atom from i cluster ci into shmem!
706                         const Float4 xqi = sm_xq[i][tidxi];
707                         const Float3 xi(xqi[0], xqi[1], xqi[2]);
708
709                         // distance between i and j atoms
710                         const Float3 rv = xi - xj;
711                         float        r2 = norm2(rv);
712
713                         if constexpr (doPruneNBL)
714                         {
715                             /* If _none_ of the atoms pairs are in cutoff range,
716                              * the bit corresponding to the current
717                              * cluster-pair in imask gets set to 0. */
718                             if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
719                             {
720                                 imask &= ~maskJI;
721                             }
722                         }
723                         const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;
724
725                         // cutoff & exclusion check
726
727                         const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
728                                                                    : (wexcl & maskJI);
729
730                         // SYCL-TODO: Check optimal way of branching here.
731                         if ((r2 < rCoulombSq) && notExcluded)
732                         {
733                             const float qi = xqi[3];
734                             int         atomTypeI; // Only needed if (!props.vdwComb)
735                             float       sigma, epsilon;
736                             Float2      c6c12;
737
738                             if constexpr (!props.vdwComb)
739                             {
740                                 /* LJ 6*C6 and 12*C12 */
741                                 atomTypeI = sm_atomTypeI[i][tidxi];
742                                 c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
743                             }
744                             else
745                             {
746                                 const Float2 ljCombI = sm_ljCombI[i][tidxi];
747                                 if constexpr (props.vdwCombGeom)
748                                 {
749                                     c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
750                                 }
751                                 else
752                                 {
753                                     static_assert(props.vdwCombLB);
754                                     // LJ 2^(1/6)*sigma and 12*epsilon
755                                     sigma   = ljCombI[0] + ljCombJ[0];
756                                     epsilon = ljCombI[1] * ljCombJ[1];
757                                     if constexpr (doCalcEnergies)
758                                     {
759                                         c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
760                                     }
761                                 } // props.vdwCombGeom
762                             }     // !props.vdwComb
763
764                             // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
765                             const float c6  = c6c12[0];
766                             const float c12 = c6c12[1];
767
768                             // Ensure distance do not become so small that r^-12 overflows
769                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
770 #if GMX_SYCL_HIPSYCL
771                             // No fast/native functions in some compilation passes
772                             const float rInv = cl::sycl::rsqrt(r2);
773 #else
774                             // SYCL-TODO: sycl::half_precision::rsqrt?
775                             const float rInv = cl::sycl::native::rsqrt(r2);
776 #endif
777                             const float r2Inv = rInv * rInv;
778                             float       r6Inv, fInvR, energyLJPair;
779                             if constexpr (!props.vdwCombLB || doCalcEnergies)
780                             {
781                                 r6Inv = r2Inv * r2Inv * r2Inv;
782                                 if constexpr (doExclusionForces)
783                                 {
784                                     // SYCL-TODO: Check if true for SYCL
785                                     /* We could mask r2Inv, but with Ewald masking both
786                                      * r6Inv and fInvR is faster */
787                                     r6Inv *= pairExclMask;
788                                 }
789                                 fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
790                             }
791                             else
792                             {
793                                 float sig_r  = sigma * rInv;
794                                 float sig_r2 = sig_r * sig_r;
795                                 float sig_r6 = sig_r2 * sig_r2 * sig_r2;
796                                 if constexpr (doExclusionForces)
797                                 {
798                                     sig_r6 *= pairExclMask;
799                                 }
800                                 fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
801                             } // (!props.vdwCombLB || doCalcEnergies)
802                             if constexpr (doCalcEnergies || props.vdwPSwitch)
803                             {
804                                 energyLJPair = pairExclMask
805                                                * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
806                                                   - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
807                             }
808                             if constexpr (props.vdwFSwitch)
809                             {
810                                 ljForceSwitch<doCalcEnergies>(
811                                         dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
812                             }
813                             if constexpr (props.vdwEwald)
814                             {
815                                 ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
816                                                                      ljEwaldShift,
817                                                                      atomTypeI,
818                                                                      atomTypeJ,
819                                                                      r2,
820                                                                      r2Inv,
821                                                                      ewaldCoeffLJ_2,
822                                                                      ewaldCoeffLJ_6_6,
823                                                                      pairExclMask,
824                                                                      &fInvR,
825                                                                      &energyLJPair);
826                             } // (props.vdwEwald)
827                             if constexpr (props.vdwPSwitch)
828                             {
829                                 ljPotentialSwitch<doCalcEnergies>(
830                                         vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
831                             }
832                             if constexpr (props.elecEwaldTwin)
833                             {
834                                 // Separate VDW cut-off check to enable twin-range cut-offs
835                                 // (rVdw < rCoulomb <= rList)
836                                 const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
837                                 fInvR *= vdwInRange;
838                                 if constexpr (doCalcEnergies)
839                                 {
840                                     energyLJPair *= vdwInRange;
841                                 }
842                             }
843                             if constexpr (doCalcEnergies)
844                             {
845                                 energyVdw += energyLJPair;
846                             }
847
848                             if constexpr (props.elecCutoff)
849                             {
850                                 if constexpr (doExclusionForces)
851                                 {
852                                     fInvR += qi * qj * pairExclMask * r2Inv * rInv;
853                                 }
854                                 else
855                                 {
856                                     fInvR += qi * qj * r2Inv * rInv;
857                                 }
858                             }
859                             if constexpr (props.elecRF)
860                             {
861                                 fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
862                             }
863                             if constexpr (props.elecEwaldAna)
864                             {
865                                 fInvR += qi * qj
866                                          * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
867                             }
868                             if constexpr (props.elecEwaldTab)
869                             {
870                                 fInvR += qi * qj
871                                          * (pairExclMask * r2Inv
872                                             - interpolateCoulombForceR(
873                                                       a_coulombTab, coulombTabScale, r2 * rInv))
874                                          * rInv;
875                             }
876
877                             if constexpr (doCalcEnergies)
878                             {
879                                 if constexpr (props.elecCutoff)
880                                 {
881                                     energyElec += qi * qj * (pairExclMask * rInv - cRF);
882                                 }
883                                 if constexpr (props.elecRF)
884                                 {
885                                     energyElec +=
886                                             qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
887                                 }
888                                 if constexpr (props.elecEwald)
889                                 {
890                                     energyElec +=
891                                             qi * qj
892                                             * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
893                                                - pairExclMask * ewaldShift);
894                                 }
895                             }
896
897                             const Float3 forceIJ = rv * fInvR;
898
899                             /* accumulate j forces in registers */
900                             fCjBuf -= forceIJ;
901                             /* accumulate i forces in registers */
902                             fCiBuf[i] += forceIJ;
903                         } // (r2 < rCoulombSq) && notExcluded
904                     }     // (imask & maskJI)
905                     /* shift the mask bit by 1 */
906                     maskJI += maskJI;
907                 } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
908                 /* reduce j forces */
909                 reduceForceJShuffle(fCjBuf, itemIdx, tidxi, aj, a_f);
910             } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
911             if constexpr (doPruneNBL)
912             {
913                 /* Update the imask with the new one which does not contain the
914                  * out of range clusters anymore. */
915                 a_plistCJ4[j4].imei[imeiIdx].imask = imask;
916             }
917         } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)
918
919         /* skip central shifts when summing shift forces */
920         const bool doCalcShift = (calcShift && !(nbSci.shift == gmx::c_centralShiftIndex));
921
922         reduceForceIAndFShift(
923                 sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);
924
925         if constexpr (doCalcEnergies)
926         {
927             const float energyVdwGroup = sycl_2020::group_reduce(
928                     itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
929             const float energyElecGroup = sycl_2020::group_reduce(
930                     itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());
931
932             if (tidx == 0)
933             {
934                 atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
935                 atomicFetchAdd(a_energyElec, 0, energyElecGroup);
936             }
937         }
938     };
939 }
940
941 // SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
942 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
943 class NbnxmKernelName;
944
945 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
946 cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
947 {
948     // Should not be needed for SYCL2020.
949     using kernelNameType = NbnxmKernelName<doPruneNBL, doCalcEnergies, elecType, vdwType>;
950
951     /* Kernel launch config:
952      * - The thread block dimensions match the size of i-clusters, j-clusters,
953      *   and j-cluster concurrency, in x, y, and z, respectively.
954      * - The 1D block-grid contains as many blocks as super-clusters.
955      */
956     const int                   numBlocks = numSci;
957     const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
958     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
959     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
960
961     cl::sycl::queue q = deviceStream.stream();
962
963     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
964         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
965                 cgh, std::forward<Args>(args)...);
966         cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
967     });
968
969     return e;
970 }
971
972 template<class... Args>
973 cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
974                                            bool          doCalcEnergies,
975                                            enum ElecType elecType,
976                                            enum VdwType  vdwType,
977                                            Args&&... args)
978 {
979     return gmx::dispatchTemplatedFunction(
980             [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
981                 return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
982                         std::forward<Args>(args)...);
983             },
984             doPruneNBL,
985             doCalcEnergies,
986             elecType,
987             vdwType);
988 }
989
990 void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
991 {
992     NBAtomDataGpu*      adat         = nb->atdat;
993     NBParamGpu*         nbp          = nb->nbparam;
994     gpu_plist*          plist        = nb->plist[iloc];
995     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
996     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
997
998     // Casting to float simplifies using atomic ops in the kernel
999     cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
1000     auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
1001     cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
1002     auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
1003
1004     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
1005                                                    stepWork.computeEnergy,
1006                                                    nbp->elecType,
1007                                                    nbp->vdwType,
1008                                                    deviceStream,
1009                                                    plist->nsci,
1010                                                    adat->xq,
1011                                                    fAsFloat,
1012                                                    adat->shiftVec,
1013                                                    fShiftAsFloat,
1014                                                    adat->eElec,
1015                                                    adat->eLJ,
1016                                                    plist->cj4,
1017                                                    plist->sci,
1018                                                    plist->excl,
1019                                                    adat->ljComb,
1020                                                    adat->atomTypes,
1021                                                    nbp->nbfp,
1022                                                    nbp->nbfp_comb,
1023                                                    nbp->coulomb_tab,
1024                                                    adat->numTypes,
1025                                                    nbp->rcoulomb_sq,
1026                                                    nbp->rvdw_sq,
1027                                                    nbp->two_k_rf,
1028                                                    nbp->ewald_beta,
1029                                                    nbp->rlistOuter_sq,
1030                                                    nbp->sh_ewald,
1031                                                    nbp->epsfac,
1032                                                    nbp->ewaldcoeff_lj,
1033                                                    nbp->c_rf,
1034                                                    nbp->dispersion_shift,
1035                                                    nbp->repulsion_shift,
1036                                                    nbp->vdw_switch,
1037                                                    nbp->rvdw_switch,
1038                                                    nbp->sh_lj_ewald,
1039                                                    nbp->coulomb_tab_scale,
1040                                                    stepWork.computeVirial);
1041 }
1042
1043 } // namespace Nbnxm