Add GPU 3D FFT tests
[alexxy/gromacs.git] / src / gromacs / fft / tests / fft.cpp
index 3ef73d4f3f7133bcd60264a2e64fe58fe16c7794..b743d3d1149280a975cb43e905771b8ecf43ea21 100644 (file)
 
 #include "gromacs/fft/fft.h"
 
+#include "config.h"
+
 #include <algorithm>
 #include <vector>
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
+#include "gromacs/fft/gpu_3dfft.h"
 #include "gromacs/fft/parallel_3dfft.h"
+#include "gromacs/gpu_utils/clfftinitializer.h"
+#if GMX_GPU
+#    include "gromacs/gpu_utils/devicebuffer.h"
+#endif
 #include "gromacs/utility/stringutil.h"
 
 #include "testutils/refdata.h"
+#include "testutils/test_hardware_environment.h"
 #include "testutils/testasserts.h"
+#include "testutils/testmatchers.h"
 
-namespace
+namespace gmx
+{
+namespace test
 {
 
 /*! \brief Input data for FFT tests.
@@ -109,25 +121,23 @@ const double inputdata[] = {
 class BaseFFTTest : public ::testing::Test
 {
 public:
-    BaseFFTTest() : checker_(data_.rootChecker()), flags_(GMX_FFT_FLAG_CONSERVATIVE)
-    {
-        // TODO: These tolerances are just something that has been observed
-        // to be sufficient to pass the tests.  It would be nicer to
-        // actually argue about why they are sufficient (or what is).
-        checker_.setDefaultTolerance(gmx::test::relativeToleranceAsPrecisionDependentUlp(10.0, 64, 512));
-    }
+    BaseFFTTest() : flags_(GMX_FFT_FLAG_CONSERVATIVE) {}
     ~BaseFFTTest() override { gmx_fft_cleanup(); }
 
-    gmx::test::TestReferenceData    data_;
-    gmx::test::TestReferenceChecker checker_;
-    std::vector<real>               in_, out_;
-    int                             flags_;
+    TestReferenceData data_;
+    std::vector<real> in_, out_;
+    int               flags_;
+    // Tolerance that the fixtures and tests below pass to setDefaultTolerance() on their checkers.
+    // TODO: These values have merely been observed to be sufficient to pass the
+    // tests.  It would be nicer to argue about why they are sufficient (or what is).
+    // Should work for both one-way and forward+backward transforms.
+    FloatingPointTolerance defaultTolerance_ = relativeToleranceAsPrecisionDependentUlp(10.0, 64, 512);
 };
 
 class FFTTest : public BaseFFTTest
 {
 public:
-    FFTTest() : fft_(nullptr) {}
+    FFTTest() : fft_(nullptr) { checker_.setDefaultTolerance(defaultTolerance_); }
     ~FFTTest() override
     {
         if (fft_)
@@ -135,13 +145,14 @@ public:
             gmx_fft_destroy(fft_);
         }
     }
-    gmx_fft_t fft_;
+    TestReferenceChecker checker_ = data_.rootChecker();
+    gmx_fft_t            fft_;
 };
 
 class ManyFFTTest : public BaseFFTTest
 {
 public:
-    ManyFFTTest() : fft_(nullptr) {}
+    ManyFFTTest() : fft_(nullptr) { checker_.setDefaultTolerance(defaultTolerance_); }
     ~ManyFFTTest() override
     {
         if (fft_)
@@ -149,7 +160,8 @@ public:
             gmx_many_fft_destroy(fft_);
         }
     }
-    gmx_fft_t fft_;
+    TestReferenceChecker checker_ = data_.rootChecker();
+    gmx_fft_t            fft_;
 };
 
 
@@ -159,11 +171,11 @@ class FFTTest1D : public FFTTest, public ::testing::WithParamInterface<int>
 {
 };
 
-class FFFTest3D : public BaseFFTTest
+class FFTTest3D : public BaseFFTTest
 {
 public:
-    FFFTest3D() : fft_(nullptr) {}
-    ~FFFTest3D() override
+    FFTTest3D() : fft_(nullptr) {}
+    ~FFTTest3D() override
     {
         if (fft_)
         {
@@ -276,44 +288,172 @@ TEST_F(FFTTest, Real2DLength18_15Test)
     //    _checker.checkSequenceArray(rx*ny, out, "backward");
 }
 
+namespace
+{
+
+/*! \brief Check that the real grid after forward and backward
+ * 3D transforms matches the input real grid.
+ *
+ * The output values are first multiplied in place by
+ * 1 / (realGridSize[XX] * realGridSize[YY] * realGridSize[ZZ]),
+ * so the buffer viewed by \p outputRealGridValues is modified.
+ * The comparison is then made row by row: for each of the
+ * realGridSize[XX] * realGridSize[YY] rows, the first realGridSize[ZZ]
+ * values are compared, and consecutive rows are realGridSizePadded[ZZ]
+ * values apart.
+ *
+ * \param[in]     realGridSize          Unpadded dimensions of the real grid.
+ * \param[in]     realGridSizePadded    Padded dimensions of the real grid; only
+ *                                      the ZZ entry is read, as the row stride.
+ * \param[in]     inputRealGrid         Expected values, in the padded layout.
+ * \param[in,out] outputRealGridValues  Values to check, in the padded layout;
+ *                                      normalized in place by this function.
+ */
+void checkRealGrid(const ivec           realGridSize,
+                   const ivec           realGridSizePadded,
+                   ArrayRef<const real> inputRealGrid,
+                   ArrayRef<real>       outputRealGridValues)
+{
+    // Normalize the output (as the implementation does not
+    // normalize either FFT)
+    const real normalizationConstant = 1.0 / (realGridSize[XX] * realGridSize[YY] * realGridSize[ZZ]);
+    std::transform(outputRealGridValues.begin(),
+                   outputRealGridValues.end(),
+                   outputRealGridValues.begin(),
+                   [normalizationConstant](const real r) { return r * normalizationConstant; });
+    // Check the real grid one row at a time, skipping unused data from the padding
+    const auto realGridTolerance = relativeToleranceAsFloatingPoint(10, 1e-6);
+    for (int i = 0; i < realGridSize[XX] * realGridSize[YY]; i++)
+    {
+        auto expected =
+                arrayRefFromArray(inputRealGrid.data() + i * realGridSizePadded[ZZ], realGridSize[ZZ]);
+        auto actual = arrayRefFromArray(outputRealGridValues.data() + i * realGridSizePadded[ZZ],
+                                        realGridSize[ZZ]);
+        EXPECT_THAT(actual, Pointwise(RealEq(realGridTolerance), expected))
+                << formatString("checking backward transform part %d", i);
+    }
+}
+
+} // namespace
+
 // TODO: test with threads and more than 1 MPI ranks
-TEST_F(FFFTest3D, Real5_6_9)
+TEST_F(FFTTest3D, Real5_6_9)
 {
-    int        ndata[] = { 5, 6, 9 };
-    MPI_Comm   comm[]  = { MPI_COMM_NULL, MPI_COMM_NULL };
+    int        realGridSize[] = { 5, 6, 9 };
+    MPI_Comm   comm[]         = { MPI_COMM_NULL, MPI_COMM_NULL };
     real*      rdata;
     t_complex* cdata;
-    ivec       local_ndata, offset, rsize, csize, complex_order;
+    ivec       local_ndata, offset, realGridSizePadded, complexGridSizePadded, complex_order;
+    TestReferenceChecker checker(data_.rootChecker());
+    checker.setDefaultTolerance(defaultTolerance_);
 
-    gmx_parallel_3dfft_init(&fft_, ndata, &rdata, &cdata, comm, TRUE, 1);
+    gmx_parallel_3dfft_init(&fft_, realGridSize, &rdata, &cdata, comm, TRUE, 1);
 
-    gmx_parallel_3dfft_real_limits(fft_, local_ndata, offset, rsize);
-    gmx_parallel_3dfft_complex_limits(fft_, complex_order, local_ndata, offset, csize);
-    checker_.checkVector(rsize, "rsize");
-    checker_.checkVector(csize, "csize");
-    int size        = csize[0] * csize[1] * csize[2];
+    gmx_parallel_3dfft_real_limits(fft_, local_ndata, offset, realGridSizePadded);
+    gmx_parallel_3dfft_complex_limits(fft_, complex_order, local_ndata, offset, complexGridSizePadded);
+    checker.checkVector(realGridSizePadded, "realGridSizePadded");
+    checker.checkVector(complexGridSizePadded, "complexGridSizePadded");
+    int size = complexGridSizePadded[0] * complexGridSizePadded[1] * complexGridSizePadded[2];
     int sizeInBytes = size * sizeof(t_complex);
     int sizeInReals = sizeInBytes / sizeof(real);
 
+    // Prepare the real grid input, using as many reals as the complex grid occupies
     in_ = std::vector<real>(sizeInReals);
     // Use std::copy to convert from double to real easily
     std::copy(inputdata, inputdata + sizeInReals, in_.begin());
     // Use memcpy to convert to t_complex easily
     memcpy(rdata, in_.data(), sizeInBytes);
+
+    // Do the forward FFT to compute the complex grid
     gmx_parallel_3dfft_execute(fft_, GMX_FFT_REAL_TO_COMPLEX, 0, nullptr);
-    // TODO use std::complex and add checkComplex for it
-    checker_.checkSequenceArray(size * 2, reinterpret_cast<real*>(cdata), "forward");
 
-    // Use std::copy to convert from double to real easily
-    std::copy(inputdata, inputdata + sizeInReals, in_.begin());
-    // Use memcpy to convert to t_complex easily
-    memcpy(cdata, in_.data(), sizeInBytes);
+    // Check the complex grid (NB this data has not been normalized)
+    ArrayRef<real> complexGridValues = arrayRefFromArray(reinterpret_cast<real*>(cdata), size * 2);
+    checker.checkSequence(
+            complexGridValues.begin(), complexGridValues.end(), "ComplexGridAfterRealToComplex");
+
+    // Do the back transform; its result is read from rdata below
     gmx_parallel_3dfft_execute(fft_, GMX_FFT_COMPLEX_TO_REAL, 0, nullptr);
-    for (int i = 0; i < ndata[0] * ndata[1]; i++) // check sequence but skip unused data
+
+    ArrayRef<real> outputRealGridValues = arrayRefFromArray(
+            rdata, realGridSizePadded[XX] * realGridSizePadded[YY] * realGridSizePadded[ZZ]);
+    checkRealGrid(realGridSize, realGridSizePadded, in_, outputRealGridValues);
+}
+
+#if GMX_GPU
+TEST_F(FFTTest3D, GpuReal5_6_9)
+{
+    // Ensure library resources are managed appropriately for the whole test
+    ClfftInitializer clfftInitializer;
+    for (const auto& testDevice : getTestHardwareEnvironment()->getTestDeviceList())
     {
-        checker_.checkSequenceArray(
-                ndata[2], rdata + i * rsize[2], gmx::formatString("backward %d", i).c_str());
+        TestReferenceChecker checker(data_.rootChecker()); // Must be inside the loop to avoid warnings
+        checker.setDefaultTolerance(defaultTolerance_);
+
+        const DeviceContext& deviceContext = testDevice->deviceContext();
+        setActiveDevice(testDevice->deviceInfo());
+        const DeviceStream& deviceStream = testDevice->deviceStream();
+
+        ivec realGridSize       = { 5, 6, 9 };
+        ivec realGridSizePadded = { realGridSize[XX], realGridSize[YY], (realGridSize[ZZ] / 2 + 1) * 2 };
+        ivec complexGridSizePadded = { realGridSize[XX], realGridSize[YY], (realGridSize[ZZ] / 2) + 1 };
+
+        checker.checkVector(realGridSizePadded, "realGridSizePadded");
+        checker.checkVector(complexGridSizePadded, "complexGridSizePadded");
+
+        int size = complexGridSizePadded[0] * complexGridSizePadded[1] * complexGridSizePadded[2];
+        int sizeInReals = size * 2;
+
+        // Set up host storage for the complex grid. Each complex number
+        // takes two floats, hence sizeInReals = size * 2.
+        std::vector<float> complexGridValues(sizeInReals);
+        in_.resize(sizeInReals);
+        // Use std::copy to convert from double to real easily
+        std::copy(inputdata, inputdata + sizeInReals, in_.begin());
+
+        // Allocate the device buffers, sized like the corresponding host vectors
+        DeviceBuffer<float> realGrid, complexGrid;
+        allocateDeviceBuffer(&realGrid, in_.size(), deviceContext);
+        allocateDeviceBuffer(&complexGrid, complexGridValues.size(), deviceContext);
+
+        const bool useDecomposition     = false;
+        const bool performOutOfPlaceFFT = true;
+        Gpu3dFft   gpu3dFft(realGridSize,
+                          realGridSizePadded,
+                          complexGridSizePadded,
+                          useDecomposition,
+                          performOutOfPlaceFFT,
+                          deviceContext,
+                          deviceStream,
+                          realGrid,
+                          complexGrid);
+
+        // Transfer the real grid input data for the FFT
+        copyToDeviceBuffer(
+                &realGrid, in_.data(), 0, in_.size(), deviceStream, GpuApiCallBehavior::Sync, nullptr);
+
+        // Do the forward FFT to compute the complex grid
+        CommandEvent* timingEvent = nullptr;
+        gpu3dFft.perform3dFft(GMX_FFT_REAL_TO_COMPLEX, timingEvent);
+        deviceStream.synchronize();
+
+        // Copy the complex grid to the host and check it (NB this data has not been normalized)
+        copyFromDeviceBuffer(complexGridValues.data(),
+                             &complexGrid,
+                             0,
+                             complexGridValues.size(),
+                             deviceStream,
+                             GpuApiCallBehavior::Sync,
+                             nullptr);
+        checker.checkSequence(
+                complexGridValues.begin(), complexGridValues.end(), "ComplexGridAfterRealToComplex");
+
+        // Do the back transform
+        gpu3dFft.perform3dFft(GMX_FFT_COMPLEX_TO_REAL, timingEvent);
+        deviceStream.synchronize();
+
+        // Transfer the real grid back from the device
+        std::vector<float> outputRealGridValues(in_.size());
+        copyFromDeviceBuffer(outputRealGridValues.data(),
+                             &realGrid,
+                             0,
+                             outputRealGridValues.size(),
+                             deviceStream,
+                             GpuApiCallBehavior::Sync,
+                             nullptr);
+
+        checkRealGrid(realGridSize, realGridSizePadded, in_, outputRealGridValues);
+
+        freeDeviceBuffer(&realGrid);
+        freeDeviceBuffer(&complexGrid);
     }
 }
 
-} // namespace
+#endif
+
+} // namespace test
+} // namespace gmx