Move computeSlowForces into stepWork
[alexxy/gromacs.git] / src / gromacs / mdlib / update.cpp
index 0cc8ca8e5bd6fa4996cb74f5b87f127d45eb00e1..46f848d4e73a42a083f975935ab5780274a1a796 100644 (file)
@@ -150,6 +150,12 @@ public:
                                bool              do_log,
                                bool              do_ene);
 
+    void update_for_constraint_virial(const t_inputrec&                                inputRecord,
+                                      const t_mdatoms&                                 md,
+                                      const t_state&                                   state,
+                                      const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
+                                      const gmx_ekindata_t&                            ekind);
+
     void update_temperature_constants(const t_inputrec& inputRecord);
 
     const std::vector<bool>& getAndersenRandomizeGroup() const { return sd_.randomize_group; }
@@ -235,6 +241,15 @@ void Update::update_sd_second_half(const t_inputrec& inputRecord,
                                         constr, do_log, do_ene);
 }
 
+void Update::update_for_constraint_virial(const t_inputrec& inputRecord,
+                                          const t_mdatoms&  md,
+                                          const t_state&    state,
+                                          const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
+                                          const gmx_ekindata_t&                            ekind)
+{
+    return impl_->update_for_constraint_virial(inputRecord, md, state, f, ekind);
+}
+
 void Update::update_temperature_constants(const t_inputrec& inputRecord)
 {
     return impl_->update_temperature_constants(inputRecord);
@@ -252,6 +267,13 @@ static void clearVsiteVelocities(int start, int nrend, const unsigned short* par
     }
 }
 
+/*! \brief Sets whether we store the updated velocities */
+enum class StoreUpdatedVelocities
+{
+    yes, //!< Store the updated velocities
+    no   //!< Do not store the updated velocities
+};
+
 /*! \brief Sets the number of different temperature coupling values */
 enum class NumTempScaleValues
 {
@@ -273,6 +295,7 @@ enum class ApplyParrinelloRahmanVScaling
 
 /*! \brief Integrate using leap-frog with T-scaling and optionally diagonal Parrinello-Rahman p-coupling
  *
+ * \tparam       storeUpdatedVelocities Tells whether we should store the updated velocities
  * \tparam       numTempScaleValues     The number of different T-couple values
  * \tparam       applyPRVScaling        Apply Parrinello-Rahman velocity scaling
  * \param[in]    start                  Index of first atom to update
@@ -293,7 +316,7 @@ enum class ApplyParrinelloRahmanVScaling
  * Note that we might get even better SIMD acceleration when we introduce
  * aligned (and padded) memory, possibly with some hints for the compilers.
  */
-template<NumTempScaleValues numTempScaleValues, ApplyParrinelloRahmanVScaling applyPRVScaling>
+template<StoreUpdatedVelocities storeUpdatedVelocities, NumTempScaleValues numTempScaleValues, ApplyParrinelloRahmanVScaling applyPRVScaling>
 static void updateMDLeapfrogSimple(int         start,
                                    int         nrend,
                                    real        dt,
@@ -335,7 +358,10 @@ static void updateMDLeapfrogSimple(int         start,
             {
                 vNew -= dtPressureCouple * pRVScaleMatrixDiagonal[d] * v[a][d];
             }
-            v[a][d]      = vNew;
+            if (storeUpdatedVelocities == StoreUpdatedVelocities::yes)
+            {
+                v[a][d] = vNew;
+            }
             xprime[a][d] = x[a][d] + vNew * dt;
         }
     }
@@ -399,6 +425,7 @@ static inline void simdStoreRvecs(rvec* r, int index, SimdReal r0, SimdReal r1,
 
 /*! \brief Integrate using leap-frog with single group T-scaling and SIMD
  *
+ * \tparam       storeUpdatedVelocities Tells whether we should store the updated velocities
  * \param[in]    start                  Index of first atom to update
  * \param[in]    nrend                  Last atom to update: \p nrend - 1
  * \param[in]    dt                     The time step
@@ -409,6 +436,7 @@ static inline void simdStoreRvecs(rvec* r, int index, SimdReal r0, SimdReal r1,
  * \param[inout] v                      Velocities
  * \param[in]    f                      Forces
  */
+template<StoreUpdatedVelocities storeUpdatedVelocities>
 static void updateMDLeapfrogSimpleSimd(int         start,
                                        int         nrend,
                                        real        dt,
@@ -441,7 +469,10 @@ static void updateMDLeapfrogSimpleSimd(int         start,
         v1 = fma(f1 * invMass1, timestep, lambdaSystem * v1);
         v2 = fma(f2 * invMass2, timestep, lambdaSystem * v2);
 
-        simdStoreRvecs(v, a, v0, v1, v2);
+        if (storeUpdatedVelocities == StoreUpdatedVelocities::yes)
+        {
+            simdStoreRvecs(v, a, v0, v1, v2);
+        }
 
         SimdReal x0, x1, x2;
         simdLoadRvecs(x, a, &x0, &x1, &x2);
@@ -700,13 +731,15 @@ static void do_update_md(int         start,
 
             if (haveSingleTempScaleValue)
             {
-                updateMDLeapfrogSimple<NumTempScaleValues::single, ApplyParrinelloRahmanVScaling::diagonal>(
+                updateMDLeapfrogSimple<StoreUpdatedVelocities::yes, NumTempScaleValues::single,
+                                       ApplyParrinelloRahmanVScaling::diagonal>(
                         start, nrend, dt, dtPressureCouple, invMassPerDim, tcstat, cTC, diagM, x,
                         xprime, v, f);
             }
             else
             {
-                updateMDLeapfrogSimple<NumTempScaleValues::multiple, ApplyParrinelloRahmanVScaling::diagonal>(
+                updateMDLeapfrogSimple<StoreUpdatedVelocities::yes, NumTempScaleValues::multiple,
+                                       ApplyParrinelloRahmanVScaling::diagonal>(
                         start, nrend, dt, dtPressureCouple, invMassPerDim, tcstat, cTC, diagM, x,
                         xprime, v, f);
             }
@@ -724,25 +757,58 @@ static void do_update_md(int         start,
                 /* Check if we can use invmass instead of invMassPerDim */
                 if (!md->havePartiallyFrozenAtoms)
                 {
-                    updateMDLeapfrogSimpleSimd(start, nrend, dt, md->invmass, tcstat, x, xprime, v, f);
+                    updateMDLeapfrogSimpleSimd<StoreUpdatedVelocities::yes>(
+                            start, nrend, dt, md->invmass, tcstat, x, xprime, v, f);
                 }
                 else
 #endif
                 {
-                    updateMDLeapfrogSimple<NumTempScaleValues::single, ApplyParrinelloRahmanVScaling::no>(
+                    updateMDLeapfrogSimple<StoreUpdatedVelocities::yes, NumTempScaleValues::single,
+                                           ApplyParrinelloRahmanVScaling::no>(
                             start, nrend, dt, dtPressureCouple, invMassPerDim, tcstat, cTC, nullptr,
                             x, xprime, v, f);
                 }
             }
             else
             {
-                updateMDLeapfrogSimple<NumTempScaleValues::multiple, ApplyParrinelloRahmanVScaling::no>(
+                updateMDLeapfrogSimple<StoreUpdatedVelocities::yes, NumTempScaleValues::multiple,
+                                       ApplyParrinelloRahmanVScaling::no>(
                         start, nrend, dt, dtPressureCouple, invMassPerDim, tcstat, cTC, nullptr, x,
                         xprime, v, f);
             }
         }
     }
 }
+/*! \brief Handles the Leap-frog MD x and v integration */
+static void doUpdateMDDoNotUpdateVelocities(int         start,
+                                            int         nrend,
+                                            real        dt,
+                                            const rvec* gmx_restrict x,
+                                            rvec* gmx_restrict xprime,
+                                            rvec* gmx_restrict v,
+                                            const rvec* gmx_restrict f,
+                                            const t_mdatoms&         md,
+                                            const gmx_ekindata_t&    ekind)
+{
+    GMX_ASSERT(nrend == start || xprime != x,
+               "For SIMD optimization certain compilers need to have xprime != x");
+
+    gmx::ArrayRef<const t_grp_tcstat> tcstat = ekind.tcstat;
+
+    /* Check if we can use invmass instead of invMassPerDim */
+#if GMX_HAVE_SIMD_UPDATE
+    if (!md.havePartiallyFrozenAtoms)
+    {
+        updateMDLeapfrogSimpleSimd<StoreUpdatedVelocities::no>(start, nrend, dt, md.invmass, tcstat,
+                                                               x, xprime, v, f);
+    }
+    else
+#endif
+    {
+        updateMDLeapfrogSimple<StoreUpdatedVelocities::no, NumTempScaleValues::single, ApplyParrinelloRahmanVScaling::no>(
+                start, nrend, dt, dt, md.invMassPerDim, tcstat, nullptr, nullptr, x, xprime, v, f);
+    }
+}
 
 static void do_update_vv_vel(int                  start,
                              int                  nrend,
@@ -1494,3 +1560,37 @@ void Update::Impl::update_coords(const t_inputrec&
         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
     }
 }
+
+void Update::Impl::update_for_constraint_virial(const t_inputrec& inputRecord,
+                                                const t_mdatoms&  md,
+                                                const t_state&    state,
+                                                const gmx::ArrayRefWithPadding<const gmx::RVec>& f,
+                                                const gmx_ekindata_t& ekind)
+{
+    GMX_ASSERT(inputRecord.eI == eiMD || inputRecord.eI == eiSD1,
+               "Only leap-frog is supported here");
+
+    // Cast to real for faster code, no loss in precision
+    const real dt = inputRecord.delta_t;
+
+    const int nth = gmx_omp_nthreads_get(emntUpdate);
+
+#pragma omp parallel for num_threads(nth) schedule(static)
+    for (int th = 0; th < nth; th++)
+    {
+        try
+        {
+            int start_th, end_th;
+            getThreadAtomRange(nth, th, md.homenr, &start_th, &end_th);
+
+            const rvec* x_rvec  = state.x.rvec_array();
+            rvec*       xp_rvec = xp_.rvec_array();
+            rvec*       v_rvec  = const_cast<rvec*>(state.v.rvec_array());
+            const rvec* f_rvec  = as_rvec_array(f.unpaddedConstArrayRef().data());
+
+            doUpdateMDDoNotUpdateVelocities(start_th, end_th, dt, x_rvec, xp_rvec, v_rvec, f_rvec,
+                                            md, ekind);
+        }
+        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+    }
+}