Fix bug resetting mdatoms masses to lambda=0 state
authorPascal Merz <pascal.merz@me.com>
Tue, 29 Sep 2020 19:47:35 +0000 (19:47 +0000)
committerMark Abraham <mark.j.abraham@gmail.com>
Tue, 29 Sep 2020 19:47:35 +0000 (19:47 +0000)
Following the reorganization in !384, the TopologyHolder reset the masses
in mdatoms after they were set by the FreeEnergyPerturbationElement. This
lead to having the masses equal to state A independently of the actual
lambda state.

This change fixes this bug by moving the mass setting into the setup
phase ensuring that it is set by the FEP element independently of the
build order of the elements.

When DD is used, the mdatoms masses were reset to their lambda=0 state
after every DD step. This is fixed by making the DomDecHelper aware of
the FEP element.

Note that the use of MDAtoms should likely be revisited to allow for a
more elegant solution, but this would likely require changes which are
less local than the proposed solution here. Further development is
therefore deferred to #3700.

src/gromacs/modularsimulator/domdechelper.cpp
src/gromacs/modularsimulator/domdechelper.h
src/gromacs/modularsimulator/freeenergyperturbationdata.cpp
src/gromacs/modularsimulator/freeenergyperturbationdata.h
src/gromacs/modularsimulator/simulatoralgorithm.cpp

index 4d21b62d1000036d0dda8790071ace5cfe2390b3..9128119bfebf7724131de5b222fc4810b6937a7d 100644 (file)
@@ -50,6 +50,7 @@
 #include "gromacs/mdtypes/state.h"
 #include "gromacs/pbcutil/pbc.h"
 
+#include "freeenergyperturbationdata.h"
 #include "statepropagatordata.h"
 #include "topologyholder.h"
 
@@ -58,6 +59,7 @@ namespace gmx
 DomDecHelper::DomDecHelper(bool                            isVerbose,
                            int                             verbosePrintInterval,
                            StatePropagatorData*            statePropagatorData,
+                           FreeEnergyPerturbationData*     freeEnergyPerturbationData,
                            TopologyHolder*                 topologyHolder,
                            CheckBondedInteractionsCallback checkBondedInteractionsCallback,
                            int                             nstglobalcomm,
@@ -78,6 +80,7 @@ DomDecHelper::DomDecHelper(bool                            isVerbose,
     verbosePrintInterval_(verbosePrintInterval),
     nstglobalcomm_(nstglobalcomm),
     statePropagatorData_(statePropagatorData),
+    freeEnergyPerturbationData_(freeEnergyPerturbationData),
     topologyHolder_(topologyHolder),
     checkBondedInteractionsCallback_(std::move(checkBondedInteractionsCallback)),
     fplog_(fplog),
@@ -98,10 +101,6 @@ DomDecHelper::DomDecHelper(bool                            isVerbose,
 
 void DomDecHelper::setup()
 {
-    std::unique_ptr<t_state> localState   = statePropagatorData_->localState();
-    t_state*                 globalState  = statePropagatorData_->globalState();
-    ForceBuffers*            forcePointer = statePropagatorData_->forcePointer();
-
     // constant choices for this call to dd_partition_system
     const bool     verbose       = false;
     const bool     isMasterState = true;
@@ -109,14 +108,8 @@ void DomDecHelper::setup()
     gmx_wallcycle* wcycle        = nullptr;
 
     // Distribute the charge groups over the nodes from the master node
-    dd_partition_system(fplog_, mdlog_, inputrec_->init_step, cr_, isMasterState, nstglobalcomm,
-                        globalState, topologyHolder_->globalTopology(), inputrec_, imdSession_,
-                        pull_work_, localState.get(), forcePointer, mdAtoms_,
-                        topologyHolder_->localTopology_.get(), fr_, vsite_, constr_, nrnb_, wcycle,
-                        verbose);
-    topologyHolder_->updateLocalTopology();
-    checkBondedInteractionsCallback_();
-    statePropagatorData_->setLocalState(std::move(localState));
+    partitionSystem(verbose, isMasterState, nstglobalcomm, wcycle,
+                    statePropagatorData_->localState(), statePropagatorData_->globalState());
 }
 
 void DomDecHelper::run(Step step, Time gmx_unused time)
@@ -125,9 +118,8 @@ void DomDecHelper::run(Step step, Time gmx_unused time)
     {
         return;
     }
-    std::unique_ptr<t_state> localState   = statePropagatorData_->localState();
-    t_state*                 globalState  = statePropagatorData_->globalState();
-    ForceBuffers*            forcePointer = statePropagatorData_->forcePointer();
+    std::unique_ptr<t_state> localState  = statePropagatorData_->localState();
+    t_state*                 globalState = statePropagatorData_->globalState();
 
     // constant choices for this call to dd_partition_system
     const bool verbose = isVerbose_ && (step % verbosePrintInterval_ == 0 || step == inputrec_->init_step);
@@ -149,13 +141,31 @@ void DomDecHelper::run(Step step, Time gmx_unused time)
     }
 
     // Distribute the charge groups over the nodes from the master node
-    dd_partition_system(fplog_, mdlog_, step, cr_, isMasterState, nstglobalcomm_, globalState,
-                        topologyHolder_->globalTopology(), inputrec_, imdSession_, pull_work_,
-                        localState.get(), forcePointer, mdAtoms_, topologyHolder_->localTopology_.get(),
-                        fr_, vsite_, constr_, nrnb_, wcycle_, verbose);
+    partitionSystem(verbose, isMasterState, nstglobalcomm_, wcycle_, std::move(localState), globalState);
+}
+
+void DomDecHelper::partitionSystem(bool                     verbose,
+                                   bool                     isMasterState,
+                                   int                      nstglobalcomm,
+                                   gmx_wallcycle*           wcycle,
+                                   std::unique_ptr<t_state> localState,
+                                   t_state*                 globalState)
+{
+    ForceBuffers* forcePointer = statePropagatorData_->forcePointer();
+
+    // Distribute the charge groups over the nodes from the master node
+    dd_partition_system(fplog_, mdlog_, inputrec_->init_step, cr_, isMasterState, nstglobalcomm,
+                        globalState, topologyHolder_->globalTopology(), inputrec_, imdSession_,
+                        pull_work_, localState.get(), forcePointer, mdAtoms_,
+                        topologyHolder_->localTopology_.get(), fr_, vsite_, constr_, nrnb_, wcycle,
+                        verbose);
     topologyHolder_->updateLocalTopology();
     checkBondedInteractionsCallback_();
     statePropagatorData_->setLocalState(std::move(localState));
+    if (freeEnergyPerturbationData_)
+    {
+        freeEnergyPerturbationData_->updateMDAtoms();
+    }
 }
 
 std::optional<SignallerCallback> DomDecHelper::registerNSCallback()
index 1229ff7ed0f94bf375c1d419ee57023ad1a8c2ad..2b6691c1578a199618059eb3e67c90f22f3be9c2 100644 (file)
@@ -57,6 +57,7 @@ struct t_nrnb;
 namespace gmx
 {
 class Constraints;
+class FreeEnergyPerturbationData;
 class ImdSession;
 class MDAtoms;
 class MDLogger;
@@ -90,6 +91,7 @@ public:
     DomDecHelper(bool                            isVerbose,
                  int                             verbosePrintInterval,
                  StatePropagatorData*            statePropagatorData,
+                 FreeEnergyPerturbationData*     freeEnergyPerturbationData,
                  TopologyHolder*                 topologyHolder,
                  CheckBondedInteractionsCallback checkBondedInteractionsCallback,
                  int                             nstglobalcomm,
@@ -136,11 +138,21 @@ private:
     // TODO: Clarify relationship to data objects and find a more robust alternative to raw pointers (#3583)
     //! Pointer to the micro state
     StatePropagatorData* statePropagatorData_;
+    //! Pointer to the free energy data
+    FreeEnergyPerturbationData* freeEnergyPerturbationData_;
     //! Pointer to the topology
     TopologyHolder* topologyHolder_;
     //! Pointer to the ComputeGlobalsHelper object - to ask for # of bonded interaction checking
     CheckBondedInteractionsCallback checkBondedInteractionsCallback_;
 
+    //! Helper function unifying the DD partitioning calls in setup() and run()
+    void partitionSystem(bool                     verbose,
+                         bool                     isMasterState,
+                         int                      nstglobalcomm,
+                         gmx_wallcycle*           wcycle,
+                         std::unique_ptr<t_state> localState,
+                         t_state*                 globalState);
+
     // Access to ISimulator data
     //! Handles logging.
     FILE* fplog_;
index 86b098d89f68a41ae60f8fc46f1cdce4fc514d5b..ef1baf067d5edc5369d28b4395d9af45754fde34 100644 (file)
@@ -72,7 +72,6 @@ FreeEnergyPerturbationData::FreeEnergyPerturbationData(FILE* fplog, const t_inpu
     // available on master. We have the lambda vector available everywhere, so we pass a `true`
     // for isMaster on all ranks. See #3647.
     initialize_lambdas(fplog_, *inputrec_, true, &currentFEPState_, lambda_);
-    update_mdatoms(mdAtoms_->mdatoms(), lambda_[efptMASS]);
 }
 
 void FreeEnergyPerturbationData::Element::scheduleTask(Step step,
@@ -89,7 +88,7 @@ void FreeEnergyPerturbationData::updateLambdas(Step step)
 {
     // at beginning of step (if lambdas change...)
     lambda_ = currentLambdas(step, *(inputrec_->fepvals), currentFEPState_);
-    update_mdatoms(mdAtoms_->mdatoms(), lambda_[efptMASS]);
+    updateMDAtoms();
 }
 
 ArrayRef<real> FreeEnergyPerturbationData::lambdaView()
@@ -107,6 +106,11 @@ int FreeEnergyPerturbationData::currentFEPState()
     return currentFEPState_;
 }
 
+void FreeEnergyPerturbationData::updateMDAtoms()
+{
+    update_mdatoms(mdAtoms_->mdatoms(), lambda_[efptMASS]);
+}
+
 namespace
 {
 /*!
@@ -155,8 +159,6 @@ void FreeEnergyPerturbationData::Element::restoreCheckpointState(std::optional<R
         dd_bcast(cr->dd, freeEnergyPerturbationData_->lambda_.size() * sizeof(real),
                  freeEnergyPerturbationData_->lambda_.data());
     }
-    update_mdatoms(freeEnergyPerturbationData_->mdAtoms_->mdatoms(),
-                   freeEnergyPerturbationData_->lambda_[efptMASS]);
 }
 
 const std::string& FreeEnergyPerturbationData::Element::clientID()
@@ -171,6 +173,11 @@ FreeEnergyPerturbationData::Element::Element(FreeEnergyPerturbationData* freeEne
 {
 }
 
+void FreeEnergyPerturbationData::Element::elementSetup()
+{
+    freeEnergyPerturbationData_->updateMDAtoms();
+}
+
 FreeEnergyPerturbationData::Element* FreeEnergyPerturbationData::element()
 {
     return element_.get();
index 0ee6fe78ece8f17d29b0815d11ec24d0a8f3fa2a..b1291dfde672b3a3004c4fbe01c4b7291e8827b4 100644 (file)
@@ -83,6 +83,8 @@ public:
     ArrayRef<const real> constLambdaView();
     //! Get the current FEP state
     int currentFEPState();
+    //! Update MDAtoms (public because it's called by DomDec - see #3700)
+    void updateMDAtoms();
 
     //! The element taking part in the simulator loop
     class Element;
@@ -128,8 +130,8 @@ public:
     //! Update lambda and mdatoms
     void scheduleTask(Step step, Time time, const RegisterRunFunction& registerRunFunction) override;
 
-    //! No setup needed
-    void elementSetup() override{};
+    //! Update the MdAtoms object
+    void elementSetup() override;
 
     //! No teardown needed
     void elementTeardown() override{};
index 61e25e0b6e5c7c3ed907fa275f16beca6243f296..aa18bd767d0b24dee23c4659b5e149b06ce03cc0 100644 (file)
@@ -500,7 +500,8 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::build()
         algorithm.domDecHelper_ = std::make_unique<DomDecHelper>(
                 legacySimulatorData_->mdrunOptions.verbose,
                 legacySimulatorData_->mdrunOptions.verboseStepPrintInterval,
-                algorithm.statePropagatorData_.get(), algorithm.topologyHolder_.get(),
+                algorithm.statePropagatorData_.get(), algorithm.freeEnergyPerturbationData_.get(),
+                algorithm.topologyHolder_.get(),
                 globalCommunicationHelper_.moveCheckBondedInteractionsCallback(),
                 globalCommunicationHelper_.nstglobalcomm(), legacySimulatorData_->fplog,
                 legacySimulatorData_->cr, legacySimulatorData_->mdlog, legacySimulatorData_->constr,