Use more ArrayRefs in signatures of update impl class
authorJoe Jordan <ejjordan12@gmail.com>
Wed, 31 Mar 2021 08:30:23 +0000 (08:30 +0000)
committerJoe Jordan <ejjordan12@gmail.com>
Wed, 31 Mar 2021 08:30:23 +0000 (08:30 +0000)
Part of ongoing work to make refactoring mdatoms easier. This only
changes the impl class of update. A followup change can propagate
using ArrayRef to the update class itself, which will make it easier
to call update in both tests and in nblib.

src/gromacs/mdlib/update.cpp

index 7b7c6717ea4f724030e0f3f92b1a96bc8b5a5ee5..1afe8047c8985e5d605f3ccae3ae07986bcc2601 100644 (file)
@@ -122,7 +122,13 @@ public:
 
     void update_coords(const t_inputrec&                                inputRecord,
                        int64_t                                          step,
-                       const t_mdatoms*                                 md,
+                       int                                              homenr,
+                       bool                                             havePartiallyFrozenAtoms,
+                       gmx::ArrayRef<const ParticleType>                ptype,
+                       gmx::ArrayRef<const unsigned short>              cFREEZE,
+                       gmx::ArrayRef<const unsigned short>              cTC,
+                       gmx::ArrayRef<const real>                        invMass,
+                       gmx::ArrayRef<rvec>                              invMassPerDim,
                        t_state*                                         state,
                        const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
                        const t_fcdata&                                  fcdata,
@@ -138,21 +144,28 @@ public:
                        gmx_wallcycle_t   wcycle,
                        bool              haveConstraints);
 
-    void update_sd_second_half(const t_inputrec& inputRecord,
-                               int64_t           step,
-                               real*             dvdlambda,
-                               const t_mdatoms*  md,
-                               t_state*          state,
-                               const t_commrec*  cr,
-                               t_nrnb*           nrnb,
-                               gmx_wallcycle_t   wcycle,
-                               gmx::Constraints* constr,
-                               bool              do_log,
-                               bool              do_ene);
-
-    void update_for_constraint_virial(const t_inputrec&                                inputRecord,
-                                      const t_mdatoms&                                 md,
-                                      const t_state&                                   state,
+    void update_sd_second_half(const t_inputrec&                   inputRecord,
+                               int64_t                             step,
+                               real*                               dvdlambda,
+                               int                                 homenr,
+                               gmx::ArrayRef<const ParticleType>   ptype,
+                               gmx::ArrayRef<const unsigned short> cFREEZE,
+                               gmx::ArrayRef<const unsigned short> cTC,
+                               gmx::ArrayRef<const real>           invMass,
+                               t_state*                            state,
+                               const t_commrec*                    cr,
+                               t_nrnb*                             nrnb,
+                               gmx_wallcycle_t                     wcycle,
+                               gmx::Constraints*                   constr,
+                               bool                                do_log,
+                               bool                                do_ene);
+
+    void update_for_constraint_virial(const t_inputrec&   inputRecord,
+                                      int                 homenr,
+                                      bool                havePartiallyFrozenAtoms,
+                                      gmx::ArrayRef<real> invmass,
+                                      gmx::ArrayRef<rvec> invMassPerDim,
+                                      const t_state&      state,
                                       const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
                                       const gmx_ekindata_t&                            ekind);
 
@@ -212,8 +225,25 @@ void Update::update_coords(const t_inputrec&                                inpu
                            const t_commrec*                                 cr,
                            const bool                                       haveConstraints)
 {
-    return impl_->update_coords(
-            inputRecord, step, md, state, f, fcdata, ekind, M, updatePart, cr, haveConstraints);
+    return impl_->update_coords(inputRecord,
+                                step,
+                                md->homenr,
+                                md->havePartiallyFrozenAtoms,
+                                gmx::arrayRefFromArray(md->ptype, md->nr),
+                                md->cFREEZE ? gmx::arrayRefFromArray(md->cFREEZE, md->nr)
+                                            : gmx::ArrayRef<const unsigned short>(),
+                                md->cTC ? gmx::arrayRefFromArray(md->cTC, md->nr)
+                                        : gmx::ArrayRef<const unsigned short>(),
+                                gmx::arrayRefFromArray(md->invmass, md->nr),
+                                gmx::arrayRefFromArray(md->invMassPerDim, md->nr),
+                                state,
+                                f,
+                                fcdata,
+                                ekind,
+                                M,
+                                updatePart,
+                                cr,
+                                haveConstraints);
 }
 
 void Update::finish_update(const t_inputrec& inputRecord,
@@ -237,8 +267,23 @@ void Update::update_sd_second_half(const t_inputrec& inputRecord,
                                    bool              do_log,
                                    bool              do_ene)
 {
-    return impl_->update_sd_second_half(
-            inputRecord, step, dvdlambda, md, state, cr, nrnb, wcycle, constr, do_log, do_ene);
+    return impl_->update_sd_second_half(inputRecord,
+                                        step,
+                                        dvdlambda,
+                                        md->homenr,
+                                        gmx::arrayRefFromArray(md->ptype, md->nr),
+                                        md->cFREEZE ? gmx::arrayRefFromArray(md->cFREEZE, md->nr)
+                                                    : gmx::ArrayRef<const unsigned short>(),
+                                        md->cTC ? gmx::arrayRefFromArray(md->cTC, md->nr)
+                                                : gmx::ArrayRef<const unsigned short>(),
+                                        gmx::arrayRefFromArray(md->invmass, md->nr),
+                                        state,
+                                        cr,
+                                        nrnb,
+                                        wcycle,
+                                        constr,
+                                        do_log,
+                                        do_ene);
 }
 
 void Update::update_for_constraint_virial(const t_inputrec& inputRecord,
@@ -247,7 +292,14 @@ void Update::update_for_constraint_virial(const t_inputrec& inputRecord,
                                           const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
                                           const gmx_ekindata_t&                            ekind)
 {
-    return impl_->update_for_constraint_virial(inputRecord, md, state, f, ekind);
+    return impl_->update_for_constraint_virial(inputRecord,
+                                               md.homenr,
+                                               md.havePartiallyFrozenAtoms,
+                                               gmx::arrayRefFromArray(md.invmass, md.nr),
+                                               gmx::arrayRefFromArray(md.invMassPerDim, md.nr),
+                                               state,
+                                               f,
+                                               ekind);
 }
 
 void Update::update_temperature_constants(const t_inputrec& inputRecord)
@@ -306,14 +358,14 @@ enum class ApplyParrinelloRahmanVScaling
  */
 template<StoreUpdatedVelocities storeUpdatedVelocities, NumTempScaleValues numTempScaleValues, ApplyParrinelloRahmanVScaling applyPRVScaling, typename VelocityType>
 static std::enable_if_t<std::is_same<VelocityType, rvec*>::value || std::is_same<VelocityType, const rvec*>::value, void>
-updateMDLeapfrogSimple(int         start,
-                       int         nrend,
-                       real        dt,
-                       real        dtPressureCouple,
-                       const rvec* gmx_restrict          invMassPerDim,
-                       gmx::ArrayRef<const t_grp_tcstat> tcstat,
-                       const unsigned short*             cTC,
-                       const rvec                        pRVScaleMatrixDiagonal,
+updateMDLeapfrogSimple(int                                 start,
+                       int                                 nrend,
+                       real                                dt,
+                       real                                dtPressureCouple,
+                       gmx::ArrayRef<const rvec>           invMassPerDim,
+                       gmx::ArrayRef<const t_grp_tcstat>   tcstat,
+                       gmx::ArrayRef<const unsigned short> cTC,
+                       const rvec                          pRVScaleMatrixDiagonal,
                        const rvec* gmx_restrict x,
                        rvec* gmx_restrict xprime,
                        VelocityType gmx_restrict v,
@@ -428,10 +480,10 @@ static inline void simdStoreRvecs(rvec* r, int index, SimdReal r0, SimdReal r1,
  */
 template<StoreUpdatedVelocities storeUpdatedVelocities, typename VelocityType>
 static std::enable_if_t<std::is_same<VelocityType, rvec*>::value || std::is_same<VelocityType, const rvec*>::value, void>
-updateMDLeapfrogSimpleSimd(int         start,
-                           int         nrend,
-                           real        dt,
-                           const real* gmx_restrict          invMass,
+updateMDLeapfrogSimpleSimd(int                               start,
+                           int                               nrend,
+                           real                              dt,
+                           gmx::ArrayRef<const real>         invMass,
                            gmx::ArrayRef<const t_grp_tcstat> tcstat,
                            const rvec* gmx_restrict x,
                            rvec* gmx_restrict xprime,
@@ -444,12 +496,12 @@ updateMDLeapfrogSimpleSimd(int         start,
     /* We declare variables here, since code is often slower when declaring them inside the loop */
 
     /* Note: We should implement a proper PaddedVector, so we don't need this check */
-    GMX_ASSERT(isSimdAligned(invMass), "invMass should be aligned");
+    GMX_ASSERT(isSimdAligned(invMass.data()), "invMass should be aligned");
 
     for (int a = start; a < nrend; a += GMX_SIMD_REAL_WIDTH)
     {
         SimdReal invMass0, invMass1, invMass2;
-        expandScalarsToTriplets(simdLoad(invMass + a), &invMass0, &invMass1, &invMass2);
+        expandScalarsToTriplets(simdLoad(invMass.data() + a), &invMass0, &invMass1, &invMass2);
 
         SimdReal v0, v1, v2;
         SimdReal f0, f1, f2;
@@ -507,14 +559,15 @@ enum class AccelerationType
  * \param[in]     M                 Parrinello-Rahman scaling matrix.
  */
 template<AccelerationType accelerationType>
-static void updateMDLeapfrogGeneral(int                   start,
-                                    int                   nrend,
-                                    bool                  doNoseHoover,
-                                    real                  dt,
-                                    real                  dtPressureCouple,
-                                    const t_mdatoms*      md,
-                                    const gmx_ekindata_t* ekind,
-                                    const matrix          box,
+static void updateMDLeapfrogGeneral(int                                 start,
+                                    int                                 nrend,
+                                    bool                                doNoseHoover,
+                                    real                                dt,
+                                    real                                dtPressureCouple,
+                                    gmx::ArrayRef<const unsigned short> cTC,
+                                    gmx::ArrayRef<const rvec>           invMassPerDim,
+                                    const gmx_ekindata_t*               ekind,
+                                    const matrix                        box,
                                     const rvec* gmx_restrict x,
                                     rvec* gmx_restrict xprime,
                                     rvec* gmx_restrict v,
@@ -530,9 +583,6 @@ static void updateMDLeapfrogGeneral(int                   start,
      */
 
     gmx::ArrayRef<const t_grp_tcstat> tcstat = ekind->tcstat;
-    const unsigned short*             cTC    = md->cTC;
-
-    const rvec* gmx_restrict invMassPerDim = md->invMassPerDim;
 
     /* Initialize group values, changed later when multiple groups are used */
     int gt = 0;
@@ -541,7 +591,7 @@ static void updateMDLeapfrogGeneral(int                   start,
 
     for (int n = start; n < nrend; n++)
     {
-        if (cTC)
+        if (!cTC.empty())
         {
             gt = cTC[n];
         }
@@ -602,16 +652,19 @@ static void do_update_md(int         start,
                          const rvec* gmx_restrict x,
                          rvec* gmx_restrict xprime,
                          rvec* gmx_restrict v,
-                         const rvec* gmx_restrict  f,
-                         const TemperatureCoupling etc,
-                         const PressureCoupling    epc,
-                         const int                 nsttcouple,
-                         const int                 nstpcouple,
-                         const t_mdatoms*          md,
-                         const gmx_ekindata_t*     ekind,
-                         const matrix              box,
+                         const rvec* gmx_restrict            f,
+                         const TemperatureCoupling           etc,
+                         const PressureCoupling              epc,
+                         const int                           nsttcouple,
+                         const int                           nstpcouple,
+                         gmx::ArrayRef<const unsigned short> cTC,
+                         gmx::ArrayRef<const real> gmx_unused invmass,
+                         gmx::ArrayRef<const rvec>            invMassPerDim,
+                         const gmx_ekindata_t*                ekind,
+                         const matrix                         box,
                          const double* gmx_restrict nh_vxi,
-                         const matrix               M)
+                         const matrix               M,
+                         bool gmx_unused havePartiallyFrozenAtoms)
 {
     GMX_ASSERT(nrend == start || xprime != x,
                "For SIMD optimization certain compilers need to have xprime != x");
@@ -644,13 +697,41 @@ static void do_update_md(int         start,
 
         if (!doAcceleration)
         {
-            updateMDLeapfrogGeneral<AccelerationType::none>(
-                    start, nrend, doNoseHoover, dt, dtPressureCouple, md, ekind, box, x, xprime, v, f, nh_vxi, nsttcouple, stepM);
+            updateMDLeapfrogGeneral<AccelerationType::none>(start,
+                                                            nrend,
+                                                            doNoseHoover,
+                                                            dt,
+                                                            dtPressureCouple,
+                                                            cTC,
+                                                            invMassPerDim,
+                                                            ekind,
+                                                            box,
+                                                            x,
+                                                            xprime,
+                                                            v,
+                                                            f,
+                                                            nh_vxi,
+                                                            nsttcouple,
+                                                            stepM);
         }
         else
         {
-            updateMDLeapfrogGeneral<AccelerationType::cosine>(
-                    start, nrend, doNoseHoover, dt, dtPressureCouple, md, ekind, box, x, xprime, v, f, nh_vxi, nsttcouple, stepM);
+            updateMDLeapfrogGeneral<AccelerationType::cosine>(start,
+                                                              nrend,
+                                                              doNoseHoover,
+                                                              dt,
+                                                              dtPressureCouple,
+                                                              cTC,
+                                                              invMassPerDim,
+                                                              ekind,
+                                                              box,
+                                                              x,
+                                                              xprime,
+                                                              v,
+                                                              f,
+                                                              nh_vxi,
+                                                              nsttcouple,
+                                                              stepM);
         }
     }
     else
@@ -663,9 +744,7 @@ static void do_update_md(int         start,
         bool haveSingleTempScaleValue = (!doTempCouple || ekind->ngtc == 1);
 
         /* Extract some pointers needed by all cases */
-        const unsigned short*             cTC           = md->cTC;
-        gmx::ArrayRef<const t_grp_tcstat> tcstat        = ekind->tcstat;
-        const rvec*                       invMassPerDim = md->invMassPerDim;
+        gmx::ArrayRef<const t_grp_tcstat> tcstat = ekind->tcstat;
 
         if (doParrinelloRahman)
         {
@@ -701,10 +780,10 @@ static void do_update_md(int         start,
                  */
 #if GMX_HAVE_SIMD_UPDATE
                 /* Check if we can use invmass instead of invMassPerDim */
-                if (!md->havePartiallyFrozenAtoms)
+                if (!havePartiallyFrozenAtoms)
                 {
                     updateMDLeapfrogSimpleSimd<StoreUpdatedVelocities::yes>(
-                            start, nrend, dt, md->invmass, tcstat, x, xprime, v, f);
+                            start, nrend, dt, invmass, tcstat, x, xprime, v, f);
                 }
                 else
 #endif
@@ -729,8 +808,10 @@ static void doUpdateMDDoNotUpdateVelocities(int         start,
                                             rvec* gmx_restrict xprime,
                                             const rvec* gmx_restrict v,
                                             const rvec* gmx_restrict f,
-                                            const t_mdatoms&         md,
-                                            const gmx_ekindata_t&    ekind)
+                                            bool gmx_unused     havePartiallyFrozenAtoms,
+                                            gmx::ArrayRef<real> gmx_unused invmass,
+                                            gmx::ArrayRef<rvec>            invMassPerDim,
+                                            const gmx_ekindata_t&          ekind)
 {
     GMX_ASSERT(nrend == start || xprime != x,
                "For SIMD optimization certain compilers need to have xprime != x");
@@ -739,31 +820,31 @@ static void doUpdateMDDoNotUpdateVelocities(int         start,
 
     /* Check if we can use invmass instead of invMassPerDim */
 #if GMX_HAVE_SIMD_UPDATE
-    if (!md.havePartiallyFrozenAtoms)
+    if (!havePartiallyFrozenAtoms)
     {
         updateMDLeapfrogSimpleSimd<StoreUpdatedVelocities::no>(
-                start, nrend, dt, md.invmass, tcstat, x, xprime, v, f);
+                start, nrend, dt, invmass, tcstat, x, xprime, v, f);
     }
     else
 #endif
     {
         updateMDLeapfrogSimple<StoreUpdatedVelocities::no, NumTempScaleValues::single, ApplyParrinelloRahmanVScaling::no>(
-                start, nrend, dt, dt, md.invMassPerDim, tcstat, nullptr, nullptr, x, xprime, v, f);
+                start, nrend, dt, dt, invMassPerDim, tcstat, gmx::ArrayRef<const unsigned short>(), nullptr, x, xprime, v, f);
     }
 }
 
-static void do_update_vv_vel(int                  start,
-                             int                  nrend,
-                             real                 dt,
-                             const ivec           nFreeze[],
-                             const real           invmass[],
-                             const ParticleType   ptype[],
-                             const unsigned short cFREEZE[],
-                             rvec                 v[],
-                             const rvec           f[],
-                             gmx_bool             bExtended,
-                             real                 veta,
-                             real                 alpha)
+static void do_update_vv_vel(int                                 start,
+                             int                                 nrend,
+                             real                                dt,
+                             gmx::ArrayRef<const ivec>           nFreeze,
+                             gmx::ArrayRef<const real>           invmass,
+                             gmx::ArrayRef<const ParticleType>   ptype,
+                             gmx::ArrayRef<const unsigned short> cFREEZE,
+                             rvec                                v[],
+                             const rvec                          f[],
+                             gmx_bool                            bExtended,
+                             real                                veta,
+                             real                                alpha)
 {
     int  gf = 0;
     int  n, d;
@@ -783,7 +864,7 @@ static void do_update_vv_vel(int                  start,
     for (n = start; n < nrend; n++)
     {
         real w_dt = invmass[n] * dt;
-        if (cFREEZE)
+        if (!cFREEZE.empty())
         {
             gf = cFREEZE[n];
         }
@@ -802,17 +883,17 @@ static void do_update_vv_vel(int                  start,
     }
 } /* do_update_vv_vel */
 
-static void do_update_vv_pos(int                  start,
-                             int                  nrend,
-                             real                 dt,
-                             const ivec           nFreeze[],
-                             const ParticleType   ptype[],
-                             const unsigned short cFREEZE[],
-                             const rvec           x[],
-                             rvec                 xprime[],
-                             const rvec           v[],
-                             gmx_bool             bExtended,
-                             real                 veta)
+static void do_update_vv_pos(int                                 start,
+                             int                                 nrend,
+                             real                                dt,
+                             gmx::ArrayRef<const ivec>           nFreeze,
+                             gmx::ArrayRef<const ParticleType>   ptype,
+                             gmx::ArrayRef<const unsigned short> cFREEZE,
+                             const rvec                          x[],
+                             rvec                                xprime[],
+                             const rvec                          v[],
+                             gmx_bool                            bExtended,
+                             real                                veta)
 {
     int  gf = 0;
     int  n, d;
@@ -834,7 +915,7 @@ static void do_update_vv_pos(int                  start,
     for (n = start; n < nrend; n++)
     {
 
-        if (cFREEZE)
+        if (!cFREEZE.empty())
         {
             gf = cFREEZE[n];
         }
@@ -972,22 +1053,22 @@ enum class SDUpdate : int
  * Thus three instantiations of this templated function will be made,
  * two with only one contribution, and one with both contributions. */
 template<SDUpdate updateType>
-static void doSDUpdateGeneral(const gmx_stochd_t&  sd,
-                              int                  start,
-                              int                  nrend,
-                              real                 dt,
-                              const ivec           nFreeze[],
-                              const real           invmass[],
-                              const ParticleType   ptype[],
-                              const unsigned short cFREEZE[],
-                              const unsigned short cTC[],
-                              const rvec           x[],
-                              rvec                 xprime[],
-                              rvec                 v[],
-                              const rvec           f[],
-                              int64_t              step,
-                              int                  seed,
-                              const int*           gatindex)
+static void doSDUpdateGeneral(const gmx_stochd_t&                 sd,
+                              int                                 start,
+                              int                                 nrend,
+                              real                                dt,
+                              gmx::ArrayRef<const ivec>           nFreeze,
+                              gmx::ArrayRef<const real>           invmass,
+                              gmx::ArrayRef<const ParticleType>   ptype,
+                              gmx::ArrayRef<const unsigned short> cFREEZE,
+                              gmx::ArrayRef<const unsigned short> cTC,
+                              const rvec                          x[],
+                              rvec                                xprime[],
+                              rvec                                v[],
+                              const rvec                          f[],
+                              int64_t                             step,
+                              int                                 seed,
+                              const int*                          gatindex)
 {
     // cTC and cFREEZE can be nullptr any time, but various
     // instantiations do not make sense with particular pointer
@@ -995,7 +1076,7 @@ static void doSDUpdateGeneral(const gmx_stochd_t&  sd,
     if (updateType == SDUpdate::ForcesOnly)
     {
         GMX_ASSERT(f != nullptr, "SD update with only forces requires forces");
-        GMX_ASSERT(cTC == nullptr, "SD update with only forces cannot handle temperature groups");
+        GMX_ASSERT(cTC.empty(), "SD update with only forces cannot handle temperature groups");
     }
     if (updateType == SDUpdate::FrictionAndNoiseOnly)
     {
@@ -1019,8 +1100,8 @@ static void doSDUpdateGeneral(const gmx_stochd_t&  sd,
         real inverseMass = invmass[n];
         real invsqrtMass = std::sqrt(inverseMass);
 
-        int freezeGroup      = cFREEZE ? cFREEZE[n] : 0;
-        int temperatureGroup = cTC ? cTC[n] : 0;
+        int freezeGroup      = !cFREEZE.empty() ? cFREEZE[n] : 0;
+        int temperatureGroup = !cTC.empty() ? cTC[n] : 0;
 
         for (int d = 0; d < DIM; d++)
         {
@@ -1076,22 +1157,36 @@ static void do_update_sd(int         start,
                          const rvec* gmx_restrict x,
                          rvec* gmx_restrict xprime,
                          rvec* gmx_restrict v,
-                         const rvec* gmx_restrict f,
-                         const ivec               nFreeze[],
-                         const real               invmass[],
-                         const ParticleType       ptype[],
-                         const unsigned short     cFREEZE[],
-                         const unsigned short     cTC[],
-                         int                      seed,
-                         const t_commrec*         cr,
-                         const gmx_stochd_t&      sd,
-                         bool                     haveConstraints)
+                         const rvec* gmx_restrict            f,
+                         gmx::ArrayRef<const ivec>           nFreeze,
+                         gmx::ArrayRef<const real>           invmass,
+                         gmx::ArrayRef<const ParticleType>   ptype,
+                         gmx::ArrayRef<const unsigned short> cFREEZE,
+                         gmx::ArrayRef<const unsigned short> cTC,
+                         int                                 seed,
+                         const t_commrec*                    cr,
+                         const gmx_stochd_t&                 sd,
+                         bool                                haveConstraints)
 {
     if (haveConstraints)
     {
         // With constraints, the SD update is done in 2 parts
-        doSDUpdateGeneral<SDUpdate::ForcesOnly>(
-                sd, start, nrend, dt, nFreeze, invmass, ptype, cFREEZE, nullptr, x, xprime, v, f, step, seed, nullptr);
+        doSDUpdateGeneral<SDUpdate::ForcesOnly>(sd,
+                                                start,
+                                                nrend,
+                                                dt,
+                                                nFreeze,
+                                                invmass,
+                                                ptype,
+                                                cFREEZE,
+                                                gmx::ArrayRef<const unsigned short>(),
+                                                x,
+                                                xprime,
+                                                v,
+                                                f,
+                                                step,
+                                                seed,
+                                                nullptr);
     }
     else
     {
@@ -1121,16 +1216,16 @@ static void do_update_bd(int         start,
                          const rvec* gmx_restrict x,
                          rvec* gmx_restrict xprime,
                          rvec* gmx_restrict v,
-                         const rvec* gmx_restrict f,
-                         const ivec               nFreeze[],
-                         const real               invmass[],
-                         const ParticleType       ptype[],
-                         const unsigned short     cFREEZE[],
-                         const unsigned short     cTC[],
-                         real                     friction_coefficient,
-                         const real*              rf,
-                         int                      seed,
-                         const int*               gatindex)
+                         const rvec* gmx_restrict            f,
+                         gmx::ArrayRef<const ivec>           nFreeze,
+                         gmx::ArrayRef<const real>           invmass,
+                         gmx::ArrayRef<const ParticleType>   ptype,
+                         gmx::ArrayRef<const unsigned short> cFREEZE,
+                         gmx::ArrayRef<const unsigned short> cTC,
+                         real                                friction_coefficient,
+                         const real*                         rf,
+                         int                                 seed,
+                         const int*                          gatindex)
 {
     /* note -- these appear to be full step velocities . . .  */
     int  gf = 0, gt = 0;
@@ -1154,11 +1249,11 @@ static void do_update_bd(int         start,
         rng.restart(step, ng);
         dist.reset();
 
-        if (cFREEZE)
+        if (!cFREEZE.empty())
         {
             gf = cFREEZE[n];
         }
-        if (cTC)
+        if (!cTC.empty())
         {
             gt = cTC[n];
         }
@@ -1292,17 +1387,21 @@ void getThreadAtomRange(int numThreads, int threadIndex, int numAtoms, int* star
     }
 }
 
-void Update::Impl::update_sd_second_half(const t_inputrec& inputRecord,
-                                         int64_t           step,
-                                         real*             dvdlambda,
-                                         const t_mdatoms*  md,
-                                         t_state*          state,
-                                         const t_commrec*  cr,
-                                         t_nrnb*           nrnb,
-                                         gmx_wallcycle_t   wcycle,
-                                         gmx::Constraints* constr,
-                                         bool              do_log,
-                                         bool              do_ene)
+void Update::Impl::update_sd_second_half(const t_inputrec&                   inputRecord,
+                                         int64_t                             step,
+                                         real*                               dvdlambda,
+                                         int                                 homenr,
+                                         gmx::ArrayRef<const ParticleType>   ptype,
+                                         gmx::ArrayRef<const unsigned short> cFREEZE,
+                                         gmx::ArrayRef<const unsigned short> cTC,
+                                         gmx::ArrayRef<const real>           invMass,
+                                         t_state*                            state,
+                                         const t_commrec*                    cr,
+                                         t_nrnb*                             nrnb,
+                                         gmx_wallcycle_t                     wcycle,
+                                         gmx::Constraints*                   constr,
+                                         bool                                do_log,
+                                         bool                                do_ene)
 {
     if (!constr)
     {
@@ -1310,7 +1409,6 @@ void Update::Impl::update_sd_second_half(const t_inputrec& inputRecord,
     }
     if (inputRecord.eI == IntegrationAlgorithm::SD1)
     {
-        int homenr = md->homenr;
 
         /* Cast delta_t from double to real to make the integrators faster.
          * The only reason for having delta_t double is to get accurate values
@@ -1338,11 +1436,11 @@ void Update::Impl::update_sd_second_half(const t_inputrec& inputRecord,
                         start_th,
                         end_th,
                         dt,
-                        inputRecord.opts.nFreeze,
-                        md->invmass,
-                        md->ptype,
-                        md->cFREEZE,
-                        md->cTC,
+                        gmx::arrayRefFromArray(inputRecord.opts.nFreeze, inputRecord.opts.ngfrz),
+                        invMass,
+                        ptype,
+                        cFREEZE,
+                        cTC,
                         state->x.rvec_array(),
                         xp_.rvec_array(),
                         state->v.rvec_array(),
@@ -1430,10 +1528,16 @@ void Update::Impl::finish_update(const t_inputrec& inputRecord,
     wallcycle_stop(wcycle, ewcUPDATE);
 }
 
-void Update::Impl::update_coords(const t_inputrec&                                inputRecord,
-                                 int64_t                                          step,
-                                 const t_mdatoms*                                 md,
-                                 t_state*                                         state,
+void Update::Impl::update_coords(const t_inputrec&                   inputRecord,
+                                 int64_t                             step,
+                                 int                                 homenr,
+                                 bool                                havePartiallyFrozenAtoms,
+                                 gmx::ArrayRef<const ParticleType>   ptype,
+                                 gmx::ArrayRef<const unsigned short> cFREEZE,
+                                 gmx::ArrayRef<const unsigned short> cTC,
+                                 gmx::ArrayRef<const real>           invMass,
+                                 gmx::ArrayRef<rvec>                 invMassPerDim,
+                                 t_state*                            state,
                                  const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
                                  const t_fcdata&                                  fcdata,
                                  const gmx_ekindata_t*                            ekind,
@@ -1448,8 +1552,6 @@ void Update::Impl::update_coords(const t_inputrec&
         gmx_incons("update_coords called for velocity without VV integrator");
     }
 
-    int homenr = md->homenr;
-
     /* Cast to real for faster code, no loss in precision (see comment above) */
     real dt = inputRecord.delta_t;
 
@@ -1494,11 +1596,14 @@ void Update::Impl::update_coords(const t_inputrec&
                                  inputRecord.epc,
                                  inputRecord.nsttcouple,
                                  inputRecord.nstpcouple,
-                                 md,
+                                 cTC,
+                                 invMass,
+                                 invMassPerDim,
                                  ekind,
                                  state->box,
                                  state->nosehoover_vxi.data(),
-                                 M);
+                                 M,
+                                 havePartiallyFrozenAtoms);
                     break;
                 case (IntegrationAlgorithm::SD1):
                     do_update_sd(start_th,
@@ -1509,11 +1614,11 @@ void Update::Impl::update_coords(const t_inputrec&
                                  xp_rvec,
                                  v_rvec,
                                  f_rvec,
-                                 inputRecord.opts.nFreeze,
-                                 md->invmass,
-                                 md->ptype,
-                                 md->cFREEZE,
-                                 md->cTC,
+                                 gmx::arrayRefFromArray(inputRecord.opts.nFreeze, inputRecord.opts.ngfrz),
+                                 invMass,
+                                 ptype,
+                                 cFREEZE,
+                                 cTC,
                                  inputRecord.ld_seed,
                                  cr,
                                  sd_,
@@ -1528,11 +1633,11 @@ void Update::Impl::update_coords(const t_inputrec&
                                  xp_rvec,
                                  v_rvec,
                                  f_rvec,
-                                 inputRecord.opts.nFreeze,
-                                 md->invmass,
-                                 md->ptype,
-                                 md->cFREEZE,
-                                 md->cTC,
+                                 gmx::arrayRefFromArray(inputRecord.opts.nFreeze, inputRecord.opts.ngfrz),
+                                 invMass,
+                                 ptype,
+                                 cFREEZE,
+                                 cTC,
                                  inputRecord.bd_fric,
                                  sd_.bd_rf.data(),
                                  inputRecord.ld_seed,
@@ -1554,10 +1659,11 @@ void Update::Impl::update_coords(const t_inputrec&
                             do_update_vv_vel(start_th,
                                              end_th,
                                              dt,
-                                             inputRecord.opts.nFreeze,
-                                             md->invmass,
-                                             md->ptype,
-                                             md->cFREEZE,
+                                             gmx::arrayRefFromArray(inputRecord.opts.nFreeze,
+                                                                    inputRecord.opts.ngfrz),
+                                             invMass,
+                                             ptype,
+                                             cFREEZE,
                                              v_rvec,
                                              f_rvec,
                                              bExtended,
@@ -1568,9 +1674,10 @@ void Update::Impl::update_coords(const t_inputrec&
                             do_update_vv_pos(start_th,
                                              end_th,
                                              dt,
-                                             inputRecord.opts.nFreeze,
-                                             md->ptype,
-                                             md->cFREEZE,
+                                             gmx::arrayRefFromArray(inputRecord.opts.nFreeze,
+                                                                    inputRecord.opts.ngfrz),
+                                             ptype,
+                                             cFREEZE,
                                              x_rvec,
                                              xp_rvec,
                                              v_rvec,
@@ -1587,9 +1694,12 @@ void Update::Impl::update_coords(const t_inputrec&
     }
 }
 
-void Update::Impl::update_for_constraint_virial(const t_inputrec& inputRecord,
-                                                const t_mdatoms&  md,
-                                                const t_state&    state,
+void Update::Impl::update_for_constraint_virial(const t_inputrec&   inputRecord,
+                                                int                 homenr,
+                                                bool                havePartiallyFrozenAtoms,
+                                                gmx::ArrayRef<real> invmass,
+                                                gmx::ArrayRef<rvec> invMassPerDim,
+                                                const t_state&      state,
                                                 const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
                                                 const gmx_ekindata_t& ekind)
 {
@@ -1607,7 +1717,7 @@ void Update::Impl::update_for_constraint_virial(const t_inputrec& inputRecord,
         try
         {
             int start_th, end_th;
-            getThreadAtomRange(nth, th, md.homenr, &start_th, &end_th);
+            getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
 
             const rvec* x_rvec  = state.x.rvec_array();
             rvec*       xp_rvec = xp_.rvec_array();
@@ -1615,7 +1725,7 @@ void Update::Impl::update_for_constraint_virial(const t_inputrec& inputRecord,
             const rvec* f_rvec  = as_rvec_array(f.unpaddedConstArrayRef().data());
 
             doUpdateMDDoNotUpdateVelocities(
-                    start_th, end_th, dt, x_rvec, xp_rvec, v_rvec, f_rvec, md, ekind);
+                    start_th, end_th, dt, x_rvec, xp_rvec, v_rvec, f_rvec, havePartiallyFrozenAtoms, invmass, invMassPerDim, ekind);
         }
         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
     }