#include "gromacs/domdec/partition.h"
#include "gromacs/gmxlib/network.h"
#include "gromacs/gmxlib/nrnb.h"
+#include "gromacs/gpu_utils/device_stream_manager.h"
#include "gromacs/gpu_utils/gpu_utils.h"
#include "gromacs/hardware/hw_info.h"
#include "gromacs/listed_forces/manage_threading.h"
return bCutoffAllowed;
}
-void constructGpuHaloExchange(const gmx::MDLogger& mdlog,
- const t_commrec& cr,
- const DeviceContext& deviceContext,
- const DeviceStream& streamLocal,
- const DeviceStream& streamNonLocal)
+void constructGpuHaloExchange(const gmx::MDLogger& mdlog,
+ const t_commrec& cr,
+ const gmx::DeviceStreamManager& deviceStreamManager)
{
-
+ GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedLocal),
+ "Local non-bonded stream should be valid when using"
+ "GPU halo exchange.");
+ GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedNonLocal),
+ "Non-local non-bonded stream should be valid when using "
+ "GPU halo exchange.");
int gpuHaloExchangeSize = 0;
int pulseStart = 0;
if (cr.dd->gpuHaloExchange.empty())
for (int pulse = pulseStart; pulse < cr.dd->comm->cd[0].numPulses(); pulse++)
{
cr.dd->gpuHaloExchange.push_back(std::make_unique<gmx::GpuHaloExchange>(
- cr.dd, cr.mpi_comm_mysim, deviceContext, streamLocal, streamNonLocal, pulse));
+ cr.dd, cr.mpi_comm_mysim, deviceStreamManager.context(),
+ deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal),
+ deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedNonLocal), pulse));
}
}
}