Workaround for a hipSYCL assertion error
authorSzilárd Páll <pall.szilard@gmail.com>
Fri, 21 May 2021 10:16:51 +0000 (10:16 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Fri, 21 May 2021 10:16:51 +0000 (10:16 +0000)
src/gromacs/gpu_utils/devicebuffer_sycl.h

index 48a4ddc55d6951676a1d8b6092d28215eec8cd77..3ae9b615dad16a218a300d2a902d887702de55fc 100644 (file)
@@ -404,7 +404,7 @@ namespace gmx::internal
 {
 /*! \brief Helper function to clear device buffer.
  *
- * Not applicable to GROMACS's float3 (a.k.a. gmx::RVec) and other custom types.
+ * Not applicable to GROMACS's Float3 (a.k.a. gmx::RVec) and other custom types.
  * From SYCL specs: "T must be a scalar value or a SYCL vector type."
  */
 template<typename ValueType>
@@ -425,17 +425,43 @@ cl::sycl::event fillSyclBufferWithNull(cl::sycl::buffer<ValueType, 1>& buffer,
     });
 }
 
-//! \brief Helper function to clear device buffer of type float3.
+//! \brief Helper function to clear device buffer of type Float3.
 template<>
 inline cl::sycl::event fillSyclBufferWithNull(cl::sycl::buffer<Float3, 1>& buffer,
                                               size_t                       startingOffset,
                                               size_t                       numValues,
                                               cl::sycl::queue              queue)
 {
-    cl::sycl::buffer<float, 1> bufferAsFloat = buffer.reinterpret<float, 1>(buffer.get_count() * DIM);
-    return fillSyclBufferWithNull<float>(
-            bufferAsFloat, startingOffset * DIM, numValues * DIM, std::move(queue));
+    constexpr bool usingHipSycl =
+#ifdef __HIPSYCL__
+            true;
+#else
+            false;
+#endif
+
+
+    if constexpr (usingHipSycl)
+    {
+        // hipSYCL does not support reinterpret but allows using Float3 directly.
+        using cl::sycl::access::mode;
+        const cl::sycl::range<1> range(numValues);
+        const cl::sycl::id<1>    offset(startingOffset);
+        const Float3             pattern{ 0, 0, 0 };
+
+        return queue.submit([&](cl::sycl::handler& cgh) {
+            auto d_bufferAccessor =
+                    cl::sycl::accessor<Float3, 1, mode::discard_write>{ buffer, cgh, range, offset };
+            cgh.fill(d_bufferAccessor, pattern);
+        });
+    }
+    else // When not using hipSYCL, reinterpret as a flat float array
+    {
+        cl::sycl::buffer<float, 1> bufferAsFloat = buffer.reinterpret<float, 1>(buffer.get_count() * DIM);
+        return fillSyclBufferWithNull<float>(
+                bufferAsFloat, startingOffset * DIM, numValues * DIM, std::move(queue));
+    }
 }
+
 } // namespace gmx::internal
 
 /*! \brief