Fix FEP lambda interpolation for reruns
authorChristian Blau <cblau.mail@gmail.com>
Fri, 2 Oct 2020 07:50:52 +0000 (07:50 +0000)
committerChristian Blau <cblau.mail@gmail.com>
Fri, 2 Oct 2020 07:50:52 +0000 (07:50 +0000)
FEP lambda interpolation fix is now also applied when re-running simulations.

Closes #3585

src/gromacs/mdlib/md_support.cpp

index 9d9f8a605bf2b2ad63d8f5d1df43492b6930af13..dff38c6576b029af40333e80a7fde0ad0733ca7f 100644 (file)
@@ -316,8 +316,8 @@ void compute_globals(gmx_global_stat*          gstat,
 void setCurrentLambdasRerun(int64_t           step,
                             const t_lambda*   fepvals,
                             const t_trxframe* rerun_fr,
-                            const double*     lam0,
-                            t_state*          globalState)
+                            const double* /*lam0*/,
+                            t_state* globalState)
 {
     GMX_RELEASE_ASSERT(globalState != nullptr,
                        "setCurrentLambdasGlobalRerun should be called with a valid state object");
@@ -331,16 +331,59 @@ void setCurrentLambdasRerun(int64_t           step,
         else
         {
             /* find out between which two value of lambda we should be */
-            real frac      = step * fepvals->delta_lambda;
-            int  fep_state = static_cast<int>(std::floor(frac * fepvals->n_lambda));
-            /* interpolate between this state and the next */
-            /* this assumes that the initial lambda corresponds to lambda==0, which is verified in grompp */
-            frac = frac * fepvals->n_lambda - fep_state;
+            const real fracSimulationLambda = step * fepvals->delta_lambda;
+
+            // Set initial lambda value for the simulation either from initialFEPStateIndex or,
+            // if not set, from the initial lambda.
+            double initialGlobalLambda = 0;
+            if (fepvals->init_fep_state > -1)
+            {
+                if (fepvals->n_lambda > 1)
+                {
+                    initialGlobalLambda =
+                            static_cast<double>(fepvals->init_fep_state) / (fepvals->n_lambda - 1);
+                }
+            }
+            else
+            {
+                if (fepvals->init_lambda > -1)
+                {
+                    initialGlobalLambda = fepvals->init_lambda;
+                }
+            }
+
+            const double globalLambda = initialGlobalLambda + fracSimulationLambda;
+
+            // when there is no lambda value array, set all lambdas to steps * deltaLambdaPerStep
+            if (fepvals->n_lambda <= 0)
+            {
+                std::fill(std::begin(globalState->lambda), std::end(globalState->lambda), globalLambda);
+                return;
+            }
+
+            GMX_ASSERT(
+                    globalLambda <= 1 || gmx_within_tol(globalLambda, 1, 1e-5),
+                    "Lambda may not be larger than one when interpolating an array of multi-lambda "
+                    "values.");
+            GMX_ASSERT(globalLambda >= 0 || gmx_within_tol(globalLambda, 0, 1e-5),
+                       "Lambda may not be negative when interpolating an array of multi-lambda "
+                       "values.");
+
+            // find out between which two value lambda array elements to interpolate
+            // at the boundary of the lambda array, return the boundary array values
+            const int fepStateLeft =
+                    std::max(0, static_cast<int>(std::floor(globalLambda * (fepvals->n_lambda - 1))));
+
+            const int fepStateRight = std::min(fepvals->n_lambda - 1, fepStateLeft + 1);
+
+            // interpolate between this state and the next
+            const double fracBetween = globalLambda * (fepvals->n_lambda - 1) - fepStateLeft;
             for (int i = 0; i < efptNR; i++)
             {
-                globalState->lambda[i] =
-                        lam0[i] + (fepvals->all_lambda[i][fep_state])
-                        + frac * (fepvals->all_lambda[i][fep_state + 1] - fepvals->all_lambda[i][fep_state]);
+                globalState->lambda[i] = fepvals->all_lambda[i][fepStateLeft]
+                                         + fracBetween
+                                                   * (fepvals->all_lambda[i][fepStateRight]
+                                                      - fepvals->all_lambda[i][fepStateLeft]);
             }
         }
     }