Implement PME solve in SYCL
[alexxy/gromacs.git] / src / gromacs / gpu_utils / sycl_kernel_utils.h
index 544c9e2b13236021dda8eeef37919d77eb25bf2c..aba31db05351b8ca1d83937c98d64fd1761f0af1 100644 (file)
@@ -144,6 +144,113 @@ static inline float shift_right(sycl_2020::sub_group sg, float var, sycl_2020::s
     return sg.shuffle_up(var, delta);
 }
 #endif
+
+#if GMX_SYCL_HIPSYCL
+/*! \brief Polyfill for sycl::isfinite missing from hipSYCL
+ *
+ * Does not follow GROMACS style because it should follow the name for
+ * which it is a polyfill. */
+template<typename Real>
+__device__ __host__ static inline bool isfinite(Real value)
+{
+    // This is not yet implemented in hipSYCL pending
+    // https://github.com/illuhad/hipSYCL/issues/636
+#    ifdef SYCL_DEVICE_ONLY
+#        if defined(HIPSYCL_PLATFORM_CUDA) && defined(__HIPSYCL_ENABLE_CUDA_TARGET__)
+    return isfinite(value);
+#        elif defined(HIPSYCL_PLATFORM_ROCM) && defined(__HIPSYCL_ENABLE_HIP_TARGET__)
+    return isfinite(value);
+#        else
+#            error "Unsupported hipSYCL target"
+#        endif
+#    else
+    // Should never be called
+    assert(false);
+    GMX_UNUSED_VALUE(value);
+    return false;
+#    endif
+}
+#elif GMX_SYCL_DPCPP
+template<typename Real>
+static inline bool isfinite(Real value)
+{
+    return cl::sycl::isfinite(value);
+}
+
+#endif
+
+#if GMX_SYCL_HIPSYCL
+
+/*! \brief Polyfill for sycl::vec::load buggy in hipSYCL
+ *
+ * Loads from the address \c ptr offset in elements of type T by
+ * NumElements * offset, into the components of \c v.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void loadToVec(size_t                                     offset,
+                             cl::sycl::multi_ptr<const T, AddressSpace> ptr,
+                             cl::sycl::vec<T, NumElements>*             v)
+{
+    for (int i = 0; i < NumElements; ++i)
+    {
+        (*v)[i] = ptr.get()[offset * NumElements + i];
+    }
+}
+
+/*! \brief Polyfill for sycl::vec::store buggy in hipSYCL
+ *
+ * Loads from the address \c ptr offset in elements of type T by
+ * NumElements * offset, into the components of \c v.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void storeFromVec(const cl::sycl::vec<T, NumElements>& v,
+                                size_t                               offset,
+                                cl::sycl::multi_ptr<T, AddressSpace> ptr)
+{
+    for (int i = 0; i < NumElements; ++i)
+    {
+        ptr.get()[offset * NumElements + i] = v[i];
+    }
+}
+
+#elif GMX_SYCL_DPCPP
+
+/*! \brief Polyfill for sycl::vec::load buggy in hipSYCL
+ *
+ * Loads from the address \c ptr offset in elements of type T by
+ * NumElements * offset, into the components of \c v.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void loadToVec(size_t offset,
+                             cl::sycl::multi_ptr<const T, AddressSpace> ptr,
+                             cl::sycl::vec<T, NumElements>* v)
+{
+    v->load(offset, ptr);
+}
+
+/*! \brief Polyfill for sycl::vec::store buggy in hipSYCL
+ *
+ * Loads from the address \c ptr offset in elements of type T by
+ * NumElements * offset, into the components of \c v.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void storeFromVec(const cl::sycl::vec<T, NumElements>& v,
+                                size_t offset,
+                                cl::sycl::multi_ptr<T, AddressSpace> ptr)
+{
+    v.store(offset, ptr);
+}
+
+#endif
+
 } // namespace sycl_2020
 
 #endif /* GMX_GPU_UTILS_SYCL_KERNEL_UTILS_H */