Add helper functions for setting up Nbnxm gpu object in nblib
[alexxy/gromacs.git] / api / nblib / nbnxmsetuphelpers.cpp
index 160612f8e8228fd2d170d962b7924878074c2b25..c876482b97f3b30b9072d4554e173bbd661d6170 100644 (file)
@@ -129,6 +129,16 @@ Nbnxm::KernelSetup createKernelSetupCPU(const SimdKernels nbnxmSimd, const bool
     return kernelSetup;
 }
 
+Nbnxm::KernelSetup createKernelSetupGPU(const bool useTabulatedEwaldCorr)
+{
+    Nbnxm::KernelSetup kernelSetup;
+    kernelSetup.kernelType         = Nbnxm::KernelType::Gpu8x8x8;
+    kernelSetup.ewaldExclusionType = useTabulatedEwaldCorr ? Nbnxm::EwaldExclusionType::Table
+                                                           : Nbnxm::EwaldExclusionType::Analytical;
+
+    return kernelSetup;
+}
+
 std::vector<int64_t> createParticleInfoAllVdw(const size_t numParticles)
 {
     std::vector<int64_t> particleInfoAllVdw(numParticles);
@@ -178,6 +188,29 @@ gmx::StepWorkload createStepWorkload()
     return stepWorkload;
 }
 
+static gmx::SimulationWorkload createSimulationWorkload()
+{
+    gmx::SimulationWorkload simulationWork;
+    simulationWork.computeNonbonded = true;
+    return simulationWork;
+}
+
+gmx::SimulationWorkload createSimulationWorkloadGpu()
+{
+    gmx::SimulationWorkload simulationWork = createSimulationWorkload();
+
+    simulationWork.useGpuNonbonded = true;
+    simulationWork.useGpuUpdate    = false;
+
+    return simulationWork;
+}
+
+std::shared_ptr<gmx::DeviceStreamManager> createDeviceStreamManager(const DeviceInformation& deviceInfo,
+                                                                    const gmx::SimulationWorkload& simulationWorkload)
+{
+    return std::make_shared<gmx::DeviceStreamManager>(deviceInfo, false, simulationWorkload, false);
+}
+
 real ewaldCoeff(const real ewald_rtol, const real pairlistCutoff)
 {
     return calc_ewaldcoeff_q(pairlistCutoff, ewald_rtol);
@@ -279,6 +312,53 @@ std::unique_ptr<nonbonded_verlet_t> createNbnxmCPU(const size_t              num
     return nbv;
 }
 
+std::unique_ptr<nonbonded_verlet_t> createNbnxmGPU(const size_t               numParticleTypes,
+                                                   const NBKernelOptions&     options,
+                                                   const std::vector<real>&   nonbondedParameters,
+                                                   const interaction_const_t& interactionConst,
+                                                   const gmx::DeviceStreamManager& deviceStreamManager)
+{
+    const auto pinPolicy       = gmx::PinningPolicy::PinnedIfSupported;
+    const int  combinationRule = static_cast<int>(options.ljCombinationRule);
+
+    Nbnxm::KernelSetup kernelSetup = createKernelSetupGPU(options.useTabulatedEwaldCorr);
+
+    PairlistParams pairlistParams(kernelSetup.kernelType, false, options.pairlistCutoff, false);
+
+
+    // nbnxn_atomdata is always initialized with 1 thread if the GPU is used
+    constexpr int numThreadsInit = 1;
+    // multiple energy groups are not supported on the GPU
+    constexpr int numEnergyGroups = 1;
+    auto          atomData        = std::make_unique<nbnxn_atomdata_t>(pinPolicy,
+                                                       gmx::MDLogger(),
+                                                       kernelSetup.kernelType,
+                                                       combinationRule,
+                                                       numParticleTypes,
+                                                       nonbondedParameters,
+                                                       numEnergyGroups,
+                                                       numThreadsInit);
+
+    NbnxmGpu* nbnxmGpu = Nbnxm::gpu_init(
+            deviceStreamManager, &interactionConst, pairlistParams, atomData.get(), false);
+
+    // minimum iList count for GPU balancing
+    int iListCount = Nbnxm::gpu_min_ci_balanced(nbnxmGpu);
+
+    auto pairlistSets = std::make_unique<PairlistSets>(pairlistParams, false, iListCount);
+    auto pairSearch   = std::make_unique<PairSearch>(
+            PbcType::Xyz, false, nullptr, nullptr, pairlistParams.pairlistType, false, options.numOpenMPThreads, pinPolicy);
+
+    // Put everything together
+    auto nbv = std::make_unique<nonbonded_verlet_t>(
+            std::move(pairlistSets), std::move(pairSearch), std::move(atomData), kernelSetup, nbnxmGpu, nullptr);
+
+    // Some paramters must be copied to NbnxmGpu to have a fully constructed nonbonded_verlet_t
+    Nbnxm::gpu_init_atomdata(nbv->gpu_nbv, nbv->nbat.get());
+
+    return nbv;
+}
+
 void setGmxNonBondedNThreads(int numThreads)
 {
     gmx_omp_nthreads_set(ModuleMultiThread::Pairsearch, numThreads);