/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief
 *  NBNXM SYCL kernels
 *
 *  \ingroup module_nbnxm
 */
#include "gmxpre.h"

#include "nbnxm_sycl_kernel.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/math/functions.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/pbcutil/ishift.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"

namespace Nbnxm
{

//! \brief Set of boolean constants mimicking preprocessor macros.
template<enum ElecType elecType, enum VdwType vdwType>
struct EnergyFunctionProperties {
    static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
    static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
    static constexpr bool elecEwaldAna =
            (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
    static constexpr bool elecEwaldTab =
            (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
    static constexpr bool elecEwaldTwin =
            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
    static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
    static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
    static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
    static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
    static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
    static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
    static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
};

//! \brief Templated constants to shorten kernel function declaration.
//@{
template<enum VdwType vdwType>
constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;

template<enum ElecType elecType> // Yes, ElecType
constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;

template<enum ElecType elecType>
constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;

template<enum ElecType elecType>
constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;

template<enum VdwType vdwType>
constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;
//@}

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;

static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
{
    const float sigma2 = sigma * sigma;
    const float sigma6 = sigma2 * sigma2 * sigma2;
    const float c6     = epsilon * sigma6;
    const float c12    = c6 * sigma6;

    return Float2(c6, c12);
}
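
/* A minimal numeric sketch of the conversion above: the inputs are the scaled
 * per-atom LB parameters this kernel stores (2^(1/6)*sigma and 12*epsilon, see
 * the comments in the inner loop below), so the outputs are the (6*C6, 12*C12)
 * pair expected by the LJ force formula. With hypothetical sigma = 0.3, epsilon = 0.5:
 *
 *   const float  sigmaScaled   = 1.1224620F * 0.3F; // 2^(1/6) * sigma
 *   const float  epsilonScaled = 12.0F * 0.5F;      // 12 * epsilon
 *   const Float2 c6c12         = convertSigmaEpsilonToC6C12(sigmaScaled, epsilonScaled);
 *   // c6c12[0] ~ 6  * (4 * 0.5F * pow(0.3F, 6))   i.e. 6*C6
 *   // c6c12[1] ~ 12 * (4 * 0.5F * pow(0.3F, 12))  i.e. 12*C12
 */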

template<bool doCalcEnergies>
static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                 const shift_consts_t         repulsionShift,
                                 const float                  rVdwSwitch,
                                 const float                  c6,
                                 const float                  c12,
                                 const float                  rInv,
                                 const float                  r2,
                                 cl::sycl::private_ptr<float> fInvR,
                                 cl::sycl::private_ptr<float> eLJ)
{
    /* force switch constants */
    const float dispShiftV2 = dispersionShift.c2;
    const float dispShiftV3 = dispersionShift.c3;
    const float repuShiftV2 = repulsionShift.c2;
    const float repuShiftV3 = repulsionShift.c3;

    const float r       = r2 * rInv;
    const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)

    *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
              + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;

    if constexpr (doCalcEnergies)
    {
        const float dispShiftF2 = dispShiftV2 / 3;
        const float dispShiftF3 = dispShiftV3 / 4;
        const float repuShiftF2 = repuShiftV2 / 3;
        const float repuShiftF3 = repuShiftV3 / 4;
        *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
                - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
    }
}
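
/* Sketch of why the energy coefficients above are the force coefficients divided
 * by 3 and 4: with s = max(r - rVdwSwitch, 0), the added radial force component is
 * proportional to (c2 + c3*s) * s^2, and its antiderivative in s is
 * (c2/3)*s^3 + (c3/4)*s^4 = ((c2/3) + (c3/4)*s) * s^3, matching the *eLJ update. */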

//! \brief Fetch C6 grid contribution coefficients and return the product of these.
template<enum VdwType vdwType>
static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                                           const int                                typeI,
                                           const int                                typeJ)
{
    if constexpr (vdwType == VdwType::EwaldGeom)
    {
        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
    }
    else
    {
        static_assert(vdwType == VdwType::EwaldLB);
        /* sigma and epsilon are scaled to give 6*C6 */
        const Float2 c6c12_i = a_nbfpComb[typeI];
        const Float2 c6c12_j = a_nbfpComb[typeJ];

        const float sigma   = c6c12_i[0] + c6c12_j[0];
        const float epsilon = c6c12_i[1] * c6c12_j[1];

        const float sigma2 = sigma * sigma;
        return epsilon * sigma2 * sigma2 * sigma2;
    }
}

//! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
template<bool doCalcEnergies, enum VdwType vdwType>
static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                               const float                              sh_lj_ewald,
                               const int                                typeI,
                               const int                                typeJ,
                               const float                              r2,
                               const float                              r2Inv,
                               const float                              lje_coeff2,
                               const float                              lje_coeff6_6,
                               const float                              int_bit,
                               cl::sycl::private_ptr<float>             fInvR,
                               cl::sycl::private_ptr<float>             eLJ)
{
    const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);

    /* Recalculate inv_r6 without exclusion mask */
    const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
    const float cr2       = lje_coeff2 * r2;
    const float expmcr2   = cl::sycl::exp(-cr2);
    const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;

    /* Subtract the grid force from the total LJ force */
    *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;

    if constexpr (doCalcEnergies)
    {
        /* Shift should be applied only to real LJ pairs */
        const float sh_mask = sh_lj_ewald * int_bit;
        *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
    }
}
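
/* Notation sketch for the expressions above: with x = ewaldCoeffLJ * r we have
 * cr2 = x^2, and expmcr2 * poly = exp(-x^2) * (1 + x^2 + x^4/2) = g(x), the
 * real-space LJ-PME switching factor. The energy term therefore carries the
 * direct-space fraction (1 - g(x)) of the 1/r^6 attraction, scaled by c6grid/6
 * to match the 6*C6 convention used throughout this kernel. */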

/*! \brief Apply potential switch. */
template<bool doCalcEnergies>
static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
                                     const float                  rVdwSwitch,
                                     const float                  rInv,
                                     const float                  r2,
                                     cl::sycl::private_ptr<float> fInvR,
                                     cl::sycl::private_ptr<float> eLJ)
{
    /* potential switch constants */
    const float switchV3 = vdwSwitch.c3;
    const float switchV4 = vdwSwitch.c4;
    const float switchV5 = vdwSwitch.c5;
    const float switchF2 = 3 * vdwSwitch.c3;
    const float switchF3 = 4 * vdwSwitch.c4;
    const float switchF4 = 5 * vdwSwitch.c5;

    const float r       = r2 * rInv;
    const float rSwitch = r - rVdwSwitch;

    if (rSwitch > 0.0F)
    {
        const float sw =
                1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
        const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;

        *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
        if constexpr (doCalcEnergies)
        {
            *eLJ *= sw;
        }
    }
}
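
/* Product-rule sketch of the update above: with s = r - rVdwSwitch,
 *   sw(s)  = 1 + (c3 + (c4 + c5*s)*s)*s^3   and
 *   dsw(s) = (3*c3 + (4*c4 + 5*c5*s)*s)*s^2 = d(sw)/ds,
 * which is why switchF2..F4 are 3*c3, 4*c4, 5*c5. For the switched potential
 * V*sw, the force is -d(V*sw)/dr = F*sw - V*dsw; dividing by r gives exactly
 * the fInvR update: fInvR*sw - rInv*eLJ*dsw. */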


/*! \brief Calculate analytical Ewald correction term. */
static inline float pmeCorrF(const float z2)
{
    constexpr float FN6 = -1.7357322914161492954e-8F;
    constexpr float FN5 = 1.4703624142580877519e-6F;
    constexpr float FN4 = -0.000053401640219807709149F;
    constexpr float FN3 = 0.0010054721316683106153F;
    constexpr float FN2 = -0.019278317264888380590F;
    constexpr float FN1 = 0.069670166153766424023F;
    constexpr float FN0 = -0.75225204789749321333F;

    constexpr float FD4 = 0.0011193462567257629232F;
    constexpr float FD3 = 0.014866955030185295499F;
    constexpr float FD2 = 0.11583842382862377919F;
    constexpr float FD1 = 0.50736591960530292870F;
    constexpr float FD0 = 1.0F;

    const float z4 = z2 * z2;

    float       polyFD0 = FD4 * z4 + FD2;
    const float polyFD1 = FD3 * z4 + FD1;
    polyFD0             = polyFD0 * z4 + FD0;
    polyFD0             = polyFD1 * z2 + polyFD0;

    polyFD0 = 1.0F / polyFD0;

    float polyFN0 = FN6 * z4 + FN4;
    float polyFN1 = FN5 * z4 + FN3;
    polyFN0       = polyFN0 * z4 + FN2;
    polyFN1       = polyFN1 * z4 + FN1;
    polyFN0       = polyFN0 * z4 + FN0;
    polyFN0       = polyFN1 * z2 + polyFN0;

    return polyFN0 * polyFD0;
}
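
/* Reading aid for pmeCorrF (a sketch of the math, not part of the kernel): it
 * evaluates a degree-6 / degree-4 rational fit in z2. Numerator and denominator
 * are each computed as two interleaved Horner chains in z4 that are merged by a
 * final z2 step, which shortens the dependency chain. The caller combines the
 * result as qi*qj*(mask/r^3 + pmeCorrF(beta^2*r^2)*beta^3), i.e. the fit
 * approximates the short-range Ewald force correction
 * (2/sqrt(pi))*exp(-z^2)/z^2 - erf(z)/z^3 with z = beta*r. */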

/*! \brief Linear interpolation using exactly two FMA operations.
 *
 *  Implements numeric equivalent of: (1-t)*d0 + t*d1.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    return fma(t, d1, fma(-t, d0, d0));
}
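
/* Expansion check for lerp: fma(-t, d0, d0) computes d0*(1 - t) in a single
 * rounding, and the outer fma adds t*d1, so lerp(d0, d1, 0) == d0 and
 * lerp(d0, d1, 1) == d1 exactly; e.g. lerp(1.0F, 3.0F, 0.5F) == 2.0F. */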

/*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
                                             const float coulombTabScale,
                                             const float r)
{
    const float normalized = coulombTabScale * r;
    const int   index      = static_cast<int>(normalized);
    const float fraction   = normalized - index;

    const float left  = a_coulombTab[index];
    const float right = a_coulombTab[index + 1];

    return lerp(left, right, fraction); // TODO: cl::sycl::mix
}
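
/* Lookup sketch: the table holds the Ewald correction force scaled by r, sampled
 * at r = k / coulombTabScale. Multiplying by coulombTabScale splits r into the
 * bracketing table index and the interpolation fraction, so one lookup costs an
 * int conversion, two loads, and the two FMAs of lerp(). */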

/*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in log2(c_clSize) steps using subgroup shifts, and atomically accumulate them into \p a_f.
 */
static inline void reduceForceJShuffle(Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    f[0] += sycl_2020::shift_left(sg, f[0], 1);
    f[1] += sycl_2020::shift_right(sg, f[1], 1);
    f[2] += sycl_2020::shift_left(sg, f[2], 1);
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    f[0] += sycl_2020::shift_left(sg, f[0], 2);
    f[2] += sycl_2020::shift_right(sg, f[2], 2);
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    if constexpr (c_clSize == 8)
    {
        f[0] += sycl_2020::shift_left(sg, f[0], 4);
    }

    if (tidxi < 3)
    {
        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
    }
}
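
/* Lane pattern sketch for c_clSize == 4 (one j-atom per four consecutive lanes):
 * after the first round of shifts, even lanes hold x partial sums and odd lanes
 * y partial sums (z stays in f[2]); the second round pairs lanes two apart, after
 * which lane 0 holds sum(x), lane 1 sum(y), and lane 2 sum(z), so lanes 0..2 each
 * issue exactly one atomic add. c_clSize == 8 needs the extra shift_left by 4. */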

/*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in c_clSize steps using local memory, and atomically accumulate them into \p a_f.
 *
 * TODO: implement binary reduction flavor for the case where c_clSize is a power of two.
 */
static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                       Float3                             f,
                                       const cl::sycl::nd_item<1>         itemIdx,
                                       const int                          tidxi,
                                       const int                          tidxj,
                                       const int                          aidx,
                                       DeviceAccessor<float, mode_atomic> a_f)
{
    static constexpr int sc_fBufferStride = c_clSizeSq;
    int                  tidx             = tidxi + tidxj * c_clSize;
    sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
    sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
    sm_buf[2 * sc_fBufferStride + tidx]   = f[2];

    subGroupBarrier(itemIdx);

    // Reduce the data, c_clSize elements at a time, on the same threads that stored it above
    assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);

    if (tidxi < 3)
    {
        float fSum = 0.0F;
        for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
        {
            fSum += sm_buf[sc_fBufferStride * tidxi + j];
        }

        atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
    }
}


/*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
 */
static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                Float3                                                        f,
                                const cl::sycl::nd_item<1>         itemIdx,
                                const int                          tidxi,
                                const int                          tidxj,
                                const int                          aidx,
                                DeviceAccessor<float, mode_atomic> a_f)
{
    if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
    {
        reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
    }
    else
    {
        reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
    }
}


/*! \brief Final i-force reduction.
 *
 * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[],
 * accumulating atomically into \p a_f.
 * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
 *
 * This implementation works only with power of two array sizes.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<1>         itemIdx,
                                         const int                          tidxi,
                                         const int                          tidxj,
                                         const int                          sci,
                                         const int                          shift,
                                         DeviceAccessor<float, mode_atomic> a_f,
                                         DeviceAccessor<float, mode_atomic> a_fShift)
{
    // must have power of two elements in fCiBuf
    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));

    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    float                fShiftBuf  = 0.0F;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
        }
    }
}
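
/* Shape sketch of the reduction above for c_clSize == 8: each i-atom column starts
 * with 8 partial forces; the explicit loop runs clSizeLog2 - 1 = 2 halving steps
 * (8 -> 4 -> 2), and the final 2 -> 1 step is fused with the global write, split so
 * that tidxj = 0, 1, 2 handle the x, y, z components respectively. */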

/*! \brief Main kernel for NBNXM. */
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                 DeviceAccessor<Float4, mode::read>                   a_xq,
                 DeviceAccessor<float, mode_atomic>                   a_f,
                 DeviceAccessor<Float3, mode::read>                   a_shiftVec,
                 DeviceAccessor<float, mode_atomic>                   a_fShift,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
                 DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                 DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                 DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                 OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
                 OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
                 OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                 const int                                                   numTypes,
                 const float                                                 rCoulombSq,
                 const float                                                 rVdwSq,
                 const float                                                 twoKRf,
                 const float                                                 ewaldBeta,
                 const float                                                 rlistOuterSq,
                 const float                                                 ewaldShift,
                 const float                                                 epsFac,
                 const float                                                 ewaldCoeffLJ,
                 const float                                                 cRF,
                 const shift_consts_t                                        dispersionShift,
                 const shift_consts_t                                        repulsionShift,
                 const switch_consts_t                                       vdwSwitch,
                 const float                                                 rVdwSwitch,
                 const float                                                 ljEwaldShift,
                 const float                                                 coulombTabScale,
                 const bool                                                  calcShift)
{
    static constexpr EnergyFunctionProperties<elecType, vdwType> props;

    cgh.require(a_xq);
    cgh.require(a_f);
    cgh.require(a_shiftVec);
    cgh.require(a_fShift);
    if constexpr (doCalcEnergies)
    {
        cgh.require(a_energyElec);
        cgh.require(a_energyVdw);
    }
    cgh.require(a_plistCJ4);
    cgh.require(a_plistSci);
    cgh.require(a_plistExcl);
    if constexpr (!props.vdwComb)
    {
        cgh.require(a_atomTypes);
        cgh.require(a_nbfp);
    }
    else
    {
        cgh.require(a_ljComb);
    }
    if constexpr (props.vdwEwald)
    {
        cgh.require(a_nbfpComb);
    }
    if constexpr (props.elecEwaldTab)
    {
        cgh.require(a_coulombTab);
    }

    // shmem buffer for i x+q pre-loading
    cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

    // shmem buffer for force reduction
    // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
            cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);

    auto sm_atomTypeI = [&]() {
        if constexpr (!props.vdwComb)
        {
            return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    auto sm_ljCombI = [&]() {
        if constexpr (props.vdwComb)
        {
            return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    /* Flag to control the calculation of exclusion forces in the kernel.
     * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
     * energy terms. */
    constexpr bool doExclusionForces =
            (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));

    // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
    // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
    // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
    // Hence, the two are decoupled.
    constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
#if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
    gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
#else
    gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
#endif

    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        /* thread/block/warp id-s */
        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
        const unsigned        tidxi   = localId[0];
        const unsigned        tidxj   = localId[1];
        const unsigned        tidx    = tidxj * c_clSize + tidxi;
        const unsigned        tidxz   = 0;

        // Group indexing was flat originally, no need to unflatten it.
        const unsigned bidx = itemIdx.get_group(0);

        const sycl_2020::sub_group sg = itemIdx.get_sub_group();
        // We could use sg.get_group_range to compute the imask & exclusion index,
        // but much of the logic relies on tidx, and it would not work when
        // prunedClusterPairSize != subGroupSize anyway.
        const unsigned imeiIdx = tidx / prunedClusterPairSize;

        Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
        {
            fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
        }

        const nbnxn_sci_t nbSci     = a_plistSci[bidx];
        const int         sci       = nbSci.sci;
        const int         cij4Start = nbSci.cj4_ind_start;
        const int         cij4End   = nbSci.cj4_ind_end;

        // Only needed if props.elecEwaldAna
        const float beta2 = ewaldBeta * ewaldBeta;
        const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;

        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
        {
            /* Pre-load i-atom x and q into shared memory */
            const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
            const int             ai       = ci * c_clSize + tidxi;
            const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);

            const Float3 shift = a_shiftVec[nbSci.shift];
            Float4       xqi   = a_xq[ai];
            xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
            xqi[3] *= epsFac;
            sm_xq[cacheIdx] = xqi;

            if constexpr (!props.vdwComb)
            {
                // Pre-load the i-atom types into shared memory
                sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
            }
            else
            {
                // Pre-load the LJ combination parameters into shared memory
                sm_ljCombI[cacheIdx] = a_ljComb[ai];
            }
        }
        itemIdx.barrier(fence_space::local_space);

        float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
        if constexpr (props.vdwEwald)
        {
            ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
            ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
        }

        float energyVdw, energyElec; // Only needed if (doCalcEnergies)
        if constexpr (doCalcEnergies)
        {
            energyVdw = energyElec = 0.0F;
        }
        if constexpr (doCalcEnergies && doExclusionForces)
        {
            if (nbSci.shift == gmx::c_centralShiftIndex
                && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
            {
                // we have the diagonal: add the charge and LJ self interaction energy term
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    // TODO: Are there other options?
                    if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
                    {
                        const float qi = sm_xq[i][tidxi][3];
                        energyElec += qi * qi;
                    }
                    if constexpr (props.vdwEwald)
                    {
                        energyVdw +=
                                a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
                                       * (numTypes + 1)][0];
                    }
                }
                /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
                if constexpr (props.vdwEwald)
                {
                    energyVdw /= c_clSize;
                    energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
                }
                if constexpr (props.elecRF || props.elecCutoff)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -0.5F * cRF;
                }
                if constexpr (props.elecEwald)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                }
            } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
        }     // (doCalcEnergies && doExclusionForces)

        // Only needed if (doExclusionForces)
        const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);

        // loop over the j clusters seen by any of the atoms in the current super-cluster
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
        {
            unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
            if (!doPruneNBL && !imask)
            {
                continue;
            }
            const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
            static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
            for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            {
                const bool maskSet =
                        imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                if (!maskSet)
                {
                    continue;
                }
                unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                const int cj     = a_plistCJ4[j4].cj[jm];
                const int aj     = cj * c_clSize + tidxj;

                // load j atom data
                const Float4 xqj = a_xq[aj];

                const Float3 xj(xqj[0], xqj[1], xqj[2]);
                const float  qj = xqj[3];
                int          atomTypeJ; // Only needed if (!props.vdwComb)
                Float2       ljCombJ;   // Only needed if (props.vdwComb)
                if constexpr (props.vdwComb)
                {
                    ljCombJ = a_ljComb[aj];
                }
                else
                {
                    atomTypeJ = a_atomTypes[aj];
                }

                Float3 fCjBuf(0.0F, 0.0F, 0.0F);

                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    if (imask & maskJI)
                    {
                        // i cluster index
                        const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
                        // all threads load an atom from i cluster ci into shmem!
                        const Float4 xqi = sm_xq[i][tidxi];
                        const Float3 xi(xqi[0], xqi[1], xqi[2]);

                        // distance between i and j atoms
                        const Float3 rv = xi - xj;
                        float        r2 = norm2(rv);

                        if constexpr (doPruneNBL)
                        {
                            /* If _none_ of the atom pairs are in cutoff range,
                             * the bit corresponding to the current
                             * cluster-pair in imask gets set to 0. */
                            if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
                            {
                                imask &= ~maskJI;
                            }
                        }
                        const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;

                        // cutoff & exclusion check

                        const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
                                                                   : (wexcl & maskJI);

                        // SYCL-TODO: Check optimal way of branching here.
                        if ((r2 < rCoulombSq) && notExcluded)
                        {
                            const float qi = xqi[3];
                            int         atomTypeI; // Only needed if (!props.vdwComb)
                            float       sigma, epsilon;
                            Float2      c6c12;

                            if constexpr (!props.vdwComb)
                            {
                                /* LJ 6*C6 and 12*C12 */
                                atomTypeI = sm_atomTypeI[i][tidxi];
                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                            }
                            else
                            {
                                const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                if constexpr (props.vdwCombGeom)
                                {
                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                }
                                else
                                {
                                    static_assert(props.vdwCombLB);
                                    // LJ 2^(1/6)*sigma and 12*epsilon
                                    sigma   = ljCombI[0] + ljCombJ[0];
                                    epsilon = ljCombI[1] * ljCombJ[1];
                                    if constexpr (doCalcEnergies)
                                    {
                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                    }
                                } // props.vdwCombGeom
                            }     // !props.vdwComb

                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
                            const float c6  = c6c12[0];
                            const float c12 = c6c12[1];

                            // Ensure distances do not become so small that r^-12 overflows
                            r2 = std::max(r2, c_nbnxnMinDistanceSquared);
#if GMX_SYCL_HIPSYCL
                            // No fast/native functions in some compilation passes
                            const float rInv = cl::sycl::rsqrt(r2);
#else
                            // SYCL-TODO: sycl::half_precision::rsqrt?
                            const float rInv = cl::sycl::native::rsqrt(r2);
#endif
                            const float r2Inv = rInv * rInv;
                            float       r6Inv, fInvR, energyLJPair;
                            if constexpr (!props.vdwCombLB || doCalcEnergies)
                            {
                                r6Inv = r2Inv * r2Inv * r2Inv;
                                if constexpr (doExclusionForces)
                                {
                                    // SYCL-TODO: Check if true for SYCL
                                    /* We could mask r2Inv, but with Ewald masking both
                                     * r6Inv and fInvR is faster */
                                    r6Inv *= pairExclMask;
                                }
                                fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
                            }
                            else
                            {
                                float sig_r  = sigma * rInv;
                                float sig_r2 = sig_r * sig_r;
                                float sig_r6 = sig_r2 * sig_r2 * sig_r2;
                                if constexpr (doExclusionForces)
                                {
                                    sig_r6 *= pairExclMask;
                                }
                                fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
                            } // (!props.vdwCombLB || doCalcEnergies)
                            if constexpr (doCalcEnergies || props.vdwPSwitch)
                            {
                                energyLJPair = pairExclMask
                                               * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
                                                  - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
                            }
                            if constexpr (props.vdwFSwitch)
                            {
                                ljForceSwitch<doCalcEnergies>(
                                        dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.vdwEwald)
                            {
                                ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
                                                                     ljEwaldShift,
                                                                     atomTypeI,
                                                                     atomTypeJ,
                                                                     r2,
                                                                     r2Inv,
                                                                     ewaldCoeffLJ_2,
                                                                     ewaldCoeffLJ_6_6,
                                                                     pairExclMask,
                                                                     &fInvR,
                                                                     &energyLJPair);
                            } // (props.vdwEwald)
                            if constexpr (props.vdwPSwitch)
                            {
                                ljPotentialSwitch<doCalcEnergies>(
                                        vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.elecEwaldTwin)
                            {
                                // Separate VDW cut-off check to enable twin-range cut-offs
                                // (rVdw < rCoulomb <= rList)
                                const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
                                fInvR *= vdwInRange;
                                if constexpr (doCalcEnergies)
                                {
                                    energyLJPair *= vdwInRange;
                                }
                            }
                            if constexpr (doCalcEnergies)
                            {
                                energyVdw += energyLJPair;
                            }

                            if constexpr (props.elecCutoff)
                            {
                                if constexpr (doExclusionForces)
                                {
                                    fInvR += qi * qj * pairExclMask * r2Inv * rInv;
                                }
                                else
                                {
                                    fInvR += qi * qj * r2Inv * rInv;
                                }
                            }
                            if constexpr (props.elecRF)
                            {
                                fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
                            }
                            if constexpr (props.elecEwaldAna)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
                            }
                            if constexpr (props.elecEwaldTab)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv
                                            - interpolateCoulombForceR(
                                                    a_coulombTab, coulombTabScale, r2 * rInv))
                                         * rInv;
                            }

                            if constexpr (doCalcEnergies)
                            {
                                if constexpr (props.elecCutoff)
                                {
                                    energyElec += qi * qj * (pairExclMask * rInv - cRF);
                                }
                                if constexpr (props.elecRF)
                                {
                                    energyElec +=
                                            qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
                                }
                                if constexpr (props.elecEwald)
                                {
                                    energyElec +=
                                            qi * qj
                                            * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
                                               - pairExclMask * ewaldShift);
                                }
                            }

                            const Float3 forceIJ = rv * fInvR;

                            /* accumulate j forces in registers */
                            fCjBuf -= forceIJ;
                            /* accumulate i forces in registers */
                            fCiBuf[i] += forceIJ;
                        } // (r2 < rCoulombSq) && notExcluded
                    }     // (imask & maskJI)
                    /* shift the mask bit by 1 */
                    maskJI += maskJI;
                } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                /* reduce j forces */
                reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
            } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            if constexpr (doPruneNBL)
            {
                /* Update the imask with the new one which does not contain the
                 * out of range clusters anymore. */
                a_plistCJ4[j4].imei[imeiIdx].imask = imask;
            }
        } // for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)

        /* skip central shifts when summing shift forces */
        const bool doCalcShift = (calcShift && !(nbSci.shift == gmx::c_centralShiftIndex));

        reduceForceIAndFShift(
                sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);

        if constexpr (doCalcEnergies)
        {
            const float energyVdwGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
            const float energyElecGroup = sycl_2020::group_reduce(
                    itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());

            if (tidx == 0)
            {
                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
            }
        }
    };
}

// SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
class NbnxmKernelName;

template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
{
    // Should not be needed for SYCL2020.
    using kernelNameType = NbnxmKernelName<doPruneNBL, doCalcEnergies, elecType, vdwType>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const int                   numBlocks = numSci;
    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
    });

    return e;
}

template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                           bool          doCalcEnergies,
                                           enum ElecType elecType,
                                           enum VdwType  vdwType,
                                           Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
                return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
                        std::forward<Args>(args)...);
            },
            doPruneNBL,
            doCalcEnergies,
            elecType,
            vdwType);
}
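
/* Dispatch sketch: gmx::dispatchTemplatedFunction converts the runtime
 * (doPruneNBL, doCalcEnergies, elecType, vdwType) values into compile-time
 * template arguments, so a call with, e.g., (true, false, ElecType::RF,
 * VdwType::Cut) ends up in launchNbnxmKernel<true, false, ElecType::RF,
 * VdwType::Cut>(...). All 2 x 2 x ElecType::Count x VdwType::Count kernel
 * variants are instantiated at compile time; exactly one is selected per call. */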

void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
{
    NBAtomDataGpu*      adat         = nb->atdat;
    NBParamGpu*         nbp          = nb->nbparam;
    gpu_plist*          plist        = nb->plist[iloc];
    const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];

    // Casting to float simplifies using atomic ops in the kernel
    cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
    cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);

    cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                   stepWork.computeEnergy,
                                                   nbp->elecType,
                                                   nbp->vdwType,
                                                   deviceStream,
                                                   plist->nsci,
                                                   adat->xq,
                                                   fAsFloat,
                                                   adat->shiftVec,
                                                   fShiftAsFloat,
                                                   adat->eElec,
                                                   adat->eLJ,
                                                   plist->cj4,
                                                   plist->sci,
                                                   plist->excl,
                                                   adat->ljComb,
                                                   adat->atomTypes,
                                                   nbp->nbfp,
                                                   nbp->nbfp_comb,
                                                   nbp->coulomb_tab,
                                                   adat->numTypes,
                                                   nbp->rcoulomb_sq,
                                                   nbp->rvdw_sq,
                                                   nbp->two_k_rf,
                                                   nbp->ewald_beta,
                                                   nbp->rlistOuter_sq,
                                                   nbp->sh_ewald,
                                                   nbp->epsfac,
                                                   nbp->ewaldcoeff_lj,
                                                   nbp->c_rf,
                                                   nbp->dispersion_shift,
                                                   nbp->repulsion_shift,
                                                   nbp->vdw_switch,
                                                   nbp->rvdw_switch,
                                                   nbp->sh_lj_ewald,
                                                   nbp->coulomb_tab_scale,
                                                   stepWork.computeVirial);
}

} // namespace Nbnxm