Add a getter function for the local atom count
authorSzilárd Páll <pall.szilard@gmail.com>
Thu, 3 Jun 2021 08:18:40 +0000 (10:18 +0200)
committerMark Abraham <mark.j.abraham@gmail.com>
Thu, 3 Jun 2021 09:49:34 +0000 (09:49 +0000)
Refs #3915

src/gromacs/mdlib/sim_util.cpp

index f62bbfe8ff70dd17b9feda575b1bfb831402c8b2..7077b98151346782262b3706fad97edb15b2a5a1 100644 (file)
@@ -1193,6 +1193,14 @@ static void setupGpuForceReductions(gmx::MdrunScheduleWorkload* runScheduleWork,
 }
 
 
+/*! \brief Return the number of local atoms.
+ */
+static int getLocalAtomCount(const gmx_domdec_t& dd, const t_mdatoms& mdatoms, bool havePPDomainDecomposition)
+{
+    return havePPDomainDecomposition ? dd_numAtomsZones(dd) : mdatoms.homenr;
+}
+
+
 void do_force(FILE*                               fplog,
               const t_commrec*                    cr,
               const gmx_multisim_t*               ms,
@@ -1317,7 +1325,7 @@ void do_force(FILE*                               fplog,
         {
             // TODO refactor this to do_md, after partitioning.
             stateGpu->reinit(mdatoms->homenr,
-                             cr->dd != nullptr ? dd_numAtomsZones(*cr->dd) : mdatoms->homenr);
+                             getLocalAtomCount(*cr->dd, *mdatoms, havePPDomainDecomposition(cr)));
             if (stepWork.haveGpuPmeOnThisRank)
             {
                 // TODO: This should be moved into PME setup function ( pme_gpu_prepare_computation(...) )
@@ -2070,8 +2078,7 @@ void do_force(FILE*                               fplog,
              && !(stepWork.computeVirial || simulationWork.useGpuNonbonded || stepWork.haveGpuPmeOnThisRank));
     if (combineMtsForcesBeforeHaloExchange)
     {
-        const int numAtoms = havePPDomainDecomposition(cr) ? dd_numAtomsZones(*cr->dd) : mdatoms->homenr;
-        combineMtsForces(numAtoms,
+        combineMtsForces(getLocalAtomCount(*cr->dd, *mdatoms, havePPDomainDecomposition(cr)),
                          force.unpaddedArrayRef(),
                          forceView->forceMtsCombined(),
                          inputrec.mtsLevels[1].stepFactor);