SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
author Andrey Alekseenko <al42and@gmail.com>
Tue, 2 Nov 2021 12:20:21 +0000 (13:20 +0100)
committer Szilárd Páll <pall.szilard@gmail.com>
Wed, 3 Nov 2021 00:35:44 +0000 (00:35 +0000)
Also makes a few other minor changes to the SYCL version of DeviceBuffer.

This is a prerequisite for adding support for USM.

Refs #3847, #3965.

12 files changed:
src/gromacs/ewald/pme_gather_sycl.cpp
src/gromacs/ewald/pme_solve_sycl.cpp
src/gromacs/ewald/pme_spread_sycl.cpp
src/gromacs/gpu_utils/devicebuffer_sycl.h
src/gromacs/mdlib/gpuforcereduction_impl_internal_sycl.cpp
src/gromacs/mdlib/leapfrog_gpu_internal_sycl.cpp
src/gromacs/mdlib/lincs_gpu_internal_sycl.cpp
src/gromacs/mdlib/settle_gpu_internal_sycl.cpp
src/gromacs/mdlib/update_constrain_gpu_internal_sycl.cpp
src/gromacs/nbnxm/sycl/nbnxm_gpu_buffer_ops_internal_sycl.cpp
src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel_pruneonly.cpp

index b03a7804b38e7b608220075f4299ab623b1b48ee..d133daaeeaa5acc4668366aa26bd347f9bd44fca 100644 (file)
@@ -284,27 +284,27 @@ auto pmeGatherKernel(cl::sycl::handler&                                 cgh,
     constexpr int splineParamsSize    = atomsPerBlock * DIM * order;
     constexpr int gridlineIndicesSize = atomsPerBlock * DIM;
 
-    cgh.require(a_gridA);
-    cgh.require(a_coefficientsA);
-    cgh.require(a_forces);
+    a_gridA.bind(cgh);
+    a_coefficientsA.bind(cgh);
+    a_forces.bind(cgh);
 
     if constexpr (numGrids == 2)
     {
-        cgh.require(a_gridB);
-        cgh.require(a_coefficientsB);
+        a_gridB.bind(cgh);
+        a_coefficientsB.bind(cgh);
     }
 
     if constexpr (readGlobal)
     {
-        cgh.require(a_theta);
-        cgh.require(a_dtheta);
-        cgh.require(a_gridlineIndices);
+        a_theta.bind(cgh);
+        a_dtheta.bind(cgh);
+        a_gridlineIndices.bind(cgh);
     }
     else
     {
-        cgh.require(a_coordinates);
-        cgh.require(a_fractShiftsTable);
-        cgh.require(a_gridlineIndicesTable);
+        a_coordinates.bind(cgh);
+        a_fractShiftsTable.bind(cgh);
+        a_gridlineIndicesTable.bind(cgh);
     }
 
     // Gridline indices, ivec
index 46883060e9b1367d8f34e48a145b530d250b1309..5a0125c44d286ddcc71bc29f6dd5d80fd5024388 100644 (file)
@@ -68,13 +68,13 @@ auto makeSolveKernel(cl::sycl::handler&                            cgh,
                      OptionalAccessor<float, mode::read_write, computeEnergyAndVirial> a_virialAndEnergy,
                      DeviceAccessor<float, mode::read_write> a_fourierGrid)
 {
-    cgh.require(a_splineModuli);
-    cgh.require(a_solveKernelParams);
+    a_splineModuli.bind(cgh);
+    a_solveKernelParams.bind(cgh);
     if constexpr (computeEnergyAndVirial)
     {
-        cgh.require(a_virialAndEnergy);
+        a_virialAndEnergy.bind(cgh);
     }
-    cgh.require(a_fourierGrid);
+    a_fourierGrid.bind(cgh);
 
     /* Reduce 7 outputs per warp in the shared memory */
     const int stride =
index c88c1de374ad0b094852f16da9ddf281e285688a..41244d591fafd41bcc15689d611ee38b41bfcf9e 100644 (file)
@@ -205,31 +205,31 @@ auto pmeSplineAndSpreadKernel(
 
     if constexpr (spreadCharges)
     {
-        cgh.require(a_realGrid_0);
+        a_realGrid_0.bind(cgh);
     }
     if constexpr (writeGlobal || computeSplines)
     {
-        cgh.require(a_theta);
+        a_theta.bind(cgh);
     }
     if constexpr (computeSplines && writeGlobal)
     {
-        cgh.require(a_dtheta);
+        a_dtheta.bind(cgh);
     }
     if constexpr (writeGlobal)
     {
-        cgh.require(a_gridlineIndices);
+        a_gridlineIndices.bind(cgh);
     }
     if constexpr (computeSplines)
     {
-        cgh.require(a_fractShiftsTable);
-        cgh.require(a_gridlineIndicesTable);
-        cgh.require(a_coordinates);
+        a_fractShiftsTable.bind(cgh);
+        a_gridlineIndicesTable.bind(cgh);
+        a_coordinates.bind(cgh);
     }
-    cgh.require(a_coefficients_0);
+    a_coefficients_0.bind(cgh);
     if constexpr (numGrids == 2 && spreadCharges)
     {
-        cgh.require(a_realGrid_1);
-        cgh.require(a_coefficients_1);
+        a_realGrid_1.bind(cgh);
+        a_coefficients_1.bind(cgh);
     }
 
     // Gridline indices, ivec
index cb3277b82f4c12e60f4dd74d16042ba37aed1bad..0bc8b7c59eecee8a09ea390d66167cad34e6fa4c 100644 (file)
@@ -176,6 +176,7 @@ public:
         static_assert(mode == cl::sycl::access::mode::read,
                       "Can not create non-read-only accessor from a const DeviceBuffer");
     }
+    void bind(cl::sycl::handler& cgh) { cgh.require(*this); }
 
 private:
     //! Helper function to get sycl:buffer object from DeviceBuffer wrapper, with a sanity check.
@@ -188,15 +189,16 @@ private:
 
 namespace gmx::internal
 {
-//! A "blackhole" class to be used when we want to ignore an argument to a function.
-struct EmptyClassThatIgnoresConstructorArguments
+//! A non-functional class that can be used instead of real accessors
+template<class T>
+struct NullAccessor
 {
-    template<class... Args>
-    [[maybe_unused]] EmptyClassThatIgnoresConstructorArguments(Args&&... /*args*/)
-    {
-    }
+    NullAccessor(const DeviceBuffer<T>& /*buffer*/) {}
     //! Allow casting to nullptr
     constexpr operator std::nullptr_t() const { return nullptr; }
+    //! Placeholder implementation of \c cl::sycl::accessor::get_pointer.
+    T*   get_pointer() const noexcept { return nullptr; }
+    void bind(cl::sycl::handler& /*cgh*/) { assert(false); }
 };
 } // namespace gmx::internal
 
@@ -230,7 +232,7 @@ struct EmptyClassThatIgnoresConstructorArguments
  */
 template<class T, cl::sycl::access::mode mode, bool enabled>
 using OptionalAccessor =
-        std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::EmptyClassThatIgnoresConstructorArguments>;
+        std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::NullAccessor<T>>;
 
 #endif // #ifndef DOXYGEN
 
@@ -441,7 +443,6 @@ inline cl::sycl::event fillSyclBufferWithNull(cl::sycl::buffer<Float3, 1>& buffe
             false;
 #endif
 
-
     if constexpr (usingHipSycl)
     {
         // hipSYCL does not support reinterpret but allows using Float3 directly.
@@ -541,7 +542,7 @@ void initParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer,
 template<typename ValueType>
 void destroyParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer, DeviceTexture* /* deviceTexture */)
 {
-    deviceBuffer->buffer_.reset(nullptr);
+    freeDeviceBuffer(deviceBuffer);
 }
 
 #endif // GMX_GPU_UTILS_DEVICEBUFFER_SYCL_H
index bd76fbf7d628c717716e0285cff896ef4e26b478..b8a2bcbb0ca66f2f458080546ccd30ec08b3b91a 100644 (file)
@@ -72,13 +72,13 @@ static auto reduceKernel(cl::sycl::handler&                                 cgh,
                          DeviceAccessor<int, cl::sycl::access::mode::read> a_cell,
                          const int                                         atomStart)
 {
-    cgh.require(a_nbnxmForce);
+    a_nbnxmForce.bind(cgh);
     if constexpr (addRvecForce)
     {
-        cgh.require(a_rvecForceToAdd);
+        a_rvecForceToAdd.bind(cgh);
     }
-    cgh.require(a_forceTotal);
-    cgh.require(a_cell);
+    a_forceTotal.bind(cgh);
+    a_cell.bind(cgh);
 
     return [=](cl::sycl::id<1> itemIdx) {
         // Set to nbnxnm force, then perhaps accumulate further to it
index b5572dcfe136f4b3ccc9a6059c4c94829c164cfb..25b3199e761fae2c2d48e5a1aa0c1a56f3033470 100644 (file)
@@ -98,18 +98,18 @@ auto leapFrogKernel(
         OptionalAccessor<unsigned short, mode::read, numTempScaleValues == NumTempScaleValues::Multiple> a_tempScaleGroups,
         Float3 prVelocityScalingMatrixDiagonal)
 {
-    cgh.require(a_x);
-    cgh.require(a_xp);
-    cgh.require(a_v);
-    cgh.require(a_f);
-    cgh.require(a_inverseMasses);
+    a_x.bind(cgh);
+    a_xp.bind(cgh);
+    a_v.bind(cgh);
+    a_f.bind(cgh);
+    a_inverseMasses.bind(cgh);
     if constexpr (numTempScaleValues != NumTempScaleValues::None)
     {
-        cgh.require(a_lambdas);
+        a_lambdas.bind(cgh);
     }
     if constexpr (numTempScaleValues == NumTempScaleValues::Multiple)
     {
-        cgh.require(a_tempScaleGroups);
+        a_tempScaleGroups.bind(cgh);
     }
 
     return [=](cl::sycl::id<1> itemIdx) {
index 658a45b32c7f6fa7ac9aa8c86a09ad4941aaaafe..60b10f3f8430a8eab42760e77ae0d8f3c820f55c 100644 (file)
@@ -124,25 +124,25 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                  OptionalAccessor<float, mode::read_write, computeVirial>          a_virialScaled,
                  PbcAiuc                                                           pbcAiuc)
 {
-    cgh.require(a_constraints);
-    cgh.require(a_constraintsTargetLengths);
+    a_constraints.bind(cgh);
+    a_constraintsTargetLengths.bind(cgh);
     if constexpr (haveCoupledConstraints)
     {
-        cgh.require(a_coupledConstraintsCounts);
-        cgh.require(a_coupledConstraintsIndices);
-        cgh.require(a_massFactors);
-        cgh.require(a_matrixA);
+        a_coupledConstraintsCounts.bind(cgh);
+        a_coupledConstraintsIndices.bind(cgh);
+        a_massFactors.bind(cgh);
+        a_matrixA.bind(cgh);
     }
-    cgh.require(a_inverseMasses);
-    cgh.require(a_x);
-    cgh.require(a_xp);
+    a_inverseMasses.bind(cgh);
+    a_x.bind(cgh);
+    a_xp.bind(cgh);
     if constexpr (updateVelocities)
     {
-        cgh.require(a_v);
+        a_v.bind(cgh);
     }
     if constexpr (computeVirial)
     {
-        cgh.require(a_virialScaled);
+        a_virialScaled.bind(cgh);
     }
 
     /* Shared local memory buffer. Corresponds to sh_r, sm_rhs, and sm_threadVirial in CUDA.
index 6a32e856c0dd3e60a311e15567151567dcd159a5..2f7ae81b4c3b1a9821f96fb2505bd73f8cefc9ab 100644 (file)
@@ -72,16 +72,16 @@ auto settleKernel(cl::sycl::handler&                                           c
                   OptionalAccessor<float, mode::read_write, computeVirial>     a_virialScaled,
                   PbcAiuc                                                      pbcAiuc)
 {
-    cgh.require(a_settles);
-    cgh.require(a_x);
-    cgh.require(a_xp);
+    a_settles.bind(cgh);
+    a_x.bind(cgh);
+    a_xp.bind(cgh);
     if constexpr (updateVelocities)
     {
-        cgh.require(a_v);
+        a_v.bind(cgh);
     }
     if constexpr (computeVirial)
     {
-        cgh.require(a_virialScaled);
+        a_virialScaled.bind(cgh);
     }
 
     // shmem buffer for i x+q pre-loading
index f7113a203fdbc55bf765052ad94dbb71b73cc3a1..12f60ea7ee11360c751ee1ff76a583fb3ae8d29c 100644 (file)
@@ -60,7 +60,7 @@ static auto scaleKernel(cl::sycl::handler&
                         DeviceAccessor<Float3, cl::sycl::access::mode::read_write> a_x,
                         const ScalingMatrix                                        scalingMatrix)
 {
-    cgh.require(a_x);
+    a_x.bind(cgh);
 
     return [=](cl::sycl::id<1> itemIdx) {
         Float3 x     = a_x[itemIdx];
index d95bd6f06afec370b906665ba7e8574f28a28280..b4cda5e263d2c1f3584fd7c94a3e43bb959a1010 100644 (file)
@@ -76,11 +76,11 @@ static auto nbnxmKernelTransformXToXq(cl::sycl::handler&                       c
                                       int                                      numAtomsPerCell,
                                       int                                      columnsOffset)
 {
-    cgh.require(a_xq);
-    cgh.require(a_x);
-    cgh.require(a_atomIndex);
-    cgh.require(a_numAtoms);
-    cgh.require(a_cellIndex);
+    a_xq.bind(cgh);
+    a_x.bind(cgh);
+    a_atomIndex.bind(cgh);
+    a_numAtoms.bind(cgh);
+    a_cellIndex.bind(cgh);
 
     return [=](cl::sycl::id<2> itemIdx) {
         // Map cell-level parallelism to y component of block index.
index ea5321645d826439d21e1e0689b432c5f04caedc..a7b2c6cde6b080ac4946855703037d4209aab454 100644 (file)
@@ -583,34 +583,34 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
 {
     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
 
-    cgh.require(a_xq);
-    cgh.require(a_f);
-    cgh.require(a_shiftVec);
-    cgh.require(a_fShift);
+    a_xq.bind(cgh);
+    a_f.bind(cgh);
+    a_shiftVec.bind(cgh);
+    a_fShift.bind(cgh);
     if constexpr (doCalcEnergies)
     {
-        cgh.require(a_energyElec);
-        cgh.require(a_energyVdw);
+        a_energyElec.bind(cgh);
+        a_energyVdw.bind(cgh);
     }
-    cgh.require(a_plistCJ4);
-    cgh.require(a_plistSci);
-    cgh.require(a_plistExcl);
+    a_plistCJ4.bind(cgh);
+    a_plistSci.bind(cgh);
+    a_plistExcl.bind(cgh);
     if constexpr (!props.vdwComb)
     {
-        cgh.require(a_atomTypes);
-        cgh.require(a_nbfp);
+        a_atomTypes.bind(cgh);
+        a_nbfp.bind(cgh);
     }
     else
     {
-        cgh.require(a_ljComb);
+        a_ljComb.bind(cgh);
     }
     if constexpr (props.vdwEwald)
     {
-        cgh.require(a_nbfpComb);
+        a_nbfpComb.bind(cgh);
     }
     if constexpr (props.elecEwaldTab)
     {
-        cgh.require(a_coulombTab);
+        a_coulombTab.bind(cgh);
     }
 
     // shmem buffer for i x+q pre-loading
index b779ad79e3cc58d403293759c68ba773d753ab21..4b2ca582c3e5b98905313ca72428bb9c4527b406 100644 (file)
@@ -76,11 +76,11 @@ auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
                           const int   numParts,
                           const int   part)
 {
-    cgh.require(a_xq);
-    cgh.require(a_shiftVec);
-    cgh.require(a_plistCJ4);
-    cgh.require(a_plistSci);
-    cgh.require(a_plistIMask);
+    a_xq.bind(cgh);
+    a_shiftVec.bind(cgh);
+    a_plistCJ4.bind(cgh);
+    a_plistSci.bind(cgh);
+    a_plistIMask.bind(cgh);
 
     /* shmem buffer for i x+q pre-loading */
     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(