Convert gmx_update_t to C++
[alexxy/gromacs.git] / src / gromacs / mdlib / update.cpp
index c34571000f2094eadb316703d0f86844166d3785..e6f39df9bc17b0f77db5e4b1dc377f60747cd98c 100644 (file)
 
 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
 
-typedef struct {
+struct gmx_sd_const_t {
     double em;
-} gmx_sd_const_t;
+};
 
-typedef struct {
+struct gmx_sd_sigma_t {
     real V;
-} gmx_sd_sigma_t;
+};
 
 struct gmx_stochd_t
 {
@@ -104,22 +104,46 @@ struct gmx_stochd_t
     std::vector<bool>           randomize_group;
     std::vector<real>           boltzfac;
 
-    gmx_stochd_t(const t_inputrec *ir);
+    explicit gmx_stochd_t(const t_inputrec *ir);
 };
 
-struct gmx_update_t
+//! pImpled implementation for Update
+class Update::Impl
 {
-    std::unique_ptr<gmx_stochd_t> sd;
-    /* xprime for constraint algorithms */
-    PaddedVector<gmx::RVec>       xp;
+    public:
+        //! Constructor
+        Impl(const t_inputrec    *ir, BoxDeformation *boxDeformation);
+        //! Destructor
+        ~Impl() = default;
+        //! stochastic dynamics struct
+        std::unique_ptr<gmx_stochd_t> sd;
+        //! xprime for constraint algorithms
+        PaddedVector<RVec>            xp;
+        //! Box deformation handler (or nullptr if inactive).
+        BoxDeformation               *deform = nullptr;
+};
 
-    /* Variables for the deform algorithm */
-    int64_t           deformref_step;
-    matrix            deformref_box;
+Update::Update(const t_inputrec    *ir, BoxDeformation *boxDeformation)
+    : impl_(new Impl(ir, boxDeformation))
+{};
 
-    //! Box deformation handler (or nullptr if inactive).
-    gmx::BoxDeformation *deform;
-};
+Update::~Update()
+{};
+
+gmx_stochd_t* Update::sd() const
+{
+    return impl_->sd.get();
+}
+
+PaddedVector<RVec> * Update::xp()
+{
+    return &impl_->xp;
+}
+
+BoxDeformation * Update::deform() const
+{
+    return impl_->deform;
+}
 
 static bool isTemperatureCouplingStep(int64_t step, const t_inputrec *ir)
 {
@@ -819,7 +843,7 @@ gmx_stochd_t::gmx_stochd_t(const t_inputrec *ir)
     }
 }
 
-void update_temperature_constants(gmx_update_t *upd, const t_inputrec *ir)
+void update_temperature_constants(gmx_stochd_t *sd, const t_inputrec *ir)
 {
     if (ir->eI == eiBD)
     {
@@ -827,14 +851,14 @@ void update_temperature_constants(gmx_update_t *upd, const t_inputrec *ir)
         {
             for (int gt = 0; gt < ir->opts.ngtc; gt++)
             {
-                upd->sd->bd_rf[gt] = std::sqrt(2.0*BOLTZ*ir->opts.ref_t[gt]/(ir->bd_fric*ir->delta_t));
+                sd->bd_rf[gt] = std::sqrt(2.0*BOLTZ*ir->opts.ref_t[gt]/(ir->bd_fric*ir->delta_t));
             }
         }
         else
         {
             for (int gt = 0; gt < ir->opts.ngtc; gt++)
             {
-                upd->sd->bd_rf[gt] = std::sqrt(2.0*BOLTZ*ir->opts.ref_t[gt]);
+                sd->bd_rf[gt] = std::sqrt(2.0*BOLTZ*ir->opts.ref_t[gt]);
             }
         }
     }
@@ -844,32 +868,23 @@ void update_temperature_constants(gmx_update_t *upd, const t_inputrec *ir)
         {
             real kT = BOLTZ*ir->opts.ref_t[gt];
             /* The mass is accounted for later, since this differs per atom */
-            upd->sd->sdsig[gt].V  = std::sqrt(kT*(1 - upd->sd->sdc[gt].em*upd->sd->sdc[gt].em));
+            sd->sdsig[gt].V  = std::sqrt(kT*(1 - sd->sdc[gt].em * sd->sdc[gt].em));
         }
     }
 }
 
-gmx_update_t *init_update(const t_inputrec    *ir,
-                          gmx::BoxDeformation *deform)
+Update::Impl::Impl(const t_inputrec    *ir, BoxDeformation *boxDeformation)
 {
-    gmx_update_t *upd = new(gmx_update_t);
-
-    upd->sd    = gmx::compat::make_unique<gmx_stochd_t>(ir);
-
-    update_temperature_constants(upd, ir);
-
-    upd->xp.resizeWithPadding(0);
-
-    upd->deform = deform;
-
-    return upd;
+    sd = gmx::compat::make_unique<gmx_stochd_t>(ir);
+    update_temperature_constants(sd.get(), ir);
+    xp.resizeWithPadding(0);
+    deform = boxDeformation;
 }
 
-void update_realloc(gmx_update_t *upd, int natoms)
+void Update::setNumAtoms(int nAtoms)
 {
-    GMX_ASSERT(upd, "upd must be allocated before its fields can be reallocated");
 
-    upd->xp.resizeWithPadding(natoms);
+    impl_->xp.resizeWithPadding(nAtoms);
 }
 
 /*! \brief Sets the SD update type */
@@ -1478,15 +1493,15 @@ void constrain_velocities(int64_t                        step,
     }
 }
 
-void constrain_coordinates(int64_t                        step,
-                           real                          *dvdlambda, /* the contribution to be added to the bonded interactions */
-                           t_state                       *state,
-                           tensor                         vir_part,
-                           gmx_update_t                  *upd,
-                           gmx::Constraints              *constr,
-                           gmx_bool                       bCalcVir,
-                           bool                           do_log,
-                           bool                           do_ene)
+void constrain_coordinates(int64_t           step,
+                           real             *dvdlambda, /* the contribution to be added to the bonded interactions */
+                           t_state          *state,
+                           tensor            vir_part,
+                           Update           *upd,
+                           gmx::Constraints *constr,
+                           gmx_bool          bCalcVir,
+                           bool              do_log,
+                           bool              do_ene)
 {
     if (!constr)
     {
@@ -1502,7 +1517,7 @@ void constrain_coordinates(int64_t                        step,
         /* Constrain the coordinates upd->xp */
         constr->apply(do_log, do_ene,
                       step, 1, 1.0,
-                      state->x.rvec_array(), upd->xp.rvec_array(), nullptr,
+                      state->x.rvec_array(), upd->xp()->rvec_array(), nullptr,
                       state->box,
                       state->lambda[efptBONDED], dvdlambda,
                       as_rvec_array(state->v.data()), bCalcVir ? &vir_con : nullptr, ConstraintVariable::Positions);
@@ -1515,18 +1530,18 @@ void constrain_coordinates(int64_t                        step,
 }
 
 void
-update_sd_second_half(int64_t                        step,
-                      real                          *dvdlambda,   /* the contribution to be added to the bonded interactions */
-                      const t_inputrec              *inputrec,    /* input record and box stuff        */
-                      const t_mdatoms               *md,
-                      t_state                       *state,
-                      const t_commrec               *cr,
-                      t_nrnb                        *nrnb,
-                      gmx_wallcycle_t                wcycle,
-                      gmx_update_t                  *upd,
-                      gmx::Constraints              *constr,
-                      bool                           do_log,
-                      bool                           do_ene)
+update_sd_second_half(int64_t           step,
+                      real             *dvdlambda, /* the contribution to be added to the bonded interactions */
+                      const t_inputrec *inputrec,  /* input record and box stuff       */
+                      const t_mdatoms  *md,
+                      t_state          *state,
+                      const t_commrec  *cr,
+                      t_nrnb           *nrnb,
+                      gmx_wallcycle_t   wcycle,
+                      Update           *upd,
+                      gmx::Constraints *constr,
+                      bool              do_log,
+                      bool              do_ene)
 {
     if (!constr)
     {
@@ -1559,12 +1574,12 @@ update_sd_second_half(int64_t                        step,
                 getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
 
                 doSDUpdateGeneral<SDUpdate::FrictionAndNoiseOnly>
-                    (*upd->sd,
+                    (*upd->sd(),
                     start_th, end_th, dt,
                     inputrec->opts.acc, inputrec->opts.nFreeze,
                     md->invmass, md->ptype,
                     md->cFREEZE, nullptr, md->cTC,
-                    state->x.rvec_array(), upd->xp.rvec_array(),
+                    state->x.rvec_array(), upd->xp()->rvec_array(),
                     state->v.rvec_array(), nullptr,
                     step, inputrec->ld_seed,
                     DOMAINDECOMP(cr) ? cr->dd->globalAtomIndices.data() : nullptr);
@@ -1577,21 +1592,21 @@ update_sd_second_half(int64_t                        step,
         /* Constrain the coordinates upd->xp for half a time step */
         constr->apply(do_log, do_ene,
                       step, 1, 0.5,
-                      state->x.rvec_array(), upd->xp.rvec_array(), nullptr,
+                      state->x.rvec_array(), upd->xp()->rvec_array(), nullptr,
                       state->box,
                       state->lambda[efptBONDED], dvdlambda,
                       as_rvec_array(state->v.data()), nullptr, ConstraintVariable::Positions);
     }
 }
 
-void finish_update(const t_inputrec              *inputrec,  /* input record and box stuff     */
-                   const t_mdatoms               *md,
-                   t_state                       *state,
-                   const t_graph                 *graph,
-                   t_nrnb                        *nrnb,
-                   gmx_wallcycle_t                wcycle,
-                   gmx_update_t                  *upd,
-                   const gmx::Constraints        *constr)
+void finish_update(const t_inputrec       *inputrec, /* input record and box stuff     */
+                   const t_mdatoms        *md,
+                   t_state                *state,
+                   const t_graph          *graph,
+                   t_nrnb                 *nrnb,
+                   gmx_wallcycle_t         wcycle,
+                   Update                 *upd,
+                   const gmx::Constraints *constr)
 {
     int homenr = md->homenr;
 
@@ -1624,7 +1639,7 @@ void finish_update(const t_inputrec              *inputrec,  /* input record and
             }
             if (partialFreezeAndConstraints)
             {
-                auto xp = makeArrayRef(upd->xp).subArray(0, homenr);
+                auto xp = makeArrayRef(*upd->xp()).subArray(0, homenr);
                 auto x  = makeConstArrayRef(state->x).subArray(0, homenr);
                 for (int i = 0; i < homenr; i++)
                 {
@@ -1643,7 +1658,7 @@ void finish_update(const t_inputrec              *inputrec,  /* input record and
 
         if (graph && (graph->nnodes > 0))
         {
-            unshift_x(graph, state->box, state->x.rvec_array(), upd->xp.rvec_array());
+            unshift_x(graph, state->box, state->x.rvec_array(), upd->xp()->rvec_array());
             if (TRICLINIC(state->box))
             {
                 inc_nrnb(nrnb, eNR_SHIFTX, 2*graph->nnodes);
@@ -1655,7 +1670,7 @@ void finish_update(const t_inputrec              *inputrec,  /* input record and
         }
         else
         {
-            auto           xp = makeConstArrayRef(upd->xp).subArray(0, homenr);
+            auto           xp = makeConstArrayRef(*upd->xp()).subArray(0, homenr);
             auto           x  = makeArrayRef(state->x).subArray(0, homenr);
 #ifndef __clang_analyzer__
             int gmx_unused nth = gmx_omp_nthreads_get(emntUpdate);
@@ -1682,7 +1697,7 @@ void update_pcouple_after_coordinates(FILE             *fplog,
                                       const matrix      parrinellorahmanMu,
                                       t_state          *state,
                                       t_nrnb           *nrnb,
-                                      gmx_update_t     *upd)
+                                      Update           *upd)
 {
     int  start  = 0;
     int  homenr = md->homenr;
@@ -1764,10 +1779,10 @@ void update_pcouple_after_coordinates(FILE             *fplog,
             break;
     }
 
-    if (upd->deform)
+    if (upd->deform())
     {
         auto localX = makeArrayRef(state->x).subArray(start, homenr);
-        upd->deform->apply(localX, state->box, step);
+        upd->deform()->apply(localX, state->box, step);
     }
 }
 
@@ -1779,7 +1794,7 @@ void update_coords(int64_t                             step,
                    const t_fcdata                     *fcd,
                    const gmx_ekindata_t               *ekind,
                    const matrix                        M,
-                   gmx_update_t                       *upd,
+                   Update                             *upd,
                    int                                 UpdatePart,
                    const t_commrec                    *cr, /* these shouldn't be here -- need to think about it */
                    const gmx::Constraints             *constr)
@@ -1820,7 +1835,7 @@ void update_coords(int64_t                             step,
             getThreadAtomRange(nth, th, homenr, &start_th, &end_th);
 
             const rvec *x_rvec  = state->x.rvec_array();
-            rvec       *xp_rvec = upd->xp.rvec_array();
+            rvec       *xp_rvec = upd->xp()->rvec_array();
             rvec       *v_rvec  = state->v.rvec_array();
             const rvec *f_rvec  = as_rvec_array(f.unpaddedArrayRef().data());
 
@@ -1837,7 +1852,7 @@ void update_coords(int64_t                             step,
                     {
                         // With constraints, the SD update is done in 2 parts
                         doSDUpdateGeneral<SDUpdate::ForcesOnly>
-                            (*upd->sd,
+                            (*upd->sd(),
                             start_th, end_th, dt,
                             inputrec->opts.acc, inputrec->opts.nFreeze,
                             md->invmass, md->ptype,
@@ -1848,7 +1863,7 @@ void update_coords(int64_t                             step,
                     else
                     {
                         doSDUpdateGeneral<SDUpdate::Combined>
-                            (*upd->sd,
+                            (*upd->sd(),
                             start_th, end_th, dt,
                             inputrec->opts.acc, inputrec->opts.nFreeze,
                             md->invmass, md->ptype,
@@ -1864,7 +1879,7 @@ void update_coords(int64_t                             step,
                                  md->cFREEZE, md->cTC,
                                  x_rvec, xp_rvec, v_rvec, f_rvec,
                                  inputrec->bd_fric,
-                                 upd->sd->bd_rf.data(),
+                                 upd->sd()->bd_rf.data(),
                                  step, inputrec->ld_seed, DOMAINDECOMP(cr) ? cr->dd->globalAtomIndices.data() : nullptr);
                     break;
                 case (eiVV):
@@ -1909,7 +1924,7 @@ void update_coords(int64_t                             step,
 extern gmx_bool update_randomize_velocities(const t_inputrec *ir, int64_t step, const t_commrec *cr,
                                             const t_mdatoms *md,
                                             gmx::ArrayRef<gmx::RVec> v,
-                                            const gmx_update_t *upd,
+                                            const Update *upd,
                                             const gmx::Constraints *constr)
 {
 
@@ -1931,8 +1946,8 @@ extern gmx_bool update_randomize_velocities(const t_inputrec *ir, int64_t step,
     if ((ir->etc == etcANDERSEN) || do_per_step(step, roundToInt(1.0/rate)))
     {
         andersen_tcoupl(ir, step, cr, md, v, rate,
-                        upd->sd->randomize_group,
-                        upd->sd->boltzfac);
+                        upd->sd()->randomize_group,
+                        upd->sd()->boltzfac);
         return TRUE;
     }
     return FALSE;