Use ObservablesReducer for check of DD bonded interaction count.
[alexxy/gromacs.git] / src / gromacs / modularsimulator / statepropagatordata.cpp
index a0177f43c9ebdec912b89e10196b575207b72b44..c57ec7c0611573a01bb91da23c93fdded517c9d4 100644 (file)
@@ -146,6 +146,7 @@ StatePropagatorData::StatePropagatorData(int                numAtoms,
                                          FILE*              fplog,
                                          const t_commrec*   cr,
                                          t_state*           globalState,
+                                         t_state*           localState,
                                          bool               useGPU,
                                          bool               canMoleculesBeDistributedOverPBC,
                                          bool               writeFinalConfiguration,
@@ -180,10 +181,9 @@ StatePropagatorData::StatePropagatorData(int                numAtoms,
     // Local state only becomes valid now.
     if (DOMAINDECOMP(cr))
     {
-        auto localState = std::make_unique<t_state>();
-        dd_init_local_state(*cr->dd, globalState, localState.get());
+        dd_init_local_state(*cr->dd, globalState, localState);
         stateHasVelocities = ((localState->flags & enumValueToBitMask(StateEntry::V)) != 0);
-        setLocalState(std::move(localState));
+        setLocalState(localState);
     }
     else
     {
@@ -327,29 +327,41 @@ int StatePropagatorData::totalNumAtoms() const
     return totalNumAtoms_;
 }
 
-std::unique_ptr<t_state> StatePropagatorData::localState()
+t_state* StatePropagatorData::localState()
 {
-    auto state   = std::make_unique<t_state>();
-    state->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
-                   | enumValueToBitMask(StateEntry::Box);
-    state_change_natoms(state.get(), localNAtoms_);
-    state->x = x_;
-    state->v = v_;
-    copy_mat(box_, state->box);
-    state->ddp_count       = ddpCount_;
-    state->ddp_count_cg_gl = ddpCountCgGl_;
-    state->cg_gl           = cgGl_;
-    return state;
+    localState_->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
+                         | enumValueToBitMask(StateEntry::Box);
+    state_change_natoms(localState_, localNAtoms_);
+    std::swap(localState_->x, x_);
+    std::swap(localState_->v, v_);
+    copy_mat(box_, localState_->box);
+    localState_->ddp_count       = ddpCount_;
+    localState_->ddp_count_cg_gl = ddpCountCgGl_;
+    localState_->cg_gl           = cgGl_;
+    return localState_;
 }
 
-void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
+std::unique_ptr<t_state> StatePropagatorData::copyLocalState(std::unique_ptr<t_state> copy)
 {
+    copy->flags = enumValueToBitMask(StateEntry::X) | enumValueToBitMask(StateEntry::V)
+                  | enumValueToBitMask(StateEntry::Box);
+    state_change_natoms(copy.get(), localNAtoms_);
+    copy->x = x_;
+    copy->v = v_;
+    copy_mat(box_, copy->box);
+    copy->ddp_count       = ddpCount_;
+    copy->ddp_count_cg_gl = ddpCountCgGl_;
+    copy->cg_gl           = cgGl_;
+    return copy;
+}
+
+void StatePropagatorData::setLocalState(t_state* state)
+{
+    localState_  = state;
     localNAtoms_ = state->natoms;
-    x_.resizeWithPadding(localNAtoms_);
     previousX_.resizeWithPadding(localNAtoms_);
-    v_.resizeWithPadding(localNAtoms_);
-    x_ = state->x;
-    v_ = state->v;
+    std::swap(x_, state->x);
+    std::swap(v_, state->v);
     copy_mat(state->box, box_);
     copyPosition();
     ddpCount_     = state->ddp_count;
@@ -433,8 +445,9 @@ void StatePropagatorData::Element::scheduleTask(Step                       step,
 
 void StatePropagatorData::Element::saveState()
 {
-    GMX_ASSERT(!localStateBackup_, "Save state called again before previous state was written.");
-    localStateBackup_ = statePropagatorData_->localState();
+    GMX_ASSERT(!localStateBackupValid_,
+               "Save state called again before previous state was written.");
+    localStateBackup_ = statePropagatorData_->copyLocalState(std::move(localStateBackup_));
     if (freeEnergyPerturbationData_)
     {
         localStateBackup_->fep_state    = freeEnergyPerturbationData_->currentFEPState();
@@ -443,6 +456,7 @@ void StatePropagatorData::Element::saveState()
         localStateBackup_->flags |=
                 enumValueToBitMask(StateEntry::Lambda) | enumValueToBitMask(StateEntry::FepState);
     }
+    localStateBackupValid_ = true;
 }
 
 std::optional<SignallerCallback> StatePropagatorData::Element::registerTrajectorySignallerCallback(TrajectoryEvent event)
@@ -511,7 +525,7 @@ void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Ti
         wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
         return;
     }
-    GMX_ASSERT(localStateBackup_, "Trajectory writing called, but no state saved.");
+    GMX_ASSERT(localStateBackupValid_, "Trajectory writing called, but no state saved.");
 
     // TODO: This is only used for CPT - needs to be filled when we turn CPT back on
     ObservablesHistory* observablesHistory = nullptr;
@@ -531,7 +545,7 @@ void StatePropagatorData::Element::write(gmx_mdoutf_t outf, Step currentStep, Ti
 
     if (currentStep != lastStep_ || !isRegularSimulationEnd_)
     {
-        localStateBackup_.reset();
+        localStateBackupValid_ = false;
     }
     wallcycle_stop(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
 }
@@ -695,7 +709,7 @@ void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unus
         return;
     }
 
-    GMX_ASSERT(localStateBackup_, "Final trajectory writing called, but no state saved.");
+    GMX_ASSERT(localStateBackupValid_, "Final trajectory writing called, but no state saved.");
 
     wallcycle_start(mdoutf_get_wcycle(outf), WallCycleCounter::Traj);
     if (DOMAINDECOMP(cr_))
@@ -770,6 +784,7 @@ StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
     nstvout_(nstvout),
     nstfout_(nstfout),
     nstxout_compressed_(nstxout_compressed),
+    localStateBackup_(std::make_unique<t_state>()),
     writeOutStep_(-1),
     freeEnergyPerturbationData_(nullptr),
     isRegularSimulationEnd_(false),