#include "gromacs/gpu_utils/cudautils.cuh"
#include "gromacs/gpu_utils/gpueventsynchronizer.cuh"
#include "gromacs/gpu_utils/vectype_ops.cuh"
-#include "gromacs/mdlib/force_flags.h"
+#include "gromacs/mdlib/ppforceworkload.h"
#include "gromacs/nbnxm/atomdata.h"
#include "gromacs/nbnxm/gpu_common.h"
#include "gromacs/nbnxm/gpu_common_utils.h"
with this event in the non-local stream before launching the non-bonded kernel.
*/
void gpu_launch_kernel(gmx_nbnxn_cuda_t *nb,
- const int flags,
+ const gmx::ForceFlags &forceFlags,
const InteractionLocality iloc)
{
cu_atomdata_t *adat = nb->atdat;
cu_timers_t *t = nb->timers;
cudaStream_t stream = nb->stream[iloc];
- bool bCalcEner = flags & GMX_FORCE_ENERGY;
- bool bCalcFshift = flags & GMX_FORCE_VIRIAL;
bool bDoTime = nb->bDoTime;
/* Don't launch the non-local kernel if there is no work to do.
auto *timingEvent = bDoTime ? t->interaction[iloc].nb_k.fetchNextEvent() : nullptr;
const auto kernel = select_nbnxn_kernel(nbp->eeltype,
nbp->vdwtype,
- bCalcEner,
+ forceFlags.computeEnergy,
(plist->haveFreshList && !nb->timers->interaction[iloc].didPrune),
nb->dev_info);
- const auto kernelArgs = prepareGpuKernelArguments(kernel, config, adat, nbp, plist, &bCalcFshift);
+ const auto kernelArgs = prepareGpuKernelArguments(kernel, config, adat, nbp, plist, &forceFlags.computeVirial);
launchGpuKernel(kernel, config, timingEvent, "k_calc_nb", kernelArgs);
if (bDoTime)
void gpu_launch_cpyback(gmx_nbnxn_cuda_t *nb,
nbnxn_atomdata_t *nbatom,
- const int flags,
+ const gmx::ForceFlags &forceFlags,
const AtomLocality atomLocality,
const bool copyBackNbForce)
{
bool bDoTime = nb->bDoTime;
cudaStream_t stream = nb->stream[iloc];
- bool bCalcEner = flags & GMX_FORCE_ENERGY;
- bool bCalcFshift = flags & GMX_FORCE_VIRIAL;
-
/* don't launch non-local copy-back if there was no non-local work to do */
if ((iloc == InteractionLocality::NonLocal) && !haveGpuShortRangeWork(*nb, iloc))
{
/* only transfer shift forces and energies in the local stream */
if (iloc == InteractionLocality::Local)
{
- /* DtoH fshift */
- if (bCalcFshift)
+ /* DtoH fshift when virial is needed */
+ if (forceFlags.computeVirial)
{
cu_copy_D2H_async(nb->nbst.fshift, adat->fshift,
SHIFTS * sizeof(*nb->nbst.fshift), stream);
}
/* DtoH energies */
- if (bCalcEner)
+ if (forceFlags.computeEnergy)
{
cu_copy_D2H_async(nb->nbst.e_lj, adat->e_lj,
sizeof(*nb->nbst.e_lj), stream);