SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / gpu_utils / devicebuffer_sycl.h
index cb3277b82f4c12e60f4dd74d16042ba37aed1bad..0bc8b7c59eecee8a09ea390d66167cad34e6fa4c 100644 (file)
@@ -176,6 +176,7 @@ public:
         static_assert(mode == cl::sycl::access::mode::read,
                       "Can not create non-read-only accessor from a const DeviceBuffer");
     }
+    void bind(cl::sycl::handler& cgh) { cgh.require(*this); }
 
 private:
     //! Helper function to get sycl:buffer object from DeviceBuffer wrapper, with a sanity check.
@@ -188,15 +189,16 @@ private:
 
 namespace gmx::internal
 {
-//! A "blackhole" class to be used when we want to ignore an argument to a function.
-struct EmptyClassThatIgnoresConstructorArguments
+//! A non-functional class that can be used instead of real accessors
+template<class T>
+struct NullAccessor
 {
-    template<class... Args>
-    [[maybe_unused]] EmptyClassThatIgnoresConstructorArguments(Args&&... /*args*/)
-    {
-    }
+    NullAccessor(const DeviceBuffer<T>& /*buffer*/) {}
     //! Allow casting to nullptr
     constexpr operator std::nullptr_t() const { return nullptr; }
+    //! Placeholder implementation of \c cl::sycl::accessor::get_pointer.
+    T*   get_pointer() const noexcept { return nullptr; }
+    void bind(cl::sycl::handler& /*cgh*/) { assert(false); }
 };
 } // namespace gmx::internal
 
@@ -230,7 +232,7 @@ struct EmptyClassThatIgnoresConstructorArguments
  */
 template<class T, cl::sycl::access::mode mode, bool enabled>
 using OptionalAccessor =
-        std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::EmptyClassThatIgnoresConstructorArguments>;
+        std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::NullAccessor<T>>;
 
 #endif // #ifndef DOXYGEN
 
@@ -441,7 +443,6 @@ inline cl::sycl::event fillSyclBufferWithNull(cl::sycl::buffer<Float3, 1>& buffe
             false;
 #endif
 
-
     if constexpr (usingHipSycl)
     {
         // hipSYCL does not support reinterpret but allows using Float3 directly.
@@ -541,7 +542,7 @@ void initParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer,
 template<typename ValueType>
 void destroyParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer, DeviceTexture* /* deviceTexture */)
 {
-    deviceBuffer->buffer_.reset(nullptr);
+    freeDeviceBuffer(deviceBuffer);
 }
 
 #endif // GMX_GPU_UTILS_DEVICEBUFFER_SYCL_H