Remove thread-MPI limitation for GPU PP halo exchange
[alexxy/gromacs.git] / src / gromacs / mdrun / runner.cpp
index 68da5a1b17251bb9f8576011bd5ab93efac54107..33f9145889d2ce1899c38971f230e823691561e3 100644 (file)
 #include "gromacs/utility/programcontext.h"
 #include "gromacs/utility/smalloc.h"
 #include "gromacs/utility/stringutil.h"
+#include "gromacs/utility/mpiinfo.h"
 
 #include "isimulator.h"
 #include "membedholder.h"
@@ -206,13 +207,66 @@ static DevelopmentFeatureFlags manageDevelopmentFeatures(const gmx::MDLogger& md
 
     devFlags.enableGpuBufferOps =
             GMX_GPU_CUDA && useGpuForNonbonded && (getenv("GMX_USE_GPU_BUFFER_OPS") != nullptr);
-    devFlags.enableGpuHaloExchange = GMX_GPU_CUDA && GMX_THREAD_MPI && getenv("GMX_GPU_DD_COMMS") != nullptr;
+    devFlags.enableGpuHaloExchange = GMX_GPU_CUDA && getenv("GMX_GPU_DD_COMMS") != nullptr;
     devFlags.forceGpuUpdateDefault = (getenv("GMX_FORCE_UPDATE_DEFAULT_GPU") != nullptr) || GMX_FAHCORE;
     devFlags.enableGpuPmePPComm =
             GMX_GPU_CUDA && GMX_THREAD_MPI && getenv("GMX_GPU_PME_PP_COMMS") != nullptr;
 
 #pragma GCC diagnostic pop
 
+    // When the direct GPU communication path is used with library (non-thread) MPI,
+    // make sure that the underlying MPI implementation is CUDA-aware.
+    if (!GMX_THREAD_MPI && devFlags.enableGpuHaloExchange)
+    {
+        const bool haveDetectedCudaAwareMpi =
+                (checkMpiCudaAwareSupport() == CudaAwareMpiStatus::Supported);
+        const bool forceCudaAwareMpi = (getenv("GMX_FORCE_CUDA_AWARE_MPI") != nullptr);
+
+        if (!haveDetectedCudaAwareMpi && forceCudaAwareMpi)
+        {
+            // CUDA-aware support was not detected in the MPI library, but the user has forced its use
+            GMX_LOG(mdlog.warning)
+                    .asParagraph()
+                    .appendTextFormatted(
+                            "This run has forced use of 'CUDA-aware MPI'. "
+                            "But, GROMACS cannot determine if underlying MPI "
+                            "is CUDA-aware. GROMACS recommends use of latest openMPI version "
+                            "for CUDA-aware support. "
+                            "If you observe failures at runtime, try unsetting "
+                            "GMX_FORCE_CUDA_AWARE_MPI environment variable.");
+        }
+
+        if (haveDetectedCudaAwareMpi || forceCudaAwareMpi)
+        {
+            devFlags.usingCudaAwareMpi = true;
+            GMX_LOG(mdlog.warning)
+                    .asParagraph()
+                    .appendTextFormatted("Using CUDA-aware MPI for 'GPU halo exchange' feature.");
+        }
+        else
+        {
+            // The enclosing condition guarantees enableGpuHaloExchange is set here,
+            // so no further check is needed before disabling it.
+            GMX_LOG(mdlog.warning)
+                    .asParagraph()
+                    .appendTextFormatted(
+                            "GMX_GPU_DD_COMMS environment variable detected, but the 'GPU "
+                            "halo exchange' feature will not be enabled because GROMACS "
+                            "could not detect CUDA-aware support in the underlying MPI "
+                            "implementation.");
+            devFlags.enableGpuHaloExchange = false;
+
+            GMX_LOG(mdlog.warning)
+                    .asParagraph()
+                    .appendTextFormatted(
+                            "GROMACS recommends the latest OpenMPI version for CUDA-aware "
+                            "support. "
+                            "If you are certain that your MPI library is CUDA-aware, "
+                            "you can force its use by setting the "
+                            "GMX_FORCE_CUDA_AWARE_MPI environment variable.");
+        }
+    }
+
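For context, the checkMpiCudaAwareSupport() call above (declared in the newly included gromacs/utility/mpiinfo.h) can only report what the MPI library exposes. A minimal sketch of such a detection routine, assuming Open MPI's mpi-ext.h extension and a simplified status enum rather than the actual GROMACS implementation:

    #include <mpi.h>
    #if defined(OPEN_MPI) && OPEN_MPI
    #    include <mpi-ext.h> // declares MPIX_CUDA_AWARE_SUPPORT and MPIX_Query_cuda_support()
    #endif

    enum class CudaAwareMpiStatus { Supported, NotSupported, NotKnown };

    // Sketch only: the real check lives in gromacs/utility/mpiinfo.h.
    static CudaAwareMpiStatus checkMpiCudaAwareSupportSketch()
    {
    #if defined(MPIX_CUDA_AWARE_SUPPORT) && MPIX_CUDA_AWARE_SUPPORT
        // Support was compiled in; confirm that it is also active at runtime.
        return (MPIX_Query_cuda_support() == 1) ? CudaAwareMpiStatus::Supported
                                                : CudaAwareMpiStatus::NotSupported;
    #elif defined(MPIX_CUDA_AWARE_SUPPORT)
        return CudaAwareMpiStatus::NotSupported; // compiled without CUDA support
    #else
        // Other MPI libraries provide no portable query, hence the
        // GMX_FORCE_CUDA_AWARE_MPI override handled above.
        return CudaAwareMpiStatus::NotKnown;
    #endif
    }

This is why the logic above treats 'not detected' as weaker than 'not supported': outside Open MPI's extension the status is simply unknowable, and the override environment variable exists for exactly that case.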
     if (devFlags.enableGpuBufferOps)
     {
         GMX_LOG(mdlog.warning)
@@ -2051,7 +2105,14 @@ int Mdrunner::mdrunner()
     {
         physicalNodeComm.barrier();
     }
-    releaseDevice(deviceInfo);
+
+    if (!devFlags.usingCudaAwareMpi)
+    {
+        // Don't release the device when CUDA-aware MPI is in use:
+        // UCX creates CUDA buffers that are cleaned up as part of MPI_Finalize(),
+        // and resetting the device before MPI_Finalize() causes crashes inside UCX.
+        releaseDevice(deviceInfo);
+    }
 
     /* Does what it says */
     print_date_and_time(fplog, cr->nodeid, "Finished mdrun", gmx_gettime());
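The ordering constraint behind this hunk is easier to see in isolation. A minimal standalone sketch (not GROMACS code) of why the device must not be reset until after MPI_Finalize() when a UCX-backed CUDA-aware MPI is in use:

    #include <mpi.h>
    #include <cuda_runtime.h>

    int main(int argc, char** argv)
    {
        MPI_Init(&argc, &argv);
        // ... halo exchange through device pointers (CUDA-aware MPI) ...
        MPI_Finalize();    // UCX releases its internal CUDA buffers here
        cudaDeviceReset(); // safe only now; doing this before MPI_Finalize()
                           // would invalidate buffers UCX still holds
        return 0;
    }

In mdrunner() the releaseDevice() call runs before the library-MPI finalization that happens after the runner returns, so skipping it when devFlags.usingCudaAwareMpi is set preserves that ordering.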