Unify NB atoms and staging data structures in OpenCL, CUDA and SYCL
[alexxy/gromacs.git] / src / gromacs / nbnxm / opencl / nbnxm_ocl_types.h
index 751e3529624533bb33f025dd175ff58df4cb61ff..474d90700d95d3d6aba8600a571f3f235a43884a 100644 (file)
@@ -74,62 +74,6 @@ enum ePruneKind
     ePruneNR
 };
 
-/*! \internal
- * \brief Staging area for temporary data downloaded from the GPU.
- *
- *  The energies/shift forces get downloaded here first, before getting added
- *  to the CPU-side aggregate values.
- */
-struct nb_staging_t
-{
-    //! LJ energy
-    float* e_lj = nullptr;
-    //! electrostatic energy
-    float* e_el = nullptr;
-    //! float3 buffer with shift forces
-    Float3* fshift = nullptr;
-};
-
-/*! \internal
- * \brief Nonbonded atom data - both inputs and outputs.
- */
-typedef struct cl_atomdata
-{
-    //! number of atoms
-    int natoms;
-    //! number of local atoms
-    int natoms_local;
-    //! allocation size for the atom data (xq, f)
-    int nalloc;
-
-    //! float4 buffer with atom coordinates + charges, size natoms
-    DeviceBuffer<Float4> xq;
-
-    //! float3 buffer with force output array, size natoms
-    DeviceBuffer<Float3> f;
-
-    //! LJ energy output, size 1
-    DeviceBuffer<float> e_lj;
-    //! Electrostatics energy input, size 1
-    DeviceBuffer<float> e_el;
-
-    //! float3 buffer with shift forces
-    DeviceBuffer<Float3> fshift;
-
-    //! number of atom types
-    int ntypes;
-    //! int buffer with atom type indices, size natoms
-    DeviceBuffer<int> atom_types;
-    //! float2 buffer with sqrt(c6),sqrt(c12), size natoms
-    DeviceBuffer<Float2> lj_comb;
-
-    //! float3 buffer with shifts values
-    DeviceBuffer<Float3> shift_vec;
-
-    //! true if the shift vector has been uploaded
-    bool bShiftVecUploaded;
-} cl_atomdata_t;
-
 /*! \internal
  * \brief Data structure shared between the OpenCL device code and OpenCL host code
  *
@@ -229,13 +173,13 @@ struct NbnxmGpu
     bool bNonLocalStreamDoneMarked = false;
 
     //! atom data
-    cl_atomdata_t* atdat = nullptr;
+    NBAtomData* atdat = nullptr;
     //! parameters required for the non-bonded calc.
     NBParamGpu* nbparam = nullptr;
     //! pair-list data structures (local and non-local)
     gmx::EnumerationArray<Nbnxm::InteractionLocality, Nbnxm::gpu_plist*> plist = { nullptr };
     //! staging area where fshift/energies get downloaded
-    nb_staging_t nbst;
+    NBStagingData nbst;
 
     //! local and non-local GPU queues
     gmx::EnumerationArray<Nbnxm::InteractionLocality, const DeviceStream*> deviceStreams;