Introduce GlobalCommunicationHelper
authorPascal Merz <pascal.merz@me.com>
Thu, 6 Aug 2020 22:30:23 +0000 (16:30 -0600)
committerPascal Merz <pascal.merz@me.com>
Fri, 7 Aug 2020 18:08:00 +0000 (12:08 -0600)
In view of #3437 (!410), this introduces a helper container to store
data related to global communication. With the upcoming builder
approach, there won't be a central location where all elements are
created, so this helper container helps grouping related data allowing
for shorter call signatures. This helper object will become obsolete
when moving to a client-based global communication (#3421,
draft implementation ready).

src/gromacs/modularsimulator/domdechelper.cpp
src/gromacs/modularsimulator/modularsimulator.cpp
src/gromacs/modularsimulator/simulatoralgorithm.cpp
src/gromacs/modularsimulator/simulatoralgorithm.h

index db030fda3e7e6d34a4879a2482ceb6398c587331..200911542f148a0b26136727c21d3f2fef97e758 100644 (file)
@@ -94,6 +94,9 @@ DomDecHelper::DomDecHelper(bool                               isVerbose,
     pull_work_(pull_work)
 {
     GMX_ASSERT(DOMAINDECOMP(cr), "Domain decomposition Helper constructed in non-DD simulation");
+    GMX_ASSERT(checkBondedInteractionsCallback_,
+               "Domain decomposition needs a callback to check the number of bonded "
+               "interactions.");
 }
 
 void DomDecHelper::setup()
index 87bbe91d3fffcd81f5d29ded9f81542db97c0a87..5552e94ff23993260eb6d4001531df7f7210aa97 100644 (file)
@@ -141,13 +141,12 @@ std::unique_ptr<ISimulatorElement> ModularSimulatorAlgorithmBuilder::buildIntegr
         SignallerBuilder<TrajectorySignaller>*     trajectorySignallerBuilder,
         TrajectoryElementBuilder*                  trajectoryElementBuilder,
         std::vector<ICheckpointHelperClient*>*     checkpointClients,
-        CheckBondedInteractionsCallbackPtr*        checkBondedInteractionsCallback,
         compat::not_null<StatePropagatorData*>     statePropagatorDataPtr,
         compat::not_null<EnergyData*>              energyDataPtr,
         FreeEnergyPerturbationData*                freeEnergyPerturbationDataPtr,
         bool                                       hasReadEkinState,
         TopologyHolder::Builder*                   topologyHolderBuilder,
-        SimulationSignals*                         signals)
+        GlobalCommunicationHelper*                 globalCommunicationHelper)
 {
     auto forceElement = buildForces(neighborSearchSignallerBuilder, energySignallerBuilder,
                                     statePropagatorDataPtr, energyDataPtr,
@@ -161,19 +160,23 @@ std::unique_ptr<ISimulatorElement> ModularSimulatorAlgorithmBuilder::buildIntegr
     std::function<void()> needToCheckNumberOfBondedInteractions;
     if (legacySimulatorData_->inputrec->eI == eiMD)
     {
-        auto computeGlobalsElement = std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>(
-                statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr, signals,
-                nstglobalcomm_, legacySimulatorData_->fplog, legacySimulatorData_->mdlog,
-                legacySimulatorData_->cr, legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms,
-                legacySimulatorData_->nrnb, legacySimulatorData_->wcycle, legacySimulatorData_->fr,
-                legacySimulatorData_->top_global, legacySimulatorData_->constr, hasReadEkinState);
+        auto computeGlobalsElement =
+                std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>(
+                        statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr,
+                        globalCommunicationHelper->simulationSignals(),
+                        globalCommunicationHelper->nstglobalcomm(), legacySimulatorData_->fplog,
+                        legacySimulatorData_->mdlog, legacySimulatorData_->cr,
+                        legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms,
+                        legacySimulatorData_->nrnb, legacySimulatorData_->wcycle,
+                        legacySimulatorData_->fr, legacySimulatorData_->top_global,
+                        legacySimulatorData_->constr, hasReadEkinState);
         topologyHolderBuilder->registerClient(computeGlobalsElement.get());
         energySignallerBuilder->registerSignallerClient(compat::make_not_null(computeGlobalsElement.get()));
         trajectorySignallerBuilder->registerSignallerClient(
                 compat::make_not_null(computeGlobalsElement.get()));
 
-        *checkBondedInteractionsCallback =
-                computeGlobalsElement->getCheckNumberOfBondedInteractionsCallback();
+        globalCommunicationHelper->setCheckBondedInteractionsCallback(
+                computeGlobalsElement->getCheckNumberOfBondedInteractionsCallback());
 
         auto propagator = std::make_unique<Propagator<IntegrationStep::LeapFrog>>(
                 legacySimulatorData_->inputrec->delta_t, statePropagatorDataPtr,
@@ -255,19 +258,21 @@ std::unique_ptr<ISimulatorElement> ModularSimulatorAlgorithmBuilder::buildIntegr
     {
         auto computeGlobalsElement =
                 std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>(
-                        statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr, signals,
-                        nstglobalcomm_, legacySimulatorData_->fplog, legacySimulatorData_->mdlog,
-                        legacySimulatorData_->cr, legacySimulatorData_->inputrec,
-                        legacySimulatorData_->mdAtoms, legacySimulatorData_->nrnb,
-                        legacySimulatorData_->wcycle, legacySimulatorData_->fr,
-                        legacySimulatorData_->top_global, legacySimulatorData_->constr, hasReadEkinState);
+                        statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr,
+                        globalCommunicationHelper->simulationSignals(),
+                        globalCommunicationHelper->nstglobalcomm(), legacySimulatorData_->fplog,
+                        legacySimulatorData_->mdlog, legacySimulatorData_->cr,
+                        legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms,
+                        legacySimulatorData_->nrnb, legacySimulatorData_->wcycle,
+                        legacySimulatorData_->fr, legacySimulatorData_->top_global,
+                        legacySimulatorData_->constr, hasReadEkinState);
         topologyHolderBuilder->registerClient(computeGlobalsElement.get());
         energySignallerBuilder->registerSignallerClient(compat::make_not_null(computeGlobalsElement.get()));
         trajectorySignallerBuilder->registerSignallerClient(
                 compat::make_not_null(computeGlobalsElement.get()));
 
-        *checkBondedInteractionsCallback =
-                computeGlobalsElement->getCheckNumberOfBondedInteractionsCallback();
+        globalCommunicationHelper->setCheckBondedInteractionsCallback(
+                computeGlobalsElement->getCheckNumberOfBondedInteractionsCallback());
 
         auto propagatorVelocities = std::make_unique<Propagator<IntegrationStep::VelocitiesOnly>>(
                 legacySimulatorData_->inputrec->delta_t * 0.5, statePropagatorDataPtr,
index f4d8abbe948bcbfd256f1a60635dd5462ceef1dd..0488095fa5b449b79aaf1689a25498174526948a 100644 (file)
@@ -388,6 +388,7 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::constructElementsAnd
             legacySimulatorData_->cr, legacySimulatorData_->mdlog, legacySimulatorData_->mdrunOptions,
             legacySimulatorData_->inputrec, legacySimulatorData_->nrnb, legacySimulatorData_->wcycle,
             legacySimulatorData_->fr, legacySimulatorData_->walltime_accounting);
+    GlobalCommunicationHelper globalCommunicationHelper(nstglobalcomm_, &algorithm.signals_);
     /* When restarting from a checkpoint, it can be appropriate to
      * initialize ekind from quantities in the checkpoint. Otherwise,
      * compute_globals must initialize ekind before the simulation
@@ -454,10 +455,10 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::constructElementsAnd
      */
     const bool simulationsShareState = false;
     algorithm.stopHandler_           = legacySimulatorData_->stopHandlerBuilder->getStopHandlerMD(
-            compat::not_null<SimulationSignal*>(&algorithm.signals_[eglsSTOPCOND]),
+            compat::not_null<SimulationSignal*>(&(*globalCommunicationHelper.simulationSignals())[eglsSTOPCOND]),
             simulationsShareState, MASTER(legacySimulatorData_->cr),
             legacySimulatorData_->inputrec->nstlist, legacySimulatorData_->mdrunOptions.reproducible,
-            nstglobalcomm_, legacySimulatorData_->mdrunOptions.maximumHoursToRun,
+            globalCommunicationHelper.nstglobalcomm(), legacySimulatorData_->mdrunOptions.maximumHoursToRun,
             legacySimulatorData_->inputrec->nstlist == 0, legacySimulatorData_->fplog,
             algorithm.stophandlerCurrentStep_, algorithm.stophandlerIsNSStep_,
             legacySimulatorData_->walltime_accounting);
@@ -485,13 +486,11 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::constructElementsAnd
      */
     // TODO: Make a CheckpointHelperBuilder
     std::vector<ICheckpointHelperClient*> checkpointClients;
-    CheckBondedInteractionsCallbackPtr    checkBondedInteractionsCallback = nullptr;
-    auto                                  integrator                      = buildIntegrator(
+    auto                                  integrator = buildIntegrator(
             &neighborSearchSignallerBuilder, &lastStepSignallerBuilder, &energySignallerBuilder,
             &loggingSignallerBuilder, &trajectorySignallerBuilder, &trajectoryElementBuilder,
-            &checkpointClients, &checkBondedInteractionsCallback, statePropagatorDataPtr,
-            energyDataPtr, freeEnergyPerturbationDataPtr, hasReadEkinState, &topologyHolderBuilder,
-            &algorithm.signals_);
+            &checkpointClients, statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr,
+            hasReadEkinState, &topologyHolderBuilder, &globalCommunicationHelper);
 
     FreeEnergyPerturbationData::Element* freeEnergyPerturbationElement = nullptr;
     if (algorithm.freeEnergyPerturbationData_)
@@ -523,16 +522,14 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::constructElementsAnd
 
     if (DOMAINDECOMP(legacySimulatorData_->cr))
     {
-        GMX_ASSERT(checkBondedInteractionsCallback,
-                   "Domain decomposition needs a callback for check the number of bonded "
-                   "interactions.");
         algorithm.domDecHelper_ = std::make_unique<DomDecHelper>(
                 legacySimulatorData_->mdrunOptions.verbose,
                 legacySimulatorData_->mdrunOptions.verboseStepPrintInterval, statePropagatorDataPtr,
-                algorithm.topologyHolder_.get(), std::move(checkBondedInteractionsCallback),
-                nstglobalcomm_, legacySimulatorData_->fplog, legacySimulatorData_->cr,
-                legacySimulatorData_->mdlog, legacySimulatorData_->constr, legacySimulatorData_->inputrec,
-                legacySimulatorData_->mdAtoms, legacySimulatorData_->nrnb,
+                algorithm.topologyHolder_.get(),
+                globalCommunicationHelper.moveCheckBondedInteractionsCallback(),
+                globalCommunicationHelper.nstglobalcomm(), legacySimulatorData_->fplog,
+                legacySimulatorData_->cr, legacySimulatorData_->mdlog, legacySimulatorData_->constr,
+                legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms, legacySimulatorData_->nrnb,
                 legacySimulatorData_->wcycle, legacySimulatorData_->fr, legacySimulatorData_->vsite,
                 legacySimulatorData_->imdSession, legacySimulatorData_->pull_work);
         neighborSearchSignallerBuilder.registerSignallerClient(
@@ -541,7 +538,8 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::constructElementsAnd
 
     const bool simulationsShareResetCounters = false;
     algorithm.resetHandler_                  = std::make_unique<ResetHandler>(
-            compat::make_not_null<SimulationSignal*>(&algorithm.signals_[eglsRESETCOUNTERS]),
+            compat::make_not_null<SimulationSignal*>(
+                    &(*globalCommunicationHelper.simulationSignals())[eglsRESETCOUNTERS]),
             simulationsShareResetCounters, legacySimulatorData_->inputrec->nsteps,
             MASTER(legacySimulatorData_->cr), legacySimulatorData_->mdrunOptions.timingOptions.resetHalfway,
             legacySimulatorData_->mdrunOptions.maximumHoursToRun, legacySimulatorData_->mdlog,
@@ -578,7 +576,8 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::constructElementsAnd
     // Add checkpoint helper here since we need a pointer to the trajectory element and
     // need to register it with the lastStepSignallerBuilder
     auto checkpointHandler = std::make_unique<CheckpointHandler>(
-            compat::make_not_null<SimulationSignal*>(&algorithm.signals_[eglsCHKPT]),
+            compat::make_not_null<SimulationSignal*>(
+                    &(*globalCommunicationHelper.simulationSignals())[eglsCHKPT]),
             simulationsShareState, legacySimulatorData_->inputrec->nstlist == 0,
             MASTER(legacySimulatorData_->cr), legacySimulatorData_->mdrunOptions.writeConfout,
             legacySimulatorData_->mdrunOptions.checkpointOptions.period);
@@ -659,4 +658,31 @@ SignallerCallbackPtr ModularSimulatorAlgorithm::SignalHelper::registerNSCallback
     return std::make_unique<SignallerCallback>(
             [this](Step step, Time gmx_unused time) { this->nextNSStep_ = step; });
 }
+
+GlobalCommunicationHelper::GlobalCommunicationHelper(int nstglobalcomm, SimulationSignals* simulationSignals) :
+    nstglobalcomm_(nstglobalcomm),
+    simulationSignals_(simulationSignals)
+{
+}
+
+int GlobalCommunicationHelper::nstglobalcomm() const
+{
+    return nstglobalcomm_;
+}
+
+SimulationSignals* GlobalCommunicationHelper::simulationSignals()
+{
+    return simulationSignals_;
+}
+
+void GlobalCommunicationHelper::setCheckBondedInteractionsCallback(CheckBondedInteractionsCallbackPtr ptr)
+{
+    checkBondedInteractionsCallbackPtr_ = std::move(ptr);
+}
+
+CheckBondedInteractionsCallbackPtr GlobalCommunicationHelper::moveCheckBondedInteractionsCallback()
+{
+    return std::move(checkBondedInteractionsCallbackPtr_);
+}
+
 } // namespace gmx
index 5cab3de05431d28d02e325cf98e05e60426a078e..29087ccd9fe6908c7d49c7e26f82ed53639321d9 100644 (file)
@@ -262,6 +262,38 @@ private:
     gmx_walltime_accounting* walltime_accounting;
 };
 
+/*! \internal
+ * \brief Helper container with data connected to global communication
+ *
+ * This includes data that needs to be shared between elements involved in
+ * global communication. This will become obsolete as soon as global
+ * communication is moved to a client system (#3421).
+ */
+class GlobalCommunicationHelper
+{
+public:
+    //! Constructor
+    GlobalCommunicationHelper(int nstglobalcomm, SimulationSignals* simulationSignals);
+
+    //! Get the compute globals communication period
+    [[nodiscard]] int nstglobalcomm() const;
+    //! Get a pointer to the signals vector
+    [[nodiscard]] SimulationSignals* simulationSignals();
+
+    //! Set the callback to check the number of bonded interactions
+    void setCheckBondedInteractionsCallback(CheckBondedInteractionsCallbackPtr ptr);
+    //! Move the callback to check the number of bonded interactions
+    [[nodiscard]] CheckBondedInteractionsCallbackPtr moveCheckBondedInteractionsCallback();
+
+private:
+    //! Compute globals communication period
+    const int nstglobalcomm_;
+    //! Signal vector (used by stop / reset / checkpointing signaller)
+    SimulationSignals* simulationSignals_;
+    //! Callback to check the number of bonded interactions
+    CheckBondedInteractionsCallbackPtr checkBondedInteractionsCallbackPtr_;
+};
+
 /*!\internal
  * \brief Builder for ModularSimulatorAlgorithm objects
  *
@@ -308,13 +340,12 @@ private:
                     SignallerBuilder<TrajectorySignaller>*     trajectorySignallerBuilder,
                     TrajectoryElementBuilder*                  trajectoryElementBuilder,
                     std::vector<ICheckpointHelperClient*>*     checkpointClients,
-                    CheckBondedInteractionsCallbackPtr*        checkBondedInteractionsCallback,
                     compat::not_null<StatePropagatorData*>     statePropagatorDataPtr,
                     compat::not_null<EnergyData*>              energyDataPtr,
                     FreeEnergyPerturbationData*                freeEnergyPerturbationDataPtr,
                     bool                                       hasReadEkinState,
                     TopologyHolder::Builder*                   topologyHolderBuilder,
-                    SimulationSignals*                         signals);
+                    GlobalCommunicationHelper*                 globalCommunicationHelper);
 
     //! Build the force element - can be normal forces or shell / flex constraints
     std::unique_ptr<ISimulatorElement>