/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief
 *  NBNXM SYCL kernels
 *
 *  \ingroup module_nbnxm
 */
#include "gmxpre.h"

#include "nbnxm_sycl_kernel.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/math/functions.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/pbcutil/ishift.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"

namespace Nbnxm
{

//! \brief Set of boolean constants mimicking preprocessor macros.
template<enum ElecType elecType, enum VdwType vdwType>
struct EnergyFunctionProperties {
    static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
    static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
    static constexpr bool elecEwaldAna =
            (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
    static constexpr bool elecEwaldTab =
            (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
    static constexpr bool elecEwaldTwin =
            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin); ///< Twin-range cut-off
    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);   ///< LJ_COMB_LB
    static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
    static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
    static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
    static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
    static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
    static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
    static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
};
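
// Compile-time illustration of how the properties replace the old preprocessor
// macros (a sketch; these asserts are not required by the kernels): for
// analytical Ewald electrostatics combined with force-switched LJ we expect
static_assert(EnergyFunctionProperties<ElecType::EwaldAna, VdwType::FSwitch>().elecEwald,
              "EwaldAna implies EL_EWALD_ANY");
static_assert(EnergyFunctionProperties<ElecType::EwaldAna, VdwType::FSwitch>().vdwFSwitch,
              "FSwitch implies LJ_FORCE_SWITCH");
static_assert(!EnergyFunctionProperties<ElecType::EwaldAna, VdwType::FSwitch>().vdwComb,
              "FSwitch is not a combination-rule flavor");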

//! \brief Templated constants to shorten kernel function declaration.
//@{
template<enum VdwType vdwType>
constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;

template<enum ElecType elecType> // Yes, ElecType: whether a separate VdW cut-off check is needed is determined by the (twin-range) electrostatics flavor.
constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;

template<enum ElecType elecType>
constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;

template<enum ElecType elecType>
constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;

template<enum VdwType vdwType>
constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
//@}

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;

//! \brief Convert \p sigma and \p epsilon VdW parameters to \c c6 / \c c12 pair.
static inline void convertSigmaEpsilonToC6C12(const float                  sigma,
                                              const float                  epsilon,
                                              cl::sycl::private_ptr<float> c6,
                                              cl::sycl::private_ptr<float> c12)
{
    const float sigma2 = sigma * sigma;
    const float sigma6 = sigma2 * sigma2 * sigma2;
    *c6                = epsilon * sigma6;
    *c12               = (*c6) * sigma6;
}
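
/* Quick check of the conversion above (a sketch, relying on the pre-scaling
 * described at the call site: for the LB combination rule the stored sigma
 * carries a factor 2^(1/6) and the stored epsilon a factor 12, see the
 * "LJ 2^(1/6)*sigma and 12*epsilon" comment in the kernel below). With
 * C6 = 4*eps*sigma^6 and C12 = 4*eps*sigma^12, the outputs are the
 * kernel-internal 6*C6 and 12*C12:
 *   *c6  = (12*eps) * (2^(1/6)*sigma)^6 = 24*eps*sigma^6  = 6  * C6
 *   *c12 = *c6 * (2^(1/6)*sigma)^6      = 48*eps*sigma^12 = 12 * C12
 */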

//! \brief Apply the LJ force-switch function.
template<bool doCalcEnergies>
static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                 const shift_consts_t         repulsionShift,
                                 const float                  rVdwSwitch,
                                 const float                  c6,
                                 const float                  c12,
                                 const float                  rInv,
                                 const float                  r2,
                                 cl::sycl::private_ptr<float> fInvR,
                                 cl::sycl::private_ptr<float> eLJ)
{
    /* force switch constants */
    const float dispShiftV2 = dispersionShift.c2;
    const float dispShiftV3 = dispersionShift.c3;
    const float repuShiftV2 = repulsionShift.c2;
    const float repuShiftV3 = repulsionShift.c3;

    const float r       = r2 * rInv;
    const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)

    *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
              + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;

    if constexpr (doCalcEnergies)
    {
        const float dispShiftF2 = dispShiftV2 / 3;
        const float dispShiftF3 = dispShiftV3 / 4;
        const float repuShiftF2 = repuShiftV2 / 3;
        const float repuShiftF3 = repuShiftV3 / 4;
        *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
                - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
    }
}
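
/* Where the F2/F3 constants above come from (a sketch of the derivation,
 * assuming the usual GROMACS force-switch form): for r > rVdwSwitch each term
 * of the force is modified by a polynomial in x = r - rVdwSwitch,
 *   dF(x) = (c2 + c3*x) * x^2,
 * and the matching energy modification is its antiderivative,
 *   dV(x) = (c2/3)*x^3 + (c3/4)*x^4,
 * which is why the F2 coefficients are V2/3 and the F3 coefficients are V3/4.
 */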

//! \brief Fetch C6 grid contribution coefficients and return the product of these.
template<enum VdwType vdwType>
static inline float calculateLJEwaldC6Grid(const DeviceAccessor<float, mode::read> a_nbfpComb,
                                           const int                               typeI,
                                           const int                               typeJ)
{
    if constexpr (vdwType == VdwType::EwaldGeom)
    {
        return a_nbfpComb[2 * typeI] * a_nbfpComb[2 * typeJ];
    }
    else
    {
        static_assert(vdwType == VdwType::EwaldLB);
        /* sigma and epsilon are scaled to give 6*C6 */
        const float c6_i  = a_nbfpComb[2 * typeI];
        const float c12_i = a_nbfpComb[2 * typeI + 1];
        const float c6_j  = a_nbfpComb[2 * typeJ];
        const float c12_j = a_nbfpComb[2 * typeJ + 1];

        const float sigma   = c6_i + c6_j;
        const float epsilon = c12_i * c12_j;

        const float sigma2 = sigma * sigma;
        return epsilon * sigma2 * sigma2 * sigma2;
    }
}

//! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
template<bool doCalcEnergies, enum VdwType vdwType>
static inline void ljEwaldComb(const DeviceAccessor<float, mode::read> a_nbfpComb,
                               const float                             sh_lj_ewald,
                               const int                               typeI,
                               const int                               typeJ,
                               const float                             r2,
                               const float                             r2Inv,
                               const float                             lje_coeff2,
                               const float                             lje_coeff6_6,
                               const float                             int_bit,
                               cl::sycl::private_ptr<float>            fInvR,
                               cl::sycl::private_ptr<float>            eLJ)
{
    const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);

    /* Recalculate inv_r6 without exclusion mask */
    const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
    const float cr2       = lje_coeff2 * r2;
    const float expmcr2   = cl::sycl::exp(-cr2);
    const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;

    /* Subtract the grid force from the total LJ force */
    *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;

    if constexpr (doCalcEnergies)
    {
        /* Shift should be applied only to real LJ pairs */
        const float sh_mask = sh_lj_ewald * int_bit;
        *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
    }
}
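
/* Summary of the math above (a sketch, mirroring the analogous CUDA/OpenCL
 * kernels): with x^2 = lje_coeff2 * r^2, the grid (reciprocal-space) part of
 * the r^-6 interaction that must be removed from the direct-space sum is
 *   F_grid/r = C6_grid * (r^-6 - e^{-x^2} * (r^-6 * (1 + x^2 + x^4/2) + lje_coeff6_6)) * r^-2
 *   E_grid   = (C6_grid/6) * (r^-6 * (1 - e^{-x^2} * (1 + x^2 + x^4/2)) + shift)
 * The factor e^{-x^2} * (1 + x^2 + x^4/2) stems from the incomplete-gamma
 * splitting of the dispersion interaction in LJ-PME.
 */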

/*! \brief Apply potential switch. */
template<bool doCalcEnergies>
static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
                                     const float                  rVdwSwitch,
                                     const float                  rInv,
                                     const float                  r2,
                                     cl::sycl::private_ptr<float> fInvR,
                                     cl::sycl::private_ptr<float> eLJ)
{
    /* potential switch constants */
    const float switchV3 = vdwSwitch.c3;
    const float switchV4 = vdwSwitch.c4;
    const float switchV5 = vdwSwitch.c5;
    const float switchF2 = 3 * vdwSwitch.c3;
    const float switchF3 = 4 * vdwSwitch.c4;
    const float switchF4 = 5 * vdwSwitch.c5;

    const float r       = r2 * rInv;
    const float rSwitch = r - rVdwSwitch;

    if (rSwitch > 0.0F)
    {
        const float sw =
                1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
        const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;

        *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
        if constexpr (doCalcEnergies)
        {
            *eLJ *= sw;
        }
    }
}
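
/* For reference (a sketch): with x = rSwitch = r - rVdwSwitch > 0, the code
 * above applies
 *   S(x)  = 1 + c3*x^3 + c4*x^4 + c5*x^5
 *   S'(x) = 3*c3*x^2 + 4*c4*x^3 + 5*c5*x^4
 * and the switched quantities follow from V_sw = V * S:
 *   E_sw = E * S,   F_sw/r = (F/r) * S - E * S' / r,
 * which is exactly the fInvR/eLJ update performed when rSwitch > 0.
 */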


/*! \brief Calculate analytical Ewald correction term. */
static inline float pmeCorrF(const float z2)
{
    constexpr float FN6 = -1.7357322914161492954e-8F;
    constexpr float FN5 = 1.4703624142580877519e-6F;
    constexpr float FN4 = -0.000053401640219807709149F;
    constexpr float FN3 = 0.0010054721316683106153F;
    constexpr float FN2 = -0.019278317264888380590F;
    constexpr float FN1 = 0.069670166153766424023F;
    constexpr float FN0 = -0.75225204789749321333F;

    constexpr float FD4 = 0.0011193462567257629232F;
    constexpr float FD3 = 0.014866955030185295499F;
    constexpr float FD2 = 0.11583842382862377919F;
    constexpr float FD1 = 0.50736591960530292870F;
    constexpr float FD0 = 1.0F;

    const float z4 = z2 * z2;

    float       polyFD0 = FD4 * z4 + FD2;
    const float polyFD1 = FD3 * z4 + FD1;
    polyFD0             = polyFD0 * z4 + FD0;
    polyFD0             = polyFD1 * z2 + polyFD0;

    polyFD0 = 1.0F / polyFD0;

    float polyFN0 = FN6 * z4 + FN4;
    float polyFN1 = FN5 * z4 + FN3;
    polyFN0       = polyFN0 * z4 + FN2;
    polyFN1       = polyFN1 * z4 + FN1;
    polyFN0       = polyFN0 * z4 + FN0;
    polyFN0       = polyFN1 * z2 + polyFN0;

    return polyFN0 * polyFD0;
}
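
/* pmeCorrF is a single-precision rational-polynomial fit (the same
 * coefficients as in the CUDA and OpenCL kernels) of the analytical Ewald
 * force-correction factor. With z = ewaldBeta * r, the approximated quantity is
 *   F(z^2) ~= (2/sqrt(pi)) * exp(-z^2) / z^2 - erf(z) / z^3,
 * so that r^-3 + beta^3 * pmeCorrF(beta^2 * r^2) reproduces the short-ranged
 * Ewald real-space force kernel; see the elecEwaldAna branch below.
 */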

/*! \brief Linear interpolation using exactly two FMA operations.
 *
 *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    return fma(t, d1, fma(-t, d0, d0));
}

/*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
                                             const float coulombTabScale,
                                             const float r)
{
    const float normalized = coulombTabScale * r;
    const int   index      = static_cast<int>(normalized);
    const float fraction   = normalized - index;

    const float left  = a_coulombTab[index];
    const float right = a_coulombTab[index + 1];

    return lerp(left, right, fraction); // TODO: cl::sycl::mix
}
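
/* Worked example of the lookup above (a sketch; the numbers are purely
 * illustrative, not a real table scale): with coulombTabScale = 100 and
 * r = 0.123, normalized = 12.3, so index = 12, fraction = 0.3, and the result
 * is lerp(a_coulombTab[12], a_coulombTab[13], 0.3F). Per the \brief above, the
 * table holds the scaled force correction F*r; the elecEwaldTab branch of the
 * kernel below multiplies the interpolated value back by rInv.
 */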

//! \brief Reduce j-force components over the sub-group using shuffles and accumulate the result into a_f.
static inline void reduceForceJShuffle(float3                                  f,
                                       const cl::sycl::nd_item<1>              itemIdx,
                                       const int                               tidxi,
                                       const int                               aidx,
                                       DeviceAccessor<float, mode::read_write> a_f)
{
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    f[0] += shuffleDown(f[0], 1, sg);
    f[1] += shuffleUp(f[1], 1, sg);
    f[2] += shuffleDown(f[2], 1, sg);
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    f[0] += shuffleDown(f[0], 2, sg);
    f[2] += shuffleUp(f[2], 2, sg);
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    if constexpr (c_clSize == 8)
    {
        f[0] += shuffleDown(f[0], 4, sg);
    }

    if (tidxi < 3)
    {
        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
    }
}
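
/* How the shuffle reduction above works (a sketch for c_clSize == 8, reducing
 * over the 8 lanes that share a j-atom row):
 *  - offset 1: x and z fold downwards, y folds upwards; lanes with (tidxi & 1)
 *    then move their y pair-sum into f[0], even lanes keep the x pair-sum.
 *  - offset 2: the pair-sums fold once more; lanes with (tidxi & 2) pick up
 *    the z quad-sums into f[0].
 *  - offset 4: a final fold, after which lanes 0, 1 and 2 hold the complete
 *    x, y and z sums and each issues a single atomicFetchAdd into a_f.
 */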


/*! \brief Final i-force reduction.
 *
 * This implementation works only with power-of-two array sizes.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<1>              itemIdx,
                                         const int                               tidxi,
                                         const int                               tidxj,
                                         const int                               sci,
                                         const int                               shift,
                                         DeviceAccessor<float, mode::read_write> a_f,
                                         DeviceAccessor<float, mode::read_write> a_fShift)
{
    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    float                fShiftBuf  = 0;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads:
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
        }
    }
}
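
/* Layout note (a sketch): sm_buf is treated as three consecutive
 * c_clSize * c_clSize planes, holding the x, y and z components at offsets 0,
 * bufStride and 2 * bufStride. Each outer iteration reduces one i-cluster:
 * the tree loop halves the number of active rows per step, and the final step
 * is fused with the global-memory write, with components split over the first
 * three values of tidxj (the same threads later update fShift).
 */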


/*! \brief Main kernel for NBNXM. */
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                 DeviceAccessor<float4, mode::read>                        a_xq,
                 DeviceAccessor<float, mode::read_write>                   a_f,
                 DeviceAccessor<float3, mode::read>                        a_shiftVec,
                 DeviceAccessor<float, mode::read_write>                   a_fShift,
                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
                 DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                 DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                 DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                 OptionalAccessor<float2, mode::read, ljComb<vdwType>>       a_ljComb,
                 OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
                 OptionalAccessor<float, mode::read, !ljComb<vdwType>>       a_nbfp,
                 OptionalAccessor<float, mode::read, ljEwald<vdwType>>       a_nbfpComb,
                 OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                 const int                                                   numTypes,
                 const float                                                 rCoulombSq,
                 const float                                                 rVdwSq,
                 const float                                                 twoKRf,
                 const float                                                 ewaldBeta,
                 const float                                                 rlistOuterSq,
                 const float                                                 ewaldShift,
                 const float                                                 epsFac,
                 const float                                                 ewaldCoeffLJ,
                 const float                                                 cRF,
                 const shift_consts_t                                        dispersionShift,
                 const shift_consts_t                                        repulsionShift,
                 const switch_consts_t                                       vdwSwitch,
                 const float                                                 rVdwSwitch,
                 const float                                                 ljEwaldShift,
                 const float                                                 coulombTabScale,
                 const bool                                                  calcShift)
{
    static constexpr EnergyFunctionProperties<elecType, vdwType> props;

    cgh.require(a_xq);
    cgh.require(a_f);
    cgh.require(a_shiftVec);
    cgh.require(a_fShift);
    if constexpr (doCalcEnergies)
    {
        cgh.require(a_energyElec);
        cgh.require(a_energyVdw);
    }
    cgh.require(a_plistCJ4);
    cgh.require(a_plistSci);
    cgh.require(a_plistExcl);
    if constexpr (!props.vdwComb)
    {
        cgh.require(a_atomTypes);
        cgh.require(a_nbfp);
    }
    else
    {
        cgh.require(a_ljComb);
    }
    if constexpr (props.vdwEwald)
    {
        cgh.require(a_nbfpComb);
    }
    if constexpr (props.elecEwaldTab)
    {
        cgh.require(a_coulombTab);
    }

    // shmem buffer for i x+q pre-loading
    cl::sycl::accessor<float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

    // shmem buffer for force reduction
    // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
            cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);

    auto sm_atomTypeI = [&]() {
        if constexpr (!props.vdwComb)
        {
            return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    auto sm_ljCombI = [&]() {
        if constexpr (props.vdwComb)
        {
            return cl::sycl::accessor<float2, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    /* Flag to control the calculation of exclusion forces in the kernel.
     * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
     * energy terms. */
    constexpr bool doExclusionForces =
            (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));

    constexpr int subGroupSize = c_clSize * c_clSize / 2;
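    /* With the default c_clSize == 8 this requests a sub-group ("warp") size
     * of 32: each work-group of c_clSize^2 = 64 work-items handles one
     * super-cluster and is split into two sub-groups, each owning one of the
     * two imei[] interaction-mask entries (see widx below). This assumes the
     * device can honor the requested reqd_sub_group_size. */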

    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        /* thread/block/warp id-s */
        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
        const unsigned        tidxi   = localId[0];
        const unsigned        tidxj   = localId[1];
        const unsigned        tidx    = tidxj * c_clSize + tidxi;
        const unsigned        tidxz   = 0;

        // Group indexing was flat originally, no need to unflatten it.
        const unsigned bidx = itemIdx.get_group(0);

        const sycl_2020::sub_group sg = itemIdx.get_sub_group();
        // Better to use sg.get_group_range(), but too much of the logic relies on this flat index anyway.
        const unsigned widx = tidx / subGroupSize;

        float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
        {
            fCiBuf[i] = float3(0.0F, 0.0F, 0.0F);
        }

        const nbnxn_sci_t nbSci     = a_plistSci[bidx];
        const int         sci       = nbSci.sci;
        const int         cij4Start = nbSci.cj4_ind_start;
        const int         cij4End   = nbSci.cj4_ind_end;

        // Only needed if props.elecEwaldAna
        const float beta2 = ewaldBeta * ewaldBeta;
        const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;

        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
        {
            /* Pre-load i-atom x and q into shared memory */
            const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
            const int             ai       = ci * c_clSize + tidxi;
            const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);

            const float3 shift = a_shiftVec[nbSci.shift];
            float4       xqi   = a_xq[ai];
            xqi += float4(shift[0], shift[1], shift[2], 0.0F);
            xqi[3] *= epsFac;
            sm_xq[cacheIdx] = xqi;

            if constexpr (!props.vdwComb)
            {
                // Pre-load the i-atom types into shared memory
                sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
            }
            else
            {
                // Pre-load the LJ combination parameters into shared memory
                sm_ljCombI[cacheIdx] = a_ljComb[ai];
            }
        }
        itemIdx.barrier(fence_space::local_space);

        float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
        if constexpr (props.vdwEwald)
        {
            ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
            ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
        }

        float energyVdw, energyElec; // Only needed if (doCalcEnergies)
        if constexpr (doCalcEnergies)
        {
            energyVdw = energyElec = 0.0F;
        }
        if constexpr (doCalcEnergies && doExclusionForces)
        {
            if (nbSci.shift == CENTRAL && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
            {
                // we have the diagonal: add the charge and LJ self-interaction energy term
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    // TODO: Are there other options?
                    if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
                    {
                        const float qi = sm_xq[i][tidxi][3];
                        energyElec += qi * qi;
                    }
                    if constexpr (props.vdwEwald)
                    {
                        energyVdw +=
                                a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
                                       * (numTypes + 1) * 2];
                    }
                }
                /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
                if constexpr (props.vdwEwald)
                {
                    energyVdw /= c_clSize;
                    energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
                }
                if constexpr (props.elecRF || props.elecCutoff)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -0.5F * cRF;
                }
                if constexpr (props.elecEwald)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                }
            } // (nbSci.shift == CENTRAL && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
        }     // (doCalcEnergies && doExclusionForces)
        // Only needed if (doExclusionForces)
        const bool nonSelfInteraction = !(nbSci.shift == CENTRAL & tidxj <= tidxi); // single '&': both operands are evaluated, avoiding a divergent branch

        // loop over the j clusters seen by any of the atoms in the current super-cluster
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
        {
            unsigned imask = a_plistCJ4[j4].imei[widx].imask;
            if (!doPruneNBL && !imask)
            {
                continue;
            }
            const int wexclIdx = a_plistCJ4[j4].imei[widx].excl_ind;
            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (subGroupSize - 1)]; // sg.get_local_linear_id()
            for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            {
                const bool maskSet =
                        imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                if (!maskSet)
                {
                    continue;
                }
                unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                const int cj     = a_plistCJ4[j4].cj[jm];
                const int aj     = cj * c_clSize + tidxj;

                // load j atom data
                const float4 xqj = a_xq[aj];

                const float3 xj(xqj[0], xqj[1], xqj[2]);
                const float  qj = xqj[3];
                int          atomTypeJ; // Only needed if (!props.vdwComb)
                float2       ljCombJ;   // Only needed if (props.vdwComb)
                if constexpr (props.vdwComb)
                {
                    ljCombJ = a_ljComb[aj];
                }
                else
                {
                    atomTypeJ = a_atomTypes[aj];
                }

                float3 fCjBuf(0.0F, 0.0F, 0.0F);

                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    if (imask & maskJI)
                    {
                        // i cluster index
                        const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
                        // all threads load an atom from i cluster ci into shmem!
                        const float4 xqi = sm_xq[i][tidxi];
                        const float3 xi(xqi[0], xqi[1], xqi[2]);

                        // distance between i and j atoms
                        const float3 rv = xi - xj;
                        float        r2 = norm2(rv);

                        if constexpr (doPruneNBL)
                        {
                            /* If _none_ of the atom pairs are within cutoff range,
                             * the bit corresponding to the current
                             * cluster-pair in imask gets set to 0. */
                            if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
                            {
                                imask &= ~maskJI;
                            }
                        }
                        const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;

                        // cutoff & exclusion check

                        const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
                                                                   : (wexcl & maskJI);

                        // SYCL-TODO: Check optimal way of branching here.
                        if ((r2 < rCoulombSq) && notExcluded)
                        {
                            const float qi = xqi[3];
                            int         atomTypeI; // Only needed if (!props.vdwComb)
                            float       c6, c12, sigma, epsilon;

                            if constexpr (!props.vdwComb)
                            {
                                /* LJ 6*C6 and 12*C12 */
                                atomTypeI     = sm_atomTypeI[i][tidxi];
                                const int idx = (numTypes * atomTypeI + atomTypeJ) * 2;
                                c6            = a_nbfp[idx]; // TODO: Make a_nbfp into float2
                                c12           = a_nbfp[idx + 1];
                            }
                            else
                            {
                                const float2 ljCombI = sm_ljCombI[i][tidxi];
                                if constexpr (props.vdwCombGeom)
                                {
                                    c6  = ljCombI[0] * ljCombJ[0];
                                    c12 = ljCombI[1] * ljCombJ[1];
                                }
                                else
                                {
                                    static_assert(props.vdwCombLB);
                                    // LJ 2^(1/6)*sigma and 12*epsilon
                                    sigma   = ljCombI[0] + ljCombJ[0];
                                    epsilon = ljCombI[1] * ljCombJ[1];
                                    if constexpr (doCalcEnergies)
                                    {
                                        convertSigmaEpsilonToC6C12(sigma, epsilon, &c6, &c12);
                                    }
                                } // props.vdwCombGeom
                            }     // !props.vdwComb
                            // Ensure the distance does not become so small that r^-12 overflows
                            r2 = std::max(r2, c_nbnxnMinDistanceSquared);
                            // SYCL-TODO: sycl::half_precision::rsqrt?
                            const float rInv  = cl::sycl::native::rsqrt(r2);
                            const float r2Inv = rInv * rInv;
                            float       r6Inv, fInvR, energyLJPair;
                            if constexpr (!props.vdwCombLB || doCalcEnergies)
                            {
                                r6Inv = r2Inv * r2Inv * r2Inv;
                                if constexpr (doExclusionForces)
                                {
                                    // SYCL-TODO: Check if true for SYCL
                                    /* We could mask r2Inv, but with Ewald masking both
                                     * r6Inv and fInvR is faster */
                                    r6Inv *= pairExclMask;
                                }
                                fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
                            }
                            else
                            {
                                float sig_r  = sigma * rInv;
                                float sig_r2 = sig_r * sig_r;
                                float sig_r6 = sig_r2 * sig_r2 * sig_r2;
                                if constexpr (doExclusionForces)
                                {
                                    sig_r6 *= pairExclMask;
                                }
                                fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
                            } // (!props.vdwCombLB || doCalcEnergies)
                            if constexpr (doCalcEnergies || props.vdwPSwitch)
                            {
                                energyLJPair = pairExclMask
                                               * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
                                                  - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
                            }
                            if constexpr (props.vdwFSwitch)
                            {
                                ljForceSwitch<doCalcEnergies>(
                                        dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.vdwEwald)
                            {
                                ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
                                                                     ljEwaldShift,
                                                                     atomTypeI,
                                                                     atomTypeJ,
                                                                     r2,
                                                                     r2Inv,
                                                                     ewaldCoeffLJ_2,
                                                                     ewaldCoeffLJ_6_6,
                                                                     pairExclMask,
                                                                     &fInvR,
                                                                     &energyLJPair);
                            } // (props.vdwEwald)
                            if constexpr (props.vdwPSwitch)
                            {
                                ljPotentialSwitch<doCalcEnergies>(
                                        vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.elecEwaldTwin)
                            {
                                // Separate VdW cut-off check to enable twin-range cut-offs
                                // (rVdw < rCoulomb <= rList)
                                const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
                                fInvR *= vdwInRange;
                                if constexpr (doCalcEnergies)
                                {
                                    energyLJPair *= vdwInRange;
                                }
                            }
                            if constexpr (doCalcEnergies)
                            {
                                energyVdw += energyLJPair;
                            }

                            if constexpr (props.elecCutoff)
                            {
                                if constexpr (doExclusionForces)
                                {
                                    fInvR += qi * qj * pairExclMask * r2Inv * rInv;
                                }
                                else
                                {
                                    fInvR += qi * qj * r2Inv * rInv;
                                }
                            }
                            if constexpr (props.elecRF)
                            {
                                fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
                            }
                            if constexpr (props.elecEwaldAna)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
                            }
                            if constexpr (props.elecEwaldTab)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv
                                            - interpolateCoulombForceR(
                                                      a_coulombTab, coulombTabScale, r2 * rInv))
                                         * rInv;
                            }

                            if constexpr (doCalcEnergies)
                            {
                                if constexpr (props.elecCutoff)
                                {
                                    energyElec += qi * qj * (pairExclMask * rInv - cRF);
                                }
                                if constexpr (props.elecRF)
                                {
                                    energyElec +=
                                            qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
                                }
                                if constexpr (props.elecEwald)
                                {
                                    energyElec +=
                                            qi * qj
                                            * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
                                               - pairExclMask * ewaldShift);
                                }
                            }

                            const float3 forceIJ = rv * fInvR;

                            /* accumulate j forces in registers */
                            fCjBuf -= forceIJ;
                            /* accumulate i forces in registers */
                            fCiBuf[i] += forceIJ;
                        } // (r2 < rCoulombSq) && notExcluded
                    }     // (imask & maskJI)
                    /* shift the mask bit by 1 */
                    maskJI += maskJI;
                } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                /* reduce j forces */
                reduceForceJShuffle(fCjBuf, itemIdx, tidxi, aj, a_f);
            } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            if constexpr (doPruneNBL)
            {
                /* Update the imask with the new one which does not contain the
                 * out-of-range clusters anymore. */
                a_plistCJ4[j4].imei[widx].imask = imask;
            }
        } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)

        /* skip central shifts when summing shift forces */
        const bool doCalcShift = (calcShift && !(nbSci.shift == CENTRAL));

        reduceForceIAndFShift(
                sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);

        if constexpr (doCalcEnergies)
        {
            const float energyVdwGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
            const float energyElecGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());

            if (tidx == 0)
            {
                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
            }
        }
    };
}

// SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
class NbnxmKernelName;

template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
{
    // Should not be needed for SYCL2020.
    using kernelNameType = NbnxmKernelName<doPruneNBL, doCalcEnergies, elecType, vdwType>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const int                   numBlocks = numSci;
    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
    });

    return e;
}
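
/* Example instantiation (a sketch; this is what chooseAndLaunchNbnxmKernel
 * below generates via gmx::dispatchTemplatedFunction): a step that both prunes
 * and computes energies, with analytical Ewald electrostatics and geometric LJ
 * combination, would resolve to
 *   launchNbnxmKernel<true, true, ElecType::EwaldAna, VdwType::CutCombGeom>(
 *           deviceStream, plist->nsci, ...);
 * With c_clSize == 8 each work-group has 8 x 8 x 1 = 64 work-items, the grid
 * has one work-group per super-cluster (numSci), and flattenNDRange folds the
 * 3D range into the 1D nd_range the kernel lambda expects.
 */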

template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                           bool          doCalcEnergies,
                                           enum ElecType elecType,
                                           enum VdwType  vdwType,
                                           Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
                return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
                        std::forward<Args>(args)...);
            },
            doPruneNBL,
            doCalcEnergies,
            elecType,
            vdwType);
}

void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
{
    sycl_atomdata_t*    adat         = nb->atdat;
    NBParamGpu*         nbp          = nb->nbparam;
    gpu_plist*          plist        = nb->plist[iloc];
    const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];

    // Casting to float simplifies using atomic ops in the kernel
    cl::sycl::buffer<float3, 1> f(*adat->f.buffer_);
    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
    cl::sycl::buffer<float3, 1> fShift(*adat->fShift.buffer_);
    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);

    cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                   stepWork.computeEnergy,
                                                   nbp->elecType,
                                                   nbp->vdwType,
                                                   deviceStream,
                                                   plist->nsci,
                                                   adat->xq,
                                                   fAsFloat,
                                                   adat->shiftVec,
                                                   fShiftAsFloat,
                                                   adat->eElec,
                                                   adat->eLJ,
                                                   plist->cj4,
                                                   plist->sci,
                                                   plist->excl,
                                                   adat->ljComb,
                                                   adat->atomTypes,
                                                   nbp->nbfp,
                                                   nbp->nbfp_comb,
                                                   nbp->coulomb_tab,
                                                   adat->numTypes,
                                                   nbp->rcoulomb_sq,
                                                   nbp->rvdw_sq,
                                                   nbp->two_k_rf,
                                                   nbp->ewald_beta,
                                                   nbp->rlistOuter_sq,
                                                   nbp->sh_ewald,
                                                   nbp->epsfac,
                                                   nbp->ewaldcoeff_lj,
                                                   nbp->c_rf,
                                                   nbp->dispersion_shift,
                                                   nbp->repulsion_shift,
                                                   nbp->vdw_switch,
                                                   nbp->rvdw_switch,
                                                   nbp->sh_lj_ewald,
                                                   nbp->coulomb_tab_scale,
                                                   stepWork.computeVirial);
}

} // namespace Nbnxm