#include "nbnxm_cuda.h"
#include "gromacs/gpu_utils/cudautils.cuh"
+#include "gromacs/gpu_utils/gpueventsynchronizer.cuh"
#include "gromacs/gpu_utils/vectype_ops.cuh"
#include "gromacs/mdlib/force_flags.h"
#include "gromacs/nbnxm/atomdata.h"
}
/* F buffer operations on GPU: performs force summations and conversion from nb to rvec format. */
-void nbnxn_gpu_add_nbat_f_to_f(const AtomLocality atomLocality,
- gmx_nbnxn_gpu_t *nb,
- int atomStart,
- int nAtoms,
- GpuBufferOpsAccumulateForce accumulateForce)
+void nbnxn_gpu_add_nbat_f_to_f(const AtomLocality atomLocality,
+ gmx_nbnxn_gpu_t *nb,
+ void *fPmeDevicePtr,
+ GpuEventSynchronizer *pmeForcesReady,
+ int atomStart,
+ int nAtoms,
+ bool useGpuFPmeReduction,
+ bool accumulateForce)
{
GMX_ASSERT(nb, "Need a valid nbnxn_gpu object");
const InteractionLocality iLocality = gpuAtomToInteractionLocality(atomLocality);
cudaStream_t stream = nb->stream[iLocality];
+ cu_atomdata_t *adat = nb->atdat;
+ bool addPmeF = useGpuFPmeReduction;
- cu_atomdata_t *adat = nb->atdat;
+ if (addPmeF)
+ {
+ //Stream must wait for PME force completion
+ pmeForcesReady->enqueueWaitEvent(stream);
+ }
/* launch kernel */
config.sharedMemorySize = 0;
config.stream = stream;
- auto kernelFn = (accumulateForce == GpuBufferOpsAccumulateForce::True) ?
- nbnxn_gpu_add_nbat_f_to_f_kernel<true> : nbnxn_gpu_add_nbat_f_to_f_kernel<false>;
- const float3 *fPtr = adat->f;
- rvec *frvec = nb->frvec;
- const int *cell = nb->cell;
+ auto kernelFn = accumulateForce ?
+ nbnxn_gpu_add_nbat_f_to_f_kernel<true, false> :
+ nbnxn_gpu_add_nbat_f_to_f_kernel<false, false>;
+
+ if (addPmeF)
+ {
+ kernelFn = accumulateForce ?
+ nbnxn_gpu_add_nbat_f_to_f_kernel<true, true> :
+ nbnxn_gpu_add_nbat_f_to_f_kernel<false, true>;
+ }
+
+ const float3 *d_f = adat->f;
+ float3 *d_fNB = (float3*) nb->frvec;
+ const float3 *d_fPme = (float3*) fPmeDevicePtr;
+ const int *d_cell = nb->cell;
const auto kernelArgs = prepareGpuKernelArguments(kernelFn, config,
- &fPtr,
- &frvec,
- &cell,
+ &d_f,
+ &d_fPme,
+ &d_fNB,
+ &d_cell,
&atomStart,
&nAtoms);