Made pme_work_t opaque
author    Mark Abraham <mark.j.abraham@gmail.com>
          Wed, 25 Feb 2015 00:24:09 +0000 (01:24 +0100)
committer Mark Abraham <mark.j.abraham@gmail.com>
          Wed, 8 Apr 2015 11:25:48 +0000 (13:25 +0200)
Introduced pme_init/free_all_work to act on the newly opaque data,
and moved the helper functions that they call to pme-solve.c.
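The pattern here is the standard C opaque-type idiom: the header forward-declares
the struct together with its lifecycle functions, and only one source file knows
the layout. A minimal sketch with hypothetical names (foo_work_t and friends are
illustrations, not part of this change):

    /* foo.h -- callers see only an incomplete type */
    struct foo_work_t;
    void foo_init_all_work(struct foo_work_t **work, int nthread, int n);
    void foo_free_all_work(struct foo_work_t **work, int nthread);

    /* foo.c -- the only translation unit that knows the layout */
    #include "foo.h"
    struct foo_work_t
    {
        int     nalloc;
        double *buf;
    };

Callers can store and pass struct foo_work_t * freely but cannot dereference it,
so the layout can change without touching any other translation unit.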

Renamed the data type to pme_solve_work_t and the field in gmx_pme_t to
solve_work, to help differentiate between the different "work" structs.

Moved the definition of PME_SIMD_SOLVE to pme-solve.c, since it
is used nowhere else.
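PME_SIMD_SOLVE then only gates code in that file: realloc_work pads each
aligned buffer with one extra SIMD width of elements so that vector loads
running past nkx stay inside the allocation. A schematic of that padding
arithmetic, using C11 aligned_alloc as a hypothetical stand-in for
snew_aligned:

    #include <stdlib.h>

    /* Hypothetical stand-in for snew_aligned(ptr, n, alignment): pad the
     * element count by one vector width and align the block to the vector
     * size. */
    static double *alloc_padded(int n, int simd_width)
    {
        size_t align = (size_t)simd_width * sizeof(double);
        size_t bytes = (size_t)(n + simd_width) * sizeof(double);
        /* aligned_alloc requires the size to be a multiple of the alignment */
        bytes = (bytes + align - 1) / align * align;
        return aligned_alloc(align, bytes);
    }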

The PME work arrays are now allocated by their respective threads.
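On NUMA machines this first-touch placement matters: each OpenMP thread in
pme_init_all_work allocates and initializes its own buffers, so their pages
tend to land on the memory node of the thread that will use them. A
standalone sketch of the same idea, with a hypothetical work_t and plain
malloc/calloc in place of snew:

    #include <stdlib.h>

    typedef struct { int nalloc; double *buf; } work_t;

    static void init_all_work(work_t **work, int nthread, int n)
    {
        /* One contiguous array of per-thread structs... */
        *work = calloc((size_t)nthread, sizeof(**work));
        /* ...but each thread allocates and first-touches its own buffer. */
    #pragma omp parallel for num_threads(nthread) schedule(static)
        for (int thread = 0; thread < nthread; thread++)
        {
            work_t *w = &(*work)[thread];
            w->nalloc = n;
            w->buf    = malloc((size_t)n * sizeof(*w->buf));
            for (int i = 0; i < n; i++)
            {
                w->buf[i] = 1.0; /* cf. the denom init in realloc_work */
            }
        }
    }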

Change-Id: I02467ac2a4c2e8e6a9c45731ccec248b766609ff

src/gromacs/ewald/pme-internal.h
src/gromacs/ewald/pme-simd.h
src/gromacs/ewald/pme-solve.c
src/gromacs/ewald/pme-solve.h
src/gromacs/ewald/pme-spline-work.c
src/gromacs/ewald/pme.c

diff --git a/src/gromacs/ewald/pme-internal.h b/src/gromacs/ewald/pme-internal.h
index af968b0e562d2aebcf1d10e009efc77aa3540ffb..f9e50aa5ff875d43a229a6abd986395ca433670a 100644
--- a/src/gromacs/ewald/pme-internal.h
+++ b/src/gromacs/ewald/pme-internal.h
@@ -193,25 +193,7 @@ typedef struct {
 
 struct pme_spline_work;
 
-typedef struct {
-    /* work data for solve_pme */
-    int      nalloc;
-    real *   mhx;
-    real *   mhy;
-    real *   mhz;
-    real *   m2;
-    real *   denom;
-    real *   tmp1_alloc;
-    real *   tmp1;
-    real *   tmp2;
-    real *   eterm;
-    real *   m2inv;
-
-    real     energy_q;
-    matrix   vir_q;
-    real     energy_lj;
-    matrix   vir_lj;
-} pme_work_t;
+struct pme_solve_work_t;
 
 typedef struct gmx_pme_t {
     int           ndecompdim; /* The number of decomposition dimensions */
@@ -298,7 +280,7 @@ typedef struct gmx_pme_t {
     int                   buf_nalloc;    /* The communication buffer size */
 
     /* thread local work data for solve_pme */
-    pme_work_t *work;
+    struct pme_solve_work_t *solve_work;
 
     /* Work data for sum_qgrid */
     real *   sum_qgrid_tmp;
diff --git a/src/gromacs/ewald/pme-simd.h b/src/gromacs/ewald/pme-simd.h
index 05ac9518ea993b70d741ada231b1d28f48ab8e4a..2b65de6d92a22dc06b2cfddadb646cf08e31211a 100644
--- a/src/gromacs/ewald/pme-simd.h
+++ b/src/gromacs/ewald/pme-simd.h
 
 /* Include the SIMD macro file and then check for support */
 #include "gromacs/simd/simd.h"
-#include "gromacs/simd/simd_math.h"
-#ifdef GMX_SIMD_HAVE_REAL
-/* Turn on arbitrary width SIMD intrinsics for PME solve */
-#    define PME_SIMD_SOLVE
-#endif
 
 /* Check if we have 4-wide SIMD macro support */
 #if (defined GMX_SIMD4_HAVE_REAL)
diff --git a/src/gromacs/ewald/pme-solve.c b/src/gromacs/ewald/pme-solve.c
index 1f8959d1d50bd17fa7c27f94203ae679c67f8f08..3f713c3e7ce3795022b765e0730b60944e962237 100644
--- a/src/gromacs/ewald/pme-solve.c
+++ b/src/gromacs/ewald/pme-solve.c
 #include "gromacs/ewald/pme-simd.h"
 #include "gromacs/fft/parallel_3dfft.h"
 #include "gromacs/math/vec.h"
+#include "gromacs/simd/simd_math.h"
+#include "gromacs/utility/smalloc.h"
 
-void get_pme_ener_vir_q(const struct gmx_pme_t *pme, int nthread,
+#ifdef GMX_SIMD_HAVE_REAL
+/* Turn on arbitrary width SIMD intrinsics for PME solve */
+#    define PME_SIMD_SOLVE
+#endif
+
+struct pme_solve_work_t
+{
+    /* work data for solve_pme */
+    int      nalloc;
+    real *   mhx;
+    real *   mhy;
+    real *   mhz;
+    real *   m2;
+    real *   denom;
+    real *   tmp1_alloc;
+    real *   tmp1;
+    real *   tmp2;
+    real *   eterm;
+    real *   m2inv;
+
+    real     energy_q;
+    matrix   vir_q;
+    real     energy_lj;
+    matrix   vir_lj;
+};
+
+static void realloc_work(struct pme_solve_work_t *work, int nkx)
+{
+    if (nkx > work->nalloc)
+    {
+        int simd_width, i;
+
+        work->nalloc = nkx;
+        srenew(work->mhx, work->nalloc);
+        srenew(work->mhy, work->nalloc);
+        srenew(work->mhz, work->nalloc);
+        srenew(work->m2, work->nalloc);
+        /* Allocate an aligned pointer for SIMD operations, including extra
+         * elements at the end for padding.
+         */
+#ifdef PME_SIMD_SOLVE
+        simd_width = GMX_SIMD_REAL_WIDTH;
+#else
+        /* We can use any alignment, apart from 0, so we use 4 */
+        simd_width = 4;
+#endif
+        sfree_aligned(work->denom);
+        sfree_aligned(work->tmp1);
+        sfree_aligned(work->tmp2);
+        sfree_aligned(work->eterm);
+        snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
+        snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
+        snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
+        snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
+        srenew(work->m2inv, work->nalloc);
+
+        /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
+         * of simd padded elements.
+         */
+        for (i = 0; i < work->nalloc+simd_width; i++)
+        {
+            work->denom[i] = 1;
+        }
+    }
+}
+
+void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
+{
+    int thread;
+    /* Use fft5d, order after FFT is y major, z, x minor */
+
+    snew(*work, nthread);
+    /* Allocate the work arrays thread local to optimize memory access */
+#pragma omp parallel for num_threads(nthread) schedule(static)
+    for (thread = 0; thread < nthread; thread++)
+    {
+        realloc_work(&((*work)[thread]), nkx);
+    }
+}
+
+static void free_work(struct pme_solve_work_t *work)
+{
+    sfree(work->mhx);
+    sfree(work->mhy);
+    sfree(work->mhz);
+    sfree(work->m2);
+    sfree_aligned(work->denom);
+    sfree_aligned(work->tmp1);
+    sfree_aligned(work->tmp2);
+    sfree_aligned(work->eterm);
+    sfree(work->m2inv);
+}
+
+void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
+{
+    int thread;
+
+    for (thread = 0; thread < nthread; thread++)
+    {
+        free_work(&(*work)[thread]);
+    }
+    sfree(*work);
+    *work = NULL;
+}
+
+void get_pme_ener_vir_q(struct pme_solve_work_t *work, int nthread,
                         real *mesh_energy, matrix vir)
 {
     /* This function sums output over threads and should therefore
@@ -54,17 +161,17 @@ void get_pme_ener_vir_q(const struct gmx_pme_t *pme, int nthread,
      */
     int thread;
 
-    *mesh_energy = pme->work[0].energy_q;
-    copy_mat(pme->work[0].vir_q, vir);
+    *mesh_energy = work[0].energy_q;
+    copy_mat(work[0].vir_q, vir);
 
     for (thread = 1; thread < nthread; thread++)
     {
-        *mesh_energy += pme->work[thread].energy_q;
-        m_add(vir, pme->work[thread].vir_q, vir);
+        *mesh_energy += work[thread].energy_q;
+        m_add(vir, work[thread].vir_q, vir);
     }
 }
 
-void get_pme_ener_vir_lj(const struct gmx_pme_t *pme, int nthread,
+void get_pme_ener_vir_lj(struct pme_solve_work_t *work, int nthread,
                          real *mesh_energy, matrix vir)
 {
     /* This function sums output over threads and should therefore
@@ -72,13 +179,13 @@ void get_pme_ener_vir_lj(const struct gmx_pme_t *pme, int nthread,
      */
     int thread;
 
-    *mesh_energy = pme->work[0].energy_lj;
-    copy_mat(pme->work[0].vir_lj, vir);
+    *mesh_energy = work[0].energy_lj;
+    copy_mat(work[0].vir_lj, vir);
 
     for (thread = 1; thread < nthread; thread++)
     {
-        *mesh_energy += pme->work[thread].energy_lj;
-        m_add(vir, pme->work[thread].vir_lj, vir);
+        *mesh_energy += work[thread].energy_lj;
+        m_add(vir, work[thread].vir_lj, vir);
     }
 }
 
@@ -180,23 +287,23 @@ int solve_pme_yzx(struct gmx_pme_t *pme, t_complex *grid,
 {
     /* do recip sum over local cells in grid */
     /* y major, z middle, x minor or continuous */
-    t_complex  *p0;
-    int         kx, ky, kz, maxkx, maxky, maxkz;
-    int         nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
-    real        mx, my, mz;
-    real        factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
-    real        ets2, struct2, vfactor, ets2vf;
-    real        d1, d2, energy = 0;
-    real        by, bz;
-    real        virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
-    real        rxx, ryx, ryy, rzx, rzy, rzz;
-    pme_work_t *work;
-    real       *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
-    real        mhxk, mhyk, mhzk, m2k;
-    real        corner_fac;
-    ivec        complex_order;
-    ivec        local_ndata, local_offset, local_size;
-    real        elfac;
+    t_complex               *p0;
+    int                      kx, ky, kz, maxkx, maxky, maxkz;
+    int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
+    real                     mx, my, mz;
+    real                     factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
+    real                     ets2, struct2, vfactor, ets2vf;
+    real                     d1, d2, energy = 0;
+    real                     by, bz;
+    real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
+    real                     rxx, ryx, ryy, rzx, rzy, rzz;
+    struct pme_solve_work_t *work;
+    real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
+    real                     mhxk, mhyk, mhzk, m2k;
+    real                     corner_fac;
+    ivec                     complex_order;
+    ivec                     local_ndata, local_offset, local_size;
+    real                     elfac;
 
     elfac = ONE_4PI_EPS0/pme->epsilon_r;
 
@@ -222,7 +329,7 @@ int solve_pme_yzx(struct gmx_pme_t *pme, t_complex *grid,
     maxky = (ny+1)/2;
     maxkz = nz/2+1;
 
-    work  = &pme->work[thread];
+    work  = &pme->solve_work[thread];
     mhx   = work->mhx;
     mhy   = work->mhy;
     mhz   = work->mhz;
@@ -431,23 +538,23 @@ int solve_pme_lj_yzx(struct gmx_pme_t *pme, t_complex **grid, gmx_bool bLB,
 {
     /* do recip sum over local cells in grid */
     /* y major, z middle, x minor or continuous */
-    int         ig, gcount;
-    int         kx, ky, kz, maxkx, maxky, maxkz;
-    int         nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
-    real        mx, my, mz;
-    real        factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
-    real        ets2, ets2vf;
-    real        eterm, vterm, d1, d2, energy = 0;
-    real        by, bz;
-    real        virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
-    real        rxx, ryx, ryy, rzx, rzy, rzz;
-    real       *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
-    real        mhxk, mhyk, mhzk, m2k;
-    real        mk;
-    pme_work_t *work;
-    real        corner_fac;
-    ivec        complex_order;
-    ivec        local_ndata, local_offset, local_size;
+    int                      ig, gcount;
+    int                      kx, ky, kz, maxkx, maxky, maxkz;
+    int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
+    real                     mx, my, mz;
+    real                     factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
+    real                     ets2, ets2vf;
+    real                     eterm, vterm, d1, d2, energy = 0;
+    real                     by, bz;
+    real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
+    real                     rxx, ryx, ryy, rzx, rzy, rzz;
+    real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
+    real                     mhxk, mhyk, mhzk, m2k;
+    real                     mk;
+    struct pme_solve_work_t *work;
+    real                     corner_fac;
+    ivec                     complex_order;
+    ivec                     local_ndata, local_offset, local_size;
     nx = pme->nkx;
     ny = pme->nky;
     nz = pme->nkz;
@@ -469,7 +576,7 @@ int solve_pme_lj_yzx(struct gmx_pme_t *pme, t_complex **grid, gmx_bool bLB,
     maxky = (ny+1)/2;
     maxkz = nz/2+1;
 
-    work  = &pme->work[thread];
+    work  = &pme->solve_work[thread];
     mhx   = work->mhx;
     mhy   = work->mhy;
     mhz   = work->mhz;
diff --git a/src/gromacs/ewald/pme-solve.h b/src/gromacs/ewald/pme-solve.h
index fd755cc4758009060414532181c383b2fdcf92d2..e6f3d3f33fb7b7cf411392aa0c58a5e23699cce1 100644
--- a/src/gromacs/ewald/pme-solve.h
+++ b/src/gromacs/ewald/pme-solve.h
 extern "C" {
 #endif
 
+struct pme_solve_work_t;
 struct gmx_pme_t;
 
-void get_pme_ener_vir_q(const struct gmx_pme_t *pme, int nthread,
+/*! \brief Allocates array of work structures
+ *
+ * Note that work is the address of a pointer; this function
+ * allocates the array, so that upon return *work points at
+ * an array of nthread work structures.
+ */
+void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx);
+
+/*! \brief Frees array of work structures
+ *
+ * Frees the array pointed to by *work and sets *work to NULL. */
+void pme_free_all_work(struct pme_solve_work_t **work, int nthread);
+
+/*! \brief Get energy and virial for electrostatics
+ *
+ * Note that work points at an array of nthread work structures.
+ */
+void get_pme_ener_vir_q(struct pme_solve_work_t *work, int nthread,
                         real *mesh_energy, matrix vir);
 
-void get_pme_ener_vir_lj(const struct gmx_pme_t *pme, int nthread,
+/*! \brief Get energy and virial for L-J
+ *
+ * Note that work points at an array of nthread work structures.
+ */
+void get_pme_ener_vir_lj(struct pme_solve_work_t *work, int nthread,
                          real *mesh_energy, matrix vir);
 
 int solve_pme_yzx(struct gmx_pme_t *pme, t_complex *grid,
diff --git a/src/gromacs/ewald/pme-spline-work.c b/src/gromacs/ewald/pme-spline-work.c
index 137fc0a5d0211e8d069ecd51de6f3ea08e319971..af09e7f78b4deb55eaba30b3bdf0fcd11917b0ef 100644
--- a/src/gromacs/ewald/pme-spline-work.c
+++ b/src/gromacs/ewald/pme-spline-work.c
@@ -40,6 +40,7 @@
 #include "pme-spline-work.h"
 
 #include "gromacs/ewald/pme-simd.h"
+#include "gromacs/utility/real.h"
 #include "gromacs/utility/smalloc.h"
 
 struct pme_spline_work *make_pme_spline_work(int gmx_unused order)
diff --git a/src/gromacs/ewald/pme.c b/src/gromacs/ewald/pme.c
index 1d7b33dce6152a4aa9eca798f29711fdfb9a3de6..91cff2c4977a6fc5d30c412aeb07cf505372012d 100644
--- a/src/gromacs/ewald/pme.c
+++ b/src/gromacs/ewald/pme.c
@@ -73,7 +73,6 @@
 #include "gromacs/ewald/pme-grid.h"
 #include "gromacs/ewald/pme-internal.h"
 #include "gromacs/ewald/pme-redistribute.h"
-#include "gromacs/ewald/pme-simd.h"
 #include "gromacs/ewald/pme-solve.h"
 #include "gromacs/ewald/pme-spline-work.h"
 #include "gromacs/ewald/pme-spread.h"
 #define GMX_CACHE_SEP 64
 
 
-static void realloc_work(pme_work_t *work, int nkx)
-{
-    int simd_width, i;
-
-    if (nkx > work->nalloc)
-    {
-        work->nalloc = nkx;
-        srenew(work->mhx, work->nalloc);
-        srenew(work->mhy, work->nalloc);
-        srenew(work->mhz, work->nalloc);
-        srenew(work->m2, work->nalloc);
-        /* Allocate an aligned pointer for SIMD operations, including extra
-         * elements at the end for padding.
-         */
-#ifdef PME_SIMD_SOLVE
-        simd_width = GMX_SIMD_REAL_WIDTH;
-#else
-        /* We can use any alignment, apart from 0, so we use 4 */
-        simd_width = 4;
-#endif
-        sfree_aligned(work->denom);
-        sfree_aligned(work->tmp1);
-        sfree_aligned(work->tmp2);
-        sfree_aligned(work->eterm);
-        snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
-        snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
-        snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
-        snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
-        srenew(work->m2inv, work->nalloc);
-#ifndef NDEBUG
-        for (i = 0; i < work->nalloc+simd_width; i++)
-        {
-            work->denom[i] = 1; /* init to 1 to avoid 1/0 exceptions of simd padded elements */
-        }
-#endif
-    }
-}
-
-
-static void free_work(pme_work_t *work)
-{
-    sfree(work->mhx);
-    sfree(work->mhy);
-    sfree(work->mhz);
-    sfree(work->m2);
-    sfree_aligned(work->denom);
-    sfree_aligned(work->tmp1);
-    sfree_aligned(work->tmp2);
-    sfree_aligned(work->eterm);
-    sfree(work->m2inv);
-}
-
 static void setup_coordinate_communication(pme_atomcomm_t *atc)
 {
     int nslab, n, i;
@@ -215,11 +162,7 @@ int gmx_pme_destroy(FILE *log, struct gmx_pme_t **pmedata)
     sfree((*pmedata)->lb_buf1);
     sfree((*pmedata)->lb_buf2);
 
-    for (thread = 0; thread < (*pmedata)->nthread; thread++)
-    {
-        free_work(&(*pmedata)->work[thread]);
-    }
-    sfree((*pmedata)->work);
+    pme_free_all_work(&(*pmedata)->solve_work, (*pmedata)->nthread);
 
     sfree(*pmedata);
     *pmedata = NULL;
@@ -827,17 +770,7 @@ int gmx_pme_init(struct gmx_pme_t **pmedata,
     pme->lb_buf2       = NULL;
     pme->lb_buf_nalloc = 0;
 
-    {
-        int thread;
-
-        /* Use fft5d, order after FFT is y major, z, x minor */
-
-        snew(pme->work, pme->nthread);
-        for (thread = 0; thread < pme->nthread; thread++)
-        {
-            realloc_work(&pme->work[thread], pme->nkx);
-        }
-    }
+    pme_init_all_work(&pme->solve_work, pme->nthread, pme->nkx);
 
     *pmedata = pme;
 
@@ -1274,11 +1207,11 @@ int gmx_pme_do(struct gmx_pme_t *pme,
              */
             if (grid_index < 2)
             {
-                get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
+                get_pme_ener_vir_q(pme->solve_work, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
             }
             else
             {
-                get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
+                get_pme_ener_vir_lj(pme->solve_work, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
             }
         }
         bFirst = FALSE;
@@ -1453,7 +1386,7 @@ int gmx_pme_do(struct gmx_pme_t *pme,
                 /* This should only be called on the master thread and
                  * after the threads have synchronized.
                  */
-                get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
+                get_pme_ener_vir_lj(pme->solve_work, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
             }
 
             if (bCalcF)
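
Taken together, the lifecycle of the now-opaque solve work data reduces to
three calls, condensed from the hunks above (energy and vir stand in for the
energy_AB/vir_AB entries used there, and the LJ/FEP variants are elided):

    /* gmx_pme_init(): */
    pme_init_all_work(&pme->solve_work, pme->nthread, pme->nkx);

    /* gmx_pme_do(), on the master thread after the solve: */
    get_pme_ener_vir_q(pme->solve_work, pme->nthread, &energy, vir);

    /* gmx_pme_destroy(): */
    pme_free_all_work(&pme->solve_work, pme->nthread);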