#include <cmath>
+#include "gromacs/gpu_utils/device_stream_manager.h"
#include "gromacs/gpu_utils/gpu_utils.h"
#include "gromacs/gpu_utils/oclutils.h"
#include "gromacs/hardware/gpu_hw_info.h"
cl_int cl_error;
cl_atomdata_t* adat = nb->atdat;
- cl_command_queue ls = nb->deviceStreams[InteractionLocality::Local].stream();
+ cl_command_queue ls = nb->deviceStreams[InteractionLocality::Local]->stream();
size_t local_work_size[3] = { 1, 1, 1 };
size_t global_work_size[3] = { 1, 1, 1 };
//! This function is documented in the header file
-NbnxmGpu* gpu_init(const DeviceContext& deviceContext,
- const interaction_const_t* ic,
- const PairlistParams& listParams,
- const nbnxn_atomdata_t* nbat,
- const bool bLocalAndNonlocal)
+NbnxmGpu* gpu_init(const gmx::DeviceStreamManager& deviceStreamManager,
+ const interaction_const_t* ic,
+ const PairlistParams& listParams,
+ const nbnxn_atomdata_t* nbat,
+ const bool bLocalAndNonlocal)
{
GMX_ASSERT(ic, "Need a valid interaction constants object");
auto nb = new NbnxmGpu();
- nb->deviceContext_ = &deviceContext;
+ nb->deviceContext_ = &deviceStreamManager.context();
snew(nb->atdat, 1);
snew(nb->nbparam, 1);
snew(nb->plist[InteractionLocality::Local], 1);
nb->timers = new cl_timers_t();
snew(nb->timings, 1);
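+ /* allocate the device runtime data object
+  * NOTE(review): no selection among detected GPUs happens at this point -- confirm the old "point it to the right GPU" wording is obsolete */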
nb->dev_rundata = new gmx_device_runtime_data_t();
/* init nbst */
nb->bDoTime = (getenv("GMX_DISABLE_GPU_TIMING") == nullptr);
/* local/non-local GPU streams */
- nb->deviceStreams[InteractionLocality::Local].init(*nb->deviceContext_,
- DeviceStreamPriority::Normal, nb->bDoTime);
+ GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedLocal),
+ "Local non-bonded stream should be initialized to use GPU for non-bonded.");
+ nb->deviceStreams[InteractionLocality::Local] =
+ &deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal);
if (nb->bUseTwoStreams)
{
init_plist(nb->plist[InteractionLocality::NonLocal]);
- nb->deviceStreams[InteractionLocality::NonLocal].init(
- *nb->deviceContext_, DeviceStreamPriority::High, nb->bDoTime);
+ GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedNonLocal),
+ "Non-local non-bonded stream should be initialized to use GPU for "
+ "non-bonded with domain decomposition.");
+ nb->deviceStreams[InteractionLocality::NonLocal] =
+ &deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedNonLocal);
}
if (nb->bDoTime)
cl_int gmx_used_in_debug cl_error;
cl_atomdata_t* atomData = nb->atdat;
- cl_command_queue ls = nb->deviceStreams[InteractionLocality::Local].stream();
+ cl_command_queue ls = nb->deviceStreams[InteractionLocality::Local]->stream();
cl_float value = 0.0F;
cl_error = clEnqueueFillBuffer(ls, atomData->f, &value, sizeof(cl_float), 0,
/* flush the local stream so that the previously enqueued buffer clearing is submitted to the device now, to ensure concurrency with constraints/update */
cl_int gmx_unused cl_error;
- cl_error = clFlush(nb->deviceStreams[InteractionLocality::Local].stream());
+ cl_error = clFlush(nb->deviceStreams[InteractionLocality::Local]->stream());
GMX_ASSERT(cl_error == CL_SUCCESS, ("clFlush failed: " + ocl_get_error_string(cl_error)).c_str());
}
// because getLastRangeTime() gets skipped with empty lists later
// which leads to the counter not being reset.
bool bDoTime = (nb->bDoTime && !h_plist->sci.empty());
- const DeviceStream& deviceStream = nb->deviceStreams[iloc];
+ const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
cl_plist_t* d_plist = nb->plist[iloc];
if (d_plist->na_c < 0)
void gpu_upload_shiftvec(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom)
{
cl_atomdata_t* adat = nb->atdat;
- cl_command_queue ls = nb->deviceStreams[InteractionLocality::Local].stream();
+ cl_command_queue ls = nb->deviceStreams[InteractionLocality::Local]->stream();
/* only if we have a dynamic box */
if (nbatom->bDynamicBox || !adat->bShiftVecUploaded)
bool bDoTime = nb->bDoTime;
cl_timers_t* timers = nb->timers;
cl_atomdata_t* d_atdat = nb->atdat;
- const DeviceStream& deviceStream = nb->deviceStreams[InteractionLocality::Local];
+ const DeviceStream& deviceStream = *nb->deviceStreams[InteractionLocality::Local];
natoms = nbat->numAtoms();
realloced = false;