/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
 * Copyright (c) 2001-2004, The GROMACS development team.
 * Copyright (c) 2013,2014,2015,2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
#include "gmxpre.h"

#include "pme-solve.h"

#include "config.h"

#include <cmath>

#include "gromacs/fft/parallel_3dfft.h"
#include "gromacs/math/units.h"
#include "gromacs/math/utilities.h"
#include "gromacs/math/vec.h"
#include "gromacs/simd/simd.h"
#include "gromacs/simd/simd_math.h"
#include "gromacs/utility/exceptions.h"
#include "gromacs/utility/smalloc.h"

#include "pme-internal.h"
#if GMX_SIMD_HAVE_REAL
/* Turn on arbitrary width SIMD intrinsics for PME solve */
#    define PME_SIMD_SOLVE
#endif

using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
struct pme_solve_work_t
{
    /* work data for solve_pme */
    int      nalloc;
    real    *mhx;
    real    *mhy;
    real    *mhz;
    real    *m2;
    real    *denom;
    real    *tmp1;
    real    *tmp2;
    real    *eterm;
    real    *m2inv;

    real     energy_q;
    matrix   vir_q;
    real     energy_lj;
    matrix   vir_lj;
};
static void realloc_work(struct pme_solve_work_t *work, int nkx)
{
    if (nkx > work->nalloc)
    {
        int simd_width, i;

        work->nalloc = nkx;
        srenew(work->mhx, work->nalloc);
        srenew(work->mhy, work->nalloc);
        srenew(work->mhz, work->nalloc);
        srenew(work->m2, work->nalloc);
        /* Allocate an aligned pointer for SIMD operations, including extra
         * elements at the end for padding.
         */
#ifdef PME_SIMD_SOLVE
        simd_width = GMX_SIMD_REAL_WIDTH;
#else
        /* We can use any alignment, apart from 0, so we use 4 */
        simd_width = 4;
#endif
        sfree_aligned(work->denom);
        sfree_aligned(work->tmp1);
        sfree_aligned(work->tmp2);
        sfree_aligned(work->eterm);
        snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
        snew_aligned(work->tmp1, work->nalloc+simd_width, simd_width*sizeof(real));
        snew_aligned(work->tmp2, work->nalloc+simd_width, simd_width*sizeof(real));
        snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
        srenew(work->m2inv, work->nalloc);

        /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
         * in the SIMD padding elements.
         */
        for (i = 0; i < work->nalloc+simd_width; i++)
        {
            work->denom[i] = 1;
        }
    }
}
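
/* Worked example of the padding logic above (illustrative numbers only):
 * with GMX_SIMD_REAL_WIDTH == 4 and nkx == 10, the aligned arrays hold
 * 10 + 4 = 14 reals, so the last vector iteration in calc_exponentials_q()
 * (kx == 8) loads elements 8..11 without reading past the allocation.
 * Initializing all of denom to 1 means the unused padding lanes divide by 1
 * rather than by 0; their results land in the padded tail of eterm and are
 * never read back.
 */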
void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
{
    int thread;
    /* Use fft5d, order after FFT is y major, z, x minor */

    snew(*work, nthread);
    /* Allocate the work arrays thread local to optimize memory access */
#pragma omp parallel for num_threads(nthread) schedule(static)
    for (thread = 0; thread < nthread; thread++)
    {
        try
        {
            realloc_work(&((*work)[thread]), nkx);
        }
        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
    }
}
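
/* Usage sketch (hedged; the real call sites live in the PME driver code and
 * the variable names here are illustrative):
 *
 *   struct pme_solve_work_t *solve_work = NULL;
 *   pme_init_all_work(&solve_work, nthread, pme->nkx);
 *   // ... threads call solve_pme_yzx()/solve_pme_lj_yzx() ...
 *   pme_free_all_work(&solve_work, nthread);
 *
 * Calling realloc_work() from inside the parallel for gives first-touch
 * allocation, so each thread's arrays end up in that thread's local memory
 * on NUMA machines.
 */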
static void free_work(struct pme_solve_work_t *work)
{
    if (work)
    {
        sfree(work->mhx);
        sfree(work->mhy);
        sfree(work->mhz);
        sfree(work->m2);
        sfree_aligned(work->denom);
        sfree_aligned(work->tmp1);
        sfree_aligned(work->tmp2);
        sfree_aligned(work->eterm);
        sfree(work->m2inv);
    }
}
void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
{
    if (*work)
    {
        int thread;

        for (thread = 0; thread < nthread; thread++)
        {
            free_work(&(*work)[thread]);
        }
    }
    sfree(*work);
    *work = NULL;
}
void get_pme_ener_vir_q(struct pme_solve_work_t *work, int nthread,
                        real *mesh_energy, matrix vir)
{
    /* This function sums output over threads and should therefore
     * only be called after thread synchronization.
     */
    int thread;

    *mesh_energy = work[0].energy_q;
    copy_mat(work[0].vir_q, vir);

    for (thread = 1; thread < nthread; thread++)
    {
        *mesh_energy += work[thread].energy_q;
        m_add(vir, work[thread].vir_q, vir);
    }
}
void get_pme_ener_vir_lj(struct pme_solve_work_t *work, int nthread,
                         real *mesh_energy, matrix vir)
{
    /* This function sums output over threads and should therefore
     * only be called after thread synchronization.
     */
    int thread;

    *mesh_energy = work[0].energy_lj;
    copy_mat(work[0].vir_lj, vir);

    for (thread = 1; thread < nthread; thread++)
    {
        *mesh_energy += work[thread].energy_lj;
        m_add(vir, work[thread].vir_lj, vir);
    }
}
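
/* Usage sketch for both reductions (hedged, caller-side; the names are
 * illustrative). They must run on a single thread, after all solver threads
 * have finished and been synchronized:
 *
 *   real   energy;
 *   matrix vir;
 *   // after an omp barrier, on the master thread:
 *   get_pme_ener_vir_q(pme->solve_work, nthread, &energy, vir);
 */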
#if defined PME_SIMD_SOLVE
/* Calculate exponentials through SIMD */
gmx_inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
{
    SimdReal f_simd(f);
    SimdReal tmp_d1, tmp_r, tmp_e;
    int      kx;

    /* We only need to calculate from start. But since start is 0 or 1
     * and we want to use aligned loads/stores, we always start from 0.
     */
    for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
    {
        tmp_d1 = load(d_aligned+kx);
        tmp_r  = load(r_aligned+kx);
        tmp_r  = gmx::exp(tmp_r);
        tmp_e  = f_simd / tmp_d1;
        tmp_e  = tmp_e * tmp_r;
        store(e_aligned+kx, tmp_e);
    }
}
#else
gmx_inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
{
    int kx;

    for (kx = start; kx < end; kx++)
    {
        d[kx] = 1.0/d[kx];
    }

    for (kx = start; kx < end; kx++)
    {
        r[kx] = std::exp(r[kx]);
    }

    for (kx = start; kx < end; kx++)
    {
        e[kx] = f*r[kx]*d[kx];
    }
}
#endif
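
/* Written out, for the values passed in by solve_pme_yzx() (f is
 * ONE_4PI_EPS0/epsilon_r, d[kx] arrives as pi*V*B(m)*|m|^2 with B(m) the
 * product of B-spline moduli, and r[kx] as -pi^2*|m|^2/beta^2), both
 * variants compute the reciprocal-space Ewald factor
 *
 *   e[kx] = f * exp(-pi^2*|m|^2/beta^2) / (pi*V*B(m)*|m|^2)
 *
 * with beta the Ewald splitting coefficient.
 */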
#if defined PME_SIMD_SOLVE
/* Calculate exponentials through SIMD */
gmx_inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
{
    SimdReal       tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
    const SimdReal sqr_PI = sqrt(SimdReal(M_PI));
    int            kx;

    for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
    {
        /* We only need to calculate from start. But since start is 0 or 1
         * and we want to use aligned loads/stores, we always start from 0.
         */
        tmp_d = load(d_aligned+kx);
        d_inv = SimdReal(1.0) / tmp_d;
        store(d_aligned+kx, d_inv);
        tmp_r = load(r_aligned+kx);
        tmp_r = gmx::exp(tmp_r);
        store(r_aligned+kx, tmp_r);
        tmp_mk  = load(factor_aligned+kx);
        tmp_fac = sqr_PI * tmp_mk * erfc(tmp_mk);
        store(factor_aligned+kx, tmp_fac);
    }
}
#else
gmx_inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
{
    int  kx;
    real mk;

    for (kx = start; kx < end; kx++)
    {
        d[kx] = 1.0/d[kx];
    }

    for (kx = start; kx < end; kx++)
    {
        r[kx] = std::exp(r[kx]);
    }

    for (kx = start; kx < end; kx++)
    {
        mk       = tmp2[kx];
        tmp2[kx] = sqrt(M_PI)*mk*std::erfc(mk);
    }
}
#endif
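
/* Written out, with x = sqrt(factor*|m|^2) = pi*|m|/beta as set up by
 * solve_pme_lj_yzx(), the three arrays leave this function holding
 *
 *   r[kx]    = exp(-x^2)
 *   tmp2[kx] = sqrt(pi) * x * erfc(x)
 *   d[kx]    = 1/denom[kx]
 *
 * which the caller combines into the LJ-PME (r^-6 dispersion) kernel
 *   eterm = -[(1 - 2x^2) exp(-x^2) + 2 sqrt(pi) x^3 erfc(x)] / denom.
 */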
int solve_pme_yzx(struct gmx_pme_t *pme, t_complex *grid, real vol,
                  gmx_bool bEnerVir,
                  int nthread, int thread)
{
    /* do recip sum over local cells in grid */
    /* y major, z middle, x minor or continuous */
    t_complex               *p0;
    int                      kx, ky, kz, maxkx, maxky;
    int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
    real                     mx, my, mz;
    real                     ewaldcoeff = pme->ewaldcoeff_q;
    real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
    real                     ets2, struct2, vfactor, ets2vf;
    real                     d1, d2, energy = 0;
    real                     by, bz;
    real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
    real                     rxx, ryx, ryy, rzx, rzy, rzz;
    struct pme_solve_work_t *work;
    real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
    real                     mhxk, mhyk, mhzk, m2k;
    real                     corner_fac;
    ivec                     complex_order;
    ivec                     local_ndata, local_offset, local_size;
    real                     elfac;

    elfac = ONE_4PI_EPS0/pme->epsilon_r;
    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
                                      complex_order,
                                      local_ndata,
                                      local_offset,
                                      local_size);

    rxx = pme->recipbox[XX][XX];
    ryx = pme->recipbox[YY][XX];
    ryy = pme->recipbox[YY][YY];
    rzx = pme->recipbox[ZZ][XX];
    rzy = pme->recipbox[ZZ][YY];
    rzz = pme->recipbox[ZZ][ZZ];

    maxkx = (nx+1)/2;
    maxky = (ny+1)/2;

    work  = &pme->solve_work[thread];
    mhx   = work->mhx;
    mhy   = work->mhy;
    mhz   = work->mhz;
    m2    = work->m2;
    denom = work->denom;
    tmp1  = work->tmp1;
    eterm = work->eterm;
    m2inv = work->m2inv;

    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
    for (iyz = iyz0; iyz < iyz1; iyz++)
    {
        iy = iyz/local_ndata[ZZ];
        iz = iyz - iy*local_ndata[ZZ];

        ky = iy + local_offset[YY];

        if (ky < maxky)
        {
            my = ky;
        }
        else
        {
            my = (ky - ny);
        }

        by = M_PI*vol*pme->bsp_mod[YY][ky];

        kz = iz + local_offset[ZZ];

        mz = kz;

        bz = pme->bsp_mod[ZZ][kz];

        /* 0.5 correction for corner points */
        corner_fac = 1;
        if (kz == 0 || kz == (nz+1)/2)
        {
            corner_fac = 0.5;
        }

        p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];

        /* We should skip the k-space point (0,0,0) */
        /* Note that since here x is the minor index, local_offset[XX]=0 */
        if (local_offset[XX] > 0 || ky > 0 || kz > 0)
        {
            kxstart = local_offset[XX];
        }
        else
        {
            kxstart = local_offset[XX] + 1;
            p0++;
        }
        kxend = local_offset[XX] + local_ndata[XX];
        if (bEnerVir)
        {
            /* More expensive inner loop, especially because of the storage
             * of the mh elements in arrays.
             * Because x is the minor grid index, all mh elements
             * depend on kx for triclinic unit cells.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = kxstart; kx < kxend; kx++)
            {
                m2inv[kx] = 1.0/m2[kx];
            }

            calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];

                struct2 = 2.0*(d1*d1+d2*d2);

                tmp1[kx] = eterm[kx]*struct2;
            }

            for (kx = kxstart; kx < kxend; kx++)
            {
                ets2     = corner_fac*tmp1[kx];
                vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
                energy  += ets2;

                ets2vf   = ets2*vfactor;
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
                virxy   += ets2vf*mhx[kx]*mhy[kx];
                virxz   += ets2vf*mhx[kx]*mhz[kx];
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
                viryz   += ets2vf*mhy[kx]*mhz[kx];
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
            }
        }
        else
        {
            /* We don't need to calculate the energy and the virial.
             * In this case the triclinic overhead is small.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1     = p0->re;
                d2     = p0->im;

                p0->re = d1*eterm[kx];
                p0->im = d2*eterm[kx];
            }
        }
    }
    if (bEnerVir)
    {
        /* Update virial with local values.
         * The virial is symmetric by definition.
         * this virial seems ok for isotropic scaling, but I'm
         * experiencing problems on semiisotropic membranes.
         * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
         */
        work->vir_q[XX][XX] = 0.25*virxx;
        work->vir_q[YY][YY] = 0.25*viryy;
        work->vir_q[ZZ][ZZ] = 0.25*virzz;
        work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
        work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
        work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;

        /* This energy should be corrected for a charged system */
        work->energy_q = 0.5*energy;
    }

    /* Return the loop count */
    return local_ndata[YY]*local_ndata[XX];
}
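
/* Caller-side sketch (hedged; the real driver lives in the PME code and the
 * thread-index helper shown is just one option):
 *
 *   #pragma omp parallel num_threads(nthread)
 *   {
 *       int t = gmx_omp_get_thread_num();
 *       solve_pme_yzx(pme, cfftgrid, vol, bEnerVir, nthread, t);
 *   }
 *   // after synchronization:
 *   get_pme_ener_vir_q(pme->solve_work, nthread, &energy, vir);
 *
 * The [iyz0, iyz1) slab decomposition above gives each thread a disjoint set
 * of y-z grid lines, so no locking is needed; only the per-thread energy and
 * virial require the reduction.
 */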
int solve_pme_lj_yzx(struct gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real vol,
                     gmx_bool bEnerVir, int nthread, int thread)
{
    /* do recip sum over local cells in grid */
    /* y major, z middle, x minor or continuous */
    int                      ig, gcount;
    int                      kx, ky, kz, maxkx, maxky;
    int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
    real                     mx, my, mz;
    real                     ewaldcoeff = pme->ewaldcoeff_lj;
    real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
    real                     ets2, ets2vf;
    real                     eterm, vterm, d1, d2, energy = 0;
    real                     by, bz;
    real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
    real                     rxx, ryx, ryy, rzx, rzy, rzz;
    real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
    real                     mhxk, mhyk, mhzk, m2k;
    struct pme_solve_work_t *work;
    real                     corner_fac;
    ivec                     complex_order;
    ivec                     local_ndata, local_offset, local_size;
    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
                                      complex_order,
                                      local_ndata,
                                      local_offset,
                                      local_size);

    rxx = pme->recipbox[XX][XX];
    ryx = pme->recipbox[YY][XX];
    ryy = pme->recipbox[YY][YY];
    rzx = pme->recipbox[ZZ][XX];
    rzy = pme->recipbox[ZZ][YY];
    rzz = pme->recipbox[ZZ][ZZ];

    maxkx = (nx+1)/2;
    maxky = (ny+1)/2;

    work  = &pme->solve_work[thread];
    mhx   = work->mhx;
    mhy   = work->mhy;
    mhz   = work->mhz;
    m2    = work->m2;
    denom = work->denom;
    tmp1  = work->tmp1;
    tmp2  = work->tmp2;

    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
    for (iyz = iyz0; iyz < iyz1; iyz++)
    {
        iy = iyz/local_ndata[ZZ];
        iz = iyz - iy*local_ndata[ZZ];

        ky = iy + local_offset[YY];

        if (ky < maxky)
        {
            my = ky;
        }
        else
        {
            my = (ky - ny);
        }

        by = 3.0*vol*pme->bsp_mod[YY][ky]
            / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);

        kz = iz + local_offset[ZZ];

        mz = kz;

        bz = pme->bsp_mod[ZZ][kz];

        /* 0.5 correction for corner points */
        corner_fac = 1;
        if (kz == 0 || kz == (nz+1)/2)
        {
            corner_fac = 0.5;
        }

        kxstart = local_offset[XX];
        kxend   = local_offset[XX] + local_ndata[XX];
        if (bEnerVir)
        {
            /* More expensive inner loop, especially because of the
             * storage of the mh elements in arrays. Because x is the
             * minor grid index, all mh elements depend on kx for
             * triclinic unit cells.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);

            for (kx = kxstart; kx < kxend; kx++)
            {
                m2k      = factor*m2[kx];
                eterm    = -((1.0 - 2.0*m2k)*tmp1[kx]
                             + 2.0*m2k*tmp2[kx]);
                vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
                tmp1[kx] = eterm*denom[kx];
                tmp2[kx] = vterm*denom[kx];
            }
            if (!bLB)
            {
                t_complex *p0;
                real       struct2;

                p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                for (kx = kxstart; kx < kxend; kx++, p0++)
                {
                    d1      = p0->re;
                    d2      = p0->im;

                    eterm   = tmp1[kx];
                    vterm   = tmp2[kx];
                    p0->re  = d1*eterm;
                    p0->im  = d2*eterm;

                    struct2 = 2.0*(d1*d1+d2*d2);

                    tmp1[kx] = eterm*struct2;
                    tmp2[kx] = vterm*struct2;
                }
            }
            else
            {
                real *struct2 = denom;
                real  str2;

                for (kx = kxstart; kx < kxend; kx++)
                {
                    struct2[kx] = 0.0;
                }
                /* Due to symmetry we only need to calculate 4 of the 7 terms */
                for (ig = 0; ig <= 3; ++ig)
                {
                    t_complex *p0, *p1;
                    real       scale;

                    p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                    p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                    scale = 2.0*lb_scale_factor_symm[ig];
                    for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
                    {
                        struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
                    }
                }
                for (ig = 0; ig <= 6; ++ig)
                {
                    t_complex *p0;

                    p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                    for (kx = kxstart; kx < kxend; kx++, p0++)
                    {
                        d1     = p0->re;
                        d2     = p0->im;

                        eterm  = tmp1[kx];
                        p0->re = d1*eterm;
                        p0->im = d2*eterm;
                    }
                }
                for (kx = kxstart; kx < kxend; kx++)
                {
                    eterm    = tmp1[kx];
                    vterm    = tmp2[kx];
                    str2     = struct2[kx];
                    tmp1[kx] = eterm*str2;
                    tmp2[kx] = vterm*str2;
                }
            }
            for (kx = kxstart; kx < kxend; kx++)
            {
                ets2     = corner_fac*tmp1[kx];
                vterm    = 2.0*factor*tmp2[kx];
                energy  += ets2;
                ets2vf   = corner_fac*vterm;
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
                virxy   += ets2vf*mhx[kx]*mhy[kx];
                virxz   += ets2vf*mhx[kx]*mhz[kx];
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
                viryz   += ets2vf*mhy[kx]*mhz[kx];
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
            }
        }
        else
        {
            /* We don't need to calculate the energy and the virial.
             * In this case the triclinic overhead is small.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);

            for (kx = kxstart; kx < kxend; kx++)
            {
                m2k      = factor*m2[kx];
                eterm    = -((1.0 - 2.0*m2k)*tmp1[kx]
                             + 2.0*m2k*tmp2[kx]);
                tmp1[kx] = eterm*denom[kx];
            }
            gcount = (bLB ? 7 : 1);
            for (ig = 0; ig < gcount; ++ig)
            {
                t_complex *p0;

                p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                for (kx = kxstart; kx < kxend; kx++, p0++)
                {
                    d1     = p0->re;
                    d2     = p0->im;

                    eterm  = tmp1[kx];

                    p0->re = d1*eterm;
                    p0->im = d2*eterm;
                }
            }
        }
    }
    if (bEnerVir)
    {
        work->vir_lj[XX][XX] = 0.25*virxx;
        work->vir_lj[YY][YY] = 0.25*viryy;
        work->vir_lj[ZZ][ZZ] = 0.25*virzz;
        work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
        work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
        work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;

        /* This energy should be corrected for a charged system */
        work->energy_lj = 0.5*energy;
    }

    /* Return the loop count */
    return local_ndata[YY]*local_ndata[XX];
}
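
/* Caller-side sketch for the LJ path (hedged; the grid array name is
 * illustrative). With geometric combination rules only grid[0] is solved
 * (bLB == FALSE); with Lorentz-Berthelot rules all seven C6 component grids
 * are passed in and combined via lb_scale_factor_symm[]:
 *
 *   solve_pme_lj_yzx(pme, lj_grids, bLB, vol, bEnerVir, nthread, thread);
 *   // after synchronization:
 *   get_pme_ener_vir_lj(pme->solve_work, nthread, &energy_lj, vir_lj);
 */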