Add function calls to MdModules to sign up for notifications
[alexxy/gromacs.git] / src / gromacs / applied_forces / densityfitting.cpp
index 5a134e4327ab0a1d9c929d5511a895cb07dae96f..2889e61a5072eabe85e1f4a13fa11f2c8a34ec45 100644 (file)
@@ -223,8 +223,12 @@ private:
  */
 class DensityFitting final : public IMDModule
 {
+
 public:
-    /*! \brief Construct the density fitting module.
+    //! Construct the density fitting module.
+    explicit DensityFitting() = default;
+
+    /*! \brief Request to be notified during pre-processing.
      *
      * \param[in] notifier allows the module to subscribe to notifications from MdModules.
      *
@@ -234,25 +238,18 @@ public:
      *   - storing its internal parameters in a tpr file by writing to a
      *     key-value-tree during pre-processing by a function taking a
      *     KeyValueTreeObjectBuilder as parameter
-     *   - reading its internal parameters from a key-value-tree during
-     *     simulation setup by taking a const KeyValueTreeObject & parameter
-     *   - constructing local atom sets in the simulation parameter setup
-     *     by taking a LocalAtomSetManager * as parameter
-     *   - the type of periodic boundary conditions that are used
-     *     by taking a PeriodicBoundaryConditionType as parameter
-     *   - the writing of checkpoint data
-     *     by taking a MdModulesWriteCheckpointData as parameter
-     *   - the reading of checkpoint data
-     *     by taking a MdModulesCheckpointReadingDataOnMaster as parameter
-     *   - the broadcasting of checkpoint data
-     *     by taking MdModulesCheckpointReadingBroadcast as parameter
      */
-    explicit DensityFitting(MdModulesNotifier* notifier)
+    void subscribeToPreProcessingNotifications(MdModulesNotifier* notifier) override
     {
         // Callbacks for several kinds of MdModuleNotification are created
         // and subscribed, and will be dispatched correctly at run time
         // based on the type of the parameter required by the lambda.
 
+        if (!densityFittingOptions_.active())
+        {
+            return;
+        }
+
         // Setting the atom group indices from index group string
         const auto setFitGroupIndicesFunction = [this](const IndexGroupsAndNames& indexGroupsAndNames) {
             densityFittingOptions_.setFitGroupIndices(indexGroupsAndNames);
@@ -265,18 +262,41 @@ public:
         };
         notifier->preProcessingNotifications_.subscribe(writeInternalParametersFunction);
 
-        // Reading internal parameters during simulation setup
-        const auto readInternalParametersFunction = [this](const KeyValueTreeObject& tree) {
-            densityFittingOptions_.readInternalParametersFromKvt(tree);
-        };
-        notifier->simulationSetupNotifications_.subscribe(readInternalParametersFunction);
-
         // Checking for consistency with all .mdp options
         const auto checkEnergyCaluclationFrequencyFunction =
                 [this](EnergyCalculationFrequencyErrors* energyCalculationFrequencyErrors) {
                     densityFittingOptions_.checkEnergyCaluclationFrequency(energyCalculationFrequencyErrors);
                 };
         notifier->preProcessingNotifications_.subscribe(checkEnergyCaluclationFrequencyFunction);
+    }
+
+    /*! \brief Request to be notified.
+     * The density fitting code subscribes to these notifications:
+     *   - reading its internal parameters from a key-value-tree during
+     *     simulation setup by taking a const KeyValueTreeObject & parameter
+     *   - constructing local atom sets in the simulation parameter setup
+     *     by taking a LocalAtomSetManager * as parameter
+     *   - the type of periodic boundary conditions that are used
+     *     by taking a PeriodicBoundaryConditionType as parameter
+     *   - the writing of checkpoint data
+     *     by taking a MdModulesWriteCheckpointData as parameter
+     *   - the reading of checkpoint data
+     *     by taking a MdModulesCheckpointReadingDataOnMaster as parameter
+     *   - the broadcasting of checkpoint data
+     *     by taking MdModulesCheckpointReadingBroadcast as parameter
+     */
+    void subscribeToSimulationSetupNotifications(MdModulesNotifier* notifier) override
+    {
+        if (!densityFittingOptions_.active())
+        {
+            return;
+        }
+
+        // Reading internal parameters during simulation setup
+        const auto readInternalParametersFunction = [this](const KeyValueTreeObject& tree) {
+            densityFittingOptions_.readInternalParametersFromKvt(tree);
+        };
+        notifier->simulationSetupNotifications_.subscribe(readInternalParametersFunction);
 
         // constructing local atom sets during simulation setup
         const auto setLocalAtomSetFunction = [this](LocalAtomSetManager* localAtomSetManager) {
@@ -359,12 +379,8 @@ public:
      */
     void constructLocalAtomSet(LocalAtomSetManager* localAtomSetManager)
     {
-        if (densityFittingOptions_.active())
-        {
-            LocalAtomSet atomSet =
-                    localAtomSetManager->add(densityFittingOptions_.buildParameters().indices_);
-            densityFittingSimulationParameters_.setLocalAtomSet(atomSet);
-        }
+        LocalAtomSet atomSet = localAtomSetManager->add(densityFittingOptions_.buildParameters().indices_);
+        densityFittingSimulationParameters_.setLocalAtomSet(atomSet);
     }
 
     /*! \brief Request energy output to energy file during simulation.
@@ -384,21 +400,17 @@ public:
      */
     void writeCheckpointData(MdModulesWriteCheckpointData checkpointWriting)
     {
-        if (densityFittingOptions_.active())
-        {
-            const DensityFittingForceProviderState& state = forceProvider_->stateToCheckpoint();
-            checkpointWriting.builder_.addValue<std::int64_t>(
-                    DensityFittingModuleInfo::name_ + "-stepsSinceLastCalculation",
-                    state.stepsSinceLastCalculation_);
-            checkpointWriting.builder_.addValue<real>(
-                    DensityFittingModuleInfo::name_ + "-adaptiveForceConstantScale",
-                    state.adaptiveForceConstantScale_);
-            KeyValueTreeObjectBuilder exponentialMovingAverageKvtEntry =
-                    checkpointWriting.builder_.addObject(DensityFittingModuleInfo::name_
-                                                         + "-exponentialMovingAverageState");
-            exponentialMovingAverageStateAsKeyValueTree(exponentialMovingAverageKvtEntry,
-                                                        state.exponentialMovingAverageState_);
-        }
+        const DensityFittingForceProviderState& state = forceProvider_->stateToCheckpoint();
+        checkpointWriting.builder_.addValue<std::int64_t>(
+                DensityFittingModuleInfo::name_ + "-stepsSinceLastCalculation",
+                state.stepsSinceLastCalculation_);
+        checkpointWriting.builder_.addValue<real>(
+                DensityFittingModuleInfo::name_ + "-adaptiveForceConstantScale",
+                state.adaptiveForceConstantScale_);
+        KeyValueTreeObjectBuilder exponentialMovingAverageKvtEntry = checkpointWriting.builder_.addObject(
+                DensityFittingModuleInfo::name_ + "-exponentialMovingAverageState");
+        exponentialMovingAverageStateAsKeyValueTree(exponentialMovingAverageKvtEntry,
+                                                    state.exponentialMovingAverageState_);
     }
 
     /*! \brief Read the internal parameters from the checkpoint file on master
@@ -406,34 +418,31 @@ public:
      */
     void readCheckpointDataOnMaster(MdModulesCheckpointReadingDataOnMaster checkpointReading)
     {
-        if (densityFittingOptions_.active())
+        if (checkpointReading.checkpointedData_.keyExists(DensityFittingModuleInfo::name_
+                                                          + "-stepsSinceLastCalculation"))
         {
-            if (checkpointReading.checkpointedData_.keyExists(DensityFittingModuleInfo::name_
-                                                              + "-stepsSinceLastCalculation"))
-            {
-                densityFittingState_.stepsSinceLastCalculation_ =
-                        checkpointReading
-                                .checkpointedData_[DensityFittingModuleInfo::name_
-                                                   + "-stepsSinceLastCalculation"]
-                                .cast<std::int64_t>();
-            }
-            if (checkpointReading.checkpointedData_.keyExists(DensityFittingModuleInfo::name_
-                                                              + "-adaptiveForceConstantScale"))
-            {
-                densityFittingState_.adaptiveForceConstantScale_ =
-                        checkpointReading
-                                .checkpointedData_[DensityFittingModuleInfo::name_
-                                                   + "-adaptiveForceConstantScale"]
-                                .cast<real>();
-            }
-            if (checkpointReading.checkpointedData_.keyExists(DensityFittingModuleInfo::name_
-                                                              + "-exponentialMovingAverageState"))
-            {
-                densityFittingState_.exponentialMovingAverageState_ = exponentialMovingAverageStateFromKeyValueTree(
-                        checkpointReading
-                                .checkpointedData_[DensityFittingModuleInfo::name_ + "-exponentialMovingAverageState"]
-                                .asObject());
-            }
+            densityFittingState_.stepsSinceLastCalculation_ =
+                    checkpointReading
+                            .checkpointedData_[DensityFittingModuleInfo::name_
+                                               + "-stepsSinceLastCalculation"]
+                            .cast<std::int64_t>();
+        }
+        if (checkpointReading.checkpointedData_.keyExists(DensityFittingModuleInfo::name_
+                                                          + "-adaptiveForceConstantScale"))
+        {
+            densityFittingState_.adaptiveForceConstantScale_ =
+                    checkpointReading
+                            .checkpointedData_[DensityFittingModuleInfo::name_
+                                               + "-adaptiveForceConstantScale"]
+                            .cast<real>();
+        }
+        if (checkpointReading.checkpointedData_.keyExists(DensityFittingModuleInfo::name_
+                                                          + "-exponentialMovingAverageState"))
+        {
+            densityFittingState_.exponentialMovingAverageState_ = exponentialMovingAverageStateFromKeyValueTree(
+                    checkpointReading
+                            .checkpointedData_[DensityFittingModuleInfo::name_ + "-exponentialMovingAverageState"]
+                            .asObject());
         }
     }
 
@@ -443,14 +452,11 @@ public:
      */
     void broadcastCheckpointData(MdModulesCheckpointReadingBroadcast checkpointBroadcast)
     {
-        if (densityFittingOptions_.active())
+        if (PAR(&(checkpointBroadcast.cr_)))
         {
-            if (PAR(&(checkpointBroadcast.cr_)))
-            {
-                block_bc(&(checkpointBroadcast.cr_), densityFittingState_.stepsSinceLastCalculation_);
-                block_bc(&(checkpointBroadcast.cr_), densityFittingState_.adaptiveForceConstantScale_);
-                block_bc(&(checkpointBroadcast.cr_), densityFittingState_.exponentialMovingAverageState_);
-            }
+            block_bc(&(checkpointBroadcast.cr_), densityFittingState_.stepsSinceLastCalculation_);
+            block_bc(&(checkpointBroadcast.cr_), densityFittingState_.adaptiveForceConstantScale_);
+            block_bc(&(checkpointBroadcast.cr_), densityFittingState_.exponentialMovingAverageState_);
         }
     }
 
@@ -473,9 +479,9 @@ private:
 
 } // namespace
 
-std::unique_ptr<IMDModule> DensityFittingModuleInfo::create(MdModulesNotifier* notifier)
+std::unique_ptr<IMDModule> DensityFittingModuleInfo::create()
 {
-    return std::make_unique<DensityFitting>(notifier);
+    return std::make_unique<DensityFitting>();
 }
 
 const std::string DensityFittingModuleInfo::name_ = "density-guided-simulation";