/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2020,2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief NBNXM SYCL kernel
 *
 *  \ingroup module_nbnxm
 */
#include "nbnxm_sycl_kernel.h"

#include "gromacs/gpu_utils/devicebuffer.h"
#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/math/functions.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/pbcutil/ishift.h"
#include "gromacs/utility/template_mp.h"

#include "nbnxm_sycl_kernel_utils.h"
#include "nbnxm_sycl_types.h"
//! \brief Class name for NBNXM kernel
template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
class NbnxmKernel;

namespace Nbnxm
{
//! \brief Set of boolean constants mimicking preprocessor macros.
template<enum ElecType elecType, enum VdwType vdwType>
struct EnergyFunctionProperties {
    static constexpr bool elecCutoff = (elecType == ElecType::Cut); ///< EL_CUTOFF
    static constexpr bool elecRF     = (elecType == ElecType::RF);  ///< EL_RF
    static constexpr bool elecEwaldAna =
            (elecType == ElecType::EwaldAna || elecType == ElecType::EwaldAnaTwin); ///< EL_EWALD_ANA
    static constexpr bool elecEwaldTab =
            (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
    static constexpr bool elecEwaldTwin =
            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin); ///< Use twin cut-off.
    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab);  ///< EL_EWALD_ANY
    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB); ///< LJ_COMB && !LJ_COMB_GEOM
    static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
    static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
    static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
    static constexpr bool vdwEwaldCombLB   = (vdwType == VdwType::EwaldLB);   ///< LJ_EWALD_COMB_LB
    static constexpr bool vdwEwald         = (vdwEwaldCombGeom || vdwEwaldCombLB); ///< LJ_EWALD
    static constexpr bool vdwFSwitch       = (vdwType == VdwType::FSwitch); ///< LJ_FORCE_SWITCH
    static constexpr bool vdwPSwitch       = (vdwType == VdwType::PSwitch); ///< LJ_POT_SWITCH
};

//! \brief Templated constants to shorten kernel function declarations.
template<enum VdwType vdwType>
constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;

template<enum ElecType elecType>
constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;

template<enum ElecType elecType>
constexpr bool elecEwaldTab = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTab;

template<enum VdwType vdwType>
constexpr bool ljEwald = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwEwald;

using cl::sycl::access::fence_space;
using cl::sycl::access::mode;
using cl::sycl::access::target;
//! \brief Convert \p sigma and \p epsilon VdW parameters to the \c c6, \c c12 pair.
static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
{
    const float sigma2 = sigma * sigma;
    const float sigma6 = sigma2 * sigma2 * sigma2;
    const float c6     = epsilon * sigma6;
    const float c12    = c6 * sigma6;

    return { c6, c12 };
}
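/* Note: with the convention used elsewhere in this kernel of storing
 * 2^(1/6)*sigma and 12*epsilon (see the LB combination-rule branch in the main
 * kernel below), the products above yield the scaled coefficients c6 = 6*C6 and
 * c12 = 12*C12 of V_LJ(r) = C12/r^12 - C6/r^6, so the force can later be formed
 * as r6Inv * (c12 * r6Inv - c6) * r2Inv without extra factors of 6 and 12.
 */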
//! \brief Calculate force and energy for a pair of atoms, VdW force-switch flavor.
template<bool doCalcEnergies>
static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                 const shift_consts_t         repulsionShift,
                                 const float                  rVdwSwitch,
                                 const float                  c6,
                                 const float                  c12,
                                 const float                  rInv,
                                 const float                  r2,
                                 cl::sycl::private_ptr<float> fInvR,
                                 cl::sycl::private_ptr<float> eLJ)
{
    /* force switch constants */
    const float dispShiftV2 = dispersionShift.c2;
    const float dispShiftV3 = dispersionShift.c3;
    const float repuShiftV2 = repulsionShift.c2;
    const float repuShiftV3 = repulsionShift.c3;

    const float r       = r2 * rInv;
    const float rSwitch = cl::sycl::fdim(r, rVdwSwitch); // max(r - rVdwSwitch, 0)

    *fInvR += -c6 * (dispShiftV2 + dispShiftV3 * rSwitch) * rSwitch * rSwitch * rInv
              + c12 * (repuShiftV2 + repuShiftV3 * rSwitch) * rSwitch * rSwitch * rInv;

    if constexpr (doCalcEnergies)
    {
        const float dispShiftF2 = dispShiftV2 / 3;
        const float dispShiftF3 = dispShiftV3 / 4;
        const float repuShiftF2 = repuShiftV2 / 3;
        const float repuShiftF3 = repuShiftV3 / 4;
        *eLJ += c6 * (dispShiftF2 + dispShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch
                - c12 * (repuShiftF2 + repuShiftF3 * rSwitch) * rSwitch * rSwitch * rSwitch;
    }
}
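/* Note: past rVdwSwitch the force-switch flavor adds A*d^2 + B*d^3, with
 * d = r - rVdwSwitch and A = c2, B = c3 precomputed on the host in shift_consts_t,
 * to the dispersion and repulsion force terms. The energy terms are the
 * corresponding integrals A/3*d^3 + B/4*d^4, which is where the divisions by
 * 3 and 4 above come from. Inside rVdwSwitch, fdim() clamps d to zero, so the
 * plain LJ interaction is recovered.
 */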
//! \brief Fetch C6 grid contribution coefficients and return the product of these.
template<enum VdwType vdwType>
static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                                           const int                                typeI,
                                           const int                                typeJ)
{
    if constexpr (vdwType == VdwType::EwaldGeom)
    {
        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
    }
    else
    {
        static_assert(vdwType == VdwType::EwaldLB);
        /* sigma and epsilon are scaled to give 6*C6 */
        const Float2 c6c12_i = a_nbfpComb[typeI];
        const Float2 c6c12_j = a_nbfpComb[typeJ];

        const float sigma   = c6c12_i[0] + c6c12_j[0];
        const float epsilon = c6c12_i[1] * c6c12_j[1];

        const float sigma2 = sigma * sigma;
        return epsilon * sigma2 * sigma2 * sigma2;
    }
}
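/* Note (assuming the host-side parameter layout implied here): for the geometric
 * rule the per-type values in a_nbfpComb[...][0] act as square roots of the C6
 * grid coefficients, so C6grid_ij is a single product; for the LB rule the
 * pair's C6 is rebuilt as epsilon_ij * sigma_ij^6 from sigma_ij = sigma_i + sigma_j
 * and epsilon_ij = epsilon_i * epsilon_j, with the scaling chosen to give 6*C6.
 */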
//! \brief Calculate LJ-PME grid force contribution with geometric or LB combination rule.
template<bool doCalcEnergies, enum VdwType vdwType>
static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
                               const float                              sh_lj_ewald,
                               const int                                typeI,
                               const int                                typeJ,
                               const float                              r2,
                               const float                              r2Inv,
                               const float                              lje_coeff2,
                               const float                              lje_coeff6_6,
                               const float                              int_bit,
                               cl::sycl::private_ptr<float>             fInvR,
                               cl::sycl::private_ptr<float>             eLJ)
{
    const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);

    /* Recalculate inv_r6 without exclusion mask */
    const float inv_r6_nm = r2Inv * r2Inv * r2Inv;
    const float cr2       = lje_coeff2 * r2;
    const float expmcr2   = cl::sycl::exp(-cr2);
    const float poly      = 1.0F + cr2 + 0.5F * cr2 * cr2;

    /* Subtract the grid force from the total LJ force */
    *fInvR += c6grid * (inv_r6_nm - expmcr2 * (inv_r6_nm * poly + lje_coeff6_6)) * r2Inv;

    if constexpr (doCalcEnergies)
    {
        /* Shift should be applied only to real LJ pairs */
        const float sh_mask = sh_lj_ewald * int_bit;
        *eLJ += c_oneSixth * c6grid * (inv_r6_nm * (1.0F - expmcr2 * poly) + sh_mask);
    }
}
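/* Note: this is the real-space part of LJ-PME. With x^2 = lje_coeff2 * r2,
 * poly = 1 + x^2 + x^4/2 is the truncated exponential series that appears in
 * the incomplete-gamma form of the dispersion grid term, and
 * lje_coeff6_6 = b^6/6 for b = ewaldCoeffLJ; the *fInvR update removes the
 * grid (reciprocal-space) contribution that PME adds for this pair,
 * including for excluded pairs.
 */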
/*! \brief Apply potential switch. */
template<bool doCalcEnergies>
static inline void ljPotentialSwitch(const switch_consts_t        vdwSwitch,
                                     const float                  rVdwSwitch,
                                     const float                  rInv,
                                     const float                  r2,
                                     cl::sycl::private_ptr<float> fInvR,
                                     cl::sycl::private_ptr<float> eLJ)
{
    /* potential switch constants */
    const float switchV3 = vdwSwitch.c3;
    const float switchV4 = vdwSwitch.c4;
    const float switchV5 = vdwSwitch.c5;
    const float switchF2 = 3 * vdwSwitch.c3;
    const float switchF3 = 4 * vdwSwitch.c4;
    const float switchF4 = 5 * vdwSwitch.c5;

    const float r       = r2 * rInv;
    const float rSwitch = r - rVdwSwitch;

    if (rSwitch > 0.0F)
    {
        const float sw =
                1.0F + (switchV3 + (switchV4 + switchV5 * rSwitch) * rSwitch) * rSwitch * rSwitch * rSwitch;
        const float dsw = (switchF2 + (switchF3 + switchF4 * rSwitch) * rSwitch) * rSwitch * rSwitch;

        *fInvR = (*fInvR) * sw - rInv * (*eLJ) * dsw;
        if constexpr (doCalcEnergies)
        {
            *eLJ *= sw;
        }
    }
}
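/* Note: sw is the switching function S(r) = 1 + c3*d^3 + c4*d^4 + c5*d^5 with
 * d = r - rVdwSwitch, and dsw = dS/dr. Since potential switching computes
 * V_sw = V * S, the product rule gives F_sw/r = (F/r)*S - V*(dS/dr)/r, which is
 * exactly the *fInvR update above; the energy is scaled by S when requested.
 */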
/*! \brief Calculate analytical Ewald correction term. */
static inline float pmeCorrF(const float z2)
{
    constexpr float FN6 = -1.7357322914161492954e-8F;
    constexpr float FN5 = 1.4703624142580877519e-6F;
    constexpr float FN4 = -0.000053401640219807709149F;
    constexpr float FN3 = 0.0010054721316683106153F;
    constexpr float FN2 = -0.019278317264888380590F;
    constexpr float FN1 = 0.069670166153766424023F;
    constexpr float FN0 = -0.75225204789749321333F;

    constexpr float FD4 = 0.0011193462567257629232F;
    constexpr float FD3 = 0.014866955030185295499F;
    constexpr float FD2 = 0.11583842382862377919F;
    constexpr float FD1 = 0.50736591960530292870F;
    constexpr float FD0 = 1.0F;

    const float z4 = z2 * z2;

    float       polyFD0 = FD4 * z4 + FD2;
    const float polyFD1 = FD3 * z4 + FD1;
    polyFD0             = polyFD0 * z4 + FD0;
    polyFD0             = polyFD1 * z2 + polyFD0;

    polyFD0 = 1.0F / polyFD0;

    float polyFN0 = FN6 * z4 + FN4;
    float polyFN1 = FN5 * z4 + FN3;
    polyFN0       = polyFN0 * z4 + FN2;
    polyFN1       = polyFN1 * z4 + FN1;
    polyFN0       = polyFN0 * z4 + FN0;
    polyFN0       = polyFN1 * z2 + polyFN0;

    return polyFN0 * polyFD0;
}
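/* Note: this is a single-precision rational fit in z2 = (beta*r)^2 of
 *   f(z2) = ((2/sqrt(pi)) * z * exp(-z^2) - erf(z)) / z^3,
 * allowing the analytical-Ewald path below to use
 *   fInvR += qi * qj * (excl * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3)
 * without evaluating erf() per pair; FN0 equals f(0) = -4/(3*sqrt(pi)).
 */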
/*! \brief Linear interpolation using exactly two FMA operations.
 *
 * Implements numeric equivalent of: (1 - t) * d0 + t * d1.
 */
template<typename T>
static inline T lerp(T d0, T d1, T t)
{
    return fma(t, d1, fma(-t, d0, d0));
}
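/* Correctness sketch: fma(-t, d0, d0) computes d0 - t*d0 = (1 - t)*d0, and the
 * outer fma adds t*d1, giving (1 - t)*d0 + t*d1 with one rounding per FMA.
 */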
/*! \brief Interpolate Ewald coulomb force correction using the F*r table. */
static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::read> a_coulombTab,
                                             const float coulombTabScale,
                                             const float r)
{
    const float normalized = coulombTabScale * r;
    const int   index      = static_cast<int>(normalized);
    const float fraction   = normalized - index;

    const float left  = a_coulombTab[index];
    const float right = a_coulombTab[index + 1];

    return lerp(left, right, fraction); // TODO: cl::sycl::mix
}
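/* Example: the table is sampled at r_i = i / coulombTabScale, so for
 * coulombTabScale = 100 and r = 0.1234 we get normalized = 12.34, index = 12,
 * and fraction = 0.34, blending entries 12 and 13 of the F*r table.
 */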
/*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into \p a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in log2(c_clSize) steps using sub-group shifts, and atomically accumulate into \p a_f.
 */
static inline void reduceForceJShuffle(Float3                                   f,
                                       const cl::sycl::nd_item<3>               itemIdx,
                                       const int                                tidxi,
                                       const int                                aidx,
                                       DeviceAccessor<Float3, mode::read_write> a_f)
{
    static_assert(c_clSize == 8 || c_clSize == 4);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    f[0] += sycl_2020::shift_left(sg, f[0], 1);
    f[1] += sycl_2020::shift_right(sg, f[1], 1);
    f[2] += sycl_2020::shift_left(sg, f[2], 1);
    if (tidxi & 1)
    {
        f[0] = f[1];
    }

    f[0] += sycl_2020::shift_left(sg, f[0], 2);
    f[2] += sycl_2020::shift_right(sg, f[2], 2);
    if (tidxi & 2)
    {
        f[0] = f[2];
    }

    if constexpr (c_clSize == 8)
    {
        f[0] += sycl_2020::shift_left(sg, f[0], 4);
    }

    if (tidxi < 3)
    {
        atomicFetchAdd(a_f[aidx][tidxi], f[0]);
    }
}
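/* Dataflow sketch (for c_clSize = 4): step 1 interleaves the components, so even
 * lanes keep pairwise x sums in f[0] while odd lanes copy in pairwise y sums;
 * step 2 repeats this with stride 2, moving z sums into the lanes with bit 1 set.
 * After log2(c_clSize) steps, lanes 0, 1, 2 of each c_clSize-wide group hold the
 * full x, y, z sums in f[0], so one atomic per component completes the reduction.
 */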
/*! \brief Do workgroup-level reduction of a single \c float.
 *
 * While SYCL has \c sycl::reduce_over_group, it currently (oneAPI 2021.3.0) uses a very large
 * shared memory buffer, which leads to a reduced occupancy.
 *
 * \note The caller must make sure there are no races when reusing the \p sm_buf.
 *
 * \tparam subGroupSize Size of a sub-group.
 * \tparam groupSize    Size of a work-group.
 * \param itemIdx       Current thread's \c sycl::nd_item.
 * \param tidxi         Current thread's linearized local index.
 * \param sm_buf        Accessor for local reduction buffer; must have space for at least one float.
 * \param valueToReduce Current thread's value.
 * \return For thread with \p tidxi 0: sum of all \p valueToReduce. Other threads: unspecified.
 */
template<int subGroupSize, int groupSize>
static inline float groupReduce(const cl::sycl::nd_item<3> itemIdx,
                                const unsigned int         tidxi,
                                cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                float valueToReduce)
{
    constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
    static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
    sycl_2020::sub_group sg = itemIdx.get_sub_group();
    valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
    // If we have two sub-groups, we should reduce across them.
    if constexpr (numSubGroupsInGroup == 2)
    {
        if (tidxi == subGroupSize)
        {
            sm_buf[0] = valueToReduce;
        }
        itemIdx.barrier(fence_space::local_space);
        if (tidxi == 0)
        {
            valueToReduce += sm_buf[0];
        }
    }
    return valueToReduce;
}
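/* Usage sketch: every work-item contributes its private value; the sub-group
 * reduction handles subGroupSize values at once, and with two sub-groups the
 * first thread of the second one (tidxi == subGroupSize) parks its partial sum
 * in sm_buf[0] for thread 0 to pick up after the barrier. Only thread 0's
 * return value is meaningful, which is all the later atomicFetchAdd needs.
 */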
/*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into \p a_f.
 *
 * c_clSize consecutive threads hold the force components of a j-atom, which we
 * reduce in c_clSize steps using local memory, and atomically accumulate into \p a_f.
 *
 * TODO: implement binary reduction flavor for the case where c_clSize is power of two.
 */
static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                       Float3                                   f,
                                       const cl::sycl::nd_item<3>               itemIdx,
                                       const int                                tidxi,
                                       const int                                tidxj,
                                       const int                                aidx,
                                       DeviceAccessor<Float3, mode::read_write> a_f)
{
    static constexpr int sc_fBufferStride = c_clSizeSq;
    int                  tidx             = tidxi + tidxj * c_clSize;
    sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
    sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
    sm_buf[2 * sc_fBufferStride + tidx]   = f[2];

    subGroupBarrier(itemIdx);

    // reduce the data c_clSize elements at a time, on the leader threads of the same sub-group that stored it above
    assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);

    if (tidxi < 3)
    {
        float fSum = 0.0F;
        for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
        {
            fSum += sm_buf[sc_fBufferStride * tidxi + j];
        }

        atomicFetchAdd(a_f[aidx][tidxi], fSum);
    }
}
/*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into \p a_f. */
static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                Float3                                                        f,
                                const cl::sycl::nd_item<3>                                    itemIdx,
                                const int                                                     tidxi,
                                const int                                                     tidxj,
                                const int                                                     aidx,
                                DeviceAccessor<Float3, mode::read_write>                      a_f)
{
    if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
    {
        reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
    }
    else
    {
        reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
    }
}
/*! \brief Final i-force reduction.
 *
 * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[],
 * accumulating atomically into \p a_f.
 * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
 *
 * This implementation works only with power of two array sizes.
 */
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                         const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                         const bool   calcFShift,
                                         const cl::sycl::nd_item<3>               itemIdx,
                                         const int                                tidxi,
                                         const int                                tidxj,
                                         const int                                sci,
                                         const int                                shift,
                                         DeviceAccessor<Float3, mode::read_write> a_f,
                                         DeviceAccessor<Float3, mode::read_write> a_fShift)
{
    // must have power of two elements in fCiBuf
    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));

    static constexpr int bufStride  = c_clSize * c_clSize;
    static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
    const int            tidx       = tidxi + tidxj * c_clSize;
    float                fShiftBuf  = 0.0F;
    for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
    {
        const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
        /* store i forces in shmem */
        sm_buf[tidx]                 = fCiBuf[ciOffset][0];
        sm_buf[bufStride + tidx]     = fCiBuf[ciOffset][1];
        sm_buf[2 * bufStride + tidx] = fCiBuf[ciOffset][2];
        itemIdx.barrier(fence_space::local_space);

        /* Reduce the initial c_clSize values for each i atom to half
         * every step by using c_clSize * i threads. */
        int i = c_clSize / 2;
        for (int j = clSizeLog2 - 1; j > 0; j--)
        {
            if (tidxj < i)
            {
                sm_buf[tidxj * c_clSize + tidxi] += sm_buf[(tidxj + i) * c_clSize + tidxi];
                sm_buf[bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[bufStride + (tidxj + i) * c_clSize + tidxi];
                sm_buf[2 * bufStride + tidxj * c_clSize + tidxi] +=
                        sm_buf[2 * bufStride + (tidxj + i) * c_clSize + tidxi];
            }
            i >>= 1;
            itemIdx.barrier(fence_space::local_space);
        }

        /* i == 1, last reduction step, writing to global mem */
        /* Split the reduction between the first 3 line threads:
           Threads with line id 0 will do the reduction for (float3).x components
           Threads with line id 1 will do the reduction for (float3).y components
           Threads with line id 2 will do the reduction for (float3).z components. */
        if (tidxj < 3)
        {
            const float f =
                    sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
            atomicFetchAdd(a_f[aidx][tidxj], f);

            if (calcFShift)
            {
                fShiftBuf += f;
            }
        }
        itemIdx.barrier(fence_space::local_space);
    }
    /* add up local shift forces into global mem */
    if (calcFShift)
    {
        /* Only threads with tidxj < 3 will update fshift.
           The threads performing the update must be the same as the threads
           storing the reduction result above. */
        if (tidxj < 3)
        {
            if constexpr (c_clSize == 4)
            {
                /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
                 * a compare-and-swap (CAS) loop. It has particularly poor performance when
                 * updating the same memory location from the same work-group.
                 * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
                 * but it is unlikely to make a big difference and thus was not evaluated.
                 */
                auto sg = itemIdx.get_sub_group();
                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
                if (tidxi == 0)
                {
                    atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
                }
            }
            else
            {
                atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
            }
        }
    }
}
/*! \brief Main kernel for NBNXM.
 */
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler& cgh,
                 DeviceAccessor<Float4, mode::read>                        a_xq,
                 DeviceAccessor<Float3, mode::read_write>                  a_f,
                 DeviceAccessor<Float3, mode::read>                        a_shiftVec,
                 DeviceAccessor<Float3, mode::read_write>                  a_fShift,
                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
                 DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                 DeviceAccessor<nbnxn_sci_t, mode::read>                   a_plistSci,
                 DeviceAccessor<nbnxn_excl_t, mode::read>                  a_plistExcl,
                 OptionalAccessor<Float2, mode::read, ljComb<vdwType>>     a_ljComb,
                 OptionalAccessor<int, mode::read, !ljComb<vdwType>>       a_atomTypes,
                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>    a_nbfp,
                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>    a_nbfpComb,
                 OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                 const int                                                 numTypes,
                 const float                                               rCoulombSq,
                 const float                                               rVdwSq,
                 const float                                               twoKRf,
                 const float                                               ewaldBeta,
                 const float                                               rlistOuterSq,
                 const float                                               ewaldShift,
                 const float                                               epsFac,
                 const float                                               ewaldCoeffLJ,
                 const float                                               cRF,
                 const shift_consts_t                                      dispersionShift,
                 const shift_consts_t                                      repulsionShift,
                 const switch_consts_t                                     vdwSwitch,
                 const float                                               rVdwSwitch,
                 const float                                               ljEwaldShift,
                 const float                                               coulombTabScale,
                 const bool                                                calcShift)
{
    static constexpr EnergyFunctionProperties<elecType, vdwType> props;

    a_xq.bind(cgh);
    a_f.bind(cgh);
    a_shiftVec.bind(cgh);
    a_fShift.bind(cgh);
    if constexpr (doCalcEnergies)
    {
        a_energyElec.bind(cgh);
        a_energyVdw.bind(cgh);
    }
    a_plistCJ4.bind(cgh);
    a_plistSci.bind(cgh);
    a_plistExcl.bind(cgh);
    if constexpr (!props.vdwComb)
    {
        a_atomTypes.bind(cgh);
        a_nbfp.bind(cgh);
    }
    else
    {
        a_ljComb.bind(cgh);
    }
    if constexpr (props.vdwEwald)
    {
        a_nbfpComb.bind(cgh);
    }
    if constexpr (props.elecEwaldTab)
    {
        a_coulombTab.bind(cgh);
    }

    // shmem buffer for i x+q pre-loading
    cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
            cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);

    // shmem buffer for force reduction
    // SYCL-TODO: Make into 3D; section 4.7.6.11 of SYCL2020 specs
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_reductionBuffer(
            cl::sycl::range<1>(c_clSize * c_clSize * DIM), cgh);

    auto sm_atomTypeI = [&]() {
        if constexpr (!props.vdwComb)
        {
            return cl::sycl::accessor<int, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    auto sm_ljCombI = [&]() {
        if constexpr (props.vdwComb)
        {
            return cl::sycl::accessor<Float2, 2, mode::read_write, target::local>(
                    cl::sycl::range<2>(c_nbnxnGpuNumClusterPerSupercluster, c_clSize), cgh);
        }
        else
        {
            return nullptr;
        }
    }();

    /* Flag to control the calculation of exclusion forces in the kernel.
     * We do that with Ewald (elec/vdw) and RF. Cut-off only has exclusion
     * energy terms. */
    constexpr bool doExclusionForces =
            (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));

    // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
    // Currently, this is ideally suited for 32-wide subgroup size but slightly less so for others,
    // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
    // Hence, the two are decoupled.
    // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
    constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
#if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
    gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
#else
    gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
#endif
    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        /* thread/block/warp id-s */
        const unsigned tidxi = itemIdx.get_local_id(2);
        const unsigned tidxj = itemIdx.get_local_id(1);
        const unsigned tidx  = tidxj * c_clSize + tidxi;
        const unsigned tidxz = 0;

        const unsigned bidx = itemIdx.get_group(0);

        const sycl_2020::sub_group sg = itemIdx.get_sub_group();
        // Could use sg.get_group_range to compute the imask & exclusion index, but too much
        // of the logic below relies on the explicit formula, and in cases where
        // prunedClusterPairSize != subGroupSize we could not use it anyway.
        const unsigned imeiIdx = tidx / prunedClusterPairSize;

        Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
        {
            fCiBuf[i] = Float3(0.0F, 0.0F, 0.0F);
        }

        const nbnxn_sci_t nbSci     = a_plistSci[bidx];
        const int         sci       = nbSci.sci;
        const int         cij4Start = nbSci.cj4_ind_start;
        const int         cij4End   = nbSci.cj4_ind_end;

        // Only needed if props.elecEwaldAna
        const float beta2 = ewaldBeta * ewaldBeta;
        const float beta3 = ewaldBeta * ewaldBeta * ewaldBeta;

        for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i += c_clSize)
        {
            /* Pre-load i-atom x and q into shared memory */
            const int             ci       = sci * c_nbnxnGpuNumClusterPerSupercluster + tidxj + i;
            const int             ai       = ci * c_clSize + tidxi;
            const cl::sycl::id<2> cacheIdx = cl::sycl::id<2>(tidxj + i, tidxi);

            const Float3 shift = a_shiftVec[nbSci.shift];
            Float4       xqi   = a_xq[ai];
            xqi += Float4(shift[0], shift[1], shift[2], 0.0F);
            xqi[3] *= epsFac;
            sm_xq[cacheIdx] = xqi;

            if constexpr (!props.vdwComb)
            {
                // Pre-load the i-atom types into shared memory
                sm_atomTypeI[cacheIdx] = a_atomTypes[ai];
            }
            else
            {
                // Pre-load the LJ combination parameters into shared memory
                sm_ljCombI[cacheIdx] = a_ljComb[ai];
            }
        }
        itemIdx.barrier(fence_space::local_space);

        float ewaldCoeffLJ_2, ewaldCoeffLJ_6_6; // Only needed if (props.vdwEwald)
        if constexpr (props.vdwEwald)
        {
            ewaldCoeffLJ_2   = ewaldCoeffLJ * ewaldCoeffLJ;
            ewaldCoeffLJ_6_6 = ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * ewaldCoeffLJ_2 * c_oneSixth;
        }

        float energyVdw, energyElec; // Only needed if (doCalcEnergies)
        if constexpr (doCalcEnergies)
        {
            energyVdw = energyElec = 0.0F;
        }
        if constexpr (doCalcEnergies && doExclusionForces)
        {
            if (nbSci.shift == gmx::c_centralShiftIndex
                && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
            {
                // we have the diagonal: add the charge and LJ self interaction energy term
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    // TODO: Are there other options?
                    if constexpr (props.elecEwald || props.elecRF || props.elecCutoff)
                    {
                        const float qi = sm_xq[i][tidxi][3];
                        energyElec += qi * qi;
                    }
                    if constexpr (props.vdwEwald)
                    {
                        energyVdw +=
                                a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
                                       * (numTypes + 1)][0];
                    }
                }
                /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
                if constexpr (props.vdwEwald)
                {
                    energyVdw /= c_clSize;
                    energyVdw *= 0.5F * c_oneSixth * ewaldCoeffLJ_6_6; // c_OneTwelfth?
                }
                if constexpr (props.elecRF || props.elecCutoff)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -0.5F * cRF;
                }
                if constexpr (props.elecEwald)
                {
                    /* Correct for epsfac^2 due to adding qi^2 */
                    energyElec /= epsFac * c_clSize;
                    energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                }
            } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
        }     // (doCalcEnergies && doExclusionForces)

        // Only needed if (doExclusionForces)
        const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);
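        /* Note: bitwise & is used instead of && above (and | instead of || in the
         * exclusion check further down) to keep the flag computation branch-free;
         * both operands are plain booleans, so the result is identical. */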
        // loop over the j clusters seen by any of the atoms in the current super-cluster
        for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
        {
            unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
            if (!doPruneNBL && !imask)
            {
                continue;
            }
            const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
            static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
            for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            {
                const bool maskSet =
                        imask & (superClInteractionMask << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                if (!maskSet)
                {
                    continue;
                }
                unsigned  maskJI = (1U << (jm * c_nbnxnGpuNumClusterPerSupercluster));
                const int cj     = a_plistCJ4[j4].cj[jm];
                const int aj     = cj * c_clSize + tidxj;

                // load j atom data
                const Float4 xqj = a_xq[aj];

                const Float3 xj(xqj[0], xqj[1], xqj[2]);
                const float  qj = xqj[3];
                int          atomTypeJ; // Only needed if (!props.vdwComb)
                Float2       ljCombJ;   // Only needed if (props.vdwComb)
                if constexpr (props.vdwComb)
                {
                    ljCombJ = a_ljComb[aj];
                }
                else
                {
                    atomTypeJ = a_atomTypes[aj];
                }

                Float3 fCjBuf(0.0F, 0.0F, 0.0F);
                for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                {
                    if (imask & maskJI)
                    {
                        // i cluster index
                        const int ci = sci * c_nbnxnGpuNumClusterPerSupercluster + i;
                        // read the i-atom data (x, q) cached in shared memory above
                        const Float4 xqi = sm_xq[i][tidxi];
                        const Float3 xi(xqi[0], xqi[1], xqi[2]);

                        // distance between i and j atoms
                        const Float3 rv = xi - xj;
                        float        r2 = norm2(rv);

                        if constexpr (doPruneNBL)
                        {
                            /* If _none_ of the atom pairs is within the cut-off range,
                             * the bit corresponding to the current
                             * cluster-pair in imask gets set to 0. */
                            if (!sycl_2020::group_any_of(sg, r2 < rlistOuterSq))
                            {
                                imask &= ~maskJI;
                            }
                        }
                        const float pairExclMask = (wexcl & maskJI) ? 1.0F : 0.0F;

                        // cutoff & exclusion check

                        const bool notExcluded = doExclusionForces ? (nonSelfInteraction | (ci != cj))
                                                                   : (wexcl & maskJI);
                        // SYCL-TODO: Check optimal way of branching here.
                        if ((r2 < rCoulombSq) && notExcluded)
                        {
                            const float qi = xqi[3];
                            int         atomTypeI; // Only needed if (!props.vdwComb)
                            float       sigma, epsilon;
                            Float2      c6c12;

                            if constexpr (!props.vdwComb)
                            {
                                /* LJ 6*C6 and 12*C12 */
                                atomTypeI = sm_atomTypeI[i][tidxi];
                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                            }
                            else
                            {
                                const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                if constexpr (props.vdwCombGeom)
                                {
                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                }
                                else
                                {
                                    static_assert(props.vdwCombLB);
                                    // LJ 2^(1/6)*sigma and 12*epsilon
                                    sigma   = ljCombI[0] + ljCombJ[0];
                                    epsilon = ljCombI[1] * ljCombJ[1];
                                    if constexpr (doCalcEnergies)
                                    {
                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                    }
                                } // props.vdwCombGeom
                            }     // !props.vdwComb

                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
                            const float c6  = c6c12[0];
                            const float c12 = c6c12[1];

                            // Ensure the distance does not become so small that r^-12 overflows
                            r2 = std::max(r2, c_nbnxnMinDistanceSquared);
#if GMX_SYCL_HIPSYCL
                            // No fast/native functions in some compilation passes
                            const float rInv = cl::sycl::rsqrt(r2);
#else
                            // SYCL-TODO: sycl::half_precision::rsqrt?
                            const float rInv = cl::sycl::native::rsqrt(r2);
#endif
                            const float r2Inv = rInv * rInv;
                            float       r6Inv, fInvR, energyLJPair;
                            if constexpr (!props.vdwCombLB || doCalcEnergies)
                            {
                                r6Inv = r2Inv * r2Inv * r2Inv;
                                if constexpr (doExclusionForces)
                                {
                                    // SYCL-TODO: Check if true for SYCL
                                    /* We could mask r2Inv, but with Ewald masking both
                                     * r6Inv and fInvR is faster */
                                    r6Inv *= pairExclMask;
                                }
                                fInvR = r6Inv * (c12 * r6Inv - c6) * r2Inv;
                            }
                            else
                            {
                                float sig_r  = sigma * rInv;
                                float sig_r2 = sig_r * sig_r;
                                float sig_r6 = sig_r2 * sig_r2 * sig_r2;
                                if constexpr (doExclusionForces)
                                {
                                    sig_r6 *= pairExclMask;
                                }
                                fInvR = epsilon * sig_r6 * (sig_r6 - 1.0F) * r2Inv;
                            } // (!props.vdwCombLB || doCalcEnergies)
                            if constexpr (doCalcEnergies || props.vdwPSwitch)
                            {
                                energyLJPair = pairExclMask
                                               * (c12 * (r6Inv * r6Inv + repulsionShift.cpot) * c_oneTwelfth
                                                  - c6 * (r6Inv + dispersionShift.cpot) * c_oneSixth);
                            }
                            if constexpr (props.vdwFSwitch)
                            {
                                ljForceSwitch<doCalcEnergies>(
                                        dispersionShift, repulsionShift, rVdwSwitch, c6, c12, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.vdwEwald)
                            {
                                ljEwaldComb<doCalcEnergies, vdwType>(a_nbfpComb,
                                                                     ljEwaldShift,
                                                                     atomTypeI,
                                                                     atomTypeJ,
                                                                     r2,
                                                                     r2Inv,
                                                                     ewaldCoeffLJ_2,
                                                                     ewaldCoeffLJ_6_6,
                                                                     pairExclMask,
                                                                     &fInvR,
                                                                     &energyLJPair);
                            } // (props.vdwEwald)
                            if constexpr (props.vdwPSwitch)
                            {
                                ljPotentialSwitch<doCalcEnergies>(
                                        vdwSwitch, rVdwSwitch, rInv, r2, &fInvR, &energyLJPair);
                            }
                            if constexpr (props.elecEwaldTwin)
                            {
                                // Separate VDW cut-off check to enable twin-range cut-offs
                                // (rVdw < rCoulomb <= rList)
                                const float vdwInRange = (r2 < rVdwSq) ? 1.0F : 0.0F;
                                fInvR *= vdwInRange;
                                if constexpr (doCalcEnergies)
                                {
                                    energyLJPair *= vdwInRange;
                                }
                            }
                            if constexpr (doCalcEnergies)
                            {
                                energyVdw += energyLJPair;
                            }
                            if constexpr (props.elecCutoff)
                            {
                                if constexpr (doExclusionForces)
                                {
                                    fInvR += qi * qj * pairExclMask * r2Inv * rInv;
                                }
                                else
                                {
                                    fInvR += qi * qj * r2Inv * rInv;
                                }
                            }
                            if constexpr (props.elecRF)
                            {
                                fInvR += qi * qj * (pairExclMask * r2Inv * rInv - twoKRf);
                            }
                            if constexpr (props.elecEwaldAna)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv * rInv + pmeCorrF(beta2 * r2) * beta3);
                            }
                            if constexpr (props.elecEwaldTab)
                            {
                                fInvR += qi * qj
                                         * (pairExclMask * r2Inv
                                            - interpolateCoulombForceR(
                                                    a_coulombTab, coulombTabScale, r2 * rInv))
                                         * rInv;
                            }

                            if constexpr (doCalcEnergies)
                            {
                                if constexpr (props.elecCutoff)
                                {
                                    energyElec += qi * qj * (pairExclMask * rInv - cRF);
                                }
                                if constexpr (props.elecRF)
                                {
                                    energyElec +=
                                            qi * qj * (pairExclMask * rInv + 0.5F * twoKRf * r2 - cRF);
                                }
                                if constexpr (props.elecEwald)
                                {
                                    energyElec +=
                                            qi * qj
                                            * (rInv * (pairExclMask - cl::sycl::erf(r2 * rInv * ewaldBeta))
                                               - pairExclMask * ewaldShift);
                                }
                            }
                            const Float3 forceIJ = rv * fInvR;

                            /* accumulate j forces in registers */
                            fCjBuf -= forceIJ;
                            /* accumulate i forces in registers */
                            fCiBuf[i] += forceIJ;
                        } // (r2 < rCoulombSq) && notExcluded
                    }     // (imask & maskJI)
                    /* shift the mask bit by 1 */
                    maskJI += maskJI;
                } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                /* reduce j forces */
                reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
            } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
            if constexpr (doPruneNBL)
            {
                /* Update the imask with the new one which does not contain the
                 * out of range clusters anymore. */
                a_plistCJ4[j4].imei[imeiIdx].imask = imask;
            }
        } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)

        /* skip central shifts when summing shift forces */
        const bool doCalcShift = (calcShift && nbSci.shift != gmx::c_centralShiftIndex);

        reduceForceIAndFShift(
                sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);

        if constexpr (doCalcEnergies)
        {
            const float energyVdwGroup =
                    groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
            itemIdx.barrier(fence_space::local_space); // Prevent the race on sm_reductionBuffer.
            const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
                    itemIdx, tidx, sm_reductionBuffer, energyElec);

            if (tidx == 0)
            {
                atomicFetchAdd(a_energyVdw[0], energyVdwGroup);
                atomicFetchAdd(a_energyElec[0], energyElecGroup);
            }
        }
    };
}
//! \brief NBNXM kernel launch code.
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
{
    using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;

    /* Kernel launch config:
     * - The thread block dimensions match the size of i-clusters, j-clusters,
     *   and j-cluster concurrency, in x, y, and z, respectively.
     * - The 1D block-grid contains as many blocks as super-clusters.
     */
    const int                   numBlocks = numSci;
    const cl::sycl::range<3>    blockSize{ 1, c_clSize, c_clSize };
    const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
    const cl::sycl::nd_range<3> range{ globalSize, blockSize };

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                cgh, std::forward<Args>(args)...);
        cgh.parallel_for<kernelNameType>(range, kernel);
    });

    return e;
}
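/* Example: with c_clSize = 8, each work-group is 1 x 8 x 8 = 64 work-items,
 * covering the 8 x 8 atom pairs of an i-cluster/j-cluster tile (tidxz = 0 only),
 * and the grid has one work-group per pair-list sci entry; numSci = 1000 yields
 * nd_range<3>{ { 1000, 8, 8 }, { 1, 8, 8 } }.
 */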
//! \brief Select templated kernel and launch it.
template<class... Args>
cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                           bool          doCalcEnergies,
                                           enum ElecType elecType,
                                           enum VdwType  vdwType,
                                           Args&&... args)
{
    return gmx::dispatchTemplatedFunction(
            [&](auto doPruneNBL_, auto doCalcEnergies_, auto elecType_, auto vdwType_) {
                return launchNbnxmKernel<doPruneNBL_, doCalcEnergies_, elecType_, vdwType_>(
                        std::forward<Args>(args)...);
            },
            doPruneNBL,
            doCalcEnergies,
            elecType,
            vdwType);
}
void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
{
    NBAtomDataGpu*      adat         = nb->atdat;
    NBParamGpu*         nbp          = nb->nbparam;
    gpu_plist*          plist        = nb->plist[iloc];
    const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];

    cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                   stepWork.computeEnergy,
                                                   /* ... the remaining kernel arguments (device
                                                    * buffers and interaction constants from adat,
                                                    * nbp, and plist, matching the nbnxmKernel
                                                    * signature) are elided in this excerpt ... */
                                                   nbp->dispersion_shift,
                                                   nbp->repulsion_shift,
                                                   /* ... */
                                                   nbp->coulomb_tab_scale,
                                                   stepWork.computeVirial);
}

} // namespace Nbnxm