Turn t_forcerec.shift_vec into an std::vector of gmx::RVec
[alexxy/gromacs.git] / src / gromacs / nbnxm / kernels_reference / kernel_gpu_ref.cpp
index 8fa8d6797a9fdfd52d7f2395b807227e601ee223..6002eb5be7ebe912d50b28d0a90548ead7e88e03 100644 (file)
 
 static constexpr int c_clSize = c_nbnxnGpuClusterSize;
 
-void nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu*    nbl,
-                          const nbnxn_atomdata_t*    nbat,
-                          const interaction_const_t* iconst,
-                          rvec*                      shift_vec,
-                          const gmx::StepWorkload&   stepWork,
-                          int                        clearF,
-                          gmx::ArrayRef<real>        f,
-                          real*                      fshift,
-                          real*                      Vc,
-                          real*                      Vvdw)
+void nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu*        nbl,
+                          const nbnxn_atomdata_t*        nbat,
+                          const interaction_const_t*     iconst,
+                          gmx::ArrayRef<const gmx::RVec> shiftvec,
+                          const gmx::StepWorkload&       stepWork,
+                          int                            clearF,
+                          gmx::ArrayRef<real>            f,
+                          real*                          fshift,
+                          real*                          Vc,
+                          real*                          Vvdw)
 {
     real                fscal = NAN;
     real                vcoul = 0;
@@ -97,7 +97,6 @@ void nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu*    nbl,
 
     const int*  type     = nbat->params().type.data();
     const real  facel    = iconst->epsfac;
-    const real* shiftvec = shift_vec[0];
     const real* vdwparam = nbat->params().nbfp.data();
     const int   ntype    = nbat->params().numTypes;
 
@@ -109,10 +108,11 @@ void nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu*    nbl,
 
     for (const nbnxn_sci_t& nbln : nbl->sci)
     {
-        const int  ish3     = 3 * nbln.shift;
-        const real shX      = shiftvec[ish3];
-        const real shY      = shiftvec[ish3 + 1];
-        const real shZ      = shiftvec[ish3 + 2];
+        const int  ish      = nbln.shift;
+        const int  ish3     = DIM * ish;
+        const real shX      = shiftvec[ish][XX];
+        const real shY      = shiftvec[ish][YY];
+        const real shZ      = shiftvec[ish][ZZ];
         const int  cj4_ind0 = nbln.cj4_ind_start;
         const int  cj4_ind1 = nbln.cj4_ind_end;
         const int  sci      = nbln.sci;