#include "gromacs/ewald/pme-simd.h"
#include "gromacs/fft/parallel_3dfft.h"
#include "gromacs/math/vec.h"
+#include "gromacs/simd/simd_math.h"
+#include "gromacs/utility/smalloc.h"
-void get_pme_ener_vir_q(const struct gmx_pme_t *pme, int nthread,
+#ifdef GMX_SIMD_HAVE_REAL
+/* Turn on arbitrary width SIMD intrinsics for PME solve */
+# define PME_SIMD_SOLVE
+#endif
+
+struct pme_solve_work_t
+{
+ /* work data for solve_pme */
+ int nalloc;
+ real * mhx;
+ real * mhy;
+ real * mhz;
+ real * m2;
+ real * denom;
+ real * tmp1_alloc;
+ real * tmp1;
+ real * tmp2;
+ real * eterm;
+ real * m2inv;
+
+ real energy_q;
+ matrix vir_q;
+ real energy_lj;
+ matrix vir_lj;
+};
+
+static void realloc_work(struct pme_solve_work_t *work, int nkx)
+{
+ if (nkx > work->nalloc)
+ {
+ int simd_width, i;
+
+ work->nalloc = nkx;
+ srenew(work->mhx, work->nalloc);
+ srenew(work->mhy, work->nalloc);
+ srenew(work->mhz, work->nalloc);
+ srenew(work->m2, work->nalloc);
+ /* Allocate an aligned pointer for SIMD operations, including extra
+ * elements at the end for padding.
+ */
+#ifdef PME_SIMD_SOLVE
+ simd_width = GMX_SIMD_REAL_WIDTH;
+#else
+ /* We can use any alignment, apart from 0, so we use 4 */
+ simd_width = 4;
+#endif
+ sfree_aligned(work->denom);
+ sfree_aligned(work->tmp1);
+ sfree_aligned(work->tmp2);
+ sfree_aligned(work->eterm);
+ snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
+ snew_aligned(work->tmp1, work->nalloc+simd_width, simd_width*sizeof(real));
+ snew_aligned(work->tmp2, work->nalloc+simd_width, simd_width*sizeof(real));
+ snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
+ srenew(work->m2inv, work->nalloc);
+
+ /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
+ * of simd padded elements.
+ */
+ for (i = 0; i < work->nalloc+simd_width; i++)
+ {
+ work->denom[i] = 1;
+ }
+ }
+}
+
+void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
+{
+ int thread;
+ /* Use fft5d, order after FFT is y major, z, x minor */
+
+ snew(*work, nthread);
+ /* Allocate the work arrays thread local to optimize memory access */
+#pragma omp parallel for num_threads(nthread) schedule(static)
+ for (thread = 0; thread < nthread; thread++)
+ {
+ realloc_work(&((*work)[thread]), nkx);
+ }
+}
+
+static void free_work(struct pme_solve_work_t *work)
+{
+ sfree(work->mhx);
+ sfree(work->mhy);
+ sfree(work->mhz);
+ sfree(work->m2);
+ sfree_aligned(work->denom);
+ sfree_aligned(work->tmp1);
+ sfree_aligned(work->tmp2);
+ sfree_aligned(work->eterm);
+ sfree(work->m2inv);
+}
+
+void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
+{
+ int thread;
+
+ for (thread = 0; thread < nthread; thread++)
+ {
+ free_work(&(*work)[thread]);
+ }
+ sfree(work);
+ *work = NULL;
+}
+
+void get_pme_ener_vir_q(struct pme_solve_work_t *work, int nthread,
real *mesh_energy, matrix vir)
{
/* This function sums output over threads and should therefore
*/
int thread;
- *mesh_energy = pme->work[0].energy_q;
- copy_mat(pme->work[0].vir_q, vir);
+ *mesh_energy = work[0].energy_q;
+ copy_mat(work[0].vir_q, vir);
for (thread = 1; thread < nthread; thread++)
{
- *mesh_energy += pme->work[thread].energy_q;
- m_add(vir, pme->work[thread].vir_q, vir);
+ *mesh_energy += work[thread].energy_q;
+ m_add(vir, work[thread].vir_q, vir);
}
}
-void get_pme_ener_vir_lj(const struct gmx_pme_t *pme, int nthread,
+void get_pme_ener_vir_lj(struct pme_solve_work_t *work, int nthread,
real *mesh_energy, matrix vir)
{
/* This function sums output over threads and should therefore
*/
int thread;
- *mesh_energy = pme->work[0].energy_lj;
- copy_mat(pme->work[0].vir_lj, vir);
+ *mesh_energy = work[0].energy_lj;
+ copy_mat(work[0].vir_lj, vir);
for (thread = 1; thread < nthread; thread++)
{
- *mesh_energy += pme->work[thread].energy_lj;
- m_add(vir, pme->work[thread].vir_lj, vir);
+ *mesh_energy += work[thread].energy_lj;
+ m_add(vir, work[thread].vir_lj, vir);
}
}
{
/* do recip sum over local cells in grid */
/* y major, z middle, x minor or continuous */
- t_complex *p0;
- int kx, ky, kz, maxkx, maxky, maxkz;
- int nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
- real mx, my, mz;
- real factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
- real ets2, struct2, vfactor, ets2vf;
- real d1, d2, energy = 0;
- real by, bz;
- real virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
- real rxx, ryx, ryy, rzx, rzy, rzz;
- pme_work_t *work;
- real *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
- real mhxk, mhyk, mhzk, m2k;
- real corner_fac;
- ivec complex_order;
- ivec local_ndata, local_offset, local_size;
- real elfac;
+ t_complex *p0;
+ int kx, ky, kz, maxkx, maxky, maxkz;
+ int nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
+ real mx, my, mz;
+ real factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
+ real ets2, struct2, vfactor, ets2vf;
+ real d1, d2, energy = 0;
+ real by, bz;
+ real virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
+ real rxx, ryx, ryy, rzx, rzy, rzz;
+ struct pme_solve_work_t *work;
+ real *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
+ real mhxk, mhyk, mhzk, m2k;
+ real corner_fac;
+ ivec complex_order;
+ ivec local_ndata, local_offset, local_size;
+ real elfac;
elfac = ONE_4PI_EPS0/pme->epsilon_r;
maxky = (ny+1)/2;
maxkz = nz/2+1;
- work = &pme->work[thread];
+ work = &pme->solve_work[thread];
mhx = work->mhx;
mhy = work->mhy;
mhz = work->mhz;
{
/* do recip sum over local cells in grid */
/* y major, z middle, x minor or continuous */
- int ig, gcount;
- int kx, ky, kz, maxkx, maxky, maxkz;
- int nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
- real mx, my, mz;
- real factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
- real ets2, ets2vf;
- real eterm, vterm, d1, d2, energy = 0;
- real by, bz;
- real virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
- real rxx, ryx, ryy, rzx, rzy, rzz;
- real *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
- real mhxk, mhyk, mhzk, m2k;
- real mk;
- pme_work_t *work;
- real corner_fac;
- ivec complex_order;
- ivec local_ndata, local_offset, local_size;
+ int ig, gcount;
+ int kx, ky, kz, maxkx, maxky, maxkz;
+ int nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
+ real mx, my, mz;
+ real factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
+ real ets2, ets2vf;
+ real eterm, vterm, d1, d2, energy = 0;
+ real by, bz;
+ real virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
+ real rxx, ryx, ryy, rzx, rzy, rzz;
+ real *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
+ real mhxk, mhyk, mhzk, m2k;
+ real mk;
+ struct pme_solve_work_t *work;
+ real corner_fac;
+ ivec complex_order;
+ ivec local_ndata, local_offset, local_size;
nx = pme->nkx;
ny = pme->nky;
nz = pme->nkz;
maxky = (ny+1)/2;
maxkz = nz/2+1;
- work = &pme->work[thread];
+ work = &pme->solve_work[thread];
mhx = work->mhx;
mhy = work->mhy;
mhz = work->mhz;
#include "gromacs/ewald/pme-grid.h"
#include "gromacs/ewald/pme-internal.h"
#include "gromacs/ewald/pme-redistribute.h"
-#include "gromacs/ewald/pme-simd.h"
#include "gromacs/ewald/pme-solve.h"
#include "gromacs/ewald/pme-spline-work.h"
#include "gromacs/ewald/pme-spread.h"
#define GMX_CACHE_SEP 64
-static void realloc_work(pme_work_t *work, int nkx)
-{
- int simd_width, i;
-
- if (nkx > work->nalloc)
- {
- work->nalloc = nkx;
- srenew(work->mhx, work->nalloc);
- srenew(work->mhy, work->nalloc);
- srenew(work->mhz, work->nalloc);
- srenew(work->m2, work->nalloc);
- /* Allocate an aligned pointer for SIMD operations, including extra
- * elements at the end for padding.
- */
-#ifdef PME_SIMD_SOLVE
- simd_width = GMX_SIMD_REAL_WIDTH;
-#else
- /* We can use any alignment, apart from 0, so we use 4 */
- simd_width = 4;
-#endif
- sfree_aligned(work->denom);
- sfree_aligned(work->tmp1);
- sfree_aligned(work->tmp2);
- sfree_aligned(work->eterm);
- snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
- snew_aligned(work->tmp1, work->nalloc+simd_width, simd_width*sizeof(real));
- snew_aligned(work->tmp2, work->nalloc+simd_width, simd_width*sizeof(real));
- snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
- srenew(work->m2inv, work->nalloc);
-#ifndef NDEBUG
- for (i = 0; i < work->nalloc+simd_width; i++)
- {
- work->denom[i] = 1; /* init to 1 to avoid 1/0 exceptions of simd padded elements */
- }
-#endif
- }
-}
-
-
-static void free_work(pme_work_t *work)
-{
- sfree(work->mhx);
- sfree(work->mhy);
- sfree(work->mhz);
- sfree(work->m2);
- sfree_aligned(work->denom);
- sfree_aligned(work->tmp1);
- sfree_aligned(work->tmp2);
- sfree_aligned(work->eterm);
- sfree(work->m2inv);
-}
-
static void setup_coordinate_communication(pme_atomcomm_t *atc)
{
int nslab, n, i;
sfree((*pmedata)->lb_buf1);
sfree((*pmedata)->lb_buf2);
- for (thread = 0; thread < (*pmedata)->nthread; thread++)
- {
- free_work(&(*pmedata)->work[thread]);
- }
- sfree((*pmedata)->work);
+ pme_free_all_work(&(*pmedata)->solve_work, (*pmedata)->nthread);
sfree(*pmedata);
*pmedata = NULL;
pme->lb_buf2 = NULL;
pme->lb_buf_nalloc = 0;
- {
- int thread;
-
- /* Use fft5d, order after FFT is y major, z, x minor */
-
- snew(pme->work, pme->nthread);
- for (thread = 0; thread < pme->nthread; thread++)
- {
- realloc_work(&pme->work[thread], pme->nkx);
- }
- }
+ pme_init_all_work(&pme->solve_work, pme->nthread, pme->nkx);
*pmedata = pme;
*/
if (grid_index < 2)
{
- get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
+ get_pme_ener_vir_q(pme->solve_work, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
}
else
{
- get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
+ get_pme_ener_vir_lj(pme->solve_work, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
}
}
bFirst = FALSE;
/* This should only be called on the master thread and
* after the threads have synchronized.
*/
- get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
+ get_pme_ener_vir_lj(pme->solve_work, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
}
if (bCalcF)