Activate GPU update support in SYCL build
[alexxy/gromacs.git] / src / gromacs / taskassignment / decidegpuusage.cpp
index 36c4650f0649311fe23d213ec27e7a66f737e389..5bac5adcef47df7fd726cdb4fd16aa3d7b02930a 100644 (file)
@@ -73,6 +73,7 @@
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/logger.h"
+#include "gromacs/utility/message_string_collector.h"
 #include "gromacs/utility/stringutil.h"
 
 
@@ -155,8 +156,50 @@ bool decideWhetherToUseGpusForNonbondedWithThreadMpi(const TaskTarget        non
     return haveAvailableDevices;
 }
 
+static bool canUseGpusForPme(const bool           useGpuForNonbonded,
+                             const TaskTarget     pmeTarget,
+                             const TaskTarget     pmeFftTarget,
+                             const gmx_hw_info_t& hardwareInfo,
+                             const t_inputrec&    inputrec,
+                             std::string*         errorMessage)
+{
+    if (pmeTarget == TaskTarget::Cpu)
+    {
+        return false;
+    }
+
+    std::string                 tempString;
+    gmx::MessageStringCollector errorReasons;
+    // Before changing the prefix string, make sure that it is not searched for in regression tests.
+    errorReasons.startContext("Cannot compute PME interactions on a GPU, because:");
+    errorReasons.appendIf(!useGpuForNonbonded, "Nonbonded interactions must also run on GPUs.");
+    errorReasons.appendIf(!pme_gpu_supports_build(&tempString), tempString);
+    errorReasons.appendIf(!pme_gpu_supports_hardware(hardwareInfo, &tempString), tempString);
+    errorReasons.appendIf(!pme_gpu_supports_input(inputrec, &tempString), tempString);
+    if (pmeFftTarget == TaskTarget::Cpu)
+    {
+        // User requested PME FFT on CPU, so we check whether we are able to use PME Mixed mode.
+        errorReasons.appendIf(!pme_gpu_mixed_mode_supports_input(inputrec, &tempString), tempString);
+    }
+    errorReasons.finishContext();
+
+    if (errorReasons.isEmpty())
+    {
+        return true;
+    }
+    else
+    {
+        if (pmeTarget == TaskTarget::Gpu && errorMessage != nullptr)
+        {
+            *errorMessage = errorReasons.toString();
+        }
+        return false;
+    }
+}
+
 bool decideWhetherToUseGpusForPmeWithThreadMpi(const bool              useGpuForNonbonded,
                                                const TaskTarget        pmeTarget,
+                                               const TaskTarget        pmeFftTarget,
                                                const int               numDevicesToUse,
                                                const std::vector<int>& userGpuTaskAssignment,
                                                const gmx_hw_info_t&    hardwareInfo,
@@ -165,11 +208,9 @@ bool decideWhetherToUseGpusForPmeWithThreadMpi(const bool              useGpuFor
                                                const int               numPmeRanksPerSimulation)
 {
     // First, exclude all cases where we can't run PME on GPUs.
-    if ((pmeTarget == TaskTarget::Cpu) || !useGpuForNonbonded || !pme_gpu_supports_build(nullptr)
-        || !pme_gpu_supports_hardware(hardwareInfo, nullptr) || !pme_gpu_supports_input(inputrec, nullptr))
+    if (!canUseGpusForPme(useGpuForNonbonded, pmeTarget, pmeFftTarget, hardwareInfo, inputrec, nullptr))
     {
-        // PME can't run on a GPU. If the user required that, we issue
-        // an error later.
+        // PME can't run on a GPU. If the user required that, we issue an error later.
         return false;
     }
 
@@ -330,6 +371,7 @@ bool decideWhetherToUseGpusForNonbonded(const TaskTarget          nonbondedTarge
 
 bool decideWhetherToUseGpusForPme(const bool              useGpuForNonbonded,
                                   const TaskTarget        pmeTarget,
+                                  const TaskTarget        pmeFftTarget,
                                   const std::vector<int>& userGpuTaskAssignment,
                                   const gmx_hw_info_t&    hardwareInfo,
                                   const t_inputrec&       inputrec,
@@ -337,43 +379,12 @@ bool decideWhetherToUseGpusForPme(const bool              useGpuForNonbonded,
                                   const int               numPmeRanksPerSimulation,
                                   const bool              gpusWereDetected)
 {
-    if (pmeTarget == TaskTarget::Cpu)
-    {
-        return false;
-    }
-
-    if (!useGpuForNonbonded)
-    {
-        if (pmeTarget == TaskTarget::Gpu)
-        {
-            GMX_THROW(NotImplementedError(
-                    "PME on GPUs is only supported when nonbonded interactions run on GPUs also."));
-        }
-        return false;
-    }
-
     std::string message;
-    if (!pme_gpu_supports_build(&message))
+    if (!canUseGpusForPme(useGpuForNonbonded, pmeTarget, pmeFftTarget, hardwareInfo, inputrec, &message))
     {
-        if (pmeTarget == TaskTarget::Gpu)
-        {
-            GMX_THROW(NotImplementedError("Cannot compute PME interactions on a GPU, because " + message));
-        }
-        return false;
-    }
-    if (!pme_gpu_supports_hardware(hardwareInfo, &message))
-    {
-        if (pmeTarget == TaskTarget::Gpu)
-        {
-            GMX_THROW(NotImplementedError("Cannot compute PME interactions on a GPU, because " + message));
-        }
-        return false;
-    }
-    if (!pme_gpu_supports_input(inputrec, &message))
-    {
-        if (pmeTarget == TaskTarget::Gpu)
+        if (!message.empty())
         {
-            GMX_THROW(NotImplementedError("Cannot compute PME interactions on a GPU, because " + message));
+            GMX_THROW(InconsistentInputError(message));
         }
         return false;
     }
@@ -607,9 +618,9 @@ bool decideWhetherToUseGpuForUpdate(const bool                     isDomainDecom
     {
         errorMessage += "Compatible GPUs must have been found.\n";
     }
-    if (!GMX_GPU_CUDA)
+    if (!(GMX_GPU_CUDA || GMX_GPU_SYCL))
     {
-        errorMessage += "Only a CUDA build is supported.\n";
+        errorMessage += "Only CUDA and SYCL builds are supported.\n";
     }
     if (inputrec.eI != IntegrationAlgorithm::MD)
     {