SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / gpu_utils / devicebuffer_sycl.h
index 68d245337fb6d643865850680f4d455aab6ba527..0bc8b7c59eecee8a09ea390d66167cad34e6fa4c 100644 (file)
@@ -81,9 +81,16 @@ DeviceBuffer<T>::~DeviceBuffer() = default;
 
 //! Copy constructor (references the same underlying SYCL buffer).
 template<typename T>
-DeviceBuffer<T>::DeviceBuffer(DeviceBuffer<T> const& src) :
-    buffer_(new ClSyclBufferWrapper(*src.buffer_))
+DeviceBuffer<T>::DeviceBuffer(DeviceBuffer<T> const& src)
 {
+    if (src.buffer_)
+    {
+        buffer_ = std::make_unique<ClSyclBufferWrapper>(*src.buffer_);
+    }
+    else
+    {
+        buffer_ = nullptr;
+    }
 }
 
 //! Move constructor.
@@ -94,7 +101,14 @@ DeviceBuffer<T>::DeviceBuffer(DeviceBuffer<T>&& src) noexcept = default;
 template<typename T>
 DeviceBuffer<T>& DeviceBuffer<T>::operator=(DeviceBuffer<T> const& src)
 {
-    buffer_.reset(new ClSyclBufferWrapper(*src.buffer_));
+    if (src.buffer_)
+    {
+        buffer_ = std::make_unique<ClSyclBufferWrapper>(*src.buffer_);
+    }
+    else
+    {
+        buffer_.reset(nullptr);
+    }
     return *this;
 }
 
@@ -162,6 +176,7 @@ public:
         static_assert(mode == cl::sycl::access::mode::read,
                       "Can not create non-read-only accessor from a const DeviceBuffer");
     }
+    void bind(cl::sycl::handler& cgh) { cgh.require(*this); }
 
 private:
     //! Helper function to get sycl:buffer object from DeviceBuffer wrapper, with a sanity check.
@@ -174,13 +189,16 @@ private:
 
 namespace gmx::internal
 {
-//! A "blackhole" class to be used when we want to ignore an argument to a function.
-struct EmptyClassThatIgnoresConstructorArguments
+//! A non-functional class that can be used instead of real accessors
+template<class T>
+struct NullAccessor
 {
-    template<class... Args>
-    [[maybe_unused]] EmptyClassThatIgnoresConstructorArguments(Args&&... /*args*/)
-    {
-    }
+    NullAccessor(const DeviceBuffer<T>& /*buffer*/) {}
+    //! Allow casting to nullptr
+    constexpr operator std::nullptr_t() const { return nullptr; }
+    //! Placeholder implementation of \c cl::sycl::accessor::get_pointer.
+    T*   get_pointer() const noexcept { return nullptr; }
+    void bind(cl::sycl::handler& /*cgh*/) { assert(false); }
 };
 } // namespace gmx::internal
 
@@ -214,7 +232,7 @@ struct EmptyClassThatIgnoresConstructorArguments
  */
 template<class T, cl::sycl::access::mode mode, bool enabled>
 using OptionalAccessor =
-        std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::EmptyClassThatIgnoresConstructorArguments>;
+        std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::NullAccessor<T>>;
 
 #endif // #ifndef DOXYGEN
 
@@ -390,7 +408,7 @@ namespace gmx::internal
 {
 /*! \brief Helper function to clear device buffer.
  *
- * Not applicable to GROMACS's float3 (a.k.a. gmx::RVec) and other custom types.
+ * Not applicable to GROMACS's Float3 (a.k.a. gmx::RVec) and other custom types.
  * From SYCL specs: "T must be a scalar value or a SYCL vector type."
  */
 template<typename ValueType>
@@ -411,17 +429,44 @@ cl::sycl::event fillSyclBufferWithNull(cl::sycl::buffer<ValueType, 1>& buffer,
     });
 }
 
-//! \brief Helper function to clear device buffer of type float3.
+//! \brief Helper function to clear device buffer of type Float3.
 template<>
 inline cl::sycl::event fillSyclBufferWithNull(cl::sycl::buffer<Float3, 1>& buffer,
                                               size_t                       startingOffset,
                                               size_t                       numValues,
                                               cl::sycl::queue              queue)
 {
-    cl::sycl::buffer<float, 1> bufferAsFloat = buffer.reinterpret<float, 1>(buffer.get_count() * DIM);
-    return fillSyclBufferWithNull<float>(
-            bufferAsFloat, startingOffset * DIM, numValues * DIM, std::move(queue));
+    constexpr bool usingHipSycl =
+#ifdef __HIPSYCL__
+            true;
+#else
+            false;
+#endif
+
+    if constexpr (usingHipSycl)
+    {
+        // hipSYCL does not support reinterpret but allows using Float3 directly.
+        using cl::sycl::access::mode;
+        const cl::sycl::range<1> range(numValues);
+        const cl::sycl::id<1>    offset(startingOffset);
+        const Float3             pattern{ 0, 0, 0 };
+
+        return queue.submit([&](cl::sycl::handler& cgh) {
+            auto d_bufferAccessor =
+                    cl::sycl::accessor<Float3, 1, mode::discard_write>{ buffer, cgh, range, offset };
+            cgh.fill(d_bufferAccessor, pattern);
+        });
+    }
+    else // When not using hipSYCL, reinterpret as a flat float array
+    {
+#ifndef __HIPSYCL__
+        cl::sycl::buffer<float, 1> bufferAsFloat = buffer.reinterpret<float, 1>(buffer.get_count() * DIM);
+        return fillSyclBufferWithNull<float>(
+                bufferAsFloat, startingOffset * DIM, numValues * DIM, std::move(queue));
+#endif
+    }
 }
+
 } // namespace gmx::internal
 
 /*! \brief
@@ -495,9 +540,9 @@ void initParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer,
  * \param[in,out] deviceBuffer  Device buffer to store data in.
  */
 template<typename ValueType>
-void destroyParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer, DeviceTexture& /* deviceTexture */)
+void destroyParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer, DeviceTexture* /* deviceTexture */)
 {
-    deviceBuffer->buffer_.reset(nullptr);
+    freeDeviceBuffer(deviceBuffer);
 }
 
 #endif // GMX_GPU_UTILS_DEVICEBUFFER_SYCL_H