Implement PME solve in SYCL
author    Mark Abraham <mark.j.abraham@gmail.com>
          Wed, 13 Oct 2021 16:06:54 +0000 (16:06 +0000)
committer Andrey Alekseenko <al42and@gmail.com>
          Wed, 13 Oct 2021 16:06:54 +0000 (16:06 +0000)
Refs #3965

src/gromacs/ewald/CMakeLists.txt
src/gromacs/ewald/pme_gpu_program_impl_sycl.cpp
src/gromacs/ewald/pme_solve_sycl.cpp [new file with mode: 0644]
src/gromacs/ewald/pme_solve_sycl.h [new file with mode: 0644]
src/gromacs/ewald/tests/pmesolvetest.cpp
src/gromacs/gpu_utils/sycl_kernel_utils.h

index 78d2a4ed02b2b5eebdf1fd3533c4a12e24e038fb..fd24566ee0c883f576b1f9ed6ab3e7efb32e2238 100644 (file)
@@ -92,8 +92,9 @@ elseif (GMX_GPU_SYCL)
         pme_gpu.cpp
         pme_gpu_internal.cpp
         pme_gpu_program_impl_sycl.cpp
-        pme_spread_sycl.cpp
         pme_gpu_timings.cpp
+        pme_solve_sycl.cpp
+        pme_spread_sycl.cpp
         )
     _gmx_add_files_to_property(SYCL_SOURCES
         pme_gather_sycl.cpp
@@ -102,6 +103,7 @@ elseif (GMX_GPU_SYCL)
         pme_gpu_program_impl_sycl.cpp
         pme_gpu_3dfft_sycl.cpp
         pme_gpu_timings.cpp
+        pme_solve_sycl.cpp
         pme_spread_sycl.cpp
       )
 else()
index 196cff26ae3881a4bbee79fe6f1e5296971bf349..a65a2828b074859802b30735aaafa7da323e21f2 100644 (file)
@@ -50,6 +50,7 @@
 
 #include "pme_gpu_program_impl.h"
 #include "pme_gather_sycl.h"
+#include "pme_solve_sycl.h"
 #include "pme_spread_sycl.h"
 
 #include "pme_gpu_constants.h"
@@ -62,6 +63,9 @@ constexpr int c_pmeOrder = 4;
 constexpr bool c_wrapX = true;
 constexpr bool c_wrapY = true;
 
+constexpr int c_stateA = 0;
+constexpr int c_stateB = 1;
+
 static int subGroupSizeFromVendor(const DeviceInformation& deviceInfo)
 {
     switch (deviceInfo.deviceVendor)
@@ -96,9 +100,20 @@ static int subGroupSizeFromVendor(const DeviceInformation& deviceInfo)
     INSTANTIATE_##x(order, 2, ThreadsPerAtom::Order, subGroupSize);        \
     INSTANTIATE_##x(order, 2, ThreadsPerAtom::OrderSquared, subGroupSize);
 
+#define INSTANTIATE_SOLVE(subGroupSize)                                                     \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, false, c_stateA, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, true, c_stateA, subGroupSize>;  \
+    extern template class PmeSolveKernel<GridOrdering::YZX, false, c_stateA, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::YZX, true, c_stateA, subGroupSize>;  \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, false, c_stateB, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::XYZ, true, c_stateB, subGroupSize>;  \
+    extern template class PmeSolveKernel<GridOrdering::YZX, false, c_stateB, subGroupSize>; \
+    extern template class PmeSolveKernel<GridOrdering::YZX, true, c_stateB, subGroupSize>;
+
 #define INSTANTIATE(order, subGroupSize)        \
     INSTANTIATE_X(SPREAD, order, subGroupSize); \
-    INSTANTIATE_X(GATHER, order, subGroupSize);
+    INSTANTIATE_X(GATHER, order, subGroupSize); \
+    INSTANTIATE_SOLVE(subGroupSize);
 
 #if GMX_SYCL_DPCPP
 INSTANTIATE(4, 16);
@@ -107,7 +122,6 @@ INSTANTIATE(4, 32);
 INSTANTIATE(4, 64);
 #endif
 
-
 //! Helper function to set proper kernel functor pointers
 template<int subGroupSize>
 static void setKernelPointers(struct PmeGpuProgramImpl* pmeGpuProgram)
@@ -164,6 +178,22 @@ static void setKernelPointers(struct PmeGpuProgramImpl* pmeGpuProgram)
             new PmeGatherKernel<c_pmeOrder, c_wrapX, c_wrapY, 2, true, ThreadsPerAtom::OrderSquared, subGroupSize>();
     pmeGpuProgram->gatherKernelReadSplinesThPerAtom4Dual =
             new PmeGatherKernel<c_pmeOrder, c_wrapX, c_wrapY, 2, true, ThreadsPerAtom::Order, subGroupSize>();
+    pmeGpuProgram->solveXYZKernelA =
+            new PmeSolveKernel<GridOrdering::XYZ, false, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveXYZEnergyKernelA =
+            new PmeSolveKernel<GridOrdering::XYZ, true, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveYZXKernelA =
+            new PmeSolveKernel<GridOrdering::YZX, false, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveYZXEnergyKernelA =
+            new PmeSolveKernel<GridOrdering::YZX, true, c_stateA, subGroupSize>();
+    pmeGpuProgram->solveXYZKernelB =
+            new PmeSolveKernel<GridOrdering::XYZ, false, c_stateB, subGroupSize>();
+    pmeGpuProgram->solveXYZEnergyKernelB =
+            new PmeSolveKernel<GridOrdering::XYZ, true, c_stateB, subGroupSize>();
+    pmeGpuProgram->solveYZXKernelB =
+            new PmeSolveKernel<GridOrdering::YZX, false, c_stateB, subGroupSize>();
+    pmeGpuProgram->solveYZXEnergyKernelB =
+            new PmeSolveKernel<GridOrdering::YZX, true, c_stateB, subGroupSize>();
 }
 
 PmeGpuProgramImpl::PmeGpuProgramImpl(const DeviceContext& deviceContext) :
@@ -205,4 +235,20 @@ PmeGpuProgramImpl::~PmeGpuProgramImpl()
     delete splineAndSpreadKernelThPerAtom4Dual;
     delete splineAndSpreadKernelWriteSplinesDual;
     delete splineAndSpreadKernelWriteSplinesThPerAtom4Dual;
+    delete gatherKernelSingle;
+    delete gatherKernelThPerAtom4Single;
+    delete gatherKernelReadSplinesSingle;
+    delete gatherKernelReadSplinesThPerAtom4Single;
+    delete gatherKernelDual;
+    delete gatherKernelThPerAtom4Dual;
+    delete gatherKernelReadSplinesDual;
+    delete gatherKernelReadSplinesThPerAtom4Dual;
+    delete solveYZXKernelA;
+    delete solveXYZKernelA;
+    delete solveYZXEnergyKernelA;
+    delete solveXYZEnergyKernelA;
+    delete solveYZXKernelB;
+    delete solveXYZKernelB;
+    delete solveYZXEnergyKernelB;
+    delete solveXYZEnergyKernelB;
 }
diff --git a/src/gromacs/ewald/pme_solve_sycl.cpp b/src/gromacs/ewald/pme_solve_sycl.cpp
new file mode 100644 (file)
index 0000000..633cf31
--- /dev/null
@@ -0,0 +1,491 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2021, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+
+/*! \internal \file
+ *  \brief Implements PME GPU Fourier grid solving in SYCL.
+ *
+ *  \author Mark Abraham <mark.j.abraham@gmail.com>
+ */
+
+#include "gmxpre.h"
+
+#include "pme_solve_sycl.h"
+
+#include <cassert>
+
+#include "gromacs/gpu_utils/gmxsycl.h"
+#include "gromacs/gpu_utils/sycl_kernel_utils.h"
+#include "gromacs/math/units.h"
+
+#include "pme_gpu_constants.h"
+
+using cl::sycl::access::mode;
+
+/*! \brief
+ * PME complex grid solver kernel function.
+ *
+ * \tparam     gridOrdering             Specifies the dimension ordering of the complex grid.
+ * \tparam     computeEnergyAndVirial   Tells if the reciprocal energy and virial should be
+ *                                        computed.
+ * \tparam     subGroupSize             Describes the width of a SYCL subgroup
+ */
+template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int subGroupSize>
+auto makeSolveKernel(cl::sycl::handler&                            cgh,
+                     DeviceAccessor<float, mode::read>             a_splineModuli,
+                     DeviceAccessor<SolveKernelParams, mode::read> a_solveKernelParams,
+                     DeviceAccessor<float, mode::read_write>       a_virialAndEnergy,
+                     DeviceAccessor<float, mode::read_write>       a_fourierGrid)
+{
+    cgh.require(a_splineModuli);
+    cgh.require(a_solveKernelParams);
+    cgh.require(a_virialAndEnergy);
+    cgh.require(a_fourierGrid);
+
+    /* Reduce 7 outputs per warp in the shared memory */
+    const int stride =
+            8; // this is c_virialAndEnergyCount==7 rounded up to power of 2 for convenience, hence the assert
+    static_assert(c_virialAndEnergyCount == 7);
+    const int reductionBufferSize = c_solveMaxWarpsPerBlock * stride;
+    cl::sycl::accessor<float, 1, mode::read_write, cl::sycl::target::local> sm_virialAndEnergy(
+            cl::sycl::range<1>(reductionBufferSize), cgh);
+
+    /* Each thread works on one cell of the Fourier space complex 3D grid (gm_grid).
+     * Each block handles up to c_solveMaxWarpsPerBlock * subGroupSize cells -
+     * depending on the grid contiguous dimension size,
+     * that can range from a part of a single gridline to several complete gridlines.
+     */
+    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
+    {
+        /* This kernel supports 2 different grid dimension orderings: YZX and XYZ */
+        int majorDim, middleDim, minorDim;
+        switch (gridOrdering)
+        {
+            case GridOrdering::YZX:
+                majorDim  = YY;
+                middleDim = ZZ;
+                minorDim  = XX;
+                break;
+
+            case GridOrdering::XYZ:
+                majorDim  = XX;
+                middleDim = YY;
+                minorDim  = ZZ;
+                break;
+
+            default: assert(false);
+        }
+
+        /* Global memory pointers */
+        const float* __restrict__ gm_splineValueMajor =
+                a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[majorDim];
+        const float* __restrict__ gm_splineValueMiddle =
+                a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[middleDim];
+        const float* __restrict__ gm_splineValueMinor =
+                a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[minorDim];
+        // The Fourier grid is allocated as float values, even though
+        // it logically contains complex values. (It also can be
+        // the same memory as the real grid for in-place transforms.)
+        // The buffer underlying the accessor may have a size that is
+        // larger than the active grid, because it is allocated with
+        // reallocateDeviceBuffer. The size of that larger-than-needed
+        // grid can be an odd number of floats, even though actual
+        // grid code only accesses up to an even number of floats. If
+        // we would use the reinterpet method of the accessor to
+        // convert from float to float2, runtime boundary checks can
+        // fail because of this mismatch. So, we extract the
+        // underlying global_ptr and use that to construct
+        // cl::sycl::float2 values when needed.
+        cl::sycl::global_ptr<float> gm_fourierGrid = a_fourierGrid.get_pointer();
+
+        /* Various grid sizes and indices */
+        const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0;
+        const int localSizeMinor   = a_solveKernelParams[0].complexGridSizePadded[minorDim];
+        const int localSizeMiddle  = a_solveKernelParams[0].complexGridSizePadded[middleDim];
+        const int localCountMiddle = a_solveKernelParams[0].complexGridSize[middleDim];
+        const int localCountMinor  = a_solveKernelParams[0].complexGridSize[minorDim];
+        const int nMajor           = a_solveKernelParams[0].realGridSize[majorDim];
+        const int nMiddle          = a_solveKernelParams[0].realGridSize[middleDim];
+        const int nMinor           = a_solveKernelParams[0].realGridSize[minorDim];
+        const int maxkMajor        = (nMajor + 1) / 2;  // X or Y
+        const int maxkMiddle       = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
+        const int maxkMinor        = (nMinor + 1) / 2;  // Z or X => only check for YZX
+
+        const int threadLocalId     = itemIdx.get_local_linear_id();
+        const int gridLineSize      = localCountMinor;
+        const int gridLineIndex     = threadLocalId / gridLineSize;
+        const int gridLineCellIndex = threadLocalId - gridLineSize * gridLineIndex;
+        const int gridLinesPerBlock =
+                cl::sycl::max(itemIdx.get_local_range(2) / size_t(gridLineSize), size_t(1));
+        const int activeWarps = (itemIdx.get_local_range(2) / subGroupSize);
+        const int indexMinor = itemIdx.get_group(2) * itemIdx.get_local_range(2) + gridLineCellIndex;
+        const int indexMiddle = itemIdx.get_group(1) * gridLinesPerBlock + gridLineIndex;
+        const int indexMajor  = itemIdx.get_group(0);
+
+        /* Optional outputs */
+        float energy = 0.0F;
+        float virxx  = 0.0F;
+        float virxy  = 0.0F;
+        float virxz  = 0.0F;
+        float viryy  = 0.0F;
+        float viryz  = 0.0F;
+        float virzz  = 0.0F;
+
+        assert(indexMajor < a_solveKernelParams[0].complexGridSize[majorDim]);
+        if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor)
+            & (gridLineIndex < gridLinesPerBlock))
+        {
+            /* The offset should be equal to the global thread index for coalesced access */
+            const int gridThreadIndex =
+                    (indexMajor * localSizeMiddle + indexMiddle) * localSizeMinor + indexMinor;
+
+            const int kMajor = indexMajor + localOffsetMajor;
+            /* Checking either X in XYZ, or Y in YZX cases */
+            const float mMajor = (kMajor < maxkMajor) ? kMajor : (kMajor - nMajor);
+
+            const int kMiddle = indexMiddle + localOffsetMiddle;
+            float     mMiddle = kMiddle;
+            /* Checking Y in XYZ case */
+            if (gridOrdering == GridOrdering::XYZ)
+            {
+                mMiddle = (kMiddle < maxkMiddle) ? kMiddle : (kMiddle - nMiddle);
+            }
+            const int kMinor = localOffsetMinor + indexMinor;
+            float     mMinor = kMinor;
+            /* Checking X in YZX case */
+            if (gridOrdering == GridOrdering::YZX)
+            {
+                mMinor = (kMinor < maxkMinor) ? kMinor : (kMinor - nMinor);
+            }
+            /* We should skip the k-space point (0,0,0) */
+            const bool notZeroPoint = (kMinor > 0) | (kMajor > 0) | (kMiddle > 0);
+
+            float mX, mY, mZ;
+            switch (gridOrdering)
+            {
+                case GridOrdering::YZX:
+                    mX = mMinor;
+                    mY = mMajor;
+                    mZ = mMiddle;
+                    break;
+
+                case GridOrdering::XYZ:
+                    mX = mMajor;
+                    mY = mMiddle;
+                    mZ = mMinor;
+                    break;
+
+                default: assert(false);
+            }
+
+            /* 0.5 correction factor for the first and last components of a Z dimension */
+            float corner_fac = 1.0F;
+            switch (gridOrdering)
+            {
+                case GridOrdering::YZX:
+                    if ((kMiddle == 0) | (kMiddle == maxkMiddle))
+                    {
+                        corner_fac = 0.5F;
+                    }
+                    break;
+
+                case GridOrdering::XYZ:
+                    if ((kMinor == 0) | (kMinor == maxkMinor))
+                    {
+                        corner_fac = 0.5F;
+                    }
+                    break;
+
+                default: assert(false);
+            }
+
+            if (notZeroPoint)
+            {
+                const float mhxk = mX * a_solveKernelParams[0].recipBox[XX][XX];
+                const float mhyk = mX * a_solveKernelParams[0].recipBox[XX][YY]
+                                   + mY * a_solveKernelParams[0].recipBox[YY][YY];
+                const float mhzk = mX * a_solveKernelParams[0].recipBox[XX][ZZ]
+                                   + mY * a_solveKernelParams[0].recipBox[YY][ZZ]
+                                   + mZ * a_solveKernelParams[0].recipBox[ZZ][ZZ];
+
+                const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
+                assert(m2k != 0.0F);
+                float denom = m2k * float(M_PI) * a_solveKernelParams[0].boxVolume
+                              * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle]
+                              * gm_splineValueMinor[kMinor];
+                assert(sycl_2020::isfinite(denom));
+                assert(denom != 0.0F);
+
+                const float tmp1   = cl::sycl::exp(-a_solveKernelParams[0].ewaldFactor * m2k);
+                const float etermk = a_solveKernelParams[0].elFactor * tmp1 / denom;
+
+                // sycl::float2::load and store are buggy in hipSYCL,
+                // but can probably be used after resolution of
+                // https://github.com/illuhad/hipSYCL/issues/647
+                cl::sycl::float2 gridValue;
+                sycl_2020::loadToVec(
+                        gridThreadIndex, cl::sycl::global_ptr<const float>(gm_fourierGrid), &gridValue);
+                const cl::sycl::float2 oldGridValue = gridValue;
+                gridValue *= etermk;
+                sycl_2020::storeFromVec(gridValue, gridThreadIndex, gm_fourierGrid);
+
+                if (computeEnergyAndVirial)
+                {
+                    const float tmp1k = 2.0F * cl::sycl::dot(gridValue, oldGridValue);
+
+                    float vfactor = (a_solveKernelParams[0].ewaldFactor + 1.0F / m2k) * 2.0F;
+                    float ets2    = corner_fac * tmp1k;
+                    energy        = ets2;
+
+                    float ets2vf = ets2 * vfactor;
+
+                    virxx = ets2vf * mhxk * mhxk - ets2;
+                    virxy = ets2vf * mhxk * mhyk;
+                    virxz = ets2vf * mhxk * mhzk;
+                    viryy = ets2vf * mhyk * mhyk - ets2;
+                    viryz = ets2vf * mhyk * mhzk;
+                    virzz = ets2vf * mhzk * mhzk - ets2;
+                }
+            }
+        }
+
+        /* Optional energy/virial reduction */
+        if (computeEnergyAndVirial)
+        {
+            /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
+             * The idea is to reduce 7 energy/virial components into a single variable (aligned by
+             * 8). We will reduce everything into virxx.
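+             * After the pair, quad and octet stages and the final shift loop
+             * below, lane L of each sub-group ends up holding the sub-group
+             * sum of one component in virxx: L=0 virxx, L=1 viryy, L=2 virzz,
+             * L=3 virxy, L=4 virxz, L=5 viryz, L=6 energy (L=7 is unused padding).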
+             */
+
+            /* We can only reduce warp-wise */
+            const int width = subGroupSize;
+            static_assert(subGroupSize >= 8);
+
+            sycl_2020::sub_group sg = itemIdx.get_sub_group();
+
+            /* Making pair sums */
+            virxx += sycl_2020::shift_left(sg, virxx, 1);
+            viryy += sycl_2020::shift_right(sg, viryy, 1);
+            virzz += sycl_2020::shift_left(sg, virzz, 1);
+            virxy += sycl_2020::shift_right(sg, virxy, 1);
+            virxz += sycl_2020::shift_left(sg, virxz, 1);
+            viryz += sycl_2020::shift_right(sg, viryz, 1);
+            energy += sycl_2020::shift_left(sg, energy, 1);
+            if (threadLocalId & 1)
+            {
+                virxx = viryy; // virxx now holds virxx and viryy pair sums
+                virzz = virxy; // virzz now holds virzz and virxy pair sums
+                virxz = viryz; // virxz now holds virxz and viryz pair sums
+            }
+
+            /* Making quad sums */
+            virxx += sycl_2020::shift_left(sg, virxx, 2);
+            virzz += sycl_2020::shift_right(sg, virzz, 2);
+            virxz += sycl_2020::shift_left(sg, virxz, 2);
+            energy += sycl_2020::shift_right(sg, energy, 2);
+            if (threadLocalId & 2)
+            {
+                virxx = virzz; // virxx now holds quad sums of virxx, viryy, virzz and virxy
+                virxz = energy; // virxz now holds quad sums of virxz, viryz, energy and unused padding
+            }
+
+            /* Making octet sums */
+            virxx += sycl_2020::shift_left(sg, virxx, 4);
+            virxz += sycl_2020::shift_right(sg, virxz, 4);
+            if (threadLocalId & 4)
+            {
+                virxx = virxz; // virxx now holds all 7 components' octet sums + unused padding
+            }
+
+            /* We only need to reduce virxx now */
+#pragma unroll
+            for (int delta = 8; delta < width; delta <<= 1)
+            {
+                virxx += sycl_2020::shift_left(sg, virxx, delta);
+            }
+            /* Now first 7 threads of each warp have the full output contributions in virxx */
+
+            const int  componentIndex      = threadLocalId & (subGroupSize - 1);
+            const bool validComponentIndex = (componentIndex < c_virialAndEnergyCount);
+
+            if (validComponentIndex)
+            {
+                const int warpIndex = threadLocalId / subGroupSize;
+                sm_virialAndEnergy[warpIndex * stride + componentIndex] = virxx;
+            }
+            itemIdx.barrier(cl::sycl::access::fence_space::local_space);
+
+            /* Reduce to the single warp size */
+            const int targetIndex = threadLocalId;
+#pragma unroll
+            for (int reductionStride = reductionBufferSize >> 1; reductionStride >= subGroupSize;
+                 reductionStride >>= 1)
+            {
+                const int sourceIndex = targetIndex + reductionStride;
+                if ((targetIndex < reductionStride) & (sourceIndex < activeWarps * stride))
+                {
+                    sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[sourceIndex];
+                }
+                itemIdx.barrier(cl::sycl::access::fence_space::local_space);
+            }
+
+            /* Now use shuffle again */
+            /* NOTE: This reduction assumes there are at least 4 warps (asserted).
+             *       To use fewer warps, add to the conditional:
+             *       && threadLocalId < activeWarps * stride
+             */
+            assert(activeWarps * stride >= subGroupSize);
+            if (threadLocalId < subGroupSize)
+            {
+                float output = sm_virialAndEnergy[threadLocalId];
+#pragma unroll
+                for (int delta = stride; delta < subGroupSize; delta <<= 1)
+                {
+                    output += sycl_2020::shift_left(sg, output, delta);
+                }
+                /* Final output */
+                if (validComponentIndex)
+                {
+                    assert(sycl_2020::isfinite(output));
+                    atomicFetchAdd(a_virialAndEnergy[componentIndex], output);
+                }
+            }
+        }
+    };
+}
+
+template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
+PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::PmeSolveKernel()
+{
+    reset();
+}
+
+template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
+void PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::setArg(size_t argIndex,
+                                                                                           void* arg)
+{
+    if (argIndex == 0)
+    {
+        auto* params = reinterpret_cast<PmeGpuKernelParams*>(arg);
+
+        constParams_                             = &params->constants;
+        gridParams_                              = &params->grid;
+        solveKernelParams_.ewaldFactor           = params->grid.ewaldFactor;
+        solveKernelParams_.realGridSize          = params->grid.realGridSize;
+        solveKernelParams_.complexGridSize       = params->grid.complexGridSize;
+        solveKernelParams_.complexGridSizePadded = params->grid.complexGridSizePadded;
+        solveKernelParams_.splineValuesOffset    = params->grid.splineValuesOffset;
+        solveKernelParams_.recipBox[XX]          = params->current.recipBox[XX];
+        solveKernelParams_.recipBox[YY]          = params->current.recipBox[YY];
+        solveKernelParams_.recipBox[ZZ]          = params->current.recipBox[ZZ];
+        solveKernelParams_.boxVolume             = params->current.boxVolume;
+        solveKernelParams_.elFactor              = params->constants.elFactor;
+    }
+    else
+    {
+        GMX_RELEASE_ASSERT(argIndex == 0, "Trying to pass too many args to the solve kernel");
+    }
+}
+
+template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
+cl::sycl::event PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::launch(
+        const KernelLaunchConfig& config,
+        const DeviceStream&       deviceStream)
+{
+    GMX_RELEASE_ASSERT(gridParams_, "Can not launch the kernel before setting its args");
+    GMX_RELEASE_ASSERT(constParams_, "Can not launch the kernel before setting its args");
+
+    using KernelNameType = PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>;
+
+    // SYCL has a different multidimensional layout than OpenCL/CUDA.
+    const cl::sycl::range<3> localSize{ config.blockSize[2], config.blockSize[1], config.blockSize[0] };
+    const cl::sycl::range<3> groupRange{ config.gridSize[2], config.gridSize[1], config.gridSize[0] };
+    const cl::sycl::nd_range<3> range{ groupRange * localSize, localSize };
+
+    cl::sycl::queue q = deviceStream.stream();
+
+    cl::sycl::buffer<SolveKernelParams, 1> d_solveKernelParams(&solveKernelParams_, 1);
+    cl::sycl::event                        e = q.submit([&](cl::sycl::handler& cgh) {
+        auto kernel = makeSolveKernel<gridOrdering, computeEnergyAndVirial, subGroupSize>(
+                cgh,
+                gridParams_->d_splineModuli[gridIndex],
+                d_solveKernelParams,
+                constParams_->d_virialAndEnergy[gridIndex],
+                gridParams_->d_fourierGrid[gridIndex]);
+        cgh.parallel_for<KernelNameType>(range, kernel);
+    });
+
+    // Clear the stored args, so that we do not forget to set them before the next launch.
+    reset();
+
+    return e;
+}
+
+template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
+void PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::reset()
+{
+    gridParams_  = nullptr;
+    constParams_ = nullptr;
+}
+
+//! Kernel class instantiations
+/* Disable the "explicit template instantiation 'PmeSolveKernel<...>' will emit a vtable in every
+ * translation unit [-Wweak-template-vtables]" warning.
+ * It is only explicitly instantiated in this translation unit, so we should be safe.
+ */
+#ifdef __clang__
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wweak-template-vtables"
+#endif
+
+#define INSTANTIATE(subGroupSize)                                             \
+    template class PmeSolveKernel<GridOrdering::XYZ, false, 0, subGroupSize>; \
+    template class PmeSolveKernel<GridOrdering::XYZ, true, 0, subGroupSize>;  \
+    template class PmeSolveKernel<GridOrdering::YZX, false, 0, subGroupSize>; \
+    template class PmeSolveKernel<GridOrdering::YZX, true, 0, subGroupSize>;  \
+    template class PmeSolveKernel<GridOrdering::XYZ, false, 1, subGroupSize>; \
+    template class PmeSolveKernel<GridOrdering::XYZ, true, 1, subGroupSize>;  \
+    template class PmeSolveKernel<GridOrdering::YZX, false, 1, subGroupSize>; \
+    template class PmeSolveKernel<GridOrdering::YZX, true, 1, subGroupSize>;
+
+#if GMX_SYCL_DPCPP
+INSTANTIATE(16);
+#elif GMX_SYCL_HIPSYCL
+INSTANTIATE(32);
+INSTANTIATE(64);
+#endif
+
+#ifdef __clang__
+#    pragma clang diagnostic pop
+#endif
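For reference, the in-place scaling applied to each complex grid value by the kernel above (this just restates the denom/etermk computation, it is not new behaviour):

    eterm = elFactor * exp(-ewaldFactor * m2k)
            / (M_PI * boxVolume * m2k * B_major(kMajor) * B_middle(kMiddle) * B_minor(kMinor))

where m2k is the squared length of the reciprocal-space vector built from the grid indices and recipBox, and B_* are the spline moduli. When computeEnergyAndVirial is set, the energy and virial contributions are formed from the scaled and original grid values and then reduced as described in the kernel.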
diff --git a/src/gromacs/ewald/pme_solve_sycl.h b/src/gromacs/ewald/pme_solve_sycl.h
new file mode 100644 (file)
index 0000000..78ff491
--- /dev/null
@@ -0,0 +1,95 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2021, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+
+/*! \internal \file
+ *  \brief Declares the kernel class for PME GPU Fourier grid solving in SYCL.
+ *
+ *  \author Mark Abraham <mark.j.abraham@gmail.com>
+ *  \author Andrey Alekseenko <al42and@gmail.com>
+ */
+
+#include "gromacs/gpu_utils/gmxsycl.h"
+#include "gromacs/gpu_utils/syclutils.h"
+#include "gromacs/math/vectypes.h"
+
+#include "pme_gpu_internal.h"
+#include "pme_gpu_types.h"
+
+struct PmeGpuConstParams;
+struct PmeGpuGridParams;
+
+//! Contains most of the parameters used by the solve kernel
+struct SolveKernelParams
+{
+    /*! \brief Ewald solving factor = (M_PI / pme->ewaldcoeff_q)^2 */
+    float ewaldFactor;
+    /*! \brief Real-space grid data dimensions. */
+    gmx::IVec realGridSize;
+    /*! \brief Fourier grid dimensions. This counts the complex numbers! */
+    gmx::IVec complexGridSize;
+    /*! \brief Fourier grid dimensions (padded). This counts the complex numbers! */
+    gmx::IVec complexGridSizePadded;
+    /*! \brief Offsets for X/Y/Z components of d_splineModuli */
+    gmx::IVec splineValuesOffset;
+    /*! \brief Reciprocal (inverted unit cell) box. */
+    gmx::RVec recipBox[DIM];
+    /*! \brief The unit cell volume for solving. */
+    float boxVolume;
+    /*! \brief Electrostatics coefficient = c_one4PiEps0 / pme->epsilon_r */
+    float elFactor;
+};
+
+//! The kernel for PME solve
+template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
+class PmeSolveKernel : public ISyclKernelFunctor
+{
+public:
+    PmeSolveKernel();
+    //! Sets the kernel arguments
+    void setArg(size_t argIndex, void* arg) override;
+    //! Launches the kernel with given \c config and \c deviceStream
+    cl::sycl::event launch(const KernelLaunchConfig& config, const DeviceStream& deviceStream) override;
+
+private:
+    //! Kernel argument set by \c setArg()
+    PmeGpuConstParams* constParams_ = nullptr;
+    //! Kernel argument set by \c setArg()
+    PmeGpuGridParams* gridParams_ = nullptr;
+    //! Kernel argument set by \c setArg()
+    SolveKernelParams solveKernelParams_;
+
+    //! Called after each launch to ensure we set the arguments again properly
+    void reset();
+};
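A minimal host-side sketch of exercising the interface declared above (a sketch only: kernelParams, config and deviceStream are assumed to be supplied by the PME GPU module, and the template arguments are just one of the combinations instantiated in pme_solve_sycl.cpp):

    // Sketch: drive PmeSolveKernel through its ISyclKernelFunctor interface.
    PmeSolveKernel<GridOrdering::YZX, true, 0, 32> solveKernel;
    solveKernel.setArg(0, &kernelParams); // PmeGpuKernelParams*; relevant fields are copied into SolveKernelParams
    cl::sycl::event done = solveKernel.launch(config, deviceStream); // submits the SYCL kernel
    done.wait();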
index 5c48955cc473f27f37737dce24a47f83052b0eb3..c8ea1bff1240ded7d5f06553d523529368557964 100644 (file)
@@ -76,7 +76,17 @@ public:
     PmeSolveTest() = default;
 
     //! Sets the programs once
-    static void SetUpTestSuite() { s_pmeTestHardwareContexts = createPmeTestHardwareContextList(); }
+    static void SetUpTestSuite()
+    {
+        s_pmeTestHardwareContexts    = createPmeTestHardwareContextList();
+        g_allowPmeWithSyclForTesting = true; // We support PmeSolve with SYCL
+    }
+
+    static void TearDownTestSuite()
+    {
+        // Revert the value back.
+        g_allowPmeWithSyclForTesting = false;
+    }
 
     //! The test
     static void runTest()
index 544c9e2b13236021dda8eeef37919d77eb25bf2c..aba31db05351b8ca1d83937c98d64fd1761f0af1 100644 (file)
@@ -144,6 +144,113 @@ static inline float shift_right(sycl_2020::sub_group sg, float var, sycl_2020::s
     return sg.shuffle_up(var, delta);
 }
 #endif
+
+#if GMX_SYCL_HIPSYCL
+/*! \brief Polyfill for sycl::isfinite missing from hipSYCL
+ *
+ * Does not follow GROMACS style because it should follow the name for
+ * which it is a polyfill. */
+template<typename Real>
+__device__ __host__ static inline bool isfinite(Real value)
+{
+    // This is not yet implemented in hipSYCL pending
+    // https://github.com/illuhad/hipSYCL/issues/636
+#    ifdef SYCL_DEVICE_ONLY
+#        if defined(HIPSYCL_PLATFORM_CUDA) && defined(__HIPSYCL_ENABLE_CUDA_TARGET__)
+    return isfinite(value);
+#        elif defined(HIPSYCL_PLATFORM_ROCM) && defined(__HIPSYCL_ENABLE_HIP_TARGET__)
+    return isfinite(value);
+#        else
+#            error "Unsupported hipSYCL target"
+#        endif
+#    else
+    // Should never be called
+    assert(false);
+    GMX_UNUSED_VALUE(value);
+    return false;
+#    endif
+}
+#elif GMX_SYCL_DPCPP
+template<typename Real>
+static inline bool isfinite(Real value)
+{
+    return cl::sycl::isfinite(value);
+}
+
+#endif
+
+#if GMX_SYCL_HIPSYCL
+
+/*! \brief Polyfill for sycl::vec::load buggy in hipSYCL
+ *
+ * Loads the components of \c v from the address \c ptr, offset by
+ * NumElements * offset elements of type T.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void loadToVec(size_t                                     offset,
+                             cl::sycl::multi_ptr<const T, AddressSpace> ptr,
+                             cl::sycl::vec<T, NumElements>*             v)
+{
+    for (int i = 0; i < NumElements; ++i)
+    {
+        (*v)[i] = ptr.get()[offset * NumElements + i];
+    }
+}
+
+/*! \brief Polyfill for sycl::vec::store buggy in hipSYCL
+ *
+ * Stores the components of \c v to the address \c ptr, offset by
+ * NumElements * offset elements of type T.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void storeFromVec(const cl::sycl::vec<T, NumElements>& v,
+                                size_t                               offset,
+                                cl::sycl::multi_ptr<T, AddressSpace> ptr)
+{
+    for (int i = 0; i < NumElements; ++i)
+    {
+        ptr.get()[offset * NumElements + i] = v[i];
+    }
+}
+
+#elif GMX_SYCL_DPCPP
+
+/*! \brief Polyfill for sycl::vec::load buggy in hipSYCL
+ *
+ * Loads the components of \c v from the address \c ptr, offset by
+ * NumElements * offset elements of type T.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void loadToVec(size_t offset,
+                             cl::sycl::multi_ptr<const T, AddressSpace> ptr,
+                             cl::sycl::vec<T, NumElements>* v)
+{
+    v->load(offset, ptr);
+}
+
+/*! \brief Polyfill for sycl::vec::store buggy in hipSYCL
+ *
+ * Stores the components of \c v to the address \c ptr, offset by
+ * NumElements * offset elements of type T.
+ *
+ * Can probably be removed when
+ * https://github.com/illuhad/hipSYCL/issues/647 is resolved. */
+template<cl::sycl::access::address_space AddressSpace, typename T, int NumElements>
+static inline void storeFromVec(const cl::sycl::vec<T, NumElements>& v,
+                                size_t offset,
+                                cl::sycl::multi_ptr<T, AddressSpace> ptr)
+{
+    v.store(offset, ptr);
+}
+
+#endif
+
 } // namespace sycl_2020
 
 #endif /* GMX_GPU_UTILS_SYCL_KERNEL_UTILS_H */