406e701e2d9976bf0664b5064bc93305f6558923
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37
38 #include "gmxpre.h"
39
40 #include "pme_solve.h"
41
42 #include <cmath>
43
44 #include "gromacs/fft/parallel_3dfft.h"
45 #include "gromacs/math/units.h"
46 #include "gromacs/math/utilities.h"
47 #include "gromacs/math/vec.h"
48 #include "gromacs/simd/simd.h"
49 #include "gromacs/simd/simd_math.h"
50 #include "gromacs/utility/arrayref.h"
51 #include "gromacs/utility/exceptions.h"
52 #include "gromacs/utility/smalloc.h"
53
54 #include "pme_internal.h"
55
56 #if GMX_SIMD_HAVE_REAL
57 /* Turn on arbitrary width SIMD intrinsics for PME solve */
58 #    define PME_SIMD_SOLVE
59 #endif
60
61 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
62
63 struct pme_solve_work_t
64 {
65     /* work data for solve_pme */
66     int   nalloc;
67     real* mhx;
68     real* mhy;
69     real* mhz;
70     real* m2;
71     real* denom;
72     real* tmp1_alloc;
73     real* tmp1;
74     real* tmp2;
75     real* eterm;
76     real* m2inv;
77
78     real   energy_q;
79     matrix vir_q;
80     real   energy_lj;
81     matrix vir_lj;
82 };
83
84 #ifdef PME_SIMD_SOLVE
85 constexpr int c_simdWidth = GMX_SIMD_REAL_WIDTH;
86 #else
87 /* We can use any alignment > 0, so we use 4 */
88 constexpr int c_simdWidth = 4;
89 #endif
90
91 /* Returns the smallest number >= \p that is a multiple of \p factor, \p factor must be a power of 2 */
92 template<unsigned int factor>
93 static size_t roundUpToMultipleOfFactor(size_t number)
94 {
95     static_assert(factor > 0 && (factor & (factor - 1)) == 0,
96                   "factor should be >0 and a power of 2");
97
98     /* We need to add a most factor-1 and because factor is a power of 2,
99      * we get the result by masking out the bits corresponding to factor-1.
100      */
101     return (number + factor - 1) & ~(factor - 1);
102 }
103
104 /* Allocate an aligned pointer for SIMD operations, including extra elements
105  * at the end for padding.
106  */
107 /* TODO: Replace this SIMD reallocator with a general, C++ solution */
108 static void reallocSimdAlignedAndPadded(real** ptr, int unpaddedNumElements)
109 {
110     sfree_aligned(*ptr);
111     snew_aligned(*ptr, roundUpToMultipleOfFactor<c_simdWidth>(unpaddedNumElements),
112                  c_simdWidth * sizeof(real));
113 }
114
115 static void realloc_work(struct pme_solve_work_t* work, int nkx)
116 {
117     if (nkx > work->nalloc)
118     {
119         work->nalloc = nkx;
120         srenew(work->mhx, work->nalloc);
121         srenew(work->mhy, work->nalloc);
122         srenew(work->mhz, work->nalloc);
123         srenew(work->m2, work->nalloc);
124         reallocSimdAlignedAndPadded(&work->denom, work->nalloc);
125         reallocSimdAlignedAndPadded(&work->tmp1, work->nalloc);
126         reallocSimdAlignedAndPadded(&work->tmp2, work->nalloc);
127         reallocSimdAlignedAndPadded(&work->eterm, work->nalloc);
128         srenew(work->m2inv, work->nalloc);
129
130         /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
131          * of simd padded elements.
132          */
133         for (size_t i = 0; i < roundUpToMultipleOfFactor<c_simdWidth>(work->nalloc); i++)
134         {
135             work->denom[i] = 1;
136         }
137     }
138 }
139
140 void pme_init_all_work(struct pme_solve_work_t** work, int nthread, int nkx)
141 {
142     /* Use fft5d, order after FFT is y major, z, x minor */
143
144     snew(*work, nthread);
145     /* Allocate the work arrays thread local to optimize memory access */
146 #pragma omp parallel for num_threads(nthread) schedule(static)
147     for (int thread = 0; thread < nthread; thread++)
148     {
149         try
150         {
151             realloc_work(&((*work)[thread]), nkx);
152         }
153         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
154     }
155 }
156
157 static void free_work(struct pme_solve_work_t* work)
158 {
159     if (work)
160     {
161         sfree(work->mhx);
162         sfree(work->mhy);
163         sfree(work->mhz);
164         sfree(work->m2);
165         sfree_aligned(work->denom);
166         sfree_aligned(work->tmp1);
167         sfree_aligned(work->tmp2);
168         sfree_aligned(work->eterm);
169         sfree(work->m2inv);
170     }
171 }
172
173 void pme_free_all_work(struct pme_solve_work_t** work, int nthread)
174 {
175     if (*work)
176     {
177         for (int thread = 0; thread < nthread; thread++)
178         {
179             free_work(&(*work)[thread]);
180         }
181     }
182     sfree(*work);
183     *work = nullptr;
184 }
185
186 void get_pme_ener_vir_q(pme_solve_work_t* work, int nthread, PmeOutput* output)
187 {
188     GMX_ASSERT(output != nullptr, "Need valid output buffer");
189     /* This function sums output over threads and should therefore
190      * only be called after thread synchronization.
191      */
192     output->coulombEnergy_ = work[0].energy_q;
193     copy_mat(work[0].vir_q, output->coulombVirial_);
194
195     for (int thread = 1; thread < nthread; thread++)
196     {
197         output->coulombEnergy_ += work[thread].energy_q;
198         m_add(output->coulombVirial_, work[thread].vir_q, output->coulombVirial_);
199     }
200 }
201
202 void get_pme_ener_vir_lj(pme_solve_work_t* work, int nthread, PmeOutput* output)
203 {
204     GMX_ASSERT(output != nullptr, "Need valid output buffer");
205     /* This function sums output over threads and should therefore
206      * only be called after thread synchronization.
207      */
208     output->lennardJonesEnergy_ = work[0].energy_lj;
209     copy_mat(work[0].vir_lj, output->lennardJonesVirial_);
210
211     for (int thread = 1; thread < nthread; thread++)
212     {
213         output->lennardJonesEnergy_ += work[thread].energy_lj;
214         m_add(output->lennardJonesVirial_, work[thread].vir_lj, output->lennardJonesVirial_);
215     }
216 }
217
218 #if defined PME_SIMD_SOLVE
219 /* Calculate exponentials through SIMD */
220 inline static void calc_exponentials_q(int /*unused*/,
221                                        int /*unused*/,
222                                        real                     f,
223                                        ArrayRef<const SimdReal> d_aligned,
224                                        ArrayRef<const SimdReal> r_aligned,
225                                        ArrayRef<SimdReal>       e_aligned)
226 {
227     {
228         SimdReal f_simd(f);
229         SimdReal tmp_d1, tmp_r, tmp_e;
230
231         /* We only need to calculate from start. But since start is 0 or 1
232          * and we want to use aligned loads/stores, we always start from 0.
233          */
234         GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
235         GMX_ASSERT(d_aligned.size() == e_aligned.size(), "d and e must have same size");
236         for (size_t kx = 0; kx != d_aligned.size(); ++kx)
237         {
238             tmp_d1        = d_aligned[kx];
239             tmp_r         = r_aligned[kx];
240             tmp_r         = gmx::exp(tmp_r);
241             tmp_e         = f_simd / tmp_d1;
242             tmp_e         = tmp_e * tmp_r;
243             e_aligned[kx] = tmp_e;
244         }
245     }
246 }
247 #else
248 inline static void
249 calc_exponentials_q(int start, int end, real f, ArrayRef<real> d, ArrayRef<real> r, ArrayRef<real> e)
250 {
251     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
252     GMX_ASSERT(d.size() == e.size(), "d and e must have same size");
253     int kx;
254     for (kx = start; kx < end; kx++)
255     {
256         d[kx] = 1.0 / d[kx];
257     }
258     for (kx = start; kx < end; kx++)
259     {
260         r[kx] = std::exp(r[kx]);
261     }
262     for (kx = start; kx < end; kx++)
263     {
264         e[kx] = f * r[kx] * d[kx];
265     }
266 }
267 #endif
268
269 #if defined PME_SIMD_SOLVE
270 /* Calculate exponentials through SIMD */
271 inline static void calc_exponentials_lj(int /*unused*/,
272                                         int /*unused*/,
273                                         ArrayRef<SimdReal> r_aligned,
274                                         ArrayRef<SimdReal> factor_aligned,
275                                         ArrayRef<SimdReal> d_aligned)
276 {
277     SimdReal       tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
278     const SimdReal sqr_PI = sqrt(SimdReal(M_PI));
279
280     GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
281     GMX_ASSERT(d_aligned.size() == factor_aligned.size(), "d and factor must have same size");
282     for (size_t kx = 0; kx != d_aligned.size(); ++kx)
283     {
284         /* We only need to calculate from start. But since start is 0 or 1
285          * and we want to use aligned loads/stores, we always start from 0.
286          */
287         tmp_d              = d_aligned[kx];
288         d_inv              = SimdReal(1.0) / tmp_d;
289         d_aligned[kx]      = d_inv;
290         tmp_r              = r_aligned[kx];
291         tmp_r              = gmx::exp(tmp_r);
292         r_aligned[kx]      = tmp_r;
293         tmp_mk             = factor_aligned[kx];
294         tmp_fac            = sqr_PI * tmp_mk * erfc(tmp_mk);
295         factor_aligned[kx] = tmp_fac;
296     }
297 }
298 #else
299 inline static void
300 calc_exponentials_lj(int start, int end, ArrayRef<real> r, ArrayRef<real> tmp2, ArrayRef<real> d)
301 {
302     int  kx;
303     real mk;
304     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
305     GMX_ASSERT(d.size() == tmp2.size(), "d and tmp2 must have same size");
306     for (kx = start; kx < end; kx++)
307     {
308         d[kx] = 1.0 / d[kx];
309     }
310
311     for (kx = start; kx < end; kx++)
312     {
313         r[kx] = std::exp(r[kx]);
314     }
315
316     for (kx = start; kx < end; kx++)
317     {
318         mk       = tmp2[kx];
319         tmp2[kx] = sqrt(M_PI) * mk * std::erfc(mk);
320     }
321 }
322 #endif
323
324 #if defined PME_SIMD_SOLVE
325 using PME_T = SimdReal;
326 #else
327 using PME_T = real;
328 #endif
329
330 int solve_pme_yzx(const gmx_pme_t* pme, t_complex* grid, real vol, gmx_bool bEnerVir, int nthread, int thread)
331 {
332     /* do recip sum over local cells in grid */
333     /* y major, z middle, x minor or continuous */
334     t_complex*               p0;
335     int                      kx, ky, kz, maxkx, maxky;
336     int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
337     real                     mx, my, mz;
338     real                     ewaldcoeff = pme->ewaldcoeff_q;
339     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
340     real                     ets2, struct2, vfactor, ets2vf;
341     real                     d1, d2, energy = 0;
342     real                     by, bz;
343     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
344     real                     rxx, ryx, ryy, rzx, rzy, rzz;
345     struct pme_solve_work_t* work;
346     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
347     real                     mhxk, mhyk, mhzk, m2k;
348     real                     corner_fac;
349     ivec                     complex_order;
350     ivec                     local_ndata, local_offset, local_size;
351     real                     elfac;
352
353     elfac = ONE_4PI_EPS0 / pme->epsilon_r;
354
355     nx = pme->nkx;
356     ny = pme->nky;
357     nz = pme->nkz;
358
359     /* Dimensions should be identical for A/B grid, so we just use A here */
360     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA], complex_order, local_ndata,
361                                       local_offset, local_size);
362
363     rxx = pme->recipbox[XX][XX];
364     ryx = pme->recipbox[YY][XX];
365     ryy = pme->recipbox[YY][YY];
366     rzx = pme->recipbox[ZZ][XX];
367     rzy = pme->recipbox[ZZ][YY];
368     rzz = pme->recipbox[ZZ][ZZ];
369
370     GMX_ASSERT(rxx != 0.0, "Someone broke the reciprocal box again");
371
372     maxkx = (nx + 1) / 2;
373     maxky = (ny + 1) / 2;
374
375     work  = &pme->solve_work[thread];
376     mhx   = work->mhx;
377     mhy   = work->mhy;
378     mhz   = work->mhz;
379     m2    = work->m2;
380     denom = work->denom;
381     tmp1  = work->tmp1;
382     eterm = work->eterm;
383     m2inv = work->m2inv;
384
385     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
386     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
387
388     for (iyz = iyz0; iyz < iyz1; iyz++)
389     {
390         iy = iyz / local_ndata[ZZ];
391         iz = iyz - iy * local_ndata[ZZ];
392
393         ky = iy + local_offset[YY];
394
395         if (ky < maxky)
396         {
397             my = ky;
398         }
399         else
400         {
401             my = (ky - ny);
402         }
403
404         by = M_PI * vol * pme->bsp_mod[YY][ky];
405
406         kz = iz + local_offset[ZZ];
407
408         mz = kz;
409
410         bz = pme->bsp_mod[ZZ][kz];
411
412         /* 0.5 correction for corner points */
413         corner_fac = 1;
414         if (kz == 0 || kz == (nz + 1) / 2)
415         {
416             corner_fac = 0.5;
417         }
418
419         p0 = grid + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
420
421         /* We should skip the k-space point (0,0,0) */
422         /* Note that since here x is the minor index, local_offset[XX]=0 */
423         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
424         {
425             kxstart = local_offset[XX];
426         }
427         else
428         {
429             kxstart = local_offset[XX] + 1;
430             p0++;
431         }
432         kxend = local_offset[XX] + local_ndata[XX];
433
434         if (bEnerVir)
435         {
436             /* More expensive inner loop, especially because of the storage
437              * of the mh elements in array's.
438              * Because x is the minor grid index, all mh elements
439              * depend on kx for triclinic unit cells.
440              */
441
442             /* Two explicit loops to avoid a conditional inside the loop */
443             for (kx = kxstart; kx < maxkx; kx++)
444             {
445                 mx = kx;
446
447                 mhxk      = mx * rxx;
448                 mhyk      = mx * ryx + my * ryy;
449                 mhzk      = mx * rzx + my * rzy + mz * rzz;
450                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
451                 mhx[kx]   = mhxk;
452                 mhy[kx]   = mhyk;
453                 mhz[kx]   = mhzk;
454                 m2[kx]    = m2k;
455                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
456                 tmp1[kx]  = -factor * m2k;
457             }
458
459             for (kx = maxkx; kx < kxend; kx++)
460             {
461                 mx = (kx - nx);
462
463                 mhxk      = mx * rxx;
464                 mhyk      = mx * ryx + my * ryy;
465                 mhzk      = mx * rzx + my * rzy + mz * rzz;
466                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
467                 mhx[kx]   = mhxk;
468                 mhy[kx]   = mhyk;
469                 mhz[kx]   = mhzk;
470                 m2[kx]    = m2k;
471                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
472                 tmp1[kx]  = -factor * m2k;
473             }
474
475             for (kx = kxstart; kx < kxend; kx++)
476             {
477                 m2inv[kx] = 1.0 / m2[kx];
478             }
479
480             calc_exponentials_q(
481                     kxstart, kxend, elfac,
482                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
483                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
484                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
485
486             for (kx = kxstart; kx < kxend; kx++, p0++)
487             {
488                 d1 = p0->re;
489                 d2 = p0->im;
490
491                 p0->re = d1 * eterm[kx];
492                 p0->im = d2 * eterm[kx];
493
494                 struct2 = 2.0 * (d1 * d1 + d2 * d2);
495
496                 tmp1[kx] = eterm[kx] * struct2;
497             }
498
499             for (kx = kxstart; kx < kxend; kx++)
500             {
501                 ets2    = corner_fac * tmp1[kx];
502                 vfactor = (factor * m2[kx] + 1.0) * 2.0 * m2inv[kx];
503                 energy += ets2;
504
505                 ets2vf = ets2 * vfactor;
506                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
507                 virxy += ets2vf * mhx[kx] * mhy[kx];
508                 virxz += ets2vf * mhx[kx] * mhz[kx];
509                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
510                 viryz += ets2vf * mhy[kx] * mhz[kx];
511                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
512             }
513         }
514         else
515         {
516             /* We don't need to calculate the energy and the virial.
517              * In this case the triclinic overhead is small.
518              */
519
520             /* Two explicit loops to avoid a conditional inside the loop */
521
522             for (kx = kxstart; kx < maxkx; kx++)
523             {
524                 mx = kx;
525
526                 mhxk      = mx * rxx;
527                 mhyk      = mx * ryx + my * ryy;
528                 mhzk      = mx * rzx + my * rzy + mz * rzz;
529                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
530                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
531                 tmp1[kx]  = -factor * m2k;
532             }
533
534             for (kx = maxkx; kx < kxend; kx++)
535             {
536                 mx = (kx - nx);
537
538                 mhxk      = mx * rxx;
539                 mhyk      = mx * ryx + my * ryy;
540                 mhzk      = mx * rzx + my * rzy + mz * rzz;
541                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
542                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
543                 tmp1[kx]  = -factor * m2k;
544             }
545
546             calc_exponentials_q(
547                     kxstart, kxend, elfac,
548                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
549                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
550                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
551
552
553             for (kx = kxstart; kx < kxend; kx++, p0++)
554             {
555                 d1 = p0->re;
556                 d2 = p0->im;
557
558                 p0->re = d1 * eterm[kx];
559                 p0->im = d2 * eterm[kx];
560             }
561         }
562     }
563
564     if (bEnerVir)
565     {
566         /* Update virial with local values.
567          * The virial is symmetric by definition.
568          * this virial seems ok for isotropic scaling, but I'm
569          * experiencing problems on semiisotropic membranes.
570          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
571          */
572         work->vir_q[XX][XX] = 0.25 * virxx;
573         work->vir_q[YY][YY] = 0.25 * viryy;
574         work->vir_q[ZZ][ZZ] = 0.25 * virzz;
575         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25 * virxy;
576         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25 * virxz;
577         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25 * viryz;
578
579         /* This energy should be corrected for a charged system */
580         work->energy_q = 0.5 * energy;
581     }
582
583     /* Return the loop count */
584     return local_ndata[YY] * local_ndata[XX];
585 }
586
587 int solve_pme_lj_yzx(const gmx_pme_t* pme, t_complex** grid, gmx_bool bLB, real vol, gmx_bool bEnerVir, int nthread, int thread)
588 {
589     /* do recip sum over local cells in grid */
590     /* y major, z middle, x minor or continuous */
591     int                      ig, gcount;
592     int                      kx, ky, kz, maxkx, maxky;
593     int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
594     real                     mx, my, mz;
595     real                     ewaldcoeff = pme->ewaldcoeff_lj;
596     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
597     real                     ets2, ets2vf;
598     real                     eterm, vterm, d1, d2, energy = 0;
599     real                     by, bz;
600     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
601     real                     rxx, ryx, ryy, rzx, rzy, rzz;
602     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
603     real                     mhxk, mhyk, mhzk, m2k;
604     struct pme_solve_work_t* work;
605     real                     corner_fac;
606     ivec                     complex_order;
607     ivec                     local_ndata, local_offset, local_size;
608     nx = pme->nkx;
609     ny = pme->nky;
610     nz = pme->nkz;
611
612     /* Dimensions should be identical for A/B grid, so we just use A here */
613     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A], complex_order, local_ndata,
614                                       local_offset, local_size);
615     rxx = pme->recipbox[XX][XX];
616     ryx = pme->recipbox[YY][XX];
617     ryy = pme->recipbox[YY][YY];
618     rzx = pme->recipbox[ZZ][XX];
619     rzy = pme->recipbox[ZZ][YY];
620     rzz = pme->recipbox[ZZ][ZZ];
621
622     maxkx = (nx + 1) / 2;
623     maxky = (ny + 1) / 2;
624
625     work  = &pme->solve_work[thread];
626     mhx   = work->mhx;
627     mhy   = work->mhy;
628     mhz   = work->mhz;
629     m2    = work->m2;
630     denom = work->denom;
631     tmp1  = work->tmp1;
632     tmp2  = work->tmp2;
633
634     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
635     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
636
637     for (iyz = iyz0; iyz < iyz1; iyz++)
638     {
639         iy = iyz / local_ndata[ZZ];
640         iz = iyz - iy * local_ndata[ZZ];
641
642         ky = iy + local_offset[YY];
643
644         if (ky < maxky)
645         {
646             my = ky;
647         }
648         else
649         {
650             my = (ky - ny);
651         }
652
653         by = 3.0 * vol * pme->bsp_mod[YY][ky] / (M_PI * sqrt(M_PI) * ewaldcoeff * ewaldcoeff * ewaldcoeff);
654
655         kz = iz + local_offset[ZZ];
656
657         mz = kz;
658
659         bz = pme->bsp_mod[ZZ][kz];
660
661         /* 0.5 correction for corner points */
662         corner_fac = 1;
663         if (kz == 0 || kz == (nz + 1) / 2)
664         {
665             corner_fac = 0.5;
666         }
667
668         kxstart = local_offset[XX];
669         kxend   = local_offset[XX] + local_ndata[XX];
670         if (bEnerVir)
671         {
672             /* More expensive inner loop, especially because of the
673              * storage of the mh elements in array's.  Because x is the
674              * minor grid index, all mh elements depend on kx for
675              * triclinic unit cells.
676              */
677
678             /* Two explicit loops to avoid a conditional inside the loop */
679             for (kx = kxstart; kx < maxkx; kx++)
680             {
681                 mx = kx;
682
683                 mhxk      = mx * rxx;
684                 mhyk      = mx * ryx + my * ryy;
685                 mhzk      = mx * rzx + my * rzy + mz * rzz;
686                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
687                 mhx[kx]   = mhxk;
688                 mhy[kx]   = mhyk;
689                 mhz[kx]   = mhzk;
690                 m2[kx]    = m2k;
691                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
692                 tmp1[kx]  = -factor * m2k;
693                 tmp2[kx]  = sqrt(factor * m2k);
694             }
695
696             for (kx = maxkx; kx < kxend; kx++)
697             {
698                 mx = (kx - nx);
699
700                 mhxk      = mx * rxx;
701                 mhyk      = mx * ryx + my * ryy;
702                 mhzk      = mx * rzx + my * rzy + mz * rzz;
703                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
704                 mhx[kx]   = mhxk;
705                 mhy[kx]   = mhyk;
706                 mhz[kx]   = mhzk;
707                 m2[kx]    = m2k;
708                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
709                 tmp1[kx]  = -factor * m2k;
710                 tmp2[kx]  = sqrt(factor * m2k);
711             }
712             /* Clear padding elements to avoid (harmless) fp exceptions */
713             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
714             for (; kx < kxendSimd; kx++)
715             {
716                 tmp1[kx] = 0;
717                 tmp2[kx] = 0;
718             }
719
720             calc_exponentials_lj(
721                     kxstart, kxend,
722                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
723                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
724                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
725
726             for (kx = kxstart; kx < kxend; kx++)
727             {
728                 m2k      = factor * m2[kx];
729                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
730                 vterm    = 3.0 * (-tmp1[kx] + tmp2[kx]);
731                 tmp1[kx] = eterm * denom[kx];
732                 tmp2[kx] = vterm * denom[kx];
733             }
734
735             if (!bLB)
736             {
737                 t_complex* p0;
738                 real       struct2;
739
740                 p0 = grid[0] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
741                 for (kx = kxstart; kx < kxend; kx++, p0++)
742                 {
743                     d1 = p0->re;
744                     d2 = p0->im;
745
746                     eterm  = tmp1[kx];
747                     vterm  = tmp2[kx];
748                     p0->re = d1 * eterm;
749                     p0->im = d2 * eterm;
750
751                     struct2 = 2.0 * (d1 * d1 + d2 * d2);
752
753                     tmp1[kx] = eterm * struct2;
754                     tmp2[kx] = vterm * struct2;
755                 }
756             }
757             else
758             {
759                 real* struct2 = denom;
760                 real  str2;
761
762                 for (kx = kxstart; kx < kxend; kx++)
763                 {
764                     struct2[kx] = 0.0;
765                 }
766                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
767                 for (ig = 0; ig <= 3; ++ig)
768                 {
769                     t_complex *p0, *p1;
770                     real       scale;
771
772                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
773                     p1 = grid[6 - ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
774                     scale = 2.0 * lb_scale_factor_symm[ig];
775                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
776                     {
777                         struct2[kx] += scale * (p0->re * p1->re + p0->im * p1->im);
778                     }
779                 }
780                 for (ig = 0; ig <= 6; ++ig)
781                 {
782                     t_complex* p0;
783
784                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
785                     for (kx = kxstart; kx < kxend; kx++, p0++)
786                     {
787                         d1 = p0->re;
788                         d2 = p0->im;
789
790                         eterm  = tmp1[kx];
791                         p0->re = d1 * eterm;
792                         p0->im = d2 * eterm;
793                     }
794                 }
795                 for (kx = kxstart; kx < kxend; kx++)
796                 {
797                     eterm    = tmp1[kx];
798                     vterm    = tmp2[kx];
799                     str2     = struct2[kx];
800                     tmp1[kx] = eterm * str2;
801                     tmp2[kx] = vterm * str2;
802                 }
803             }
804
805             for (kx = kxstart; kx < kxend; kx++)
806             {
807                 ets2  = corner_fac * tmp1[kx];
808                 vterm = 2.0 * factor * tmp2[kx];
809                 energy += ets2;
810                 ets2vf = corner_fac * vterm;
811                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
812                 virxy += ets2vf * mhx[kx] * mhy[kx];
813                 virxz += ets2vf * mhx[kx] * mhz[kx];
814                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
815                 viryz += ets2vf * mhy[kx] * mhz[kx];
816                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
817             }
818         }
819         else
820         {
821             /* We don't need to calculate the energy and the virial.
822              *  In this case the triclinic overhead is small.
823              */
824
825             /* Two explicit loops to avoid a conditional inside the loop */
826
827             for (kx = kxstart; kx < maxkx; kx++)
828             {
829                 mx = kx;
830
831                 mhxk      = mx * rxx;
832                 mhyk      = mx * ryx + my * ryy;
833                 mhzk      = mx * rzx + my * rzy + mz * rzz;
834                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
835                 m2[kx]    = m2k;
836                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
837                 tmp1[kx]  = -factor * m2k;
838                 tmp2[kx]  = sqrt(factor * m2k);
839             }
840
841             for (kx = maxkx; kx < kxend; kx++)
842             {
843                 mx = (kx - nx);
844
845                 mhxk      = mx * rxx;
846                 mhyk      = mx * ryx + my * ryy;
847                 mhzk      = mx * rzx + my * rzy + mz * rzz;
848                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
849                 m2[kx]    = m2k;
850                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
851                 tmp1[kx]  = -factor * m2k;
852                 tmp2[kx]  = sqrt(factor * m2k);
853             }
854             /* Clear padding elements to avoid (harmless) fp exceptions */
855             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
856             for (; kx < kxendSimd; kx++)
857             {
858                 tmp1[kx] = 0;
859                 tmp2[kx] = 0;
860             }
861
862             calc_exponentials_lj(
863                     kxstart, kxend,
864                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
865                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
866                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
867
868             for (kx = kxstart; kx < kxend; kx++)
869             {
870                 m2k      = factor * m2[kx];
871                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
872                 tmp1[kx] = eterm * denom[kx];
873             }
874             gcount = (bLB ? 7 : 1);
875             for (ig = 0; ig < gcount; ++ig)
876             {
877                 t_complex* p0;
878
879                 p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
880                 for (kx = kxstart; kx < kxend; kx++, p0++)
881                 {
882                     d1 = p0->re;
883                     d2 = p0->im;
884
885                     eterm = tmp1[kx];
886
887                     p0->re = d1 * eterm;
888                     p0->im = d2 * eterm;
889                 }
890             }
891         }
892     }
893     if (bEnerVir)
894     {
895         work->vir_lj[XX][XX] = 0.25 * virxx;
896         work->vir_lj[YY][YY] = 0.25 * viryy;
897         work->vir_lj[ZZ][ZZ] = 0.25 * virzz;
898         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25 * virxy;
899         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25 * virxz;
900         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25 * viryz;
901
902         /* This energy should be corrected for a charged system */
903         work->energy_lj = 0.5 * energy;
904     }
905     /* Return the loop count */
906     return local_ndata[YY] * local_ndata[XX];
907 }