Use ArrayRef(WithPadding) in constraint code
[alexxy/gromacs.git] / src / gromacs / mdlib / lincs.cpp
index 98fdc2e03c04bea63cfbd928990de56d222c079f..b425d12eafdd4429b045e3144ffcc2bc1b963038 100644 (file)
@@ -569,17 +569,17 @@ static void gmx_simdcall calc_dr_x_f_simd(int                           b0,
 #endif // GMX_SIMD_HAVE_REAL
 
 /*! \brief LINCS projection, works on derivatives of the coordinates. */
-static void do_lincsp(const rvec*        x,
-                      rvec*              f,
-                      rvec*              fp,
-                      t_pbc*             pbc,
-                      Lincs*             lincsd,
-                      int                th,
-                      real*              invmass,
-                      ConstraintVariable econq,
-                      bool               bCalcDHDL,
-                      bool               bCalcVir,
-                      tensor             rmdf)
+static void do_lincsp(ArrayRefWithPadding<const RVec> xPadded,
+                      ArrayRefWithPadding<RVec>       fPadded,
+                      ArrayRef<RVec>                  fp,
+                      t_pbc*                          pbc,
+                      Lincs*                          lincsd,
+                      int                             th,
+                      real*                           invmass,
+                      ConstraintVariable              econq,
+                      bool                            bCalcDHDL,
+                      bool                            bCalcVir,
+                      tensor                          rmdf)
 {
     const int b0 = lincsd->task[th].b0;
     const int b1 = lincsd->task[th].b1;
@@ -608,6 +608,9 @@ static void do_lincsp(const rvec*        x,
     gmx::ArrayRef<real> rhs2 = lincsd->tmp2;
     gmx::ArrayRef<real> sol  = lincsd->tmp3;
 
+    const rvec* x = as_rvec_array(xPadded.paddedArrayRef().data());
+    rvec*       f = as_rvec_array(fPadded.paddedArrayRef().data());
+
 #if GMX_SIMD_HAVE_REAL
     /* This SIMD code does the same as the plain-C code after the #else.
      * The only difference is that we always call pbc code, as with SIMD
@@ -706,8 +709,8 @@ static void do_lincsp(const rvec*        x,
     /* When constraining forces, we should not use mass weighting,
      * so we pass invmass=NULL, which results in the use of 1 for all atoms.
      */
-    lincs_update_atoms(lincsd, th, 1.0, sol, r,
-                       (econq != ConstraintVariable::Force) ? invmass : nullptr, fp);
+    lincs_update_atoms(lincsd, th, 1.0, sol, r, (econq != ConstraintVariable::Force) ? invmass : nullptr,
+                       as_rvec_array(fp.data()));
 
     if (bCalcDHDL)
     {
@@ -945,22 +948,26 @@ static void gmx_simdcall calc_dist_iter_simd(int                           b0,
 #endif // GMX_SIMD_HAVE_REAL
 
 //! Implements LINCS constraining.
-static void do_lincs(const rvec*      x,
-                     rvec*            xp,
-                     const matrix     box,
-                     t_pbc*           pbc,
-                     Lincs*           lincsd,
-                     int              th,
-                     const real*      invmass,
-                     const t_commrec* cr,
-                     bool             bCalcDHDL,
-                     real             wangle,
-                     bool*            bWarn,
-                     real             invdt,
-                     rvec* gmx_restrict v,
-                     bool               bCalcVir,
-                     tensor             vir_r_m_dr)
+static void do_lincs(ArrayRefWithPadding<const RVec> xPadded,
+                     ArrayRefWithPadding<RVec>       xpPadded,
+                     const matrix                    box,
+                     t_pbc*                          pbc,
+                     Lincs*                          lincsd,
+                     int                             th,
+                     const real*                     invmass,
+                     const t_commrec*                cr,
+                     bool                            bCalcDHDL,
+                     real                            wangle,
+                     bool*                           bWarn,
+                     real                            invdt,
+                     ArrayRef<RVec>                  vRef,
+                     bool                            bCalcVir,
+                     tensor                          vir_r_m_dr)
 {
+    const rvec* x        = as_rvec_array(xPadded.paddedArrayRef().data());
+    rvec*       xp       = as_rvec_array(xpPadded.paddedArrayRef().data());
+    rvec* gmx_restrict v = as_rvec_array(vRef.data());
+
     const int b0 = lincsd->task[th].b0;
     const int b1 = lincsd->task[th].b1;
 
@@ -1098,7 +1105,7 @@ static void do_lincs(const rvec*      x,
                 /* Communicate the corrected non-local coordinates */
                 if (DOMAINDECOMP(cr))
                 {
-                    dd_move_x_constraints(cr->dd, box, xp, nullptr, FALSE);
+                    dd_move_x_constraints(cr->dd, box, xpPadded.unpaddedArrayRef(), ArrayRef<RVec>(), FALSE);
                 }
             }
 #pragma omp barrier
@@ -2124,8 +2131,8 @@ void set_lincs(const t_idef& idef, const t_mdatoms& md, bool bDynamics, const t_
 
 //! Issues a warning when LINCS constraints cannot be satisfied.
 static void lincs_warning(gmx_domdec_t*                 dd,
-                          const rvec*                   x,
-                          rvec*                         xprime,
+                          ArrayRef<const RVec>          x,
+                          ArrayRef<const RVec>          xprime,
                           t_pbc*                        pbc,
                           int                           ncons,
                           gmx::ArrayRef<const AtomPair> atoms,
@@ -2192,7 +2199,7 @@ struct LincsDeviations
 };
 
 //! Determine how well the constraints have been satisfied.
-static LincsDeviations makeLincsDeviations(const Lincs& lincsd, const rvec* x, const t_pbc* pbc)
+static LincsDeviations makeLincsDeviations(const Lincs& lincsd, ArrayRef<const RVec> x, const t_pbc* pbc)
 {
     LincsDeviations                result;
     const ArrayRef<const AtomPair> atoms  = lincsd.atoms;
@@ -2241,28 +2248,28 @@ static LincsDeviations makeLincsDeviations(const Lincs& lincsd, const rvec* x, c
     return result;
 }
 
-bool constrain_lincs(bool                  computeRmsd,
-                     const t_inputrec&     ir,
-                     int64_t               step,
-                     Lincs*                lincsd,
-                     const t_mdatoms&      md,
-                     const t_commrec*      cr,
-                     const gmx_multisim_t* ms,
-                     const rvec*           x,
-                     rvec*                 xprime,
-                     rvec*                 min_proj,
-                     const matrix          box,
-                     t_pbc*                pbc,
-                     real                  lambda,
-                     real*                 dvdlambda,
-                     real                  invdt,
-                     rvec*                 v,
-                     bool                  bCalcVir,
-                     tensor                vir_r_m_dr,
-                     ConstraintVariable    econq,
-                     t_nrnb*               nrnb,
-                     int                   maxwarn,
-                     int*                  warncount)
+bool constrain_lincs(bool                            computeRmsd,
+                     const t_inputrec&               ir,
+                     int64_t                         step,
+                     Lincs*                          lincsd,
+                     const t_mdatoms&                md,
+                     const t_commrec*                cr,
+                     const gmx_multisim_t*           ms,
+                     ArrayRefWithPadding<const RVec> xPadded,
+                     ArrayRefWithPadding<RVec>       xprimePadded,
+                     ArrayRef<RVec>                  min_proj,
+                     const matrix                    box,
+                     t_pbc*                          pbc,
+                     real                            lambda,
+                     real*                           dvdlambda,
+                     real                            invdt,
+                     ArrayRef<RVec>                  v,
+                     bool                            bCalcVir,
+                     tensor                          vir_r_m_dr,
+                     ConstraintVariable              econq,
+                     t_nrnb*                         nrnb,
+                     int                             maxwarn,
+                     int*                            warncount)
 {
     bool bOK = TRUE;
 
@@ -2282,6 +2289,9 @@ bool constrain_lincs(bool                  computeRmsd,
         return bOK;
     }
 
+    ArrayRef<const RVec> x      = xPadded.unpaddedArrayRef();
+    ArrayRef<RVec>       xprime = xprimePadded.unpaddedArrayRef();
+
     if (econq == ConstraintVariable::Positions)
     {
         /* We can't use bCalcDHDL here, since NULL can be passed for dvdlambda
@@ -2355,8 +2365,9 @@ bool constrain_lincs(bool                  computeRmsd,
 
                 clear_mat(lincsd->task[th].vir_r_m_dr);
 
-                do_lincs(x, xprime, box, pbc, lincsd, th, md.invmass, cr, bCalcDHDL, ir.LincsWarnAngle,
-                         &bWarn, invdt, v, bCalcVir, th == 0 ? vir_r_m_dr : lincsd->task[th].vir_r_m_dr);
+                do_lincs(xPadded, xprimePadded, box, pbc, lincsd, th, md.invmass, cr, bCalcDHDL,
+                         ir.LincsWarnAngle, &bWarn, invdt, v, bCalcVir,
+                         th == 0 ? vir_r_m_dr : lincsd->task[th].vir_r_m_dr);
             }
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
@@ -2433,8 +2444,8 @@ bool constrain_lincs(bool                  computeRmsd,
             {
                 int th = gmx_omp_get_thread_num();
 
-                do_lincsp(x, xprime, min_proj, pbc, lincsd, th, md.invmass, econq, bCalcDHDL,
-                          bCalcVir, th == 0 ? vir_r_m_dr : lincsd->task[th].vir_r_m_dr);
+                do_lincsp(xPadded, xprimePadded, min_proj, pbc, lincsd, th, md.invmass, econq,
+                          bCalcDHDL, bCalcVir, th == 0 ? vir_r_m_dr : lincsd->task[th].vir_r_m_dr);
             }
             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
         }
@@ -2475,7 +2486,7 @@ bool constrain_lincs(bool                  computeRmsd,
     {
         inc_nrnb(nrnb, eNR_LINCSMAT, lincsd->nOrder * lincsd->ncc_triangle);
     }
-    if (v)
+    if (!v.empty())
     {
         inc_nrnb(nrnb, eNR_CONSTR_V, lincsd->nc_real * 2);
     }