Fix getLocalAtomCount()
authorSzilárd Páll <pall.szilard@gmail.com>
Thu, 3 Jun 2021 13:26:20 +0000 (15:26 +0200)
committerSzilárd Páll <pall.szilard@gmail.com>
Thu, 3 Jun 2021 13:26:20 +0000 (15:26 +0200)
Pass dd as pointer to avoid nullptr dereferencing introduced in
ba71b7526c22715aceb36586e76f2ae58c4a2461.

Refs #3915

src/gromacs/mdlib/sim_util.cpp

index 5b1ee2983af43a22a9160ab8e386388c67456dc6..1439c40199f16d876c7a5acd7b5064c3a6b4c350 100644 (file)
@@ -1200,9 +1200,10 @@ static void setupGpuForceReductions(gmx::MdrunScheduleWorkload* runScheduleWork,
 
 /*! \brief Return the number of local atoms.
  */
-static int getLocalAtomCount(const gmx_domdec_t& dd, const t_mdatoms& mdatoms, bool havePPDomainDecomposition)
+static int getLocalAtomCount(const gmx_domdec_t* dd, const t_mdatoms& mdatoms, bool havePPDomainDecomposition)
 {
-    return havePPDomainDecomposition ? dd_numAtomsZones(dd) : mdatoms.homenr;
+    GMX_ASSERT(!(havePPDomainDecomposition && (dd == nullptr)), "Can't have PP decomposition with dd uninitialized!");
+    return havePPDomainDecomposition ? dd_numAtomsZones(*dd) : mdatoms.homenr;
 }
 
 
@@ -1330,7 +1331,7 @@ void do_force(FILE*                               fplog,
         {
             // TODO refactor this to do_md, after partitioning.
             stateGpu->reinit(mdatoms->homenr,
-                             getLocalAtomCount(*cr->dd, *mdatoms, havePPDomainDecomposition(cr)));
+                             getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr)));
             if (stepWork.haveGpuPmeOnThisRank)
             {
                 // TODO: This should be moved into PME setup function ( pme_gpu_prepare_computation(...) )
@@ -2077,7 +2078,7 @@ void do_force(FILE*                               fplog,
              && !(stepWork.computeVirial || simulationWork.useGpuNonbonded || stepWork.haveGpuPmeOnThisRank));
     if (combineMtsForcesBeforeHaloExchange)
     {
-        combineMtsForces(getLocalAtomCount(*cr->dd, *mdatoms, havePPDomainDecomposition(cr)),
+        combineMtsForces(getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr)),
                          force.unpaddedArrayRef(),
                          forceView->forceMtsCombined(),
                          inputrec.mtsLevels[1].stepFactor);