From 691e8291a8f4a02fa1eada7dddca6e3dd653ff55 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Szil=C3=A1rd=20P=C3=A1ll?= Date: Thu, 3 Jun 2021 15:26:20 +0200 Subject: [PATCH] Fix getLocalAtomCount() Pass dd as pointer to avoid nullptr dereferencing introduced in ba71b7526c22715aceb36586e76f2ae58c4a2461. Refs #3915 --- src/gromacs/mdlib/sim_util.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/gromacs/mdlib/sim_util.cpp b/src/gromacs/mdlib/sim_util.cpp index 5b1ee2983a..1439c40199 100644 --- a/src/gromacs/mdlib/sim_util.cpp +++ b/src/gromacs/mdlib/sim_util.cpp @@ -1200,9 +1200,10 @@ static void setupGpuForceReductions(gmx::MdrunScheduleWorkload* runScheduleWork, /*! \brief Return the number of local atoms. */ -static int getLocalAtomCount(const gmx_domdec_t& dd, const t_mdatoms& mdatoms, bool havePPDomainDecomposition) +static int getLocalAtomCount(const gmx_domdec_t* dd, const t_mdatoms& mdatoms, bool havePPDomainDecomposition) { - return havePPDomainDecomposition ? dd_numAtomsZones(dd) : mdatoms.homenr; + GMX_ASSERT(!(havePPDomainDecomposition && (dd == nullptr)), "Can't have PP decomposition with dd uninitialized!"); + return havePPDomainDecomposition ? dd_numAtomsZones(*dd) : mdatoms.homenr; } @@ -1330,7 +1331,7 @@ void do_force(FILE* fplog, { // TODO refactor this to do_md, after partitioning. stateGpu->reinit(mdatoms->homenr, - getLocalAtomCount(*cr->dd, *mdatoms, havePPDomainDecomposition(cr))); + getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr))); if (stepWork.haveGpuPmeOnThisRank) { // TODO: This should be moved into PME setup function ( pme_gpu_prepare_computation(...) ) @@ -2077,7 +2078,7 @@ void do_force(FILE* fplog, && !(stepWork.computeVirial || simulationWork.useGpuNonbonded || stepWork.haveGpuPmeOnThisRank)); if (combineMtsForcesBeforeHaloExchange) { - combineMtsForces(getLocalAtomCount(*cr->dd, *mdatoms, havePPDomainDecomposition(cr)), + combineMtsForces(getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr)), force.unpaddedArrayRef(), forceView->forceMtsCombined(), inputrec.mtsLevels[1].stepFactor); -- 2.22.0