SYCL: 3D FFT using oneMKL
[alexxy/gromacs.git] / src / gromacs / fft / tests / fft.cpp
index dfe7189795e539d09d85a817a52459a81356d29d..e06d89e0d27794ba4a504a1f974eab331223ca0e 100644 (file)
@@ -364,7 +364,8 @@ TEST_F(FFTTest3D, Real5_6_9)
     checkRealGrid(realGridSize, realGridSizePadded, in_, outputRealGridValues);
 }
 
-#if GMX_GPU_CUDA || GMX_GPU_OPENCL || (GMX_GPU_SYCL && GMX_SYCL_HIPSYCL)
+#if GMX_GPU_CUDA || GMX_GPU_OPENCL \
+        || (GMX_GPU_SYCL && (GMX_SYCL_HIPSYCL || (GMX_SYCL_DPCPP && GMX_FFT_MKL)))
 TEST_F(FFTTest3D, GpuReal5_6_9)
 {
     // Ensure library resources are managed appropriately
@@ -397,10 +398,16 @@ TEST_F(FFTTest3D, GpuReal5_6_9)
         // Use std::copy to convert from double to real easily
         std::copy(inputdata, inputdata + sizeInReals, in_.begin());
 
+        // DPCPP uses oneMKL, which seems to have troubles with out-of-place transforms
+        const bool performOutOfPlaceFFT = !GMX_SYCL_DPCPP;
+
         SCOPED_TRACE("Allocating the device buffers");
         DeviceBuffer<float> realGrid, complexGrid;
         allocateDeviceBuffer(&realGrid, in_.size(), deviceContext);
-        allocateDeviceBuffer(&complexGrid, complexGridValues.size(), deviceContext);
+        if (performOutOfPlaceFFT)
+        {
+            allocateDeviceBuffer(&complexGrid, complexGridValues.size(), deviceContext);
+        }
 
 #    if GMX_GPU_CUDA
         const FftBackend backend = FftBackend::Cufft;
@@ -409,9 +416,10 @@ TEST_F(FFTTest3D, GpuReal5_6_9)
 #    elif GMX_GPU_SYCL
 #        if GMX_SYCL_HIPSYCL
         const FftBackend backend = FftBackend::SyclRocfft;
+#        elif GMX_SYCL_DPCPP && GMX_FFT_MKL
+        const FftBackend backend = FftBackend::SyclMkl;
 #        endif
 #    endif
-        const bool         performOutOfPlaceFFT    = true;
         MPI_Comm           comm                    = MPI_COMM_NULL;
         const bool         allocateGrid            = false;
         std::array<int, 1> gridSizesInXForEachRank = { 0 };
@@ -430,7 +438,7 @@ TEST_F(FFTTest3D, GpuReal5_6_9)
                           realGridSizePadded,
                           complexGridSizePadded,
                           &realGrid,
-                          &complexGrid);
+                          performOutOfPlaceFFT ? &complexGrid : &realGrid);
 
         // Transfer the real grid input data for the FFT
         copyToDeviceBuffer(
@@ -443,7 +451,7 @@ TEST_F(FFTTest3D, GpuReal5_6_9)
 
         // Check the complex grid (NB this data has not been normalized)
         copyFromDeviceBuffer(complexGridValues.data(),
-                             &complexGrid,
+                             performOutOfPlaceFFT ? &complexGrid : &realGrid,
                              0,
                              complexGridValues.size(),
                              deviceStream,
@@ -452,17 +460,20 @@ TEST_F(FFTTest3D, GpuReal5_6_9)
         checker.checkSequence(
                 complexGridValues.begin(), complexGridValues.end(), "ComplexGridAfterRealToComplex");
 
-        // Clear the real grid input data for the FFT so we can
-        // compute the back transform into it and observe that it did
-        // the work expected.
         std::vector<float> outputRealGridValues(in_.size());
-        copyToDeviceBuffer(&realGrid,
-                           outputRealGridValues.data(),
-                           0,
-                           outputRealGridValues.size(),
-                           deviceStream,
-                           GpuApiCallBehavior::Sync,
-                           nullptr);
+        if (performOutOfPlaceFFT)
+        {
+            // Clear the real grid input data for the FFT so we can
+            // compute the back transform into it and observe that it did
+            // the work expected.
+            copyToDeviceBuffer(&realGrid,
+                               outputRealGridValues.data(),
+                               0,
+                               outputRealGridValues.size(),
+                               deviceStream,
+                               GpuApiCallBehavior::Sync,
+                               nullptr);
+        }
 
         SCOPED_TRACE("Doing the back transform");
         gpu3dFft.perform3dFft(GMX_FFT_COMPLEX_TO_REAL, timingEvent);
@@ -481,7 +492,10 @@ TEST_F(FFTTest3D, GpuReal5_6_9)
 
         SCOPED_TRACE("Cleaning up");
         freeDeviceBuffer(&realGrid);
-        freeDeviceBuffer(&complexGrid);
+        if (performOutOfPlaceFFT)
+        {
+            freeDeviceBuffer(&complexGrid);
+        }
     }
 }