Unify gpu_init function in NBNXM
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_data_mgmt.cpp
index 2f37a0c011c90f8a9b7e7575b62d06ed1b92bad1..3998f833b51f7ab18ef723b8e0a96ce78bedd9eb 100644 (file)
@@ -41,7 +41,6 @@
  */
 #include "gmxpre.h"
 
-#include "gromacs/gpu_utils/device_stream_manager.h"
 #include "gromacs/gpu_utils/pmalloc.h"
 #include "gromacs/hardware/device_information.h"
 #include "gromacs/mdtypes/interaction_const.h"
 namespace Nbnxm
 {
 
-/*! \brief Initialize \p atomdata first time; it only gets filled at pair-search. */
-static void initAtomdataFirst(NBAtomData*          atomdata,
-                              int                  numTypes,
-                              const DeviceContext& deviceContext,
-                              const DeviceStream&  localStream)
+void gpu_init_platform_specific(NbnxmGpu* /* nb */)
 {
-    atomdata->numTypes = numTypes;
-    allocateDeviceBuffer(&atomdata->shiftVec, SHIFTS, deviceContext);
-    atomdata->shiftVecUploaded = false;
-
-    allocateDeviceBuffer(&atomdata->fShift, SHIFTS, deviceContext);
-    allocateDeviceBuffer(&atomdata->eLJ, 1, deviceContext);
-    allocateDeviceBuffer(&atomdata->eElec, 1, deviceContext);
-
-    clearDeviceBufferAsync(&atomdata->fShift, 0, SHIFTS, localStream);
-    clearDeviceBufferAsync(&atomdata->eElec, 0, 1, localStream);
-    clearDeviceBufferAsync(&atomdata->eLJ, 0, 1, localStream);
-
-    /* initialize to nullptr pointers to data that is not allocated here and will
-       need reallocation in later */
-    atomdata->xq = nullptr;
-    atomdata->f  = nullptr;
-
-    /* size -1 indicates that the respective array hasn't been initialized yet */
-    atomdata->numAtoms      = -1;
-    atomdata->numAtomsAlloc = -1;
-}
-
-/*! \brief Initialize the nonbonded parameter data structure. */
-static void initNbparam(NBParamGpu*                     nbp,
-                        const interaction_const_t&      ic,
-                        const PairlistParams&           listParams,
-                        const nbnxn_atomdata_t::Params& nbatParams,
-                        const DeviceContext&            deviceContext)
-{
-    const int numTypes = nbatParams.numTypes;
-
-    set_cutoff_parameters(nbp, &ic, listParams);
-
-    nbp->vdwType  = nbnxmGpuPickVdwKernelType(&ic, nbatParams.ljCombinationRule);
-    nbp->elecType = nbnxmGpuPickElectrostaticsKernelType(&ic, deviceContext.deviceInfo());
-
-    /* generate table for PME */
-    nbp->coulomb_tab = nullptr;
-    if (nbp->elecType == ElecType::EwaldTab || nbp->elecType == ElecType::EwaldTabTwin)
-    {
-        GMX_RELEASE_ASSERT(ic.coulombEwaldTables, "Need valid Coulomb Ewald correction tables");
-        init_ewald_coulomb_force_table(*ic.coulombEwaldTables, nbp, deviceContext);
-    }
-
-    /* set up LJ parameter lookup table */
-    if (!useLjCombRule(nbp->vdwType))
-    {
-        static_assert(sizeof(decltype(nbp->nbfp)) == 2 * sizeof(decltype(*nbatParams.nbfp.data())),
-                      "Mismatch in the size of host / device data types");
-        initParamLookupTable(&nbp->nbfp,
-                             &nbp->nbfp_texobj,
-                             reinterpret_cast<const Float2*>(nbatParams.nbfp.data()),
-                             numTypes * numTypes,
-                             deviceContext);
-    }
-
-    /* set up LJ-PME parameter lookup table */
-    if (ic.vdwtype == VanDerWaalsType::Pme)
-    {
-        static_assert(sizeof(decltype(nbp->nbfp_comb))
-                              == 2 * sizeof(decltype(*nbatParams.nbfp_comb.data())),
-                      "Mismatch in the size of host / device data types");
-        initParamLookupTable(&nbp->nbfp_comb,
-                             &nbp->nbfp_comb_texobj,
-                             reinterpret_cast<const Float2*>(nbatParams.nbfp_comb.data()),
-                             numTypes,
-                             deviceContext);
-    }
-}
-
-NbnxmGpu* gpu_init(const gmx::DeviceStreamManager& deviceStreamManager,
-                   const interaction_const_t*      ic,
-                   const PairlistParams&           listParams,
-                   const nbnxn_atomdata_t*         nbat,
-                   const bool                      bLocalAndNonlocal)
-{
-    auto* nb                              = new NbnxmGpu();
-    nb->deviceContext_                    = &deviceStreamManager.context();
-    nb->atdat                             = new NBAtomData;
-    nb->nbparam                           = new NBParamGpu;
-    nb->plist[InteractionLocality::Local] = new Nbnxm::gpu_plist;
-    if (bLocalAndNonlocal)
-    {
-        nb->plist[InteractionLocality::NonLocal] = new Nbnxm::gpu_plist;
-    }
-
-    nb->bUseTwoStreams = bLocalAndNonlocal;
-
-    nb->timers  = nullptr;
-    nb->timings = nullptr;
-
-    /* init nbst */
-    pmalloc(reinterpret_cast<void**>(&nb->nbst.eLJ), sizeof(*nb->nbst.eLJ));
-    pmalloc(reinterpret_cast<void**>(&nb->nbst.eElec), sizeof(*nb->nbst.eElec));
-    pmalloc(reinterpret_cast<void**>(&nb->nbst.fShift), SHIFTS * sizeof(*nb->nbst.fShift));
-
-    init_plist(nb->plist[InteractionLocality::Local]);
-
-    /* local/non-local GPU streams */
-    GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedLocal),
-                       "Local non-bonded stream should be initialized to use GPU for non-bonded.");
-    const DeviceStream& localStream = deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal);
-    nb->deviceStreams[InteractionLocality::Local] = &localStream;
-    // In general, it's not strictly necessary to use 2 streams for SYCL, since they are
-    // out-of-order. But for the time being, it will be less disruptive to keep them.
-    if (nb->bUseTwoStreams)
-    {
-        init_plist(nb->plist[InteractionLocality::NonLocal]);
-
-        GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedNonLocal),
-                           "Non-local non-bonded stream should be initialized to use GPU for "
-                           "non-bonded with domain decomposition.");
-        nb->deviceStreams[InteractionLocality::NonLocal] =
-                &deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedNonLocal);
-    }
-
-    nb->bDoTime = false;
-
-    const nbnxn_atomdata_t::Params& nbatParams    = nbat->params();
-    const DeviceContext&            deviceContext = *nb->deviceContext_;
-
-    initNbparam(nb->nbparam, *ic, listParams, nbatParams, deviceContext);
-    initAtomdataFirst(nb->atdat, nbatParams.numTypes, deviceContext, localStream);
-
-    return nb;
+    // Intentionally empty: no platform-specific initialization is needed for SYCL
 }
 
 void gpu_upload_shiftvec(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom)
@@ -218,11 +89,13 @@ void gpu_free(NbnxmGpu* nb)
         return;
     }
 
+    delete nb->timers;
+    sfree(nb->timings);
+
     NBAtomData* atdat   = nb->atdat;
     NBParamGpu* nbparam = nb->nbparam;
 
-    if ((!nbparam->coulomb_tab)
-        && (nbparam->elecType == ElecType::EwaldTab || nbparam->elecType == ElecType::EwaldTabTwin))
+    if (nbparam->elecType == ElecType::EwaldTab || nbparam->elecType == ElecType::EwaldTabTwin)
     {
         destroyParamLookupTable(&nbparam->coulomb_tab, nbparam->coulomb_tab_texobj);
     }