Improve handling of CUDA API errors
[alexxy/gromacs.git] / src / gromacs / gpu_utils / device_stream.cu
index cc1f8798622bc30a284cdef90f3967cef8eae88e..5a309c1156256563a512eb8d501759cc66a2dd00 100644 (file)
@@ -44,6 +44,7 @@
 
 #include "device_stream.h"
 
+#include "gromacs/gpu_utils/cudautils.cuh"
 #include "gromacs/utility/exceptions.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/stringutil.h"
@@ -57,11 +58,7 @@ DeviceStream::DeviceStream(const DeviceContext& /* deviceContext */,
     if (priority == DeviceStreamPriority::Normal)
     {
         stat = cudaStreamCreate(&stream_);
-        if (stat != cudaSuccess)
-        {
-            GMX_THROW(gmx::InternalError(gmx::formatString(
-                    "Could not create CUDA stream (CUDA error %d: %s).", stat, cudaGetErrorString(stat))));
-        }
+        gmx::checkDeviceError(stat, "Could not create CUDA stream.");
     }
     else if (priority == DeviceStreamPriority::High)
     {
@@ -70,20 +67,10 @@ DeviceStream::DeviceStream(const DeviceContext& /* deviceContext */,
         // range, which in that case will be a single value.
         int highestPriority;
         stat = cudaDeviceGetStreamPriorityRange(nullptr, &highestPriority);
-        if (stat != cudaSuccess)
-        {
-            GMX_THROW(gmx::InternalError(gmx::formatString(
-                    "Could not query CUDA stream priority range (CUDA error %d: %s).", stat,
-                    cudaGetErrorString(stat))));
-        }
+        gmx::checkDeviceError(stat, "Could not query CUDA stream priority range.");
 
         stat = cudaStreamCreateWithPriority(&stream_, cudaStreamDefault, highestPriority);
-        if (stat != cudaSuccess)
-        {
-            GMX_THROW(gmx::InternalError(gmx::formatString(
-                    "Could not create CUDA stream with high priority (CUDA error %d: %s).", stat,
-                    cudaGetErrorString(stat))));
-        }
+        gmx::checkDeviceError(stat, "Could not create CUDA stream with high priority.");
     }
 }
 
@@ -93,9 +80,7 @@ DeviceStream::~DeviceStream()
     {
         cudaError_t stat = cudaStreamDestroy(stream_);
         GMX_RELEASE_ASSERT(stat == cudaSuccess,
-                           gmx::formatString("Failed to release CUDA stream (CUDA error %d: %s).",
-                                             stat, cudaGetErrorString(stat))
-                                   .c_str());
+                           ("Failed to release CUDA stream. " + gmx::getDeviceErrorString(stat)).c_str());
         stream_ = nullptr;
     }
 }
@@ -114,7 +99,5 @@ void DeviceStream::synchronize() const
 {
     cudaError_t stat = cudaStreamSynchronize(stream_);
     GMX_RELEASE_ASSERT(stat == cudaSuccess,
-                       gmx::formatString("cudaStreamSynchronize failed  (CUDA error %d: %s).", stat,
-                                         cudaGetErrorString(stat))
-                               .c_str());
+                       ("cudaStreamSynchronize failed. " + gmx::getDeviceErrorString(stat)).c_str());
 }