Unify constructor of nbnxn_atomdata_t
authorJoe Jordan <ejjordan12@gmail.com>
Tue, 27 Apr 2021 08:40:11 +0000 (08:40 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Tue, 27 Apr 2021 08:40:11 +0000 (08:40 +0000)
api/nblib/gmxsetup.cpp
src/gromacs/nbnxm/atomdata.cpp
src/gromacs/nbnxm/atomdata.h
src/gromacs/nbnxm/benchmark/bench_setup.cpp
src/gromacs/nbnxm/nbnxm_setup.cpp

index 955b8cc1db35d10a50f875b18aba964ae7051798..3baa358e98d0fbe5393606240f53b8f6ec031bf5 100644 (file)
@@ -192,17 +192,14 @@ void NbvSetupUtil::setupNbnxmInstance(const size_t numParticleTypes, const NBKer
     auto pairSearch   = std::make_unique<PairSearch>(
             PbcType::Xyz, false, nullptr, nullptr, pairlistParams.pairlistType, false, numThreads, pinPolicy);
 
-    auto atomData = std::make_unique<nbnxn_atomdata_t>(pinPolicy);
-
-    // Needs to be called with the number of unique ParticleTypes
-    nbnxn_atomdata_init(gmx::MDLogger(),
-                        atomData.get(),
-                        kernelSetup.kernelType,
-                        combinationRule,
-                        numParticleTypes,
-                        nonbondedParameters_,
-                        1,
-                        numThreads);
+    auto atomData = std::make_unique<nbnxn_atomdata_t>(pinPolicy,
+                                                       gmx::MDLogger(),
+                                                       kernelSetup.kernelType,
+                                                       combinationRule,
+                                                       numParticleTypes,
+                                                       nonbondedParameters_,
+                                                       1,
+                                                       numThreads);
 
     // Put everything together
     auto nbv = std::make_unique<nonbonded_verlet_t>(
index d74467f1634ebdf44e8942819532fc1803a8a8d5..bce2b142aecf82791e308f0aa93bf12a1af7d0c6 100644 (file)
@@ -432,17 +432,6 @@ nbnxn_atomdata_t::Params::Params(gmx::PinningPolicy pinningPolicy) :
 {
 }
 
-nbnxn_atomdata_t::nbnxn_atomdata_t(gmx::PinningPolicy pinningPolicy) :
-    params_(pinningPolicy),
-    numAtoms_(0),
-    natoms_local(0),
-    shift_vec({}, { pinningPolicy }),
-    x_({}, { pinningPolicy }),
-    simdMasks(),
-    bUseBufferFlags(FALSE)
-{
-}
-
 /* Initializes an nbnxn_atomdata_t::Params data structure */
 static void nbnxn_atomdata_params_init(const gmx::MDLogger&      mdlog,
                                        nbnxn_atomdata_t::Params* params,
@@ -626,17 +615,24 @@ static void nbnxn_atomdata_params_init(const gmx::MDLogger&      mdlog,
 }
 
 /* Initializes an nbnxn_atomdata_t data structure */
-void nbnxn_atomdata_init(const gmx::MDLogger&    mdlog,
-                         nbnxn_atomdata_t*       nbat,
-                         const Nbnxm::KernelType kernelType,
-                         int                     enbnxninitcombrule,
-                         int                     ntype,
-                         ArrayRef<const real>    nbfp,
-                         int                     n_energygroups,
-                         int                     nout)
+nbnxn_atomdata_t::nbnxn_atomdata_t(gmx::PinningPolicy      pinningPolicy,
+                                   const gmx::MDLogger&    mdlog,
+                                   const Nbnxm::KernelType kernelType,
+                                   int                     enbnxninitcombrule,
+                                   int                     ntype,
+                                   ArrayRef<const real>    nbfp,
+                                   int                     n_energygroups,
+                                   int                     nout) :
+    params_(pinningPolicy),
+    numAtoms_(0),
+    natoms_local(0),
+    shift_vec({}, { pinningPolicy }),
+    x_({}, { pinningPolicy }),
+    simdMasks(),
+    bUseBufferFlags(FALSE)
 {
     nbnxn_atomdata_params_init(
-            mdlog, &nbat->paramsDeprecated(), kernelType, enbnxninitcombrule, ntype, nbfp, n_energygroups);
+            mdlog, &paramsDeprecated(), kernelType, enbnxninitcombrule, ntype, nbfp, n_energygroups);
 
     const bool simple = Nbnxm::kernelTypeUsesSimplePairlist(kernelType);
     const bool bSIMD  = Nbnxm::kernelTypeIsSimd(kernelType);
@@ -648,38 +644,37 @@ void nbnxn_atomdata_init(const gmx::MDLogger&    mdlog,
             int pack_x = std::max(c_nbnxnCpuIClusterSize, Nbnxm::JClusterSizePerKernelType[kernelType]);
             switch (pack_x)
             {
-                case 4: nbat->XFormat = nbatX4; break;
-                case 8: nbat->XFormat = nbatX8; break;
+                case 4: XFormat = nbatX4; break;
+                case 8: XFormat = nbatX8; break;
                 default: gmx_incons("Unsupported packing width");
             }
         }
         else
         {
-            nbat->XFormat = nbatXYZ;
+            XFormat = nbatXYZ;
         }
 
-        nbat->FFormat = nbat->XFormat;
+        FFormat = XFormat;
     }
     else
     {
-        nbat->XFormat = nbatXYZQ;
-        nbat->FFormat = nbatXYZ;
+        XFormat = nbatXYZQ;
+        FFormat = nbatXYZ;
     }
 
-    nbat->shift_vec.resize(gmx::c_numShiftVectors);
+    shift_vec.resize(gmx::c_numShiftVectors);
 
-    nbat->xstride = (nbat->XFormat == nbatXYZQ ? STRIDE_XYZQ : DIM);
-    nbat->fstride = (nbat->FFormat == nbatXYZQ ? STRIDE_XYZQ : DIM);
+    xstride = (XFormat == nbatXYZQ ? STRIDE_XYZQ : DIM);
+    fstride = (FFormat == nbatXYZQ ? STRIDE_XYZQ : DIM);
 
     /* Initialize the output data structures */
     for (int i = 0; i < nout; i++)
     {
-        const auto& pinningPolicy = nbat->params().type.get_allocator().pinningPolicy();
-        nbat->out.emplace_back(
-                kernelType, nbat->params().nenergrp, 1 << nbat->params().neg_2log, pinningPolicy);
+        const auto& pinningPolicy = params().type.get_allocator().pinningPolicy();
+        out.emplace_back(kernelType, params().nenergrp, 1 << params().neg_2log, pinningPolicy);
     }
 
-    nbat->buffer_flags.clear();
+    buffer_flags.clear();
 }
 
 template<int packSize>
index cb878aa7791573bfc78d22cf2ec41038ab99477f..1f36881a96aa80f7edd43cfdaab5df2328cb6d1b 100644 (file)
@@ -232,9 +232,24 @@ struct nbnxn_atomdata_t
 
     /*! \brief Constructor
      *
-     * \param[in] pinningPolicy  Sets the pinning policy for all data that might be transferred to a GPU
+     * \param[in] pinningPolicy      Sets the pinning policy for all data that might be transferred
+     *                               to a GPU
+     * \param[in] mdlog              The logger
+     * \param[in] kernelType         Nonbonded NxN kernel type
+     * \param[in] enbnxninitcombrule LJ combination rule
+     * \param[in] ntype              Number of atom types
+     * \param[in] nbfp               Non-bonded force parameters
+     * \param[in] n_energygroups     Number of energy groups
+     * \param[in] nout               Number of output data structures
      */
-    nbnxn_atomdata_t(gmx::PinningPolicy pinningPolicy);
+    nbnxn_atomdata_t(gmx::PinningPolicy        pinningPolicy,
+                     const gmx::MDLogger&      mdlog,
+                     Nbnxm::KernelType         kernelType,
+                     int                       enbnxninitcombrule,
+                     int                       ntype,
+                     gmx::ArrayRef<const real> nbfp,
+                     int                       n_energygroups,
+                     int                       nout);
 
     //! Returns a const reference to the parameters
     const Params& params() const { return params_; }
index 929625fc4151b02f5f11b4b5bdd61a706dc8c9f0..642818194676e3233a85c3d5f4400b7ea6a4b121 100644 (file)
@@ -200,21 +200,19 @@ static std::unique_ptr<nonbonded_verlet_t> setupNbnxmForBenchInstance(const Kern
     auto pairSearch = std::make_unique<PairSearch>(
             PbcType::Xyz, false, nullptr, nullptr, pairlistParams.pairlistType, false, numThreads, pinPolicy);
 
-    auto atomData = std::make_unique<nbnxn_atomdata_t>(pinPolicy);
+    auto atomData = std::make_unique<nbnxn_atomdata_t>(pinPolicy,
+                                                       gmx::MDLogger(),
+                                                       kernelSetup.kernelType,
+                                                       combinationRule,
+                                                       system.numAtomTypes,
+                                                       system.nonbondedParameters,
+                                                       1,
+                                                       numThreads);
 
     // Put everything together
     auto nbv = std::make_unique<nonbonded_verlet_t>(
             std::move(pairlistSets), std::move(pairSearch), std::move(atomData), kernelSetup, nullptr, nullptr);
 
-    nbnxn_atomdata_init(gmx::MDLogger(),
-                        nbv->nbat.get(),
-                        kernelSetup.kernelType,
-                        combinationRule,
-                        system.numAtomTypes,
-                        system.nonbondedParameters,
-                        1,
-                        numThreads);
-
     t_nrnb nrnb;
 
     GMX_RELEASE_ASSERT(!TRICLINIC(system.box), "Only rectangular unit-cells are supported here");
index a1185ea2cc150b20c2a15ca90d94dd43bc3dbac0..422840d889afd285cd58ff8419eedaef24f02e1e 100644 (file)
@@ -418,8 +418,6 @@ std::unique_ptr<nonbonded_verlet_t> init_nb_verlet(const gmx::MDLogger& mdlog,
     auto pinPolicy = (useGpuForNonbonded ? gmx::PinningPolicy::PinnedIfSupported
                                          : gmx::PinningPolicy::CannotBePinned);
 
-    auto nbat = std::make_unique<nbnxn_atomdata_t>(pinPolicy);
-
     int mimimumNumEnergyGroupNonbonded = inputrec.opts.ngener;
     if (inputrec.opts.ngener - inputrec.nwall == 1)
     {
@@ -429,9 +427,10 @@ std::unique_ptr<nonbonded_verlet_t> init_nb_verlet(const gmx::MDLogger& mdlog,
          */
         mimimumNumEnergyGroupNonbonded = 1;
     }
-    nbnxn_atomdata_init(
+
+    auto nbat = std::make_unique<nbnxn_atomdata_t>(
+            pinPolicy,
             mdlog,
-            nbat.get(),
             kernelSetup.kernelType,
             enbnxninitcombrule,
             forcerec.ntype,