Update Awh initialization and lifetime management
authorMark Abraham <mark.j.abraham@gmail.com>
Sun, 3 Jun 2018 21:49:54 +0000 (23:49 +0200)
committerMark Abraham <mark.j.abraham@gmail.com>
Mon, 4 Jun 2018 20:38:57 +0000 (22:38 +0200)
inputrec no longer stores the main module, just the user-specified
parameters

RAII and make_unique is used for resource management.

The new factory function provides a good place to run cross-module
checks and checkpoint handling preparation.

Change-Id: I63218e201f965b838106fc8918a3f9fabdba88cb

src/gromacs/awh/awh.cpp
src/gromacs/awh/awh.h
src/gromacs/mdlib/force.h
src/gromacs/mdlib/shellfc.cpp
src/gromacs/mdlib/sim_util.cpp
src/gromacs/mdrun/md.cpp
src/gromacs/mdrun/minimize.cpp
src/gromacs/mdrun/tpi.cpp
src/gromacs/mdtypes/inputrec.h

index 8531347a4d95b3a87ecc12f3f642673f40ca129a..0f573e96c17f713a9e7749c7f74662e660cf4bec 100644 (file)
@@ -401,4 +401,40 @@ void Awh::writeToEnergyFrame(gmx_int64_t  step,
     }
 }
 
+std::unique_ptr<Awh>
+prepareAwhModule(FILE                 *fplog,
+                 const t_inputrec     &inputRecord,
+                 t_state              *stateGlobal,
+                 const t_commrec      *commRecord,
+                 const gmx_multisim_t *multiSimRecord,
+                 const bool            startingFromCheckpoint,
+                 const bool            usingShellParticles,
+                 const std::string    &biasInitFilename,
+                 pull_t               *pull_work)
+{
+    if (!inputRecord.bDoAwh)
+    {
+        return nullptr;
+    }
+    if (usingShellParticles)
+    {
+        GMX_THROW(InvalidInputError("AWH biasing does not support shell particles."));
+    }
+
+    auto awh = compat::make_unique<Awh>(fplog, inputRecord, commRecord, multiSimRecord, *inputRecord.awhParams,
+                                        biasInitFilename, pull_work);
+
+    if (startingFromCheckpoint)
+    {
+        // Restore the AWH history read from checkpoint
+        awh->restoreStateFromHistory(MASTER(commRecord) ? stateGlobal->awhHistory.get() : nullptr);
+    }
+    else if (MASTER(commRecord))
+    {
+        // Initialize the AWH history here
+        stateGlobal->awhHistory = awh->initHistoryFromState();
+    }
+    return awh;
+}
+
 } // namespace gmx
index 26d4efe5bd685fb93772843aaf8ed96774067fbc..858da6b7566d989db61bff50c39f931d3339f802 100644 (file)
@@ -249,6 +249,35 @@ class Awh
         double                           potentialOffset_;     /**< The offset of the bias potential which changes due to bias updates. */
 };
 
+/*! \brief Makes an Awh and prepares to use it if the user input
+ * requests that
+ *
+ * Restores state from history in checkpoint if needed.
+ *
+ * \param[in,out] fplog                   General output file, normally md.log, can be nullptr.
+ * \param[in]     inputRecord             General input parameters (as set up by grompp).
+ * \param[in]     stateGlobal             A pointer to the global state structure.
+ * \param[in]     commRecord              Struct for communication, can be nullptr.
+ * \param[in]     multiSimRecord          Multi-sim handler
+ * \param[in]     startingFromCheckpoint  Whether the simulation is starting from a checkpoint
+ * \param[in]     usingShellParticles     Whether the user requested shell particles (which is unsupported)
+ * \param[in]     biasInitFilename        Name of file to read PMF and target from.
+ * \param[in,out] pull_work               Pointer to a pull struct which AWH will couple to, has to be initialized,
+ *                                        is assumed not to change during the lifetime of the Awh object.
+ * \returns       An initialized Awh module, or nullptr if none was requested.
+ * \throws        InvalidInputError       If another active module is not supported.
+ */
+std::unique_ptr<Awh>
+prepareAwhModule(FILE                 *fplog,
+                 const t_inputrec     &inputRecord,
+                 t_state              *stateGlobal,
+                 const t_commrec      *commRecord,
+                 const gmx_multisim_t *multiSimRecord,
+                 const bool            startingFromCheckpoint,
+                 const bool            usingShellParticles,
+                 const std::string    &biasInitFilename,
+                 pull_t               *pull_work);
+
 }      // namespace gmx
 
 #endif /* GMX_AWH_H */
index 60c26a6365330dc963a3b2e74531b290f870122f..a0c0d88cde5ca924f6074181a00feb1de7a0da17 100644 (file)
@@ -64,6 +64,7 @@ struct t_nrnb;
 
 namespace gmx
 {
+class Awh;
 class ForceWithVirial;
 class MDLogger;
 }
@@ -90,6 +91,7 @@ void do_force(FILE                                     *log,
               const t_commrec                          *cr,
               const gmx_multisim_t                     *ms,
               const t_inputrec                         *inputrec,
+              gmx::Awh                                 *awh,
               gmx_int64_t                               step,
               t_nrnb                                   *nrnb,
               gmx_wallcycle                            *wcycle,
index 3958321029479c1b0a9410299f78749382761cb4..2359eac458673e44bfb4412c9b967853b2848798 100644 (file)
@@ -1137,7 +1137,8 @@ void relax_shell_flexcon(FILE                                     *fplog,
     {
         pr_rvecs(debug, 0, "x b4 do_force", as_rvec_array(state->x.data()), homenr);
     }
-    do_force(fplog, cr, ms, inputrec, mdstep, nrnb, wcycle, top, groups,
+    do_force(fplog, cr, ms, inputrec, nullptr,
+             mdstep, nrnb, wcycle, top, groups,
              state->box, state->x, &state->hist,
              force[Min], force_vir, md, enerd, fcd,
              state->lambda, graph,
@@ -1240,7 +1241,7 @@ void relax_shell_flexcon(FILE                                     *fplog,
             pr_rvecs(debug, 0, "RELAX: pos[Try]  ", as_rvec_array(pos[Try].data()), homenr);
         }
         /* Try the new positions */
-        do_force(fplog, cr, ms, inputrec, 1, nrnb, wcycle,
+        do_force(fplog, cr, ms, inputrec, nullptr, 1, nrnb, wcycle,
                  top, groups, state->box, pos[Try], &state->hist,
                  force[Try], force_vir,
                  md, enerd, fcd, state->lambda, graph,
index a6e94301e1c50c6227da8d2465aa27492cf2ca44..025d5339eca4b178d699953e61c1cf89da901017 100644 (file)
@@ -805,6 +805,7 @@ static void checkPotentialEnergyValidity(gmx_int64_t           step,
  * \param[in]     fplog            The log file
  * \param[in]     cr               The communication record
  * \param[in]     inputrec         The input record
+ * \param[in]     awh              The Awh module (nullptr if none in use).
  * \param[in]     step             The current MD step
  * \param[in]     t                The current time
  * \param[in,out] wcycle           Wallcycle accounting struct
@@ -826,6 +827,7 @@ static void
 computeSpecialForces(FILE                *fplog,
                      const t_commrec     *cr,
                      const t_inputrec    *inputrec,
+                     gmx::Awh            *awh,
                      gmx_int64_t          step,
                      double               t,
                      gmx_wallcycle_t      wcycle,
@@ -861,13 +863,12 @@ computeSpecialForces(FILE                *fplog,
                                mdatoms, enerd, lambda, t,
                                wcycle);
 
-        if (inputrec->bDoAwh)
+        if (awh)
         {
-            Awh &awh = *inputrec->awh;
             enerd->term[F_COM_PULL] +=
-                awh.applyBiasForcesAndUpdateBias(inputrec->ePBC, *mdatoms, box,
-                                                 forceWithVirial,
-                                                 t, step, wcycle, fplog);
+                awh->applyBiasForcesAndUpdateBias(inputrec->ePBC, *mdatoms, box,
+                                                  forceWithVirial,
+                                                  t, step, wcycle, fplog);
         }
     }
 
@@ -1050,6 +1051,7 @@ static void do_force_cutsVERLET(FILE *fplog,
                                 const t_commrec *cr,
                                 const gmx_multisim_t *ms,
                                 const t_inputrec *inputrec,
+                                gmx::Awh *awh,
                                 gmx_int64_t step,
                                 t_nrnb *nrnb,
                                 gmx_wallcycle_t wcycle,
@@ -1545,7 +1547,7 @@ static void do_force_cutsVERLET(FILE *fplog,
 
     wallcycle_stop(wcycle, ewcFORCE);
 
-    computeSpecialForces(fplog, cr, inputrec, step, t, wcycle,
+    computeSpecialForces(fplog, cr, inputrec, awh, step, t, wcycle,
                          fr->forceProviders, box, x, mdatoms, lambda,
                          flags, &forceWithVirial, enerd,
                          ed, bNS);
@@ -1745,6 +1747,7 @@ static void do_force_cutsGROUP(FILE *fplog,
                                const t_commrec *cr,
                                const gmx_multisim_t *ms,
                                const t_inputrec *inputrec,
+                               gmx::Awh *awh,
                                gmx_int64_t step,
                                t_nrnb *nrnb,
                                gmx_wallcycle_t wcycle,
@@ -2002,7 +2005,7 @@ static void do_force_cutsGROUP(FILE *fplog,
         }
     }
 
-    computeSpecialForces(fplog, cr, inputrec, step, t, wcycle,
+    computeSpecialForces(fplog, cr, inputrec, awh, step, t, wcycle,
                          fr->forceProviders, box, x, mdatoms, lambda,
                          flags, &forceWithVirial, enerd,
                          ed, bNS);
@@ -2077,6 +2080,7 @@ void do_force(FILE                                     *fplog,
               const t_commrec                          *cr,
               const gmx_multisim_t                     *ms,
               const t_inputrec                         *inputrec,
+              gmx::Awh                                 *awh,
               gmx_int64_t                               step,
               t_nrnb                                   *nrnb,
               gmx_wallcycle_t                           wcycle,
@@ -2114,7 +2118,7 @@ void do_force(FILE                                     *fplog,
     {
         case ecutsVERLET:
             do_force_cutsVERLET(fplog, cr, ms, inputrec,
-                                step, nrnb, wcycle,
+                                awh, step, nrnb, wcycle,
                                 top,
                                 groups,
                                 box, x, hist,
@@ -2131,7 +2135,7 @@ void do_force(FILE                                     *fplog,
             break;
         case ecutsGROUP:
             do_force_cutsGROUP(fplog, cr, ms, inputrec,
-                               step, nrnb, wcycle,
+                               awh, step, nrnb, wcycle,
                                top,
                                groups,
                                box, x, hist,
index 02f29de5c70220bfa25d6b8a5f4853f8c7c0cf58..2c1ff9962ebd2585c11be16a88532e6436553c14 100644 (file)
@@ -442,11 +442,6 @@ void gmx::Integrator::do_md()
                                  top_global, n_flexible_constraints(constr),
                                  ir->nstcalcenergy, DOMAINDECOMP(cr));
 
-    if (shellfc && ir->bDoAwh)
-    {
-        gmx_fatal(FARGS, "AWH biasing does not support shell particles.");
-    }
-
     if (inputrecDeform(ir))
     {
         tMPI_Thread_mutex_lock(&deform_init_box_mutex);
@@ -561,22 +556,10 @@ void gmx::Integrator::do_md()
         set_constraints(constr, top, ir, mdatoms, cr);
     }
 
-    /* Initialize AWH and restore state from history in checkpoint if needed. */
-    if (ir->bDoAwh)
-    {
-        ir->awh = new gmx::Awh(fplog, *ir, cr, ms, *ir->awhParams, opt2fn("-awh", nfile, fnm), ir->pull_work);
-
-        if (startingFromCheckpoint)
-        {
-            /* Restore the AWH history read from checkpoint */
-            ir->awh->restoreStateFromHistory(MASTER(cr) ? state_global->awhHistory.get() : nullptr);
-        }
-        else if (MASTER(cr))
-        {
-            /* Initialize the AWH history here */
-            state_global->awhHistory = ir->awh->initHistoryFromState();
-        }
-    }
+    // TODO: Remove this by converting AWH into a ForceProvider
+    auto awh = prepareAwhModule(fplog, *ir, state_global, cr, ms, startingFromCheckpoint,
+                                shellfc != nullptr,
+                                opt2fn("-awh", nfile, fnm), ir->pull_work);
 
     const bool useReplicaExchange = (replExParams.exchangeInterval > 0);
     if (useReplicaExchange && MASTER(cr))
@@ -1173,9 +1156,9 @@ void gmx::Integrator::do_md()
                do_md_trajectory_writing (then containing update_awh_history).
                The checkpointing will in the future probably moved to the start of the md loop which will
                rid of this issue. */
-            if (ir->bDoAwh && bCPT && MASTER(cr))
+            if (awh && bCPT && MASTER(cr))
             {
-                ir->awh->updateHistory(state_global->awhHistory.get());
+                awh->updateHistory(state_global->awhHistory.get());
             }
 
             /* The coordinates (x) are shifted (to get whole molecules)
@@ -1183,7 +1166,8 @@ void gmx::Integrator::do_md()
              * This is parallellized as well, and does communication too.
              * Check comments in sim_util.c
              */
-            do_force(fplog, cr, ms, ir, step, nrnb, wcycle, top, groups,
+            do_force(fplog, cr, ms, ir, awh.get(),
+                     step, nrnb, wcycle, top, groups,
                      state->box, state->x, &state->hist,
                      f, force_vir, mdatoms, enerd, fcd,
                      state->lambda, graph,
@@ -1753,7 +1737,7 @@ void gmx::Integrator::do_md()
 
             print_ebin(mdoutf_get_fp_ene(outf), do_ene, do_dr, do_or, do_log ? fplog : nullptr,
                        step, t,
-                       eprNORMAL, mdebin, fcd, groups, &(ir->opts), ir->awh);
+                       eprNORMAL, mdebin, fcd, groups, &(ir->opts), awh.get());
 
             if (ir->bPull)
             {
@@ -1944,7 +1928,7 @@ void gmx::Integrator::do_md()
         if (ir->nstcalcenergy > 0 && !bRerunMD)
         {
             print_ebin(mdoutf_get_fp_ene(outf), FALSE, FALSE, FALSE, fplog, step, t,
-                       eprAVER, mdebin, fcd, groups, &(ir->opts), ir->awh);
+                       eprAVER, mdebin, fcd, groups, &(ir->opts), awh.get());
         }
     }
     done_mdebin(mdebin);
@@ -1962,11 +1946,6 @@ void gmx::Integrator::do_md()
         print_replica_exchange_statistics(fplog, repl_ex);
     }
 
-    if (ir->bDoAwh)
-    {
-        delete ir->awh;
-    }
-
     // Clean up swapcoords
     if (ir->eSwapCoords != eswapNO)
     {
index 33ee89e395c23216199d0e6425aa9507a1f0585b..ea1ada43d5b98d5af4df57fe4db631ab84d15efa 100644 (file)
@@ -831,7 +831,7 @@ EnergyEvaluator::run(em_state_t *ems, rvec mu_tot,
     /* do_force always puts the charge groups in the box and shifts again
      * We do not unshift, so molecules are always whole in congrad.c
      */
-    do_force(fplog, cr, ms, inputrec,
+    do_force(fplog, cr, ms, inputrec, nullptr,
              count, nrnb, wcycle, top, &top_global->groups,
              ems->s.box, ems->s.x, &ems->s.hist,
              ems->f, force_vir, mdAtoms->mdatoms(), enerd, fcd,
index 5c8e3aff078b95064f684d79decd5b1e128e5080..b3d28a63daf1a0b2195795e61132005928ed4d91 100644 (file)
@@ -629,7 +629,7 @@ Integrator::do_tpi()
              * out of the box. */
             /* Make do_force do a single node force calculation */
             cr->nnodes = 1;
-            do_force(fplog, cr, ms, inputrec,
+            do_force(fplog, cr, ms, inputrec, nullptr,
                      step, nrnb, wcycle, top, &top_global->groups,
                      state_global->box, state_global->x, &state_global->hist,
                      f, force_vir, mdatoms, enerd, fcd,
index 9664dc5160c621502b8dc46d4eba6615aae88b98..3eff130d5531577d69ed86e5fa296a23a7c0eb07 100644 (file)
@@ -353,8 +353,6 @@ struct t_inputrec
     /* AWH bias data */
     gmx_bool                 bDoAwh;    /* Use awh biasing for PMF calculations?        */
     gmx::AwhParams          *awhParams; /* AWH biasing parameters                       */
-    // TODO: Remove this by converting AWH into a ForceProvider
-    gmx::Awh                *awh;       /* AWH work object */
 
     /* Enforced rotation data */
     gmx_bool                 bRot;           /* Calculate enforced rotation potential(s)?    */