/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief
 *  NBNXM SYCL kernels
 *
 *  \ingroup module_nbnxm
 */
#include "gmxpre.h"

#include "nbnxm_sycl_kernel.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/math/functions.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/pbcutil/ishift.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"

namespace Nbnxm
{

//! \brief Set of boolean constants mimicking preprocessor macros.
template<enum ElecType elecType, enum VdwType vdwType>
struct EnergyFunctionProperties {
    static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
    static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
    static constexpr bool elecEwaldAna =
            (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
    static constexpr bool elecEwaldTab =
            (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
    static constexpr bool elecEwaldTwin =
            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);   ///< LJ_COMB_LB
    static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
    static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
    static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
    static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
    static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
    static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
    static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
};
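
/* Usage note (illustrative only, not part of the kernel): the struct above
 * replaces the preprocessor flags of the CUDA/OpenCL kernels with compile-time
 * booleans that can be queried in `if constexpr` branches, e.g.:
 * \code
 *   static_assert(EnergyFunctionProperties<ElecType::RF, VdwType::Count>().elecRF);
 *   static_assert(!EnergyFunctionProperties<ElecType::RF, VdwType::Count>().elecEwald);
 * \endcode
 */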

//! \brief Templated constants to shorten kernel function declaration.
//@{
template<enum VdwType vdwType>
constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;

template<enum ElecType elecType> // Yes, ElecType
constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;

template<enum ElecType elecType>
constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;

template<enum ElecType elecType>
constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;

template<enum VdwType vdwType>
constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
//@}

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;

static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
{
    const float sigma2 = sigma * sigma;
    const float sigma6 = sigma2 * sigma2 * sigma2;
    const float c6     = epsilon * sigma6;
    const float c12    = c6 * sigma6;

    return Float2(c6, c12);
}
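
/* Note: the kernel stores 2^(1/6)*sigma and 12*epsilon (see the Lorentz-Berthelot
 * branch in the kernel body), so with those scaled inputs the function above
 * returns 6*C6 and 12*C12:
 *   c6  = 12*eps * (2^(1/6)*sigma)^6 = 24*eps*sigma^6  = 6  * (4*eps*sigma^6),
 *   c12 = c6 * (2^(1/6)*sigma)^6     = 48*eps*sigma^12 = 12 * (4*eps*sigma^12). */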

template<bool doCalcEnergies>
static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                 const shift_consts_t         repulsionShift,
                                 const float                  rVdwSwitch,
                                 const float                  c6,
                                 const float                  c12,
                                 const float                  rInv,
                                 const float                  r2,
                                 cl::sycl::private_ptr<float> fInvR,
                                 cl::sycl::private_ptr<float> eLJ)
{
    /* force switch constants */
    const float dispShiftV2 = dispersionShift.c2;
    const float dispShiftV3 = dispersionShift.c3;
    const float repuShiftV2 = repulsionShift.c2;
    const float repuShiftV3 = repulsionShift.c3;

    const float r       = r2 * rInv;
    const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)

    *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
              + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;

    if constexpr (doCalcEnergies)
    {
        const float dispShiftF2 = dispShiftV2 / 3;
        const float dispShiftF3 = dispShiftV3 / 4;
        const float repuShiftF2 = repuShiftV2 / 3;
        const float repuShiftF3 = repuShiftV3 / 4;
        *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
                - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
    }
}
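
/* The divisions by 3 and 4 above follow from integrating the force-switch
 * polynomial: with x = r - rVdwSwitch, integrating the force term (c2 + c3*x)*x^2
 * over x gives (c2/3 + (c3/4)*x)*x^3, so the energy branch reuses the force
 * constants divided by the polynomial order. */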

//! \brief Fetch C6 grid contribution coefficients and return the product of these.
template<enum VdwType vdwType>
static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                                           const int                                typeI,
                                           const int                                typeJ)
{
    if constexpr (vdwType == VdwType::EwaldGeom)
    {
        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
    }
    else
    {
        static_assert(vdwType == VdwType::EwaldLB);
        /* sigma and epsilon are scaled to give 6*C6 */
        const Float2 c6c12_i = a_nbfpComb[typeI];
        const Float2 c6c12_j = a_nbfpComb[typeJ];

        const float sigma   = c6c12_i[0] + c6c12_j[0];
        const float epsilon = c6c12_i[1] * c6c12_j[1];

        const float sigma2 = sigma * sigma;
        return epsilon * sigma2 * sigma2 * sigma2;
    }
}

//! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
template<bool doCalcEnergies, enum VdwType vdwType>
static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                               const float                              sh_lj_ewald,
                               const int                                typeI,
                               const int                                typeJ,
                               const float                              r2,
                               const float                              r2Inv,
                               const float                              lje_coeff2,
                               const float                              lje_coeff6_6,
                               const float                              int_bit,
                               cl::sycl::private_ptr<float>             fInvR,
                               cl::sycl::private_ptr<float>             eLJ)
{
    const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);

    /* Recalculate inv_r6 without exclusion mask */
    const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
    const float cr2       = lje_coeff2 * r2;
    const float expmcr2   = cl::sycl::exp(-cr2);
    const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;

    /* Subtract the grid force from the total LJ force */
    *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;

    if constexpr (doCalcEnergies)
    {
        /* Shift should be applied only to real LJ pairs */
        const float sh_mask = sh_lj_ewald * int_bit;
        *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
    }
}
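
/* For reference: with x = cr2 = (lje_coeff*r)^2, the energy expression above
 * applies the standard LJ-PME real-space factor 1 - exp(-x)*(1 + x + x^2/2),
 * where poly holds the second-order truncation of the exponential series. */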

/*! \brief Apply potential switch. */
template<bool doCalcEnergies>
static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
                                     const float                  rVdwSwitch,
                                     const float                  rInv,
                                     const float                  r2,
                                     cl::sycl::private_ptr<float> fInvR,
                                     cl::sycl::private_ptr<float> eLJ)
{
    /* potential switch constants */
    const float switchV3 = vdwSwitch.c3;
    const float switchV4 = vdwSwitch.c4;
    const float switchV5 = vdwSwitch.c5;
    const float switchF2 = 3 * vdwSwitch.c3;
    const float switchF3 = 4 * vdwSwitch.c4;
    const float switchF4 = 5 * vdwSwitch.c5;

    const float r       = r2 * rInv;
    const float rSwitch = r - rVdwSwitch;

    if (rSwitch > 0.0F)
    {
        const float sw =
                1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
        const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;

        *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
        if constexpr (doCalcEnergies)
        {
            *eLJ *= sw;
        }
    }
}
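
/* The switchF* constants above are the analytic derivative of the switch
 * polynomial: with x = r - rVdwSwitch,
 *   d/dx [1 + (c3 + (c4 + c5*x)*x)*x^3] = (3*c3 + (4*c4 + 5*c5*x)*x)*x^2,
 * which is exactly what dsw evaluates. */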


/*! \brief Calculate analytical Ewald correction term. */
static inline float pmeCorrF(const float z2)
{
    constexpr float FN6 = -1.7357322914161492954e-8F;
    constexpr float FN5 = 1.4703624142580877519e-6F;
    constexpr float FN4 = -0.000053401640219807709149F;
    constexpr float FN3 = 0.0010054721316683106153F;
    constexpr float FN2 = -0.019278317264888380590F;
    constexpr float FN1 = 0.069670166153766424023F;
    constexpr float FN0 = -0.75225204789749321333F;

    constexpr float FD4 = 0.0011193462567257629232F;
    constexpr float FD3 = 0.014866955030185295499F;
    constexpr float FD2 = 0.11583842382862377919F;
    constexpr float FD1 = 0.50736591960530292870F;
    constexpr float FD0 = 1.0F;

    const float z4 = z2 * z2;

    float       polyFD0 = FD4 * z4 + FD2;
    const float polyFD1 = FD3 * z4 + FD1;
    polyFD0             = polyFD0 * z4 + FD0;
    polyFD0             = polyFD1 * z2 + polyFD0;

    polyFD0 = 1.0F / polyFD0;

    float polyFN0 = FN6 * z4 + FN4;
    float polyFN1 = FN5 * z4 + FN3;
    polyFN0       = polyFN0 * z4 + FN2;
    polyFN1       = polyFN1 * z4 + FN1;
    polyFN0       = polyFN0 * z4 + FN0;
    polyFN0       = polyFN1 * z2 + polyFN0;

    return polyFN0 * polyFD0;
}
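
/* With z = sqrt(z2), the rational approximation above targets
 *   f(z2) = 2*exp(-z2)/(sqrt(pi)*z2) - erf(z)/z^3,
 * whose z -> 0 limit -4/(3*sqrt(pi)) ~= -0.75225 matches FN0/FD0. A host-side
 * accuracy check could look like this (a sketch, not part of the build):
 * \code
 *   float pmeCorrFReference(float z2)
 *   {
 *       const float z = std::sqrt(z2);
 *       return 2.0F * std::exp(-z2) / (std::sqrt(float(M_PI)) * z2)
 *              - std::erf(z) / (z2 * z);
 *   }
 * \endcode
 */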

/*! \brief Linear interpolation using exactly two FMA operations.
 *
 *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    return fma(t, d1, fma(-t, d0, d0));
}
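
/* Why two FMAs suffice: fma(-t, d0, d0) = d0 - t*d0 = (1-t)*d0, and the outer
 * fma adds t*d1, giving (1-t)*d0 + t*d1. */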

/*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
                                             const float coulombTabScale,
                                             const float r)
{
    const float normalized = coulombTabScale * r;
    const int   index      = static_cast<int>(normalized);
    const float fraction   = normalized - index;

    const float left  = a_coulombTab[index];
    const float right = a_coulombTab[index + 1];

    return lerp(left, right, fraction); // TODO: cl::sycl::mix
}
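
/* The table is assumed to hold F*r sampled at evenly spaced distances with
 * spacing 1/coulombTabScale; coulombTabScale*r is split into an integer index
 * and a fraction, and the two neighboring entries are blended with lerp(). */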

/*! \brief Reduce c_clSize j-force components and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the partial force components of a j-atom,
 * which we reduce in log2(c_clSize) steps using subgroup shifts and then
 * atomically accumulate into \p a_f.
 */
static inline void reduceForceJShuffle(Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    f[0] += sycl_2020::shift_left(sg, f[0], 1);
    f[1] += sycl_2020::shift_right(sg, f[1], 1);
    f[2] += sycl_2020::shift_left(sg, f[2], 1);
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    f[0] += sycl_2020::shift_left(sg, f[0], 2);
    f[2] += sycl_2020::shift_right(sg, f[2], 2);
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    if constexpr (c_clSize == 8)
    {
        f[0] += sycl_2020::shift_left(sg, f[0], 4);
    }

    if (tidxi < 3)
    {
        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
    }
}
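
/* Reduction pattern sketch for c_clSize == 4 (one j-atom per four lanes):
 * after the first shift step, even lanes hold pairwise x-sums in f[0] and odd
 * lanes copy in their pairwise y-sums; after the second step, lane 0 holds
 * sum(x), lane 1 sum(y), and lane 2 (via the f[2] copy) sum(z), so exactly the
 * tidxi < 3 lanes write out. For c_clSize == 8, the extra shift-by-4 step
 * completes the three component sums the same way. */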


/*! \brief Final i-force reduction.
 *
 * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[],
 * accumulating atomically into \p a_f.
 * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
 *
 * This implementation works only with power of two array sizes.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<1>         itemIdx,
                                         const int                          tidxi,
                                         const int                          tidxj,
                                         const int                          sci,
                                         const int                          shift,
                                         DeviceAccessor<float, mode_atomic> a_f,
                                         DeviceAccessor<float, mode_atomic> a_fShift)
{
    // must have power of two elements in fCiBuf
    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));

    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    float                fShiftBuf  = 0;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads:
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
        }
    }
}

/*! \brief Main kernel for NBNXM. */
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                 DeviceAccessor<Float4, mode::read>                   a_xq,
                 DeviceAccessor<float, mode_atomic>                   a_f,
                 DeviceAccessor<Float3, mode::read>                   a_shiftVec,
                 DeviceAccessor<float, mode_atomic>                   a_fShift,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
                 DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                 DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                 DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                 OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
                 OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
                 OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                 const int                                                   numTypes,
                 const float                                                 rCoulombSq,
                 const float                                                 rVdwSq,
                 const float                                                 twoKRf,
                 const float                                                 ewaldBeta,
                 const float                                                 rlistOuterSq,
                 const float                                                 ewaldShift,
                 const float                                                 epsFac,
                 const float                                                 ewaldCoeffLJ,
                 const float                                                 cRF,
                 const shift_consts_t                                        dispersionShift,
                 const shift_consts_t                                        repulsionShift,
                 const switch_consts_t                                       vdwSwitch,
                 const float                                                 rVdwSwitch,
                 const float                                                 ljEwaldShift,
                 const float                                                 coulombTabScale,
                 const bool                                                  calcShift)
{
    static constexpr EnergyFunctionProperties<elecType, vdwType> props;

    cgh.require(a_xq);
    cgh.require(a_f);
    cgh.require(a_shiftVec);
    cgh.require(a_fShift);
    if constexpr (doCalcEnergies)
    {
        cgh.require(a_energyElec);
        cgh.require(a_energyVdw);
    }
    cgh.require(a_plistCJ4);
    cgh.require(a_plistSci);
    cgh.require(a_plistExcl);
    if constexpr (!props.vdwComb)
    {
        cgh.require(a_atomTypes);
        cgh.require(a_nbfp);
    }
    else
    {
        cgh.require(a_ljComb);
    }
    if constexpr (props.vdwEwald)
    {
        cgh.require(a_nbfpComb);
    }
    if constexpr (props.elecEwaldTab)
    {
        cgh.require(a_coulombTab);
    }

    // shmem buffer for i x+q pre-loading
    cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

    // shmem buffer for force reduction
    // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
            cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);

    auto sm_atomTypeI = [&]() {
        if constexpr (!props.vdwComb)
        {
            return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    auto sm_ljCombI = [&]() {
        if constexpr (props.vdwComb)
        {
            return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    /* Flag to control the calculation of exclusion forces in the kernel.
     * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
     * energy terms. */
    constexpr bool doExclusionForces =
            (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));

    // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
    // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
    // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
    // Hence, the two are decoupled.
    constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
#if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
    constexpr int subGroupSize = c_clSize * c_clSize;
#else
    constexpr int subGroupSize = prunedClusterPairSize;
#endif

    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        /* thread/block/warp id-s */
        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
        const unsigned        tidxi   = localId[0];
        const unsigned        tidxj   = localId[1];
        const unsigned        tidx    = tidxj * c_clSize + tidxi;
        const unsigned        tidxz   = 0;

        // Group indexing was flat originally, no need to unflatten it.
        const unsigned bidx = itemIdx.get_group(0);

        const sycl_2020::sub_group sg = itemIdx.get_sub_group();
        // We could use sg.get_group_range to compute the imask & exclusion index, but
        // too much of the logic relies on this explicit formulation, and in cases where
        // prunedClusterPairSize != subGroupSize it would not work anyway.
        const unsigned imeiIdx = tidx / prunedClusterPairSize;

        Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
        {
            fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
        }

        const nbnxn_sci_t nbSci     = a_plistSci[bidx];
        const int         sci       = nbSci.sci;
        const int         cij4Start = nbSci.cj4_ind_start;
        const int         cij4End   = nbSci.cj4_ind_end;

        // Only needed if props.elecEwaldAna
        const float beta2 = ewaldBeta * ewaldBeta;
        const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;

        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
        {
            /* Pre-load i-atom x and q into shared memory */
            const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
            const int             ai       = ci * c_clSize + tidxi;
            const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);

            const Float3 shift = a_shiftVec[nbSci.shift];
            Float4       xqi   = a_xq[ai];
            xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
            xqi[3] *= epsFac;
            sm_xq[cacheIdx] = xqi;

            if constexpr (!props.vdwComb)
            {
                // Pre-load the i-atom types into shared memory
                sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
            }
            else
            {
                // Pre-load the LJ combination parameters into shared memory
                sm_ljCombI[cacheIdx] = a_ljComb[ai];
            }
        }
        itemIdx.barrier(fence_space::local_space);

        float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
        if constexpr (props.vdwEwald)
        {
            ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
            ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
        }

        float energyVdw, energyElec; // Only needed if (doCalcEnergies)
        if constexpr (doCalcEnergies)
        {
            energyVdw = energyElec = 0.0F;
        }
        if constexpr (doCalcEnergies && doExclusionForces)
        {
            if (nbSci.shift == CENTRAL && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
            {
                // we have the diagonal: add the charge and LJ self interaction energy term
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    // TODO: Are there other options?
                    if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
                    {
                        const float qi = sm_xq[i][tidxi][3];
                        energyElec += qi * qi;
                    }
                    if constexpr (props.vdwEwald)
                    {
                        energyVdw +=
                                a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
                                       * (numTypes + 1)][0];
                    }
                }
                /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
                if constexpr (props.vdwEwald)
                {
                    energyVdw /= c_clSize;
                    energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
                }
                if constexpr (props.elecRF || props.elecCutoff)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -0.5F * cRF;
                }
                if constexpr (props.elecEwald)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                }
            } // (nbSci.shift == CENTRAL && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
        }     // (doCalcEnergies && doExclusionForces)

        // Only needed if (doExclusionForces)
        const bool nonSelfInteraction = !(nbSci.shift == CENTRAL & tidxj <= tidxi);

        // loop over the j clusters = seen by any of the atoms in the current super-cluster
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
        {
            unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
            if (!doPruneNBL && !imask)
            {
                continue;
            }
            const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (subGroupSize - 1)]; // sg.get_local_linear_id()
            for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            {
                const bool maskSet =
                        imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                if (!maskSet)
                {
                    continue;
                }
                unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                const int cj     = a_plistCJ4[j4].cj[jm];
                const int aj     = cj * c_clSize + tidxj;

                // load j atom data
                const Float4 xqj = a_xq[aj];

                const Float3 xj(xqj[0], xqj[1], xqj[2]);
                const float  qj = xqj[3];
                int          atomTypeJ; // Only needed if (!props.vdwComb)
                Float2       ljCombJ;   // Only needed if (props.vdwComb)
                if constexpr (props.vdwComb)
                {
                    ljCombJ = a_ljComb[aj];
                }
                else
                {
                    atomTypeJ = a_atomTypes[aj];
                }

                Float3 fCjBuf(0.0F, 0.0F, 0.0F);

                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    if (imask & maskJI)
                    {
                        // i cluster index
                        const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
                        // all threads load an atom from i cluster ci into shmem!
                        const Float4 xqi = sm_xq[i][tidxi];
                        const Float3 xi(xqi[0], xqi[1], xqi[2]);

                        // distance between i and j atoms
                        const Float3 rv = xi - xj;
                        float        r2 = norm2(rv);

                        if constexpr (doPruneNBL)
                        {
                            /* If _none_ of the atom pairs are in cutoff range,
                             * the bit corresponding to the current
                             * cluster-pair in imask gets set to 0. */
                            if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
                            {
                                imask &= ~maskJI;
                            }
                        }
                        const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;

                        // cutoff & exclusion check

                        const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
                                                                   : (wexcl & maskJI);

                        // SYCL-TODO: Check optimal way of branching here.
                        if ((r2 < rCoulombSq) && notExcluded)
                        {
                            const float qi = xqi[3];
                            int         atomTypeI; // Only needed if (!props.vdwComb)
                            float       sigma, epsilon;
                            Float2      c6c12;

                            if constexpr (!props.vdwComb)
                            {
                                /* LJ 6*C6 and 12*C12 */
                                atomTypeI = sm_atomTypeI[i][tidxi];
                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                            }
                            else
                            {
                                const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                if constexpr (props.vdwCombGeom)
                                {
                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                }
                                else
                                {
                                    static_assert(props.vdwCombLB);
                                    // LJ 2^(1/6)*sigma and 12*epsilon
                                    sigma   = ljCombI[0] + ljCombJ[0];
                                    epsilon = ljCombI[1] * ljCombJ[1];
                                    if constexpr (doCalcEnergies)
                                    {
                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                    }
                                } // props.vdwCombGeom
                            }     // !props.vdwComb

                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
                            const float c6  = c6c12[0];
                            const float c12 = c6c12[1];

                            // Ensure the distance does not become so small that r^-12 overflows
                            r2 = std::max(r2, c_nbnxnMinDistanceSquared);
#if GMX_SYCL_HIPSYCL
                            // No fast/native functions in some compilation passes
                            const float rInv = cl::sycl::rsqrt(r2);
#else
                            // SYCL-TODO: sycl::half_precision::rsqrt?
                            const float rInv = cl::sycl::native::rsqrt(r2);
#endif
                            const float r2Inv = rInv * rInv;
                            float       r6Inv, fInvR, energyLJPair;
                            if constexpr (!props.vdwCombLB || doCalcEnergies)
                            {
                                r6Inv = r2Inv * r2Inv * r2Inv;
                                if constexpr (doExclusionForces)
                                {
                                    // SYCL-TODO: Check if true for SYCL
                                    /* We could mask r2Inv, but with Ewald masking both
                                     * r6Inv and fInvR is faster */
                                    r6Inv *= pairExclMask;
                                }
                                fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
                            }
                            else
                            {
                                float sig_r  = sigma * rInv;
                                float sig_r2 = sig_r * sig_r;
                                float sig_r6 = sig_r2 * sig_r2 * sig_r2;
                                if constexpr (doExclusionForces)
                                {
                                    sig_r6 *= pairExclMask;
                                }
                                fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
                            } // (!props.vdwCombLB || doCalcEnergies)
                            if constexpr (doCalcEnergies || props.vdwPSwitch)
                            {
                                energyLJPair = pairExclMask
                                               * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
                                                  - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
                            }
                            if constexpr (props.vdwFSwitch)
                            {
                                ljForceSwitch<doCalcEnergies>(
                                        dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.vdwEwald)
                            {
                                ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
                                                                     ljEwaldShift,
                                                                     atomTypeI,
                                                                     atomTypeJ,
                                                                     r2,
                                                                     r2Inv,
                                                                     ewaldCoeffLJ_2,
                                                                     ewaldCoeffLJ_6_6,
                                                                     pairExclMask,
                                                                     &fInvR,
                                                                     &energyLJPair);
                            } // (props.vdwEwald)
                            if constexpr (props.vdwPSwitch)
                            {
                                ljPotentialSwitch<doCalcEnergies>(
                                        vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.elecEwaldTwin)
                            {
                                // Separate VDW cut-off check to enable twin-range cut-offs
                                // (rVdw < rCoulomb <= rList)
                                const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
                                fInvR *= vdwInRange;
                                if constexpr (doCalcEnergies)
                                {
                                    energyLJPair *= vdwInRange;
                                }
                            }
                            if constexpr (doCalcEnergies)
                            {
                                energyVdw += energyLJPair;
                            }

                            if constexpr (props.elecCutoff)
                            {
                                if constexpr (doExclusionForces)
                                {
                                    fInvR += qi * qj * pairExclMask * r2Inv * rInv;
                                }
                                else
                                {
                                    fInvR += qi * qj * r2Inv * rInv;
                                }
                            }
                            if constexpr (props.elecRF)
                            {
                                fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
                            }
                            if constexpr (props.elecEwaldAna)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
                            }
                            if constexpr (props.elecEwaldTab)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv
                                            - interpolateCoulombForceR(
                                                      a_coulombTab, coulombTabScale, r2 * rInv))
                                         * rInv;
                            }

                            if constexpr (doCalcEnergies)
                            {
                                if constexpr (props.elecCutoff)
                                {
                                    energyElec += qi * qj * (pairExclMask * rInv - cRF);
                                }
                                if constexpr (props.elecRF)
                                {
                                    energyElec +=
                                            qi * qj * (pairExclMask * rInv + 0.5F * twoKRf * r2 - cRF);
                                }
                                if constexpr (props.elecEwald)
                                {
                                    energyElec +=
                                            qi * qj
                                            * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
                                               - pairExclMask * ewaldShift);
                                }
                            }

                            const Float3 forceIJ = rv * fInvR;

                            /* accumulate j forces in registers */
                            fCjBuf -= forceIJ;
                            /* accumulate i forces in registers */
                            fCiBuf[i] += forceIJ;
                        } // (r2 < rCoulombSq) && notExcluded
                    }     // (imask & maskJI)
                    /* shift the mask bit by 1 */
                    maskJI += maskJI;
                } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                /* reduce j forces */
                reduceForceJShuffle(fCjBuf, itemIdx, tidxi, aj, a_f);
            } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            if constexpr (doPruneNBL)
            {
                /* Update the imask with the new one which does not contain the
                 * out of range clusters anymore. */
                a_plistCJ4[j4].imei[imeiIdx].imask = imask;
            }
        } // for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)

        /* skip central shifts when summing shift forces */
        const bool doCalcShift = (calcShift && !(nbSci.shift == CENTRAL));

        reduceForceIAndFShift(
                sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);

        if constexpr (doCalcEnergies)
        {
            const float energyVdwGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
            const float energyElecGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());

            if (tidx == 0)
            {
                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
            }
        }
    };
}

// SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
class NbnxmKernelName;

template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
{
    // Should not be needed for SYCL2020.
    using kernelNameType = NbnxmKernelName<doPruneNBL, doCalcEnergies, elecType, vdwType>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const int                   numBlocks = numSci;
    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
    });

    return e;
}

template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                           bool          doCalcEnergies,
                                           enum ElecType elecType,
                                           enum VdwType  vdwType,
                                           Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
                return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
                        std::forward<Args>(args)...);
            },
            doPruneNBL,
            doCalcEnergies,
            elecType,
            vdwType);
}
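
/* gmx::dispatchTemplatedFunction (from gromacs/utility/template_mp.h) maps the
 * runtime bool/enum arguments onto compile-time template parameters by
 * enumerating all combinations. A minimal sketch of the mechanism for a single
 * bool (illustrative only, not the actual implementation):
 * \code
 *   template<class Function>
 *   auto dispatchOnBool(Function&& f, bool b)
 *   {
 *       return b ? f(std::true_type{}) : f(std::false_type{});
 *   }
 * \endcode
 * The generic lambda above then receives each value as a constant usable as a
 * non-type template argument of launchNbnxmKernel. */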

void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
{
    NBAtomDataGpu*      adat         = nb->atdat;
    NBParamGpu*         nbp          = nb->nbparam;
    gpu_plist*          plist        = nb->plist[iloc];
    const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];

    // Casting to float simplifies using atomic ops in the kernel
    cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
    cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);

    cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                   stepWork.computeEnergy,
                                                   nbp->elecType,
                                                   nbp->vdwType,
                                                   deviceStream,
                                                   plist->nsci,
                                                   adat->xq,
                                                   fAsFloat,
                                                   adat->shiftVec,
                                                   fShiftAsFloat,
                                                   adat->eElec,
                                                   adat->eLJ,
                                                   plist->cj4,
                                                   plist->sci,
                                                   plist->excl,
                                                   adat->ljComb,
                                                   adat->atomTypes,
                                                   nbp->nbfp,
                                                   nbp->nbfp_comb,
                                                   nbp->coulomb_tab,
                                                   adat->numTypes,
                                                   nbp->rcoulomb_sq,
                                                   nbp->rvdw_sq,
                                                   nbp->two_k_rf,
                                                   nbp->ewald_beta,
                                                   nbp->rlistOuter_sq,
                                                   nbp->sh_ewald,
                                                   nbp->epsfac,
                                                   nbp->ewaldcoeff_lj,
                                                   nbp->c_rf,
                                                   nbp->dispersion_shift,
                                                   nbp->repulsion_shift,
                                                   nbp->vdw_switch,
                                                   nbp->rvdw_switch,
                                                   nbp->sh_lj_ewald,
                                                   nbp->coulomb_tab_scale,
                                                   stepWork.computeVirial);
}

} // namespace Nbnxm