Move DeviceInfo into GPU traits
[alexxy/gromacs.git] / src / gromacs / gpu_utils / gpu_utils_ocl.cpp
index 8770e6862d1637337af772388396e4c65b14d42e..d4e9daddd099a0fc77768131f52c1a3dc74d0bd1 100644 (file)
@@ -129,26 +129,26 @@ static std::string makeOpenClInternalErrorString(const char* message, cl_int sta
 }
 
 /*!
- * \brief Checks that device \c devInfo is sane (ie can run a kernel).
+ * \brief Checks that device \c deviceInfo is sane (ie can run a kernel).
  *
  * Compiles and runs a dummy kernel to determine whether the given
  * OpenCL device functions properly.
  *
  *
- * \param[in]  devInfo         The device info pointer.
+ * \param[in]  deviceInfo      The device info pointer.
  * \param[out] errorMessage    An error message related to a failing OpenCL API call.
  * \throws     std::bad_alloc  When out of memory.
  * \returns                    Whether the device passed sanity checks
  */
-static bool isDeviceSane(const gmx_device_info_t* devInfo, std::string* errorMessage)
+static bool isDeviceSane(const DeviceInformation* deviceInfo, std::string* errorMessage)
 {
     cl_context_properties properties[] = {
-        CL_CONTEXT_PLATFORM, reinterpret_cast<cl_context_properties>(devInfo->ocl_gpu_id.ocl_platform_id), 0
+        CL_CONTEXT_PLATFORM, reinterpret_cast<cl_context_properties>(deviceInfo->oclPlatformId), 0
     };
     // uncrustify spacing
 
     cl_int    status;
-    auto      deviceId = devInfo->ocl_gpu_id.ocl_device_id;
+    auto      deviceId = deviceInfo->oclDeviceId;
     ClContext context(clCreateContext(properties, 1, &deviceId, nullptr, nullptr, &status));
     if (status != CL_SUCCESS)
     {
@@ -198,15 +198,15 @@ static bool isDeviceSane(const gmx_device_info_t* devInfo, std::string* errorMes
 }
 
 /*!
- * \brief Checks that device \c devInfo is compatible with GROMACS.
+ * \brief Checks that device \c deviceInfo is compatible with GROMACS.
  *
  *  Vendor and OpenCL version support checks are executed an the result
  *  of these returned.
  *
- * \param[in]  devInfo         The device info pointer.
- * \returns                    The result of the compatibility checks.
+ * \param[in]  deviceInfo  The device info pointer.
+ * \returns                The result of the compatibility checks.
  */
-static int isDeviceSupported(const gmx_device_info_t* devInfo)
+static int isDeviceSupported(const DeviceInformation* deviceInfo)
 {
     if (getenv("GMX_OCL_DISABLE_COMPATIBILITY_CHECK") != nullptr)
     {
@@ -222,7 +222,7 @@ static int isDeviceSupported(const gmx_device_info_t* devInfo)
     // the device which has the following format:
     //      OpenCL<space><major_version.minor_version><space><vendor-specific information>
     unsigned int deviceVersionMinor, deviceVersionMajor;
-    const int    valuesScanned = std::sscanf(devInfo->device_version, "OpenCL %u.%u",
+    const int    valuesScanned = std::sscanf(deviceInfo->device_version, "OpenCL %u.%u",
                                           &deviceVersionMajor, &deviceVersionMinor);
     const bool   versionLargeEnough =
             ((valuesScanned == 2)
@@ -234,7 +234,7 @@ static int isDeviceSupported(const gmx_device_info_t* devInfo)
     }
 
     /* Only AMD, Intel, and NVIDIA GPUs are supported for now */
-    switch (devInfo->deviceVendor)
+    switch (deviceInfo->deviceVendor)
     {
         case DeviceVendor::Nvidia: return egpuCompatible;
         case DeviceVendor::Amd:
@@ -258,7 +258,7 @@ static int isDeviceSupported(const gmx_device_info_t* devInfo)
  * \returns  An e_gpu_detect_res_t to indicate how the GPU coped with
  *           the sanity and compatibility check.
  */
-static int checkGpu(size_t deviceId, const gmx_device_info_t* deviceInfo)
+static int checkGpu(size_t deviceId, const DeviceInformation* deviceInfo)
 {
 
     int supportStatus = isDeviceSupported(deviceInfo);
@@ -393,7 +393,7 @@ void findGpus(gmx_gpu_info_t* gpu_info)
             break;
         }
 
-        snew(gpu_info->gpu_dev, gpu_info->n_dev);
+        snew(gpu_info->deviceInfo, gpu_info->n_dev);
 
         {
             int           device_index;
@@ -421,47 +421,47 @@ void findGpus(gmx_gpu_info_t* gpu_info)
 
                 for (unsigned int j = 0; j < ocl_device_count; j++)
                 {
-                    gpu_info->gpu_dev[device_index].ocl_gpu_id.ocl_platform_id = ocl_platform_ids[i];
-                    gpu_info->gpu_dev[device_index].ocl_gpu_id.ocl_device_id   = ocl_device_ids[j];
+                    gpu_info->deviceInfo[device_index].oclPlatformId = ocl_platform_ids[i];
+                    gpu_info->deviceInfo[device_index].oclDeviceId   = ocl_device_ids[j];
 
-                    gpu_info->gpu_dev[device_index].device_name[0] = 0;
+                    gpu_info->deviceInfo[device_index].device_name[0] = 0;
                     clGetDeviceInfo(ocl_device_ids[j], CL_DEVICE_NAME,
-                                    sizeof(gpu_info->gpu_dev[device_index].device_name),
-                                    gpu_info->gpu_dev[device_index].device_name, nullptr);
+                                    sizeof(gpu_info->deviceInfo[device_index].device_name),
+                                    gpu_info->deviceInfo[device_index].device_name, nullptr);
 
-                    gpu_info->gpu_dev[device_index].device_version[0] = 0;
+                    gpu_info->deviceInfo[device_index].device_version[0] = 0;
                     clGetDeviceInfo(ocl_device_ids[j], CL_DEVICE_VERSION,
-                                    sizeof(gpu_info->gpu_dev[device_index].device_version),
-                                    gpu_info->gpu_dev[device_index].device_version, nullptr);
+                                    sizeof(gpu_info->deviceInfo[device_index].device_version),
+                                    gpu_info->deviceInfo[device_index].device_version, nullptr);
 
-                    gpu_info->gpu_dev[device_index].vendorName[0] = 0;
+                    gpu_info->deviceInfo[device_index].vendorName[0] = 0;
                     clGetDeviceInfo(ocl_device_ids[j], CL_DEVICE_VENDOR,
-                                    sizeof(gpu_info->gpu_dev[device_index].vendorName),
-                                    gpu_info->gpu_dev[device_index].vendorName, nullptr);
+                                    sizeof(gpu_info->deviceInfo[device_index].vendorName),
+                                    gpu_info->deviceInfo[device_index].vendorName, nullptr);
 
-                    gpu_info->gpu_dev[device_index].compute_units = 0;
+                    gpu_info->deviceInfo[device_index].compute_units = 0;
                     clGetDeviceInfo(ocl_device_ids[j], CL_DEVICE_MAX_COMPUTE_UNITS,
-                                    sizeof(gpu_info->gpu_dev[device_index].compute_units),
-                                    &(gpu_info->gpu_dev[device_index].compute_units), nullptr);
+                                    sizeof(gpu_info->deviceInfo[device_index].compute_units),
+                                    &(gpu_info->deviceInfo[device_index].compute_units), nullptr);
 
-                    gpu_info->gpu_dev[device_index].adress_bits = 0;
+                    gpu_info->deviceInfo[device_index].adress_bits = 0;
                     clGetDeviceInfo(ocl_device_ids[j], CL_DEVICE_ADDRESS_BITS,
-                                    sizeof(gpu_info->gpu_dev[device_index].adress_bits),
-                                    &(gpu_info->gpu_dev[device_index].adress_bits), nullptr);
+                                    sizeof(gpu_info->deviceInfo[device_index].adress_bits),
+                                    &(gpu_info->deviceInfo[device_index].adress_bits), nullptr);
 
-                    gpu_info->gpu_dev[device_index].deviceVendor =
-                            getDeviceVendor(gpu_info->gpu_dev[device_index].vendorName);
+                    gpu_info->deviceInfo[device_index].deviceVendor =
+                            getDeviceVendor(gpu_info->deviceInfo[device_index].vendorName);
 
                     clGetDeviceInfo(ocl_device_ids[j], CL_DEVICE_MAX_WORK_ITEM_SIZES, 3 * sizeof(size_t),
-                                    &gpu_info->gpu_dev[device_index].maxWorkItemSizes, nullptr);
+                                    &gpu_info->deviceInfo[device_index].maxWorkItemSizes, nullptr);
 
                     clGetDeviceInfo(ocl_device_ids[j], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t),
-                                    &gpu_info->gpu_dev[device_index].maxWorkGroupSize, nullptr);
+                                    &gpu_info->deviceInfo[device_index].maxWorkGroupSize, nullptr);
 
-                    gpu_info->gpu_dev[device_index].stat =
-                            gmx::checkGpu(device_index, gpu_info->gpu_dev + device_index);
+                    gpu_info->deviceInfo[device_index].stat =
+                            gmx::checkGpu(device_index, gpu_info->deviceInfo + device_index);
 
-                    if (egpuCompatible == gpu_info->gpu_dev[device_index].stat)
+                    if (egpuCompatible == gpu_info->deviceInfo[device_index].stat)
                     {
                         gpu_info->n_dev_compatible++;
                     }
@@ -479,16 +479,13 @@ void findGpus(gmx_gpu_info_t* gpu_info)
                 int last = -1;
                 for (int i = 0; i < gpu_info->n_dev; i++)
                 {
-                    if (gpu_info->gpu_dev[i].deviceVendor == DeviceVendor::Amd)
+                    if (gpu_info->deviceInfo[i].deviceVendor == DeviceVendor::Amd)
                     {
                         last++;
 
                         if (last < i)
                         {
-                            gmx_device_info_t ocl_gpu_info;
-                            ocl_gpu_info            = gpu_info->gpu_dev[i];
-                            gpu_info->gpu_dev[i]    = gpu_info->gpu_dev[last];
-                            gpu_info->gpu_dev[last] = ocl_gpu_info;
+                            std::swap(gpu_info->deviceInfo[i], gpu_info->deviceInfo[last]);
                         }
                     }
                 }
@@ -498,16 +495,13 @@ void findGpus(gmx_gpu_info_t* gpu_info)
                 {
                     for (int i = 0; i < gpu_info->n_dev; i++)
                     {
-                        if (gpu_info->gpu_dev[i].deviceVendor == DeviceVendor::Nvidia)
+                        if (gpu_info->deviceInfo[i].deviceVendor == DeviceVendor::Nvidia)
                         {
                             last++;
 
                             if (last < i)
                             {
-                                gmx_device_info_t ocl_gpu_info;
-                                ocl_gpu_info            = gpu_info->gpu_dev[i];
-                                gpu_info->gpu_dev[i]    = gpu_info->gpu_dev[last];
-                                gpu_info->gpu_dev[last] = ocl_gpu_info;
+                                std::swap(gpu_info->deviceInfo[i], gpu_info->deviceInfo[last]);
                             }
                         }
                     }
@@ -532,7 +526,7 @@ void get_gpu_device_info_string(char* s, const gmx_gpu_info_t& gpu_info, int ind
         return;
     }
 
-    gmx_device_info_t* dinfo = &gpu_info.gpu_dev[index];
+    DeviceInformation* dinfo = &gpu_info.deviceInfo[index];
 
     bool bGpuExists = (dinfo->stat != egpuNonexistent && dinfo->stat != egpuInsane);
 
@@ -548,7 +542,7 @@ void get_gpu_device_info_string(char* s, const gmx_gpu_info_t& gpu_info, int ind
 }
 
 
-void init_gpu(const gmx_device_info_t* deviceInfo)
+void init_gpu(const DeviceInformation* deviceInfo)
 {
     assert(deviceInfo);
 
@@ -570,21 +564,21 @@ void init_gpu(const gmx_device_info_t* deviceInfo)
     }
 }
 
-gmx_device_info_t* getDeviceInfo(const gmx_gpu_info_t& gpu_info, int deviceId)
+DeviceInformation* getDeviceInfo(const gmx_gpu_info_t& gpu_info, int deviceId)
 {
     if (deviceId < 0 || deviceId >= gpu_info.n_dev)
     {
         gmx_incons("Invalid GPU deviceId requested");
     }
-    return &gpu_info.gpu_dev[deviceId];
+    return &gpu_info.deviceInfo[deviceId];
 }
 
 size_t sizeof_gpu_dev_info()
 {
-    return sizeof(gmx_device_info_t);
+    return sizeof(DeviceInformation);
 }
 
 int gpu_info_get_stat(const gmx_gpu_info_t& info, int index)
 {
-    return info.gpu_dev[index].stat;
+    return info.deviceInfo[index].stat;
 }