SYCL: Shorten mangled kernel name types
src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief
 *  NBNXM SYCL kernels
 *
 *  \ingroup module_nbnxm
 */
#include "gmxpre.h"

#include "nbnxm_sycl_kernel.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/math/functions.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/pbcutil/ishift.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"
//! \brief Class name for NBNXM kernel
template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
class NbnxmKernel;
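// NbnxmKernel is used only as the kernel *name* tag passed to parallel_for and is never
// defined. Declaring it at global scope and templating it directly on booleans and enums
// (rather than on wrapper class types) keeps the mangled kernel-name symbols short, which
// is what this change is about.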

namespace Nbnxm
{

//! \brief Set of boolean constants mimicking preprocessor macros.
template<enum ElecType elecType, enum VdwType vdwType>
struct EnergyFunctionProperties {
    static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
    static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
    static constexpr bool elecEwaldAna =
            (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
    static constexpr bool elecEwaldTab =
            (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
    static constexpr bool elecEwaldTwin =
            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
    static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
    static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
    static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
    static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
    static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
    static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
    static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
};
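// Compile-time usage sketch (illustrative only): each kernel instantiation queries these
// traits instead of relying on -D preprocessor defines, e.g.
//   using Props = EnergyFunctionProperties<ElecType::EwaldTabTwin, VdwType::CutCombGeom>;
//   static_assert(Props::elecEwald && Props::elecEwaldTab && Props::elecEwaldTwin && Props::vdwComb);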

//! \brief Templated constants to shorten kernel function declaration.
//@{
template<enum VdwType vdwType>
constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;

template<enum ElecType elecType> // Yes, ElecType: the extra VdW cutoff check is tied to the twin-cut Ewald electrostatics flavors
constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;

template<enum ElecType elecType>
constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;

template<enum ElecType elecType>
constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;

template<enum VdwType vdwType>
constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
//@}

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;

static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
{
    const float sigma2 = sigma * sigma;
    const float sigma6 = sigma2 * sigma2 * sigma2;
    const float c6     = epsilon * sigma6;
    const float c12    = c6 * sigma6;

    return Float2(c6, c12);
}
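// Derivation note: the LB path below feeds this function 2^(1/6)*sigma and 12*epsilon
// (see the "LJ 2^(1/6)*sigma and 12*epsilon" comment in the kernel). With
// sigma' = 2^(1/6)*sigma and epsilon' = 12*epsilon,
//   epsilon' * sigma'^6 = 24*epsilon*sigma^6  = 6*C6, and
//   (6*C6)   * sigma'^6 = 48*epsilon*sigma^12 = 12*C12,
// matching the 6*C6/12*C12 convention of the a_nbfp table.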

template<bool doCalcEnergies>
static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                 const shift_consts_t         repulsionShift,
                                 const float                  rVdwSwitch,
                                 const float                  c6,
                                 const float                  c12,
                                 const float                  rInv,
                                 const float                  r2,
                                 cl::sycl::private_ptr<float> fInvR,
                                 cl::sycl::private_ptr<float> eLJ)
{
    /* force switch constants */
    const float dispShiftV2 = dispersionShift.c2;
    const float dispShiftV3 = dispersionShift.c3;
    const float repuShiftV2 = repulsionShift.c2;
    const float repuShiftV3 = repulsionShift.c3;

    const float r       = r2 * rInv;
    const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)

    *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
              + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;

    if constexpr (doCalcEnergies)
    {
        const float dispShiftF2 = dispShiftV2 / 3;
        const float dispShiftF3 = dispShiftV3 / 4;
        const float repuShiftF2 = repuShiftV2 / 3;
        const float repuShiftF3 = repuShiftV3 / 4;
        *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
                - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
    }
}
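// Why the divisions by 3 and 4 above: the energy counterpart of the force-switch term
// c*(V2 + V3*rSwitch)*rSwitch^2 is its integral over rSwitch,
// c*((V2/3) + (V3/4)*rSwitch)*rSwitch^3, hence F2 = V2/3 and F3 = V3/4. The extra rInv
// on the force term is there because fInvR stores F/r rather than F.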

//! \brief Fetch C6 grid contribution coefficients and return the product of these.
template<enum VdwType vdwType>
static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                                           const int                                typeI,
                                           const int                                typeJ)
{
    if constexpr (vdwType == VdwType::EwaldGeom)
    {
        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
    }
    else
    {
        static_assert(vdwType == VdwType::EwaldLB);
        /* sigma and epsilon are scaled to give 6*C6 */
        const Float2 c6c12_i = a_nbfpComb[typeI];
        const Float2 c6c12_j = a_nbfpComb[typeJ];

        const float sigma   = c6c12_i[0] + c6c12_j[0];
        const float epsilon = c6c12_i[1] * c6c12_j[1];

        const float sigma2 = sigma * sigma;
        return epsilon * sigma2 * sigma2 * sigma2;
    }
}

//! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
template<bool doCalcEnergies, enum VdwType vdwType>
static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                               const float                              sh_lj_ewald,
                               const int                                typeI,
                               const int                                typeJ,
                               const float                              r2,
                               const float                              r2Inv,
                               const float                              lje_coeff2,
                               const float                              lje_coeff6_6,
                               const float                              int_bit,
                               cl::sycl::private_ptr<float>             fInvR,
                               cl::sycl::private_ptr<float>             eLJ)
{
    const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);

    /* Recalculate inv_r6 without exclusion mask */
    const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
    const float cr2       = lje_coeff2 * r2;
    const float expmcr2   = cl::sycl::exp(-cr2);
    const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;

    /* Subtract the grid force from the total LJ force */
    *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;

    if constexpr (doCalcEnergies)
    {
        /* Shift should be applied only to real LJ pairs */
        const float sh_mask = sh_lj_ewald * int_bit;
        *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
    }
}
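// Note: poly is the second-order Taylor expansion of exp(cr2), so the combination
// (1 - expmcr2 * poly) vanishes smoothly as r -> 0 and the subtracted grid term stays
// finite at short range; lje_coeff6_6 is ewaldCoeffLJ^6 / 6 (see ewaldCoeffLJ_6_6 in the
// kernel body below).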

/*! \brief Apply potential switch. */
template<bool doCalcEnergies>
static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
                                     const float                  rVdwSwitch,
                                     const float                  rInv,
                                     const float                  r2,
                                     cl::sycl::private_ptr<float> fInvR,
                                     cl::sycl::private_ptr<float> eLJ)
{
    /* potential switch constants */
    const float switchV3 = vdwSwitch.c3;
    const float switchV4 = vdwSwitch.c4;
    const float switchV5 = vdwSwitch.c5;
    const float switchF2 = 3 * vdwSwitch.c3;
    const float switchF3 = 4 * vdwSwitch.c4;
    const float switchF4 = 5 * vdwSwitch.c5;

    const float r       = r2 * rInv;
    const float rSwitch = r - rVdwSwitch;

    if (rSwitch > 0.0F)
    {
        const float sw =
                1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
        const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;

        *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
        if constexpr (doCalcEnergies)
        {
            *eLJ *= sw;
        }
    }
}


/*! \brief Calculate analytical Ewald correction term. */
static inline float pmeCorrF(const float z2)
{
    constexpr float FN6 = -1.7357322914161492954e-8F;
    constexpr float FN5 = 1.4703624142580877519e-6F;
    constexpr float FN4 = -0.000053401640219807709149F;
    constexpr float FN3 = 0.0010054721316683106153F;
    constexpr float FN2 = -0.019278317264888380590F;
    constexpr float FN1 = 0.069670166153766424023F;
    constexpr float FN0 = -0.75225204789749321333F;

    constexpr float FD4 = 0.0011193462567257629232F;
    constexpr float FD3 = 0.014866955030185295499F;
    constexpr float FD2 = 0.11583842382862377919F;
    constexpr float FD1 = 0.50736591960530292870F;
    constexpr float FD0 = 1.0F;

    const float z4 = z2 * z2;

    float       polyFD0 = FD4 * z4 + FD2;
    const float polyFD1 = FD3 * z4 + FD1;
    polyFD0             = polyFD0 * z4 + FD0;
    polyFD0             = polyFD1 * z2 + polyFD0;

    polyFD0 = 1.0F / polyFD0;

    float polyFN0 = FN6 * z4 + FN4;
    float polyFN1 = FN5 * z4 + FN3;
    polyFN0       = polyFN0 * z4 + FN2;
    polyFN1       = polyFN1 * z4 + FN1;
    polyFN0       = polyFN0 * z4 + FN0;
    polyFN0       = polyFN1 * z2 + polyFN0;

    return polyFN0 * polyFD0;
}
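// pmeCorrF evaluates a rational approximation (degree-6 numerator over degree-4
// denominator in z2) of the analytical Ewald correction factor; it is combined with
// beta^3 in the elecEwaldAna branch of the kernel below. Splitting the Horner evaluation
// into interleaved z2/z4 chains exposes instruction-level parallelism.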

/*! \brief Linear interpolation using exactly two FMA operations.
 *
 *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    return fma(t, d1, fma(-t, d0, d0));
}
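// Example: lerp(1.0F, 3.0F, 0.25F) == 1.5F, since fma(-t, d0, d0) = (1 - t)*d0 and the
// outer fma adds t*d1.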

/*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
                                             const float coulombTabScale,
                                             const float r)
{
    const float normalized = coulombTabScale * r;
    const int   index      = static_cast<int>(normalized);
    const float fraction   = normalized - index;

    const float left  = a_coulombTab[index];
    const float right = a_coulombTab[index + 1];

    return lerp(left, right, fraction); // TODO: cl::sycl::mix
}
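// The table holds F*r sampled at r values spaced 1/coulombTabScale apart, so the lookup
// is just: scale r, split into integer index and fraction, and lerp between the two
// bracketing entries.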

/*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in log2(c_clSize) steps using sub-group shifts and atomically accumulate into \p a_f.
 */
static inline void reduceForceJShuffle(Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    f[0] += sycl_2020::shift_left(sg, f[0], 1);
    f[1] += sycl_2020::shift_right(sg, f[1], 1);
    f[2] += sycl_2020::shift_left(sg, f[2], 1);
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    f[0] += sycl_2020::shift_left(sg, f[0], 2);
    f[2] += sycl_2020::shift_right(sg, f[2], 2);
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    if constexpr (c_clSize == 8)
    {
        f[0] += sycl_2020::shift_left(sg, f[0], 4);
    }

    if (tidxi < 3)
    {
        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
    }
}
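// How the shuffle reduction lands: after the stride-1 shifts and the (tidxi & 1) swap,
// even lanes carry pairwise x sums and odd lanes pairwise y sums (moved into f[0]); the
// stride-2 step and (tidxi & 2) swap fold in z, so lanes 0, 1 and 2 of each group end up
// holding the full x, y and z sums that are written out above.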

/*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in c_clSize steps using local memory and atomically accumulate into \p a_f.
 *
 * TODO: implement a binary-reduction flavor for the case where c_clSize is a power of two.
 */
static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                       Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          tidxj,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    static constexpr int sc_fBufferStride = c_clSizeSq;
    int                  tidx             = tidxi + tidxj * c_clSize;
    sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
    sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
    sm_buf[2 * sc_fBufferStride + tidx]   = f[2];

    subGroupBarrier(itemIdx);

    // Each of the first three threads in a row sums the c_clSize stored values of one
    // force component, on the same sub-group that stored them above.
    assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);

    if (tidxi < 3)
    {
        float fSum = 0.0F;
        for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
        {
            fSum += sm_buf[sc_fBufferStride * tidxi + j];
        }

        atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
    }
}


/*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
 */
static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                Float3                                                        f,
                                const cl::sycl::nd_item<1>         itemIdx,
                                const int                          tidxi,
                                const int                          tidxj,
                                const int                          aidx,
                                DeviceAccessor<float, mode_atomic> a_f)
{
    if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
    {
        reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
    }
    else
    {
        reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
    }
}


/*! \brief Final i-force reduction.
 *
 * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
 * accumulating atomically into \p a_f.
 * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
 *
 * This implementation works only with power-of-two array sizes.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<1>         itemIdx,
                                         const int                          tidxi,
                                         const int                          tidxj,
                                         const int                          sci,
                                         const int                          shift,
                                         DeviceAccessor<float, mode_atomic> a_f,
                                         DeviceAccessor<float, mode_atomic> a_fShift)
{
    // fCiBuf must have a power-of-two number of elements
    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));

    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    float                fShiftBuf  = 0.0F;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads:
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
        }
    }
}

/*! \brief Main kernel for NBNXM. */
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                 DeviceAccessor<Float4, mode::read>                   a_xq,
                 DeviceAccessor<float, mode_atomic>                   a_f,
                 DeviceAccessor<Float3, mode::read>                   a_shiftVec,
                 DeviceAccessor<float, mode_atomic>                   a_fShift,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
                 DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                 DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                 DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                 OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
                 OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
                 OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                 const int                                                   numTypes,
                 const float                                                 rCoulombSq,
                 const float                                                 rVdwSq,
                 const float                                                 twoKRf,
                 const float                                                 ewaldBeta,
                 const float                                                 rlistOuterSq,
                 const float                                                 ewaldShift,
                 const float                                                 epsFac,
                 const float                                                 ewaldCoeffLJ,
                 const float                                                 cRF,
                 const shift_consts_t                                        dispersionShift,
                 const shift_consts_t                                        repulsionShift,
                 const switch_consts_t                                       vdwSwitch,
                 const float                                                 rVdwSwitch,
                 const float                                                 ljEwaldShift,
                 const float                                                 coulombTabScale,
                 const bool                                                  calcShift)
{
    static constexpr EnergyFunctionProperties<elecType, vdwType> props;

    cgh.require(a_xq);
    cgh.require(a_f);
    cgh.require(a_shiftVec);
    cgh.require(a_fShift);
    if constexpr (doCalcEnergies)
    {
        cgh.require(a_energyElec);
        cgh.require(a_energyVdw);
    }
    cgh.require(a_plistCJ4);
    cgh.require(a_plistSci);
    cgh.require(a_plistExcl);
    if constexpr (!props.vdwComb)
    {
        cgh.require(a_atomTypes);
        cgh.require(a_nbfp);
    }
    else
    {
        cgh.require(a_ljComb);
    }
    if constexpr (props.vdwEwald)
    {
        cgh.require(a_nbfpComb);
    }
    if constexpr (props.elecEwaldTab)
    {
        cgh.require(a_coulombTab);
    }

    // shmem buffer for i x+q pre-loading
    cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

    // shmem buffer for force reduction
    // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
            cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);

    auto sm_atomTypeI = [&]() {
        if constexpr (!props.vdwComb)
        {
            return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    auto sm_ljCombI = [&]() {
        if constexpr (props.vdwComb)
        {
            return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    /* Flag to control the calculation of exclusion forces in the kernel
     * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
     * energy terms. */
    constexpr bool doExclusionForces =
            (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));

    // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
    // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
    // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
    // Hence, the two are decoupled.
    constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
#if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
    gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
#else
    gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
#endif
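    // Illustrative numbers (the actual values depend on the build configuration): for
    // c_clSize == 8 and c_splitClSize == 4, prunedClusterPairSize is 32 -- the 32-wide
    // layout the masks are stored for -- while the requested sub-group size is
    // c_clSize * c_clSize == 64 on ROCm and 32 elsewhere.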

    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        /* thread/block/warp id-s */
        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
        const unsigned        tidxi   = localId[0];
        const unsigned        tidxj   = localId[1];
        const unsigned        tidx    = tidxj * c_clSize + tidxi;
        const unsigned        tidxz   = 0;

        // Group indexing was flat originally, no need to unflatten it.
        const unsigned bidx = itemIdx.get_group(0);

        const sycl_2020::sub_group sg = itemIdx.get_sub_group();
        // We could use sg.get_group_range() to compute the imask & exclusion index, but much
        // of the logic below relies on this layout, and where prunedClusterPairSize !=
        // subGroupSize it could not be used anyway.
        const unsigned imeiIdx = tidx / prunedClusterPairSize;

        Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
        {
            fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
        }

        const nbnxn_sci_t nbSci     = a_plistSci[bidx];
        const int         sci       = nbSci.sci;
        const int         cij4Start = nbSci.cj4_ind_start;
        const int         cij4End   = nbSci.cj4_ind_end;

        // Only needed if props.elecEwaldAna
        const float beta2 = ewaldBeta * ewaldBeta;
        const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;

        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
        {
            /* Pre-load i-atom x and q into shared memory */
            const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
            const int             ai       = ci * c_clSize + tidxi;
            const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);

            const Float3 shift = a_shiftVec[nbSci.shift];
            Float4       xqi   = a_xq[ai];
            xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
            xqi[3] *= epsFac;
            sm_xq[cacheIdx] = xqi;

            if constexpr (!props.vdwComb)
            {
                // Pre-load the i-atom types into shared memory
                sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
            }
            else
            {
                // Pre-load the LJ combination parameters into shared memory
                sm_ljCombI[cacheIdx] = a_ljComb[ai];
            }
        }
        itemIdx.barrier(fence_space::local_space);

        float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
        if constexpr (props.vdwEwald)
        {
            ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
            ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
        }

        float energyVdw, energyElec; // Only needed if (doCalcEnergies)
        if constexpr (doCalcEnergies)
        {
            energyVdw = energyElec = 0.0F;
        }
        if constexpr (doCalcEnergies && doExclusionForces)
        {
            if (nbSci.shift == gmx::c_centralShiftIndex
                && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
            {
                // we have the diagonal: add the charge and LJ self-interaction energy terms
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    // TODO: Are there other options?
                    if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
                    {
                        const float qi = sm_xq[i][tidxi][3];
                        energyElec += qi * qi;
                    }
                    if constexpr (props.vdwEwald)
                    {
                        energyVdw +=
                                a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
                                       * (numTypes + 1)][0];
                    }
                }
                /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
                if constexpr (props.vdwEwald)
                {
                    energyVdw /= c_clSize;
                    energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
                }
                if constexpr (props.elecRF || props.elecCutoff)
                {
                    // Correct for epsfac^2 due to adding qi^2
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -0.5F * cRF;
                }
                if constexpr (props.elecEwald)
                {
                    // Correct for epsfac^2 due to adding qi^2
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                }
            } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
        }     // (doCalcEnergies && doExclusionForces)

        // Only needed if (doExclusionForces)
        const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);

        // loop over the j clusters seen by any of the atoms in the current super-cluster
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
        {
            unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
            if (!doPruneNBL && !imask)
            {
                continue;
            }
            const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
            static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
            for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            {
                const bool maskSet =
                        imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                if (!maskSet)
                {
                    continue;
                }
                unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                const int cj     = a_plistCJ4[j4].cj[jm];
                const int aj     = cj * c_clSize + tidxj;

                // load j atom data
                const Float4 xqj = a_xq[aj];

                const Float3 xj(xqj[0], xqj[1], xqj[2]);
                const float  qj = xqj[3];
                int          atomTypeJ; // Only needed if (!props.vdwComb)
                Float2       ljCombJ;   // Only needed if (props.vdwComb)
                if constexpr (props.vdwComb)
                {
                    ljCombJ = a_ljComb[aj];
                }
                else
                {
                    atomTypeJ = a_atomTypes[aj];
                }

                Float3 fCjBuf(0.0F, 0.0F, 0.0F);

                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    if (imask & maskJI)
                    {
                        // i cluster index
                        const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
                        // all threads load an atom from i cluster ci into shmem!
                        const Float4 xqi = sm_xq[i][tidxi];
                        const Float3 xi(xqi[0], xqi[1], xqi[2]);

                        // distance between i and j atoms
                        const Float3 rv = xi - xj;
                        float        r2 = norm2(rv);

                        if constexpr (doPruneNBL)
                        {
                            /* If _none_ of the atoms pairs are in cutoff range,
                             * the bit corresponding to the current
                             * cluster-pair in imask gets set to 0. */
                            if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
                            {
                                imask &= ~maskJI;
                            }
                        }
                        const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;

                        // cutoff & exclusion check

                        const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
                                                                   : (wexcl & maskJI);

                        // SYCL-TODO: Check optimal way of branching here.
                        if ((r2 < rCoulombSq) && notExcluded)
                        {
                            const float qi = xqi[3];
                            int         atomTypeI; // Only needed if (!props.vdwComb)
                            float       sigma, epsilon;
                            Float2      c6c12;

                            if constexpr (!props.vdwComb)
                            {
                                /* LJ 6*C6 and 12*C12 */
                                atomTypeI = sm_atomTypeI[i][tidxi];
                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                            }
                            else
                            {
                                const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                if constexpr (props.vdwCombGeom)
                                {
                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                }
                                else
                                {
                                    static_assert(props.vdwCombLB);
                                    // LJ 2^(1/6)*sigma and 12*epsilon
                                    sigma   = ljCombI[0] + ljCombJ[0];
                                    epsilon = ljCombI[1] * ljCombJ[1];
                                    if constexpr (doCalcEnergies)
                                    {
                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                    }
                                } // props.vdwCombGeom
                            }     // !props.vdwComb

                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
                            const float c6  = c6c12[0];
                            const float c12 = c6c12[1];

                            // Ensure the distance does not become so small that r^-12 overflows
                            r2 = std::max(r2, c_nbnxnMinDistanceSquared);
#if GMX_SYCL_HIPSYCL
                            // No fast/native functions in some compilation passes
                            const float rInv = cl::sycl::rsqrt(r2);
#else
                            // SYCL-TODO: sycl::half_precision::rsqrt?
                            const float rInv = cl::sycl::native::rsqrt(r2);
#endif
                            const float r2Inv = rInv * rInv;
                            float       r6Inv, fInvR, energyLJPair;
                            if constexpr (!props.vdwCombLB || doCalcEnergies)
                            {
                                r6Inv = r2Inv * r2Inv * r2Inv;
                                if constexpr (doExclusionForces)
                                {
                                    // SYCL-TODO: Check if true for SYCL
                                    /* We could mask r2Inv, but with Ewald masking both
                                     * r6Inv and fInvR is faster */
                                    r6Inv *= pairExclMask;
                                }
                                fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
                            }
                            else
                            {
                                float sig_r  = sigma * rInv;
                                float sig_r2 = sig_r * sig_r;
                                float sig_r6 = sig_r2 * sig_r2 * sig_r2;
                                if constexpr (doExclusionForces)
                                {
                                    sig_r6 *= pairExclMask;
                                }
                                fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
                            } // (!props.vdwCombLB || doCalcEnergies)
                            if constexpr (doCalcEnergies || props.vdwPSwitch)
                            {
                                energyLJPair = pairExclMask
                                               * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
                                                  - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
                            }
                            if constexpr (props.vdwFSwitch)
                            {
                                ljForceSwitch<doCalcEnergies>(
                                        dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.vdwEwald)
                            {
                                ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
                                                                     ljEwaldShift,
                                                                     atomTypeI,
                                                                     atomTypeJ,
                                                                     r2,
                                                                     r2Inv,
                                                                     ewaldCoeffLJ_2,
                                                                     ewaldCoeffLJ_6_6,
                                                                     pairExclMask,
                                                                     &fInvR,
                                                                     &energyLJPair);
                            } // (props.vdwEwald)
                            if constexpr (props.vdwPSwitch)
                            {
                                ljPotentialSwitch<doCalcEnergies>(
                                        vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.elecEwaldTwin)
                            {
                                // Separate VDW cut-off check to enable twin-range cut-offs
                                // (rVdw < rCoulomb <= rList)
                                const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
                                fInvR *= vdwInRange;
                                if constexpr (doCalcEnergies)
                                {
                                    energyLJPair *= vdwInRange;
                                }
                            }
                            if constexpr (doCalcEnergies)
                            {
                                energyVdw += energyLJPair;
                            }

                            if constexpr (props.elecCutoff)
                            {
                                if constexpr (doExclusionForces)
                                {
                                    fInvR += qi * qj * pairExclMask * r2Inv * rInv;
                                }
                                else
                                {
                                    fInvR += qi * qj * r2Inv * rInv;
                                }
                            }
                            if constexpr (props.elecRF)
                            {
                                fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
                            }
                            if constexpr (props.elecEwaldAna)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
                            }
                            if constexpr (props.elecEwaldTab)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv
                                            - interpolateCoulombForceR(
                                                    a_coulombTab, coulombTabScale, r2 * rInv))
                                         * rInv;
                            }

                            if constexpr (doCalcEnergies)
                            {
                                if constexpr (props.elecCutoff)
                                {
                                    energyElec += qi * qj * (pairExclMask * rInv - cRF);
                                }
                                if constexpr (props.elecRF)
                                {
                                    energyElec +=
                                            qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
                                }
                                if constexpr (props.elecEwald)
                                {
                                    energyElec +=
                                            qi * qj
                                            * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
                                               - pairExclMask * ewaldShift);
                                }
                            }

                            const Float3 forceIJ = rv * fInvR;

                            /* accumulate j forces in registers */
                            fCjBuf -= forceIJ;
                            /* accumulate i forces in registers */
                            fCiBuf[i] += forceIJ;
                        } // (r2 < rCoulombSq) && notExcluded
                    }     // (imask & maskJI)
                    /* shift the mask bit by 1 */
                    maskJI += maskJI;
                } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                /* reduce j forces */
                reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
            } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            if constexpr (doPruneNBL)
            {
                /* Update the imask with the new one which does not contain the
                 * out of range clusters anymore. */
                a_plistCJ4[j4].imei[imeiIdx].imask = imask;
            }
        } // for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)

        /* skip central shifts when summing shift forces */
        const bool doCalcShift = (calcShift && !(nbSci.shift == gmx::c_centralShiftIndex));

        reduceForceIAndFShift(
                sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);

        if constexpr (doCalcEnergies)
        {
            const float energyVdwGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
            const float energyElecGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());

            if (tidx == 0)
            {
                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
            }
        }
    };
}

template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
{
    using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const int                   numBlocks = numSci;
    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };
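    // The 3D range is flattened to 1D at submission time (flattenNDRange below) and
    // unflattened again inside the kernel via unflattenId<c_clSize, c_clSize>, since the
    // kernel body is written against a 1D nd_item.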

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
    });

    return e;
}

template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                           bool          doCalcEnergies,
                                           enum ElecType elecType,
                                           enum VdwType  vdwType,
                                           Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
                return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
                        std::forward<Args>(args)...);
            },
            doPruneNBL,
            doCalcEnergies,
            elecType,
            vdwType);
}
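// dispatchTemplatedFunction expands the four runtime values into the matching compile-time
// template arguments, so that, conceptually, a call with
// (doPruneNBL = false, doCalcEnergies = true, ElecType::EwaldAna, VdwType::CutCombGeom)
// resolves to
//   launchNbnxmKernel<false, true, ElecType::EwaldAna, VdwType::CutCombGeom>(...);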

void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
{
    NBAtomDataGpu*      adat         = nb->atdat;
    NBParamGpu*         nbp          = nb->nbparam;
    gpu_plist*          plist        = nb->plist[iloc];
    const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];

    // Casting to float simplifies using atomic ops in the kernel
    cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
    cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);

    cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                   stepWork.computeEnergy,
                                                   nbp->elecType,
                                                   nbp->vdwType,
                                                   deviceStream,
                                                   plist->nsci,
                                                   adat->xq,
                                                   fAsFloat,
                                                   adat->shiftVec,
                                                   fShiftAsFloat,
                                                   adat->eElec,
                                                   adat->eLJ,
                                                   plist->cj4,
                                                   plist->sci,
                                                   plist->excl,
                                                   adat->ljComb,
                                                   adat->atomTypes,
                                                   nbp->nbfp,
                                                   nbp->nbfp_comb,
                                                   nbp->coulomb_tab,
                                                   adat->numTypes,
                                                   nbp->rcoulomb_sq,
                                                   nbp->rvdw_sq,
                                                   nbp->two_k_rf,
                                                   nbp->ewald_beta,
                                                   nbp->rlistOuter_sq,
                                                   nbp->sh_ewald,
                                                   nbp->epsfac,
                                                   nbp->ewaldcoeff_lj,
                                                   nbp->c_rf,
                                                   nbp->dispersion_shift,
                                                   nbp->repulsion_shift,
                                                   nbp->vdw_switch,
                                                   nbp->rvdw_switch,
                                                   nbp->sh_lj_ewald,
                                                   nbp->coulomb_tab_scale,
                                                   stepWork.computeVirial);
}

} // namespace Nbnxm