/*
* This file is part of the GROMACS molecular simulation package.
*
- * Copyright (c) 2019,2020, by the GROMACS development team, led by
+ * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
* Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
* and including many others, as listed in the AUTHORS file in the
* top-level source directory and at http://www.gromacs.org.
#include "gromacs/gpu_utils/device_context.h"
#include "gromacs/gpu_utils/device_stream.h"
-#include "gromacs/gpu_utils/gputraits.h"
+#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/utility/enumerationhelpers.h"
#include "gromacs/utility/exceptions.h"
#include "gromacs/utility/gmxassert.h"
* \throws InternalError If any of the required resources could not be initialized.
*/
Impl(const DeviceInformation& deviceInfo,
- bool useGpuForPme,
bool havePpDomainDecomposition,
- bool doGpuPmePpTransfer,
- bool useGpuForUpdate,
+ SimulationWorkload simulationWork,
bool useTiming);
~Impl();
//! Device context.
DeviceContext context_;
//! GPU command streams.
- EnumerationArray<DeviceStreamType, DeviceStream> streams_;
+ EnumerationArray<DeviceStreamType, std::unique_ptr<DeviceStream>> streams_;
};
// DeviceStreamManager::Impl
DeviceStreamManager::Impl::Impl(const DeviceInformation& deviceInfo,
- const bool useGpuForPme,
const bool havePpDomainDecomposition,
- const bool doGpuPmePpTransfer,
- const bool useGpuForUpdate,
+ const SimulationWorkload simulationWork,
const bool useTiming) :
context_(deviceInfo)
{
try
{
- streams_[DeviceStreamType::NonBondedLocal].init(context_, DeviceStreamPriority::Normal, useTiming);
+ streams_[DeviceStreamType::NonBondedLocal] =
+ std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
- if (useGpuForPme)
+ if (simulationWork.useGpuPme)
{
/* Creating a PME GPU stream:
* - default high priority with CUDA
* - no priorities implemented yet with OpenCL; see #2532
*/
- streams_[DeviceStreamType::Pme].init(context_, DeviceStreamPriority::High, useTiming);
+ streams_[DeviceStreamType::Pme] =
+ std::make_unique<DeviceStream>(context_, DeviceStreamPriority::High, useTiming);
}
if (havePpDomainDecomposition)
{
- streams_[DeviceStreamType::NonBondedNonLocal].init(context_, DeviceStreamPriority::High,
- useTiming);
+ streams_[DeviceStreamType::NonBondedNonLocal] =
+ std::make_unique<DeviceStream>(context_, DeviceStreamPriority::High, useTiming);
}
// Update stream is used both for coordinates transfers and for GPU update/constraints
- if (useGpuForPme || useGpuForUpdate)
+ if (simulationWork.useGpuPme || simulationWork.useGpuUpdate || simulationWork.useGpuXBufferOps)
{
- streams_[DeviceStreamType::UpdateAndConstraints].init(
- context_, DeviceStreamPriority::Normal, useTiming);
+ streams_[DeviceStreamType::UpdateAndConstraints] =
+ std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
}
- if (doGpuPmePpTransfer)
+ if (simulationWork.useGpuPmePpCommunication)
{
- streams_[DeviceStreamType::PmePpTransfer].init(context_, DeviceStreamPriority::Normal, useTiming);
+ streams_[DeviceStreamType::PmePpTransfer] =
+ std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
}
}
GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
// DeviceStreamManager
DeviceStreamManager::DeviceStreamManager(const DeviceInformation& deviceInfo,
- const bool useGpuForPme,
const bool havePpDomainDecomposition,
- const bool doGpuPmePpTransfer,
- const bool useGpuForUpdate,
+ const SimulationWorkload simulationWork,
const bool useTiming) :
- impl_(new Impl(deviceInfo, useGpuForPme, havePpDomainDecomposition, doGpuPmePpTransfer, useGpuForUpdate, useTiming))
+ impl_(new Impl(deviceInfo, havePpDomainDecomposition, simulationWork, useTiming))
{
}
const DeviceStream& DeviceStreamManager::stream(DeviceStreamType streamToGet) const
{
- return impl_->streams_[streamToGet];
+ return *impl_->streams_[streamToGet];
}
const DeviceStream& DeviceStreamManager::bondedStream(bool hasPPDomainDecomposition) const
bool DeviceStreamManager::streamIsValid(DeviceStreamType streamToCheck) const
{
- return impl_->streams_[streamToCheck].isValid();
+ return impl_->streams_[streamToCheck] != nullptr && impl_->streams_[streamToCheck]->isValid();
}
} // namespace gmx