Unify init_gpu function in NBNXM
[alexxy/gromacs.git] / src / gromacs / nbnxm / cuda / nbnxm_cuda_data_mgmt.cu
index 804a8ea18066e61b104872e660be1ef056a95530..274f40448f256004dfe08cf223f67c70f9db78c3 100644 (file)
@@ -52,7 +52,6 @@
 // TODO Remove this comment when the above order issue is resolved
 #include "gromacs/gpu_utils/cudautils.cuh"
 #include "gromacs/gpu_utils/device_context.h"
-#include "gromacs/gpu_utils/device_stream_manager.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
 #include "gromacs/gpu_utils/gpueventsynchronizer.cuh"
 #include "gromacs/gpu_utils/pmalloc.h"
@@ -92,174 +91,11 @@ namespace Nbnxm
  */
 static unsigned int gpu_min_ci_balanced_factor = 44;
 
-/*! Initializes the atomdata structure first time, it only gets filled at
-    pair-search. */
-static void init_atomdata_first(NBAtomData*          ad,
-                                int                  nTypes,
-                                const DeviceContext& deviceContext,
-                                const DeviceStream&  localStream)
+void gpu_init_platform_specific(NbnxmGpu* /* nb */)
 {
-    ad->numTypes = nTypes;
-    allocateDeviceBuffer(&ad->shiftVec, SHIFTS, deviceContext);
-    ad->shiftVecUploaded = false;
-
-    allocateDeviceBuffer(&ad->fShift, SHIFTS, deviceContext);
-    allocateDeviceBuffer(&ad->eLJ, 1, deviceContext);
-    allocateDeviceBuffer(&ad->eElec, 1, deviceContext);
-
-    clearDeviceBufferAsync(&ad->fShift, 0, SHIFTS, localStream);
-    clearDeviceBufferAsync(&ad->eElec, 0, 1, localStream);
-    clearDeviceBufferAsync(&ad->eLJ, 0, 1, localStream);
-
-    /* initialize to nullptr poiters to data that is not allocated here and will
-       need reallocation in nbnxn_cuda_init_atomdata */
-    ad->xq = nullptr;
-    ad->f  = nullptr;
-
-    /* size -1 indicates that the respective array hasn't been initialized yet */
-    ad->numAtoms      = -1;
-    ad->numAtomsAlloc = -1;
-}
-
-/*! Initializes the nonbonded parameter data structure. */
-static void init_nbparam(NBParamGpu*                     nbp,
-                         const interaction_const_t*      ic,
-                         const PairlistParams&           listParams,
-                         const nbnxn_atomdata_t::Params& nbatParams,
-                         const DeviceContext&            deviceContext)
-{
-    const int ntypes = nbatParams.numTypes;
-
-    set_cutoff_parameters(nbp, ic, listParams);
-
-    /* The kernel code supports LJ combination rules (geometric and LB) for
-     * all kernel types, but we only generate useful combination rule kernels.
-     * We currently only use LJ combination rule (geometric and LB) kernels
-     * for plain cut-off LJ. On Maxwell the force only kernels speed up 15%
-     * with PME and 20% with RF, the other kernels speed up about half as much.
-     * For LJ force-switch the geometric rule would give 7% speed-up, but this
-     * combination is rarely used. LJ force-switch with LB rule is more common,
-     * but gives only 1% speed-up.
-     */
-    nbp->vdwType  = nbnxmGpuPickVdwKernelType(ic, nbatParams.ljCombinationRule);
-    nbp->elecType = nbnxmGpuPickElectrostaticsKernelType(ic, deviceContext.deviceInfo());
-
-    /* generate table for PME */
-    nbp->coulomb_tab = nullptr;
-    if (nbp->elecType == ElecType::EwaldTab || nbp->elecType == ElecType::EwaldTabTwin)
-    {
-        GMX_RELEASE_ASSERT(ic->coulombEwaldTables, "Need valid Coulomb Ewald correction tables");
-        init_ewald_coulomb_force_table(*ic->coulombEwaldTables, nbp, deviceContext);
-    }
-
-    /* set up LJ parameter lookup table */
-    if (!useLjCombRule(nbp->vdwType))
-    {
-        static_assert(sizeof(decltype(nbp->nbfp)) == 2 * sizeof(decltype(*nbatParams.nbfp.data())),
-                      "Mismatch in the size of host / device data types");
-        initParamLookupTable(&nbp->nbfp,
-                             &nbp->nbfp_texobj,
-                             reinterpret_cast<const Float2*>(nbatParams.nbfp.data()),
-                             ntypes * ntypes,
-                             deviceContext);
-    }
-
-    /* set up LJ-PME parameter lookup table */
-    if (ic->vdwtype == VanDerWaalsType::Pme)
-    {
-        static_assert(sizeof(decltype(nbp->nbfp_comb))
-                              == 2 * sizeof(decltype(*nbatParams.nbfp_comb.data())),
-                      "Mismatch in the size of host / device data types");
-        initParamLookupTable(&nbp->nbfp_comb,
-                             &nbp->nbfp_comb_texobj,
-                             reinterpret_cast<const Float2*>(nbatParams.nbfp_comb.data()),
-                             ntypes,
-                             deviceContext);
-    }
-}
-
-NbnxmGpu* gpu_init(const gmx::DeviceStreamManager& deviceStreamManager,
-                   const interaction_const_t*      ic,
-                   const PairlistParams&           listParams,
-                   const nbnxn_atomdata_t*         nbat,
-                   bool                            bLocalAndNonlocal)
-{
-    auto nb            = new NbnxmGpu();
-    nb->deviceContext_ = &deviceStreamManager.context();
-    snew(nb->atdat, 1);
-    snew(nb->nbparam, 1);
-    snew(nb->plist[InteractionLocality::Local], 1);
-    if (bLocalAndNonlocal)
-    {
-        snew(nb->plist[InteractionLocality::NonLocal], 1);
-    }
-
-    nb->bUseTwoStreams = bLocalAndNonlocal;
-
-    nb->timers = new Nbnxm::GpuTimers();
-    snew(nb->timings, 1);
-
-    /* init nbst */
-    pmalloc((void**)&nb->nbst.eLJ, sizeof(*nb->nbst.eLJ));
-    pmalloc((void**)&nb->nbst.eElec, sizeof(*nb->nbst.eElec));
-    pmalloc((void**)&nb->nbst.fShift, SHIFTS * sizeof(*nb->nbst.fShift));
-
-    init_plist(nb->plist[InteractionLocality::Local]);
-
-    /* local/non-local GPU streams */
-    GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedLocal),
-                       "Local non-bonded stream should be initialized to use GPU for non-bonded.");
-    const DeviceStream& localStream = deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal);
-    nb->deviceStreams[InteractionLocality::Local] = &localStream;
-    if (nb->bUseTwoStreams)
-    {
-        init_plist(nb->plist[InteractionLocality::NonLocal]);
-
-        /* Note that the device we're running on does not have to support
-         * priorities, because we are querying the priority range which in this
-         * case will be a single value.
-         */
-        GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedNonLocal),
-                           "Non-local non-bonded stream should be initialized to use GPU for "
-                           "non-bonded with domain decomposition.");
-        nb->deviceStreams[InteractionLocality::NonLocal] =
-                &deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedNonLocal);
-        ;
-    }
-
-    /* WARNING: CUDA timings are incorrect with multiple streams.
-     *          This is the main reason why they are disabled by default.
-     */
-    // TODO: Consider turning on by default when we can detect nr of streams.
-    nb->bDoTime = (getenv("GMX_ENABLE_GPU_TIMING") != nullptr);
-
-    if (nb->bDoTime)
-    {
-        init_timings(nb->timings);
-    }
-
     /* set the kernel type for the current GPU */
     /* pick L1 cache configuration */
     cuda_set_cacheconfig();
-
-    const nbnxn_atomdata_t::Params& nbatParams    = nbat->params();
-    const DeviceContext&            deviceContext = *nb->deviceContext_;
-    init_atomdata_first(nb->atdat, nbatParams.numTypes, deviceContext, localStream);
-    init_nbparam(nb->nbparam, ic, listParams, nbatParams, deviceContext);
-
-    nb->atomIndicesSize       = 0;
-    nb->atomIndicesSize_alloc = 0;
-    nb->ncxy_na               = 0;
-    nb->ncxy_na_alloc         = 0;
-    nb->ncxy_ind              = 0;
-    nb->ncxy_ind_alloc        = 0;
-
-    if (debug)
-    {
-        fprintf(debug, "Initialized CUDA data structures.\n");
-    }
-
-    return nb;
 }
 
 void gpu_upload_shiftvec(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom)
@@ -290,17 +126,17 @@ void gpu_free(NbnxmGpu* nb)
         return;
     }
 
+    delete nb->timers;
+    sfree(nb->timings);
+
     NBAtomData* atdat   = nb->atdat;
     NBParamGpu* nbparam = nb->nbparam;
 
-    if ((!nbparam->coulomb_tab)
-        && (nbparam->elecType == ElecType::EwaldTab || nbparam->elecType == ElecType::EwaldTabTwin))
+    if (nbparam->elecType == ElecType::EwaldTab || nbparam->elecType == ElecType::EwaldTabTwin)
     {
         destroyParamLookupTable(&nbparam->coulomb_tab, nbparam->coulomb_tab_texobj);
     }
 
-    delete nb->timers;
-
     if (!useLjCombRule(nb->nbparam->vdwType))
     {
         destroyParamLookupTable(&nbparam->nbfp, nbparam->nbfp_texobj);
@@ -319,8 +155,14 @@ void gpu_free(NbnxmGpu* nb)
 
     freeDeviceBuffer(&atdat->f);
     freeDeviceBuffer(&atdat->xq);
-    freeDeviceBuffer(&atdat->atomTypes);
-    freeDeviceBuffer(&atdat->ljComb);
+    if (useLjCombRule(nb->nbparam->vdwType))
+    {
+        freeDeviceBuffer(&atdat->ljComb);
+    }
+    else
+    {
+        freeDeviceBuffer(&atdat->atomTypes);
+    }
 
     /* Free plist */
     auto* plist = nb->plist[InteractionLocality::Local];
@@ -328,7 +170,7 @@ void gpu_free(NbnxmGpu* nb)
     freeDeviceBuffer(&plist->cj4);
     freeDeviceBuffer(&plist->imask);
     freeDeviceBuffer(&plist->excl);
-    sfree(plist);
+    delete plist;
     if (nb->bUseTwoStreams)
     {
         auto* plist_nl = nb->plist[InteractionLocality::NonLocal];
@@ -336,7 +178,7 @@ void gpu_free(NbnxmGpu* nb)
         freeDeviceBuffer(&plist_nl->cj4);
         freeDeviceBuffer(&plist_nl->imask);
         freeDeviceBuffer(&plist_nl->excl);
-        sfree(plist_nl);
+        delete plist_nl;
     }
 
     /* Free nbst */
@@ -349,9 +191,8 @@ void gpu_free(NbnxmGpu* nb)
     pfree(nb->nbst.fShift);
     nb->nbst.fShift = nullptr;
 
-    sfree(atdat);
-    sfree(nbparam);
-    sfree(nb->timings);
+    delete atdat;
+    delete nbparam;
     delete nb;
 
     if (debug)