using gmx::AtomLocality;
using gmx::DomainLifetimeWorkload;
using gmx::ForceOutputs;
+using gmx::ForceWithShiftForces;
using gmx::InteractionLocality;
using gmx::RVec;
using gmx::SimulationWorkload;
// PME-first ordering would suffice).
// Presence-only check of the GMX_DISABLE_ALTERNATING_GPU_WAIT environment variable;
// when set, the alternating wait on GPU results is disabled (presumably forcing a
// fixed wait order instead — confirm at the use site, which is outside this view).
static const bool c_disableAlternatingWait = (getenv("GMX_DISABLE_ALTERNATING_GPU_WAIT") != nullptr);
-static void sum_forces(rvec f[], gmx::ArrayRef<const gmx::RVec> forceToAdd)
+static void sum_forces(ArrayRef<RVec> f, ArrayRef<const RVec> forceToAdd)
{
+ GMX_ASSERT(f.size() >= forceToAdd.size(), "Accumulation buffer should be sufficiently large");
const int end = forceToAdd.size();
int gmx_unused nt = gmx_omp_nthreads_get(emntDefault);
int64_t step,
real forceTolerance,
ArrayRef<const RVec> x,
- const rvec* f)
+ ArrayRef<const RVec> f)
{
real force2Tolerance = gmx::square(forceTolerance);
gmx::index numNonFinite = 0;
}
}
-static void post_process_forces(const t_commrec* cr,
- int64_t step,
- t_nrnb* nrnb,
- gmx_wallcycle_t wcycle,
- const matrix box,
- ArrayRef<const RVec> x,
- ForceOutputs* forceOutputs,
- tensor vir_force,
- const t_mdatoms* mdatoms,
- const t_forcerec* fr,
- gmx::VirtualSitesHandler* vsite,
- const StepWorkload& stepWork)
+//! When necessary, spreads forces on vsites and computes the virial for \p forceOutputs->forceWithShiftForces()
+//!
+//! Vsite spreading is skipped here when a separate direct-virial-contribution
+//! buffer exists and no virial is requested; in that case the two force buffers
+//! are summed and the combined vsite forces are spread in one pass later.
+static void postProcessForceWithShiftForces(t_nrnb* nrnb,
+ gmx_wallcycle_t wcycle,
+ const matrix box,
+ ArrayRef<const RVec> x,
+ ForceOutputs* forceOutputs,
+ tensor vir_force,
+ const t_mdatoms& mdatoms,
+ const t_forcerec& fr,
+ gmx::VirtualSitesHandler* vsite,
+ const StepWorkload& stepWork)
{
- rvec* f = as_rvec_array(forceOutputs->forceWithShiftForces().force().data());
+ ForceWithShiftForces& forceWithShiftForces = forceOutputs->forceWithShiftForces();
+
+ /* If we have NoVirSum forces, but we do not calculate the virial,
 * we later sum the forceWithShiftForces buffer together with
 * the noVirSum buffer and spread the combined vsite forces at once.
 */
+ if (vsite && (!forceOutputs->haveForceWithVirial() || stepWork.computeVirial))
+ {
+ using VirialHandling = gmx::VirtualSitesHandler::VirialHandling;
+
+ auto f = forceWithShiftForces.force();
+ auto fshift = forceWithShiftForces.shiftForces();
+ // With a virial request, vsite spreading must account for PBC shift forces
+ const VirialHandling virialHandling =
(stepWork.computeVirial ? VirialHandling::Pbc : VirialHandling::None);
+ vsite->spreadForces(x, f, virialHandling, fshift, nullptr, nrnb, box, wcycle);
+ // Record that spreading happened, so later force post-processing can
+ // assert on (and avoid repeating) it
+ forceWithShiftForces.haveSpreadVsiteForces() = true;
+ }
+
+ if (stepWork.computeVirial)
+ {
+ /* Calculation of the virial must be done after vsites! */
+ calc_virial(0, mdatoms.homenr, as_rvec_array(x.data()), forceWithShiftForces, vir_force,
box, nrnb, &fr, fr.pbcType);
+ }
+}
+
+//! Spread, compute virial for and sum forces, when necessary
+static void postProcessForces(const t_commrec* cr,
+ int64_t step,
+ t_nrnb* nrnb,
+ gmx_wallcycle_t wcycle,
+ const matrix box,
+ ArrayRef<const RVec> x,
+ ForceOutputs* forceOutputs,
+ tensor vir_force,
+ const t_mdatoms* mdatoms,
+ const t_forcerec* fr,
+ gmx::VirtualSitesHandler* vsite,
+ const StepWorkload& stepWork)
+{
+ // Extract the final output force buffer, which is also the buffer for forces with shift forces
+ ArrayRef<RVec> f = forceOutputs->forceWithShiftForces().force();
if (forceOutputs->haveForceWithVirial())
{
* This is parallellized. MPI communication is performed
* if the constructing atoms aren't local.
*/
+ GMX_ASSERT(!stepWork.computeVirial || f.data() != forceWithVirial.force_.data(),
+ "We need separate force buffers for shift and virial forces when "
+ "computing the virial");
+ GMX_ASSERT(!stepWork.computeVirial
+ || forceOutputs->forceWithShiftForces().haveSpreadVsiteForces(),
+ "We should spread the force with shift forces separately when computing "
+ "the virial");
const gmx::VirtualSitesHandler::VirialHandling virialHandling =
(stepWork.computeVirial ? gmx::VirtualSitesHandler::VirialHandling::NonLinear
: gmx::VirtualSitesHandler::VirialHandling::None);
}
}
}
+ else
+ {
+ GMX_ASSERT(vsite == nullptr || forceOutputs->forceWithShiftForces().haveSpreadVsiteForces(),
+ "We should have spread the vsite forces (earlier)");
+ }
if (fr->print_force >= 0)
{
if (stepWork.computeForces)
{
- /* If we have NoVirSum forces, but we do not calculate the virial,
- * we sum fr->f_novirsum=forceOut.f later.
- */
- if (vsite && !(fr->forceHelperBuffers->haveDirectVirialContributions() && !stepWork.computeVirial))
- {
- auto f = forceOut.forceWithShiftForces().force();
- auto fshift = forceOut.forceWithShiftForces().shiftForces();
- const gmx::VirtualSitesHandler::VirialHandling virialHandling =
- (stepWork.computeVirial ? gmx::VirtualSitesHandler::VirialHandling::Pbc
- : gmx::VirtualSitesHandler::VirialHandling::None);
- vsite->spreadForces(x.unpaddedArrayRef(), f, virialHandling, fshift, nullptr, nrnb, box, wcycle);
- }
-
- if (stepWork.computeVirial)
- {
- /* Calculation of the virial must be done after vsites! */
- calc_virial(0, mdatoms->homenr, as_rvec_array(x.unpaddedArrayRef().data()),
- forceOut.forceWithShiftForces(), vir_force, box, nrnb, fr, inputrec->pbcType);
- }
+ postProcessForceWithShiftForces(nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut,
+ vir_force, *mdatoms, *fr, vsite, stepWork);
}
// TODO refactor this and unify with above GPU PME-PP / GPU update path call to the same function
if (stepWork.computeForces)
{
- post_process_forces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut, vir_force,
- mdatoms, fr, vsite, stepWork);
+ postProcessForces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut, vir_force,
+ mdatoms, fr, vsite, stepWork);
}
if (stepWork.computeEnergy)