Remove const cast in update code
authorBerk Hess <hess@kth.se>
Fri, 19 Mar 2021 12:42:49 +0000 (12:42 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Fri, 19 Mar 2021 12:42:49 +0000 (12:42 +0000)
src/gromacs/mdlib/update.cpp

index 16c722e915ed39bc763df6b5e759ebc3a5397826..4def5f35344359bad997ce5528da61f08bcf7f3c 100644 (file)
@@ -296,7 +296,7 @@ enum class ApplyParrinelloRahmanVScaling
  * \param[in]    pRVScaleMatrixDiagonal Parrinello-Rahman v-scale matrix diagonal
  * \param[in]    x                      Input coordinates
  * \param[out]   xprime                 Updated coordinates
- * \param[inout] v                      Velocities
+ * \param[inout] v                      Velocities, type either rvec* or const rvec*
  * \param[in]    f                      Forces
  *
  * We expect this template to get good SIMD acceleration by most compilers,
@@ -304,19 +304,20 @@ enum class ApplyParrinelloRahmanVScaling
  * Note that we might get even better SIMD acceleration when we introduce
  * aligned (and padded) memory, possibly with some hints for the compilers.
  */
-template<StoreUpdatedVelocities storeUpdatedVelocities, NumTempScaleValues numTempScaleValues, ApplyParrinelloRahmanVScaling applyPRVScaling>
-static void updateMDLeapfrogSimple(int         start,
-                                   int         nrend,
-                                   real        dt,
-                                   real        dtPressureCouple,
-                                   const rvec* gmx_restrict          invMassPerDim,
-                                   gmx::ArrayRef<const t_grp_tcstat> tcstat,
-                                   const unsigned short*             cTC,
-                                   const rvec                        pRVScaleMatrixDiagonal,
-                                   const rvec* gmx_restrict x,
-                                   rvec* gmx_restrict xprime,
-                                   rvec* gmx_restrict v,
-                                   const rvec* gmx_restrict f)
+template<StoreUpdatedVelocities storeUpdatedVelocities, NumTempScaleValues numTempScaleValues, ApplyParrinelloRahmanVScaling applyPRVScaling, typename VelocityType>
+static std::enable_if_t<std::is_same<VelocityType, rvec*>::value || std::is_same<VelocityType, const rvec*>::value, void>
+updateMDLeapfrogSimple(int         start,
+                       int         nrend,
+                       real        dt,
+                       real        dtPressureCouple,
+                       const rvec* gmx_restrict          invMassPerDim,
+                       gmx::ArrayRef<const t_grp_tcstat> tcstat,
+                       const unsigned short*             cTC,
+                       const rvec                        pRVScaleMatrixDiagonal,
+                       const rvec* gmx_restrict x,
+                       rvec* gmx_restrict xprime,
+                       VelocityType gmx_restrict v,
+                       const rvec* gmx_restrict f)
 {
     real lambdaGroup;
 
@@ -346,7 +347,7 @@ static void updateMDLeapfrogSimple(int         start,
             {
                 vNew -= dtPressureCouple * pRVScaleMatrixDiagonal[d] * v[a][d];
             }
-            if (storeUpdatedVelocities == StoreUpdatedVelocities::yes)
+            if constexpr (storeUpdatedVelocities == StoreUpdatedVelocities::yes) // NOLINT // NOLINTNEXTLINE
             {
                 v[a][d] = vNew;
             }
@@ -421,19 +422,20 @@ static inline void simdStoreRvecs(rvec* r, int index, SimdReal r0, SimdReal r1,
  * \param[in]    tcstat                 Temperature coupling information
  * \param[in]    x                      Input coordinates
  * \param[out]   xprime                 Updated coordinates
- * \param[inout] v                      Velocities
+ * \param[inout] v                      Velocities, type either rvec* or const rvec*
  * \param[in]    f                      Forces
  */
-template<StoreUpdatedVelocities storeUpdatedVelocities>
-static void updateMDLeapfrogSimpleSimd(int         start,
-                                       int         nrend,
-                                       real        dt,
-                                       const real* gmx_restrict          invMass,
-                                       gmx::ArrayRef<const t_grp_tcstat> tcstat,
-                                       const rvec* gmx_restrict x,
-                                       rvec* gmx_restrict xprime,
-                                       rvec* gmx_restrict v,
-                                       const rvec* gmx_restrict f)
+template<StoreUpdatedVelocities storeUpdatedVelocities, typename VelocityType>
+static std::enable_if_t<std::is_same<VelocityType, rvec*>::value || std::is_same<VelocityType, const rvec*>::value, void>
+updateMDLeapfrogSimpleSimd(int         start,
+                           int         nrend,
+                           real        dt,
+                           const real* gmx_restrict          invMass,
+                           gmx::ArrayRef<const t_grp_tcstat> tcstat,
+                           const rvec* gmx_restrict x,
+                           rvec* gmx_restrict xprime,
+                           VelocityType gmx_restrict v,
+                           const rvec* gmx_restrict f)
 {
     SimdReal timestep(dt);
     SimdReal lambdaSystem(tcstat[0].lambda);
@@ -457,7 +459,7 @@ static void updateMDLeapfrogSimpleSimd(int         start,
         v1 = fma(f1 * invMass1, timestep, lambdaSystem * v1);
         v2 = fma(f2 * invMass2, timestep, lambdaSystem * v2);
 
-        if (storeUpdatedVelocities == StoreUpdatedVelocities::yes)
+        if constexpr (storeUpdatedVelocities == StoreUpdatedVelocities::yes) // NOLINT // NOLINTNEXTLINE
         {
             simdStoreRvecs(v, a, v0, v1, v2);
         }
@@ -723,7 +725,7 @@ static void doUpdateMDDoNotUpdateVelocities(int         start,
                                             real        dt,
                                             const rvec* gmx_restrict x,
                                             rvec* gmx_restrict xprime,
-                                            rvec* gmx_restrict v,
+                                            const rvec* gmx_restrict v,
                                             const rvec* gmx_restrict f,
                                             const t_mdatoms&         md,
                                             const gmx_ekindata_t&    ekind)
@@ -1607,7 +1609,7 @@ void Update::Impl::update_for_constraint_virial(const t_inputrec& inputRecord,
 
             const rvec* x_rvec  = state.x.rvec_array();
             rvec*       xp_rvec = xp_.rvec_array();
-            rvec*       v_rvec  = const_cast<rvec*>(state.v.rvec_array());
+            const rvec* v_rvec  = state.v.rvec_array();
             const rvec* f_rvec  = as_rvec_array(f.unpaddedConstArrayRef().data());
 
             doUpdateMDDoNotUpdateVelocities(