Make orires work with DD particle reordering
authorBerk Hess <hess@kth.se>
Mon, 30 Aug 2021 07:53:39 +0000 (07:53 +0000)
committerBerk Hess <hess@kth.se>
Mon, 30 Aug 2021 07:53:39 +0000 (07:53 +0000)
12 files changed:
src/gromacs/listed_forces/listed_forces.cpp
src/gromacs/listed_forces/orires.cpp
src/gromacs/listed_forces/orires.h
src/gromacs/mdlib/forcerec.cpp
src/gromacs/mdlib/tests/leapfrogtestrunners.cpp
src/gromacs/mdlib/update.cpp
src/gromacs/mdlib/update.h
src/gromacs/mdlib/update_vv.cpp
src/gromacs/mdlib/update_vv.h
src/gromacs/mdrun/md.cpp
src/gromacs/mdrun/runner.cpp
src/gromacs/mdtypes/fcdata.h

index 893a1d306cd0ceb17b9e612a76ca8f4fcb2dc061..ad01d3fba15b7333f1c8eaee54180a4041b0762e 100644 (file)
@@ -854,12 +854,10 @@ void ListedForces::calculate(struct gmx_wallcycle*                     wcycle,
                                                        idef.il[F_ORIRES].size(),
                                                        idef.il[F_ORIRES].iatoms.data(),
                                                        idef.iparams.data(),
-                                                       md,
                                                        xWholeMolecules,
                                                        x,
                                                        fr->bMolPBC ? pbc : nullptr,
-                                                       fcdata->orires.get(),
-                                                       hist);
+                                                       fcdata->orires.get());
         }
         if (fcdata->disres->nres > 0)
         {
index 626d3996e7e0d57b371d24cee49b7df90ce09190..ceec9264ae57c579631dca3ac4465bd914d8c4b4 100644 (file)
@@ -42,6 +42,8 @@
 #include <climits>
 #include <cmath>
 
+#include "gromacs/domdec/ga2la.h"
+#include "gromacs/domdec/localatomsetmanager.h"
 #include "gromacs/gmxlib/network.h"
 #include "gromacs/linearalgebra/nrjac.h"
 #include "gromacs/math/do_fit.h"
@@ -69,13 +71,51 @@ using gmx::RVec;
 // TODO This implementation of ensemble orientation restraints is nasty because
 // a user can't just do multi-sim with single-sim orientation restraints.
 
-t_oriresdata::t_oriresdata(FILE*                 fplog,
-                           const gmx_mtop_t&     mtop,
-                           const t_inputrec&     ir,
-                           const t_commrec*      cr,
-                           const gmx_multisim_t* ms,
-                           t_state*              globalState) :
-    numRestraints(gmx_mtop_ftype_count(mtop, F_ORIRES))
+void extendStateWithOriresHistory(const gmx_mtop_t& mtop, const t_inputrec& ir, t_state* globalState)
+{
+    GMX_RELEASE_ASSERT(globalState != nullptr,
+                       "We need a valid global state in extendStateWithOriresHistory()");
+
+    const int numRestraints = gmx_mtop_ftype_count(mtop, F_ORIRES);
+    if (numRestraints > 0 && ir.orires_tau > 0)
+    {
+        /* Extend the state with the orires history */
+        globalState->flags |= enumValueToBitMask(StateEntry::OrireInitF);
+        globalState->hist.orire_initf = 1;
+        globalState->flags |= enumValueToBitMask(StateEntry::OrireDtav);
+        globalState->hist.orire_Dtav.resize(numRestraints * 5);
+    }
+}
+
+namespace
+{
+
+//! Creates and returns a list of global atom indices of the orientation restraint fit group
+std::vector<gmx::index> fitGlobalAtomIndices(const gmx_mtop_t& mtop)
+{
+    std::vector<gmx::index> indices;
+
+    for (int i = 0; i < mtop.natoms; i++)
+    {
+        if (getGroupType(mtop.groups, SimulationAtomGroupType::OrientationRestraintsFit, i) == 0)
+        {
+            indices.push_back(i);
+        }
+    }
+
+    return indices;
+}
+
+} // namespace
+
+t_oriresdata::t_oriresdata(FILE*                     fplog,
+                           const gmx_mtop_t&         mtop,
+                           const t_inputrec&         ir,
+                           const gmx_multisim_t*     ms,
+                           t_state*                  globalState,
+                           gmx::LocalAtomSetManager* localAtomSetManager) :
+    numRestraints(gmx_mtop_ftype_count(mtop, F_ORIRES)),
+    fitLocalAtomSet_(localAtomSetManager->add(fitGlobalAtomIndices(mtop)))
 {
     GMX_RELEASE_ASSERT(numRestraints > 0,
                        "orires() should only be called with orientation restraints present");
@@ -102,13 +142,6 @@ t_oriresdata::t_oriresdata(FILE*                 fplog,
                 "in the system"));
     }
 
-    if (cr && PAR(cr))
-    {
-        GMX_THROW(gmx::InvalidInputError(
-                "Orientation restraints do not work with MPI parallelization. Choose 1 MPI rank, "
-                "if possible."));
-    }
-
     GMX_RELEASE_ASSERT(globalState != nullptr, "We need a valid global state in t_oriresdata()");
 
     fc             = ir.orires_fc;
@@ -177,11 +210,8 @@ t_oriresdata::t_oriresdata(FILE*                 fplog,
         edt   = std::exp(-ir.delta_t / ir.orires_tau);
         edt_1 = 1.0 - edt;
 
-        /* Extend the state with the orires history */
-        globalState->flags |= enumValueToBitMask(StateEntry::OrireInitF);
-        globalState->hist.orire_initf = 1;
-        globalState->flags |= enumValueToBitMask(StateEntry::OrireDtav);
-        globalState->hist.orire_Dtav.resize(numRestraints * 5);
+        timeAveragingInitFactor_     = std::reference_wrapper<real>(globalState->hist.orire_initf);
+        DTensorsTimeAveragedHistory_ = globalState->hist.orire_Dtav;
     }
 
     orientations.resize(numRestraints);
@@ -205,25 +235,13 @@ t_oriresdata::t_oriresdata(FILE*                 fplog,
     }
     tmpEq.resize(numExperiments);
 
-    numReferenceAtoms = 0;
-    for (int i = 0; i < mtop.natoms; i++)
-    {
-        if (getGroupType(mtop.groups, SimulationAtomGroupType::OrientationRestraintsFit, i) == 0)
-        {
-            numReferenceAtoms++;
-        }
-    }
-    mref.resize(numReferenceAtoms);
-    xref.resize(numReferenceAtoms);
-    xtmp.resize(numReferenceAtoms);
-
     eigenOutput.resize(numExperiments * c_numEigenRealsPerExperiment);
 
     /* Determine the reference structure on the master node.
      * Copy it to the other nodes after checking multi compatibility,
      * so we are sure the subsystems match before copying.
      */
-    auto   x    = makeArrayRef(globalState->x);
+    auto   x    = makeConstArrayRef(globalState->x);
     rvec   com  = { 0, 0, 0 };
     double mtot = 0.0;
     int    j    = 0;
@@ -231,33 +249,38 @@ t_oriresdata::t_oriresdata(FILE*                 fplog,
     {
         const t_atom& local = atomP.atom();
         int           i     = atomP.globalAtomNumber();
-        if (mtop.groups.groupNumbers[SimulationAtomGroupType::OrientationRestraintsFit].empty()
-            || mtop.groups.groupNumbers[SimulationAtomGroupType::OrientationRestraintsFit][i] == 0)
+        if (getGroupType(mtop.groups, SimulationAtomGroupType::OrientationRestraintsFit, i) == 0)
         {
-            /* Not correct for free-energy with changing masses */
-            mref[j] = local.m;
+            // Not correct for free-energy with changing masses
+            const real mass = local.m;
             // Note that only one rank per sim is supported.
             if (isMasterSim(ms))
             {
-                copy_rvec(x[i], xref[j]);
+                referenceCoordinates_.push_back(x[i]);
                 for (int d = 0; d < DIM; d++)
                 {
-                    com[d] += mref[j] * x[i][d];
+                    com[d] += mass * x[i][d];
                 }
             }
-            mtot += mref[j];
+            fitMasses_.push_back(mass);
+            mtot += mass;
             j++;
         }
     }
+
     svmul(1.0 / mtot, com, com);
     if (isMasterSim(ms))
     {
-        for (int j = 0; j < numReferenceAtoms; j++)
+        for (auto& refCoord : referenceCoordinates_)
         {
-            rvec_dec(xref[j], com);
+            refCoord -= com;
         }
     }
 
+    const size_t numFitAtoms = referenceCoordinates_.size();
+    fitLocalAtomIndices_.resize(numFitAtoms);
+    xTmp_.resize(numFitAtoms);
+
     if (fplog)
     {
         fprintf(fplog, "Found %d orientation experiments\n", numExperiments);
@@ -266,7 +289,7 @@ t_oriresdata::t_oriresdata(FILE*                 fplog,
             fprintf(fplog, "  experiment %d has %d restraints\n", i + 1, nr_ex[i]);
         }
 
-        fprintf(fplog, "  the fit group consists of %d atoms and has total mass %g\n", numReferenceAtoms, mtot);
+        fprintf(fplog, "  the fit group consists of %zu atoms and has total mass %g\n", numFitAtoms, mtot);
     }
 
     if (ms)
@@ -280,10 +303,10 @@ t_oriresdata::t_oriresdata(FILE*                 fplog,
 
         check_multi_int(fplog, ms, numRestraints, "the number of orientation restraints", FALSE);
         check_multi_int(
-                fplog, ms, numReferenceAtoms, "the number of fit atoms for orientation restraining", FALSE);
+                fplog, ms, numFitAtoms, "the number of fit atoms for orientation restraining", FALSE);
         check_multi_int(fplog, ms, ir.nsteps, "nsteps", FALSE);
         /* Copy the reference coordinates from the master to the other nodes */
-        gmx_sum_sim(DIM * numReferenceAtoms, xref[0], ms);
+        gmx_sum_sim(DIM * referenceCoordinates_.size(), as_rvec_array(referenceCoordinates_.data())[0], ms);
     }
 
     please_cite(fplog, "Hess2003");
@@ -380,30 +403,24 @@ real calc_orires_dev(const gmx_multisim_t* ms,
                      int                   nfa,
                      const t_iatom         forceatoms[],
                      const t_iparams       ip[],
-                     const t_mdatoms*      md,
                      ArrayRef<const RVec>  xWholeMolecules,
                      const rvec            x[],
                      const t_pbc*          pbc,
-                     t_oriresdata*         od,
-                     const history_t*      hist)
+                     t_oriresdata*         od)
 {
     real       invn, pfac, r2, invr, corrfac, wsv2, sw, dev;
-    double     mtot;
     rvec       com, r_unrot, r;
     const real two_thr = 2.0 / 3.0;
 
-    const bool                     bTAV  = (od->edt != 0);
-    const real                     edt   = od->edt;
-    const real                     edt_1 = od->edt_1;
-    gmx::ArrayRef<OriresMatEq>     matEq = od->tmpEq;
-    const int                      nref  = od->numReferenceAtoms;
-    gmx::ArrayRef<real>            mref  = od->mref;
-    gmx::ArrayRef<const gmx::RVec> xref  = od->xref;
-    gmx::ArrayRef<gmx::RVec>       xtmp  = od->xtmp;
+    const bool                 bTAV  = (od->edt != 0);
+    const real                 edt   = od->edt;
+    const real                 edt_1 = od->edt_1;
+    gmx::ArrayRef<OriresMatEq> matEq = od->tmpEq;
+    gmx::ArrayRef<gmx::RVec>   xFit  = od->xTmp();
 
     if (bTAV)
     {
-        od->exp_min_t_tau = hist->orire_initf * edt;
+        od->exp_min_t_tau = od->timeAveragingInitFactor() * edt;
 
         /* Correction factor to correct for the lack of history
          * at short times.
@@ -424,32 +441,41 @@ real calc_orires_dev(const gmx_multisim_t* ms,
         invn = 1.0;
     }
 
+    // Extract the local atom indices involved in the fit group
+    const auto fitLocalAtomIndices = od->fitLocalAtomSet().localIndex();
+    // We need all atoms in the fit group to be local available. This means that
+    // orientation restraining is limited to one PP-rank. This should be ensured
+    // by the mdrun setup code. We assert here to catch incorrect setup code.
+    GMX_RELEASE_ASSERT(fitLocalAtomIndices.size() == od->referenceCoordinates().size(),
+                       "All fit atoms should be locally available");
+
     clear_rvec(com);
-    mtot        = 0;
-    int   j     = 0;
-    auto* massT = md->massT;
-    auto* cORF  = md->cORF;
-    for (int i = 0; i < md->nr; i++)
-    {
-        if (cORF[i] == 0)
+    double     mtot               = 0.0;
+    gmx::index referenceListIndex = 0;
+    for (const int fitLocalAtomIndex : fitLocalAtomIndices)
+    {
+        const gmx::RVec& x       = xWholeMolecules[fitLocalAtomIndex];
+        const real       mass    = od->fitMasses()[referenceListIndex];
+        xFit[referenceListIndex] = x;
+        for (int d = 0; d < DIM; d++)
         {
-            copy_rvec(xWholeMolecules[i], xtmp[j]);
-            mref[j] = massT[i];
-            for (int d = 0; d < DIM; d++)
-            {
-                com[d] += mref[j] * xtmp[j][d];
-            }
-            mtot += mref[j];
-            j++;
+            com[d] += mass * x[d];
         }
+        mtot += mass;
+        referenceListIndex++;
     }
     svmul(1.0 / mtot, com, com);
-    for (int j = 0; j < nref; j++)
+    for (auto& xFitCoord : xFit)
     {
-        rvec_dec(xtmp[j], com);
+        xFitCoord -= com;
     }
     /* Calculate the rotation matrix to rotate x to the reference orientation */
-    calc_fit_R(DIM, nref, mref.data(), as_rvec_array(xref.data()), as_rvec_array(xtmp.data()), od->rotationMatrix);
+    calc_fit_R(DIM,
+               xFit.size(),
+               od->fitMasses().data(),
+               as_rvec_array(od->referenceCoordinates().data()),
+               as_rvec_array(xFit.data()),
+               od->rotationMatrix);
 
     for (int fa = 0; fa < nfa; fa += 3)
     {
@@ -519,7 +545,7 @@ real calc_orires_dev(const gmx_multisim_t* ms,
              */
             for (int i = 0; i < 5; i++)
             {
-                Dtav[i] = edt * hist->orire_Dtav[restraintIndex * 5 + i]
+                Dtav[i] = edt * od->DTensorsTimeAveragedHistory()[restraintIndex * 5 + i]
                           + edt_1 * od->DTensorsEnsembleAv[restraintIndex][i];
             }
         }
@@ -737,19 +763,19 @@ real orires(int             nfa,
     /* Approx. 80*nfa/3 flops */
 }
 
-void update_orires_history(const t_oriresdata& od, history_t* hist)
+void t_oriresdata::updateHistory()
 {
-    if (od.edt != 0)
+    if (edt != 0)
     {
         /* Copy the new time averages that have been calculated
          *  in calc_orires_dev.
          */
-        hist->orire_initf = od.exp_min_t_tau;
-        for (int pair = 0; pair < od.numRestraints; pair++)
+        *timeAveragingInitFactor_ = exp_min_t_tau;
+        for (int pair = 0; pair < numRestraints; pair++)
         {
             for (int i = 0; i < 5; i++)
             {
-                hist->orire_Dtav[pair * 5 + i] = od.DTensorsTimeAndEnsembleAv[pair][i];
+                DTensorsTimeAveragedHistory_[pair * 5 + i] = DTensorsTimeAndEnsembleAv[pair][i];
             }
         }
     }
index 3a298220d67efa4f7f445a6d932cec5b97373f74..3fcb15d0ac8f48e5ed8ceaec08dc51bd6a9888dd 100644 (file)
@@ -67,6 +67,11 @@ template<typename>
 class ArrayRef;
 } // namespace gmx
 
+/*! \brief Extends \p globalState with orientation restraint history
+ * when there are restraints and time averaging is used.
+ */
+void extendStateWithOriresHistory(const gmx_mtop_t& mtop, const t_inputrec& ir, t_state* globalState);
+
 /*! \brief
  * Calculates the time averaged D matrices, the S matrix for each experiment.
  *
@@ -76,12 +81,10 @@ real calc_orires_dev(const gmx_multisim_t*          ms,
                      int                            nfa,
                      const t_iatom                  fa[],
                      const t_iparams                ip[],
-                     const t_mdatoms*               md,
                      gmx::ArrayRef<const gmx::RVec> xWholeMolecules,
                      const rvec                     x[],
                      const t_pbc*                   pbc,
-                     t_oriresdata*                  oriresdata,
-                     const history_t*               hist);
+                     t_oriresdata*                  oriresdata);
 
 /*! \brief
  * Diagonalizes the order tensor(s) of the orienation restraints.
index 70a2b1e1f54b2590032c94d301726d00b667667b..23a0639167cdff96eee5ee1ef7ef507c4d09247f 100644 (file)
@@ -752,21 +752,7 @@ void init_forcerec(FILE*                            fplog,
         if (!moleculesAreAlwaysWhole && !havePPDomainDecomposition(commrec)
             && (useEwaldSurfaceCorrection || haveOrientationRestraints))
         {
-            GMX_RELEASE_ASSERT(
-                    !havePPDomainDecomposition(commrec),
-                    "WholeMoleculesTransform only works when all atoms are in the same domain");
-            forcerec->wholeMoleculeTransform = std::make_unique<gmx::WholeMoleculeTransform>(
-                    mtop, inputrec.pbcType, DOMAINDECOMP(commrec));
-        }
-        else
-        {
-            /* With Ewald surface correction it is useful to support e.g. running water
-             * in parallel with update groups.
-             * With orientation restraints there is no sensible use case supported with DD.
-             */
-            if ((useEwaldSurfaceCorrection
-                 && !(DOMAINDECOMP(commrec) && dd_moleculesAreAlwaysWhole(*commrec->dd)))
-                || haveOrientationRestraints)
+            if (havePPDomainDecomposition(commrec))
             {
                 gmx_fatal(FARGS,
                           "You requested Ewald surface correction or orientation restraints, "
@@ -774,8 +760,13 @@ void init_forcerec(FILE*                            fplog,
                           "over periodic boundary conditions by the domain decomposition. "
                           "Run without domain decomposition instead.");
             }
+
+            forcerec->wholeMoleculeTransform = std::make_unique<gmx::WholeMoleculeTransform>(
+                    mtop, inputrec.pbcType, DOMAINDECOMP(commrec));
         }
 
+        forcerec->bMolPBC = !DOMAINDECOMP(commrec) || dd_bonded_molpbc(*commrec->dd, forcerec->pbcType);
+
         if (useEwaldSurfaceCorrection)
         {
             GMX_RELEASE_ASSERT(moleculesAreAlwaysWhole || forcerec->wholeMoleculeTransform,
index a48c90ff74ed22dca089fde83ab25b8ad31c8021..e7eaae3308aadc3898bd8a5992d5c409d06a6b73 100644 (file)
@@ -87,7 +87,7 @@ void LeapFrogHostTestRunner::integrate(LeapFrogTestData* testData, int numSteps)
                                                  : gmx::ArrayRef<rvec>{},
                 &testData->state_,
                 testData->f_,
-                testData->forceCalculationData_,
+                &testData->forceCalculationData_,
                 &testData->kineticEnergyData_,
                 testData->velocityScalingMatrix_,
                 etrtNONE,
index e51f6a355ed90545714e5f15051d3153152b3624..7bcbc656875d08056349498716ccf2ce5c64e033 100644 (file)
@@ -122,7 +122,7 @@ public:
                        gmx::ArrayRef<const rvec>                        invMassPerDim,
                        t_state*                                         state,
                        const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
-                       const t_fcdata&                                  fcdata,
+                       t_fcdata*                                        fcdata,
                        const gmx_ekindata_t*                            ekind,
                        const matrix                                     M,
                        int                                              UpdatePart,
@@ -218,7 +218,7 @@ void Update::update_coords(const t_inputrec&                 inputRecord,
                            gmx::ArrayRef<const rvec>         invMassPerDim,
                            t_state*                          state,
                            const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
-                           const t_fcdata&                                  fcdata,
+                           t_fcdata*                                        fcdata,
                            const gmx_ekindata_t*                            ekind,
                            const matrix                                     M,
                            int                                              updatePart,
@@ -1522,7 +1522,7 @@ void Update::Impl::update_coords(const t_inputrec&                 inputRecord,
                                  gmx::ArrayRef<const rvec>         invMassPerDim,
                                  t_state*                          state,
                                  const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
-                                 const t_fcdata&                                  fcdata,
+                                 t_fcdata*                                        fcdata,
                                  const gmx_ekindata_t*                            ekind,
                                  const matrix                                     M,
                                  int                                              updatePart,
@@ -1541,11 +1541,12 @@ void Update::Impl::update_coords(const t_inputrec&                 inputRecord,
     /* We need to update the NMR restraint history when time averaging is used */
     if (state->flags & enumValueToBitMask(StateEntry::DisreRm3Tav))
     {
-        update_disres_history(*fcdata.disres, &state->hist);
+        update_disres_history(*fcdata->disres, &state->hist);
     }
     if (state->flags & enumValueToBitMask(StateEntry::OrireDtav))
     {
-        update_orires_history(*fcdata.orires, &state->hist);
+        GMX_ASSERT(fcdata, "Need valid fcdata");
+        fcdata->orires->updateHistory();
     }
 
     /* ############# START The update of velocities and positions ######### */
index 6771d60f3d06fc72f0896296d5408b2d6de0c0fa..a3c4844e2e06487f65b9bd6e33fd2c1e9f59a565 100644 (file)
@@ -131,7 +131,7 @@ public:
                        gmx::ArrayRef<const rvec>                        invMassPerDim,
                        t_state*                                         state,
                        const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
-                       const t_fcdata&                                  fcdata,
+                       t_fcdata*                                        fcdata,
                        const gmx_ekindata_t*                            ekind,
                        const matrix                                     M,
                        int                                              updatePart,
index f29eabb9261105cecba2fa80ffe03218633696be..a5a9797a60c999d217a4fbb981dd8273877d96f2 100644 (file)
@@ -79,7 +79,7 @@ void integrateVVFirstStep(int64_t                   step,
                           t_commrec*                cr,
                           t_state*                  state,
                           t_mdatoms*                mdatoms,
-                          const t_fcdata&           fcdata,
+                          t_fcdata*                 fcdata,
                           t_extmass*                MassQ,
                           t_vcm*                    vcm,
                           const gmx_localtop_t&     top,
@@ -337,7 +337,7 @@ void integrateVVSecondStep(int64_t
                            t_commrec*                                               cr,
                            t_state*                                                 state,
                            t_mdatoms*                                               mdatoms,
-                           const t_fcdata&                                          fcdata,
+                           t_fcdata*                                                fcdata,
                            t_extmass*                                               MassQ,
                            t_vcm*                                                   vcm,
                            pull_t*                                                  pull_work,
index 50ebbda78ff5b4ca5bc099a2d0fe5c756bb110f1..bc4307d681b0668f16ecbad268ddfdd3be2d558c 100644 (file)
@@ -128,7 +128,7 @@ void integrateVVFirstStep(int64_t                   step,
                           t_commrec*                cr,
                           t_state*                  state,
                           t_mdatoms*                mdatoms,
-                          const t_fcdata&           fcdata,
+                          t_fcdata*                 fcdata,
                           t_extmass*                MassQ,
                           t_vcm*                    vcm,
                           const gmx_localtop_t&     top,
@@ -203,7 +203,7 @@ void integrateVVSecondStep(int64_t
                            t_commrec*                                               cr,
                            t_state*                                                 state,
                            t_mdatoms*                                               mdatoms,
-                           const t_fcdata&                                          fcdata,
+                           t_fcdata*                                                fcdata,
                            t_extmass*                                               MassQ,
                            t_vcm*                                                   vcm,
                            pull_t*                                                  pull_work,
index e83d93640b6f97f5b5931036cce5482102d85da0..11ae4fab7baf53d91d1fb4e532109b810c5f52c2 100644 (file)
@@ -276,7 +276,7 @@ void gmx::LegacySimulator::do_md()
     }
     const bool useReplicaExchange = (replExParams.exchangeInterval > 0);
 
-    const t_fcdata& fcdata = *fr->fcdata;
+    t_fcdata& fcdata = *fr->fcdata;
 
     bool simulationsShareState = false;
     int  nstSignalComm         = nstglobalcomm;
@@ -1218,7 +1218,7 @@ void gmx::LegacySimulator::do_md()
                                  cr,
                                  state,
                                  mdAtoms->mdatoms(),
-                                 fcdata,
+                                 &fcdata,
                                  &MassQ,
                                  &vcm,
                                  top,
@@ -1466,7 +1466,7 @@ void gmx::LegacySimulator::do_md()
                                   cr,
                                   state,
                                   mdAtoms->mdatoms(),
-                                  fcdata,
+                                  &fcdata,
                                   &MassQ,
                                   &vcm,
                                   pull_work,
@@ -1599,7 +1599,7 @@ void gmx::LegacySimulator::do_md()
                                   gmx::arrayRefFromArray(md->invMassPerDim, md->nr),
                                   state,
                                   forceCombined,
-                                  fcdata,
+                                  &fcdata,
                                   ekind,
                                   M,
                                   etrtPOSITION,
index 98cd178a96ed670ae7163fa4ab18e9639d0aea66..765398cca5c780195292b3646a5aeb36d80f29f5 100644 (file)
@@ -1158,10 +1158,9 @@ int Mdrunner::mdrunner()
                 globalState.get(),
                 replExParams.exchangeInterval > 0);
 
-    std::unique_ptr<t_oriresdata> oriresData;
-    if (gmx_mtop_ftype_count(mtop, F_ORIRES) > 0)
+    if (gmx_mtop_ftype_count(mtop, F_ORIRES) > 0 && isSimulationMasterRank)
     {
-        oriresData = std::make_unique<t_oriresdata>(fplog, mtop, *inputrec, cr, ms, globalState.get());
+        extendStateWithOriresHistory(mtop, *inputrec, globalState.get());
     }
 
     auto deform = prepareBoxDeformation(globalState != nullptr ? globalState->box : box,
@@ -1659,7 +1658,11 @@ int Mdrunner::mdrunner()
                       pforce);
         // Dirty hack, for fixing disres and orires should be made mdmodules
         fr->fcdata->disres = disresdata;
-        fr->fcdata->orires.swap(oriresData);
+        if (gmx_mtop_ftype_count(mtop, F_ORIRES) > 0)
+        {
+            fr->fcdata->orires = std::make_unique<t_oriresdata>(
+                    fplog, mtop, *inputrec, ms, globalState.get(), &atomSets);
+        }
 
         // Save a handle to device stream manager to use elsewhere in the code
         // TODO: Forcerec is not a correct place to store it.
index 0dc0c23d789e30411e9aec8651015d93aca22e60..3f44cb18f164969feeae0d277ffb873903177914 100644 (file)
 #ifndef GMX_MDTYPES_FCDATA_H
 #define GMX_MDTYPES_FCDATA_H
 
+#include <functional>
 #include <memory>
+#include <optional>
 #include <vector>
 
+#include "gromacs/domdec/localatomset.h"
 #include "gromacs/math/vectypes.h"
 #include "gromacs/topology/idef.h"
 #include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/real.h"
 
 enum class DistanceRestraintWeighting : int;
+class gmx_ga2la_t;
 struct gmx_mtop_t;
 struct gmx_multisim_t;
-struct t_commrec;
 struct t_inputrec;
 class t_state;
 
+namespace gmx
+{
+class LocalAtomSetManager;
+}
+
 typedef real rvec5[5];
 
 /* Distance restraining stuff */
@@ -99,23 +107,50 @@ struct t_oriresdata
      * \param[in] fplog  Log file, can be nullptr
      * \param[in] mtop   The global topology
      * \param[in] ir     The input record
-     * \param[in] cr     The commrec, can be nullptr when not running in parallel
      * \param[in] ms     The multisim communicator, pass nullptr to avoid ensemble averaging
-     * \param[in,out] globalState  The global state, orientation restraint entires are added
+     * \param[in] globalState  The global state, references are set to members
+     * \param[in,out] localAtomSetManager  The local atom set manager
      *
      * \throws InvalidInputError when there is domain decomposition, fewer than 5 restraints,
      *         periodic molecules or more than 1 molecule for a moleculetype with restraints.
      */
-    t_oriresdata(FILE*                 fplog,
-                 const gmx_mtop_t&     mtop,
-                 const t_inputrec&     ir,
-                 const t_commrec*      cr,
-                 const gmx_multisim_t* ms,
-                 t_state*              globalState);
+    t_oriresdata(FILE*                     fplog,
+                 const gmx_mtop_t&         mtop,
+                 const t_inputrec&         ir,
+                 const gmx_multisim_t*     ms,
+                 t_state*                  globalState,
+                 gmx::LocalAtomSetManager* localAtomSetManager);
 
     //! Destructor
     ~t_oriresdata();
 
+    //! Returns the local atom set for fitting
+    const gmx::LocalAtomSet& fitLocalAtomSet() const { return fitLocalAtomSet_; }
+
+    //! Returns the list of reference coordinates
+    gmx::ArrayRef<const gmx::RVec> referenceCoordinates() const { return referenceCoordinates_; }
+
+    //! Returns the list of masses for fitting
+    gmx::ArrayRef<const real> fitMasses() const { return fitMasses_; }
+
+    //! Returns the list of local atoms for fitting, matching the order of referenceCoordinates
+    gmx::ArrayRef<const int> fitLocalAtomIndices() const { return fitLocalAtomIndices_; }
+
+    //! Returns the list of coordinates for temporary use, size matches referenceCoordinates
+    gmx::ArrayRef<gmx::RVec> xTmp() { return xTmp_; }
+
+    //! Returns the factor for initializing the time averaging
+    real timeAveragingInitFactor() const { return *timeAveragingInitFactor_; }
+
+    //! Returns a const view on the time averaged D tensor history
+    gmx::ArrayRef<const real> DTensorsTimeAveragedHistory() const
+    {
+        return DTensorsTimeAveragedHistory_;
+    }
+
+    //! Updates the history with the current values
+    void updateHistory();
+
     //! Force constant for the restraints
     real fc;
     //! Multiplication factor for time averaging
@@ -130,14 +165,25 @@ struct t_oriresdata
     int numExperiments;
     //! The minimum iparam type index for restraints
     int typeMin;
-    //! The number of atoms for the fit
-    int numReferenceAtoms;
-    //! The masses of the reference atoms
-    std::vector<real> mref;
+
+private:
+    //! List of local atom corresponding to the fit group
+    gmx::LocalAtomSet fitLocalAtomSet_;
     //! The reference coordinates for the fit
-    std::vector<gmx::RVec> xref;
-    //! Temporary array for fitting
-    std::vector<gmx::RVec> xtmp;
+    std::vector<gmx::RVec> referenceCoordinates_;
+    //! The masses for fitting
+    std::vector<real> fitMasses_;
+    //! List of reference atoms for fitting
+    std::vector<int> fitLocalAtomIndices_;
+    //! Temporary array, used for fitting
+    std::vector<gmx::RVec> xTmp_;
+    //! The factor for initializing the time averaging, only present when time averaging is used
+    //! This references the value stored in the global state, which depends on time.
+    std::optional<std::reference_wrapper<real>> timeAveragingInitFactor_;
+    //! View on the time averaged history of the orientation tensors
+    gmx::ArrayRef<real> DTensorsTimeAveragedHistory_;
+
+public:
     //! Rotation matrix to rotate to the reference coordinates
     matrix rotationMatrix;
     //! Array of order tensors, one for each experiment