Use device information object instead of id when performing device checks
[alexxy/gromacs.git] / src / gromacs / hardware / device_management.cu
index 32708873ec2765d29d3474d1f9c540e6c7dcb3b9..b59b51e59a10af1a1226ed879123207bed8ac310 100644 (file)
@@ -106,78 +106,33 @@ static cudaError_t checkCompiledTargetCompatibility(int deviceId, const cudaDevi
  * \todo Introduce errors codes and handle errors more smoothly.
  *
  *
- * \param[in]  dev_id      the device ID of the GPU or -1 if the device has already been initialized
- * \param[in]  dev_prop    The device properties structure
- * \returns                0 if the device looks OK, -1 if it sanity checks failed, and -2 if the device is busy
+ * \param[in]  deviceInfo  Device information on the device to check.
+ * \returns                The status enumeration value for the checked device:
  */
-static DeviceStatus isDeviceFunctional(int dev_id, const cudaDeviceProp& dev_prop)
+static DeviceStatus isDeviceFunctional(const DeviceInformation& deviceInfo)
 {
     cudaError_t cu_err;
-    int         dev_count, id;
-
-    cu_err = cudaGetDeviceCount(&dev_count);
-    if (cu_err != cudaSuccess)
-    {
-        fprintf(stderr, "Error %d while querying device count: %s\n", cu_err, cudaGetErrorString(cu_err));
-        return DeviceStatus::NonFunctional;
-    }
-
-    /* no CUDA compatible device at all */
-    if (dev_count == 0)
-    {
-        return DeviceStatus::NonFunctional;
-    }
-
-    /* things might go horribly wrong if cudart is not compatible with the driver */
-    if (dev_count < 0 || dev_count > c_cudaMaxDeviceCount)
-    {
-        return DeviceStatus::NonFunctional;
-    }
-
-    if (dev_id == -1) /* device already selected let's not destroy the context */
-    {
-        cu_err = cudaGetDevice(&id);
-        if (cu_err != cudaSuccess)
-        {
-            fprintf(stderr, "Error %d while querying device id: %s\n", cu_err, cudaGetErrorString(cu_err));
-            return DeviceStatus::NonFunctional;
-        }
-    }
-    else
-    {
-        id = dev_id;
-        if (id > dev_count - 1) /* pfff there's no such device */
-        {
-            fprintf(stderr,
-                    "The requested device with id %d does not seem to exist (device count=%d)\n",
-                    dev_id, dev_count);
-            return DeviceStatus::NonFunctional;
-        }
-    }
 
     /* both major & minor is 9999 if no CUDA capable devices are present */
-    if (dev_prop.major == 9999 && dev_prop.minor == 9999)
+    if (deviceInfo.prop.major == 9999 && deviceInfo.prop.minor == 9999)
     {
         return DeviceStatus::NonFunctional;
     }
     /* we don't care about emulation mode */
-    if (dev_prop.major == 0)
+    if (deviceInfo.prop.major == 0)
     {
         return DeviceStatus::NonFunctional;
     }
 
-    if (id != -1)
+    cu_err = cudaSetDevice(deviceInfo.id);
+    if (cu_err != cudaSuccess)
     {
-        cu_err = cudaSetDevice(id);
-        if (cu_err != cudaSuccess)
-        {
-            fprintf(stderr, "Error %d while switching to device #%d: %s\n", cu_err, id,
-                    cudaGetErrorString(cu_err));
-            return DeviceStatus::NonFunctional;
-        }
+        fprintf(stderr, "Error %d while switching to device #%d: %s\n", cu_err, deviceInfo.id,
+                cudaGetErrorString(cu_err));
+        return DeviceStatus::NonFunctional;
     }
 
-    cu_err = checkCompiledTargetCompatibility(dev_id, dev_prop);
+    cu_err = checkCompiledTargetCompatibility(deviceInfo.id, deviceInfo.prop);
     // Avoid triggering an error if GPU devices are in exclusive or prohibited mode;
     // it is enough to check for cudaErrorDevicesUnavailable only here because
     // if we encounter it that will happen in cudaFuncGetAttributes in the above function.
@@ -196,7 +151,6 @@ static DeviceStatus isDeviceFunctional(int dev_id, const cudaDeviceProp& dev_pro
         KernelLaunchConfig config;
         config.blockSize[0]                = 512;
         const auto          dummyArguments = prepareGpuKernelArguments(dummy_kernel, config);
-        DeviceInformation   deviceInfo;
         const DeviceContext deviceContext(deviceInfo);
         const DeviceStream  deviceStream(deviceContext, DeviceStreamPriority::Normal, false);
         launchGpuKernel(dummy_kernel, config, deviceStream, nullptr, "Dummy kernel", dummyArguments);
@@ -205,8 +159,8 @@ static DeviceStatus isDeviceFunctional(int dev_id, const cudaDeviceProp& dev_pro
     {
         // launchGpuKernel error is not fatal and should continue with marking the device bad
         fprintf(stderr,
-                "Error occurred while running dummy kernel sanity check on device #%d:\n %s\n", id,
-                formatExceptionMessageToString(ex).c_str());
+                "Error occurred while running dummy kernel sanity check on device #%d:\n %s\n",
+                deviceInfo.id, formatExceptionMessageToString(ex).c_str());
         return DeviceStatus::NonFunctional;
     }
 
@@ -215,12 +169,8 @@ static DeviceStatus isDeviceFunctional(int dev_id, const cudaDeviceProp& dev_pro
         return DeviceStatus::NonFunctional;
     }
 
-    /* destroy context if we created one */
-    if (id != -1)
-    {
-        cu_err = cudaDeviceReset();
-        CU_RET_ERR(cu_err, "cudaDeviceReset failed");
-    }
+    cu_err = cudaDeviceReset();
+    CU_RET_ERR(cu_err, "cudaDeviceReset failed");
 
     return DeviceStatus::Compatible;
 }
@@ -247,17 +197,16 @@ static bool isDeviceGenerationSupported(const cudaDeviceProp& deviceProperties)
  *  upon return. Note that this also means it is the caller's responsibility to
  *  reset the CUDA runtime state.
  *
- *  \param[in]  deviceId   the ID of the GPU to check.
- *  \param[in]  deviceProp the CUDA device properties of the device checked.
+ *  \param[in]  deviceInfo The device information on the device to check.
  *  \returns               the status of the requested device
  */
-static DeviceStatus checkDeviceStatus(int deviceId, const cudaDeviceProp& deviceProp)
+static DeviceStatus checkDeviceStatus(const DeviceInformation& deviceInfo)
 {
-    if (!isDeviceGenerationSupported(deviceProp))
+    if (!isDeviceGenerationSupported(deviceInfo.prop))
     {
         return DeviceStatus::Incompatible;
     }
-    return isDeviceFunctional(deviceId, deviceProp);
+    return isDeviceFunctional(deviceInfo);
 }
 
 bool isDeviceDetectionFunctional(std::string* errorMessage)
@@ -326,6 +275,9 @@ std::vector<std::unique_ptr<DeviceInformation>> findDevices()
                 "canPerformDeviceDetection() was not called appropriately beforehand."));
     }
 
+    /* things might go horribly wrong if cudart is not compatible with the driver */
+    numDevices = std::min(numDevices, c_cudaMaxDeviceCount);
+
     // We expect to start device support/sanity checks with a clean runtime error state
     gmx::ensureNoPendingCudaError("");
 
@@ -335,13 +287,14 @@ std::vector<std::unique_ptr<DeviceInformation>> findDevices()
         cudaDeviceProp prop;
         memset(&prop, 0, sizeof(cudaDeviceProp));
         stat = cudaGetDeviceProperties(&prop, i);
-        const DeviceStatus checkResult =
-                (stat != cudaSuccess) ? DeviceStatus::NonFunctional : checkDeviceStatus(i, prop);
 
-        deviceInfoList[i] = std::make_unique<DeviceInformation>();
+        deviceInfoList[i]       = std::make_unique<DeviceInformation>();
+        deviceInfoList[i]->id   = i;
+        deviceInfoList[i]->prop = prop;
+
+        const DeviceStatus checkResult = (stat != cudaSuccess) ? DeviceStatus::NonFunctional
+                                                               : checkDeviceStatus(*deviceInfoList[i]);
 
-        deviceInfoList[i]->id     = i;
-        deviceInfoList[i]->prop   = prop;
         deviceInfoList[i]->status = checkResult;
 
         if (checkResult != DeviceStatus::Compatible)