Unify NB atoms and staging data structures in OpenCL, CUDA and SYCL
[alexxy/gromacs.git] / src / gromacs / nbnxm / cuda / nbnxm_cuda_kernel.cuh
index 688e094715ea6912efb5bb918637ffa153a052dd..344e971c845afeb7e5a25abb65e7af8e3dcf8b32 100644 (file)
@@ -159,7 +159,7 @@ __launch_bounds__(THREADS_PER_BLOCK)
         __global__ void NB_KERNEL_FUNC_NAME(nbnxn_kernel, _F_cuda)
 #    endif /* CALC_ENERGIES */
 #endif     /* PRUNE_NBL */
-                (const cu_atomdata_t atdat, const NBParamGpu nbparam, const Nbnxm::gpu_plist plist, bool bCalcFshift)
+                (const NBAtomData atdat, const NBParamGpu nbparam, const Nbnxm::gpu_plist plist, bool bCalcFshift)
 #ifdef FUNCTION_DECLARATION_ONLY
                         ; /* Only do function declaration, omit the function body. */
 #else
@@ -172,15 +172,15 @@ __launch_bounds__(THREADS_PER_BLOCK)
             nbnxn_cj4_t* pl_cj4      = plist.cj4;
     const nbnxn_excl_t*  excl        = plist.excl;
 #    ifndef LJ_COMB
-    const int*           atom_types  = atdat.atom_types;
-    int                  ntypes      = atdat.ntypes;
+    const int*           atom_types  = atdat.atomTypes;
+    int                  ntypes      = atdat.numTypes;
 #    else
-    const float2* lj_comb = atdat.lj_comb;
+    const float2* lj_comb = atdat.ljComb;
     float2        ljcp_i, ljcp_j;
 #    endif
     const float4*        xq          = atdat.xq;
     float3*              f           = asFloat3(atdat.f);
-    const float3*        shift_vec   = asFloat3(atdat.shift_vec);
+    const float3*        shift_vec   = asFloat3(atdat.shiftVec);
     float                rcoulomb_sq = nbparam.rcoulomb_sq;
 #    ifdef VDW_CUTOFF_CHECK
     float                rvdw_sq     = nbparam.rvdw_sq;
@@ -207,8 +207,8 @@ __launch_bounds__(THREADS_PER_BLOCK)
 #        else
     float reactionFieldShift = nbparam.c_rf;
 #        endif /* EL_EWALD_ANY */
-    float*               e_lj        = atdat.e_lj;
-    float*               e_el        = atdat.e_el;
+    float*               e_lj        = atdat.eLJ;
+    float*               e_el        = atdat.eElec;
 #    endif     /* CALC_ENERGIES */
 
     /* thread/block/warp id-s */
@@ -649,8 +649,8 @@ __launch_bounds__(THREADS_PER_BLOCK)
     /* add up local shift forces into global mem, tidxj indexes x,y,z */
     if (bCalcFshift && (tidxj & 3) < 3)
     {
-        float3* fshift = asFloat3(atdat.fshift);
-        atomicAdd(&(fshift[nb_sci.shift].x) + (tidxj & 3), fshift_buf);
+        float3* fShift = asFloat3(atdat.fShift);
+        atomicAdd(&(fShift[nb_sci.shift].x) + (tidxj & 3), fshift_buf);
     }
 
 #    ifdef CALC_ENERGIES