88e529b553e1a5601e8bfd33bb946b20f3e744fe
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38
39 #include "gmxpre.h"
40
41 #include "pme_solve.h"
42
43 #include <cmath>
44
45 #include "gromacs/fft/parallel_3dfft.h"
46 #include "gromacs/math/units.h"
47 #include "gromacs/math/utilities.h"
48 #include "gromacs/math/vec.h"
49 #include "gromacs/simd/simd.h"
50 #include "gromacs/simd/simd_math.h"
51 #include "gromacs/utility/arrayref.h"
52 #include "gromacs/utility/exceptions.h"
53 #include "gromacs/utility/smalloc.h"
54
55 #include "pme_internal.h"
56 #include "pme_output.h"
57
58 #if GMX_SIMD_HAVE_REAL
59 /* Turn on arbitrary width SIMD intrinsics for PME solve */
60 #    define PME_SIMD_SOLVE
61 #endif
62
63 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
64
65 struct pme_solve_work_t
66 {
67     /* work data for solve_pme */
68     int   nalloc;
69     real* mhx;
70     real* mhy;
71     real* mhz;
72     real* m2;
73     real* denom;
74     real* tmp1_alloc;
75     real* tmp1;
76     real* tmp2;
77     real* eterm;
78     real* m2inv;
79
80     real   energy_q;
81     matrix vir_q;
82     real   energy_lj;
83     matrix vir_lj;
84 };
85
86 #ifdef PME_SIMD_SOLVE
87 constexpr int c_simdWidth = GMX_SIMD_REAL_WIDTH;
88 #else
89 /* We can use any alignment > 0, so we use 4 */
90 constexpr int c_simdWidth = 4;
91 #endif
92
93 /* Returns the smallest number >= \p that is a multiple of \p factor, \p factor must be a power of 2 */
94 template<unsigned int factor>
95 static size_t roundUpToMultipleOfFactor(size_t number)
96 {
97     static_assert(factor > 0 && (factor & (factor - 1)) == 0,
98                   "factor should be >0 and a power of 2");
99
100     /* We need to add a most factor-1 and because factor is a power of 2,
101      * we get the result by masking out the bits corresponding to factor-1.
102      */
103     return (number + factor - 1) & ~(factor - 1);
104 }
105
106 /* Allocate an aligned pointer for SIMD operations, including extra elements
107  * at the end for padding.
108  */
109 /* TODO: Replace this SIMD reallocator with a general, C++ solution */
110 static void reallocSimdAlignedAndPadded(real** ptr, int unpaddedNumElements)
111 {
112     sfree_aligned(*ptr);
113     snew_aligned(*ptr,
114                  roundUpToMultipleOfFactor<c_simdWidth>(unpaddedNumElements),
115                  c_simdWidth * sizeof(real));
116 }
117
118 static void realloc_work(struct pme_solve_work_t* work, int nkx)
119 {
120     if (nkx > work->nalloc)
121     {
122         work->nalloc = nkx;
123         srenew(work->mhx, work->nalloc);
124         srenew(work->mhy, work->nalloc);
125         srenew(work->mhz, work->nalloc);
126         srenew(work->m2, work->nalloc);
127         reallocSimdAlignedAndPadded(&work->denom, work->nalloc);
128         reallocSimdAlignedAndPadded(&work->tmp1, work->nalloc);
129         reallocSimdAlignedAndPadded(&work->tmp2, work->nalloc);
130         reallocSimdAlignedAndPadded(&work->eterm, work->nalloc);
131         srenew(work->m2inv, work->nalloc);
132
133         /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
134          * of simd padded elements.
135          */
136         for (size_t i = 0; i < roundUpToMultipleOfFactor<c_simdWidth>(work->nalloc); i++)
137         {
138             work->denom[i] = 1;
139         }
140     }
141 }
142
143 void pme_init_all_work(struct pme_solve_work_t** work, int nthread, int nkx)
144 {
145     /* Use fft5d, order after FFT is y major, z, x minor */
146
147     snew(*work, nthread);
148     /* Allocate the work arrays thread local to optimize memory access */
149 #pragma omp parallel for num_threads(nthread) schedule(static)
150     for (int thread = 0; thread < nthread; thread++)
151     {
152         try
153         {
154             realloc_work(&((*work)[thread]), nkx);
155         }
156         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
157     }
158 }
159
160 static void free_work(struct pme_solve_work_t* work)
161 {
162     if (work)
163     {
164         sfree(work->mhx);
165         sfree(work->mhy);
166         sfree(work->mhz);
167         sfree(work->m2);
168         sfree_aligned(work->denom);
169         sfree_aligned(work->tmp1);
170         sfree_aligned(work->tmp2);
171         sfree_aligned(work->eterm);
172         sfree(work->m2inv);
173     }
174 }
175
176 void pme_free_all_work(struct pme_solve_work_t** work, int nthread)
177 {
178     if (*work)
179     {
180         for (int thread = 0; thread < nthread; thread++)
181         {
182             free_work(&(*work)[thread]);
183         }
184     }
185     sfree(*work);
186     *work = nullptr;
187 }
188
189 void get_pme_ener_vir_q(pme_solve_work_t* work, int nthread, PmeOutput* output)
190 {
191     GMX_ASSERT(output != nullptr, "Need valid output buffer");
192     /* This function sums output over threads and should therefore
193      * only be called after thread synchronization.
194      */
195     output->coulombEnergy_ = work[0].energy_q;
196     copy_mat(work[0].vir_q, output->coulombVirial_);
197
198     for (int thread = 1; thread < nthread; thread++)
199     {
200         output->coulombEnergy_ += work[thread].energy_q;
201         m_add(output->coulombVirial_, work[thread].vir_q, output->coulombVirial_);
202     }
203 }
204
205 void get_pme_ener_vir_lj(pme_solve_work_t* work, int nthread, PmeOutput* output)
206 {
207     GMX_ASSERT(output != nullptr, "Need valid output buffer");
208     /* This function sums output over threads and should therefore
209      * only be called after thread synchronization.
210      */
211     output->lennardJonesEnergy_ = work[0].energy_lj;
212     copy_mat(work[0].vir_lj, output->lennardJonesVirial_);
213
214     for (int thread = 1; thread < nthread; thread++)
215     {
216         output->lennardJonesEnergy_ += work[thread].energy_lj;
217         m_add(output->lennardJonesVirial_, work[thread].vir_lj, output->lennardJonesVirial_);
218     }
219 }
220
221 #if defined PME_SIMD_SOLVE
222 /* Calculate exponentials through SIMD */
223 inline static void calc_exponentials_q(int /*unused*/,
224                                        int /*unused*/,
225                                        real                     f,
226                                        ArrayRef<const SimdReal> d_aligned,
227                                        ArrayRef<const SimdReal> r_aligned,
228                                        ArrayRef<SimdReal>       e_aligned)
229 {
230     {
231         SimdReal f_simd(f);
232         SimdReal tmp_d1, tmp_r, tmp_e;
233
234         /* We only need to calculate from start. But since start is 0 or 1
235          * and we want to use aligned loads/stores, we always start from 0.
236          */
237         GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
238         GMX_ASSERT(d_aligned.size() == e_aligned.size(), "d and e must have same size");
239         for (size_t kx = 0; kx != d_aligned.size(); ++kx)
240         {
241             tmp_d1        = d_aligned[kx];
242             tmp_r         = r_aligned[kx];
243             tmp_r         = gmx::exp(tmp_r);
244             tmp_e         = f_simd / tmp_d1;
245             tmp_e         = tmp_e * tmp_r;
246             e_aligned[kx] = tmp_e;
247         }
248     }
249 }
250 #else
251 inline static void
252 calc_exponentials_q(int start, int end, real f, ArrayRef<real> d, ArrayRef<real> r, ArrayRef<real> e)
253 {
254     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
255     GMX_ASSERT(d.size() == e.size(), "d and e must have same size");
256     int kx;
257     for (kx = start; kx < end; kx++)
258     {
259         d[kx] = 1.0 / d[kx];
260     }
261     for (kx = start; kx < end; kx++)
262     {
263         r[kx] = std::exp(r[kx]);
264     }
265     for (kx = start; kx < end; kx++)
266     {
267         e[kx] = f * r[kx] * d[kx];
268     }
269 }
270 #endif
271
272 #if defined PME_SIMD_SOLVE
273 /* Calculate exponentials through SIMD */
274 inline static void calc_exponentials_lj(int /*unused*/,
275                                         int /*unused*/,
276                                         ArrayRef<SimdReal> r_aligned,
277                                         ArrayRef<SimdReal> factor_aligned,
278                                         ArrayRef<SimdReal> d_aligned)
279 {
280     SimdReal       tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
281     const SimdReal sqr_PI = sqrt(SimdReal(M_PI));
282
283     GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
284     GMX_ASSERT(d_aligned.size() == factor_aligned.size(), "d and factor must have same size");
285     for (size_t kx = 0; kx != d_aligned.size(); ++kx)
286     {
287         /* We only need to calculate from start. But since start is 0 or 1
288          * and we want to use aligned loads/stores, we always start from 0.
289          */
290         tmp_d              = d_aligned[kx];
291         d_inv              = SimdReal(1.0) / tmp_d;
292         d_aligned[kx]      = d_inv;
293         tmp_r              = r_aligned[kx];
294         tmp_r              = gmx::exp(tmp_r);
295         r_aligned[kx]      = tmp_r;
296         tmp_mk             = factor_aligned[kx];
297         tmp_fac            = sqr_PI * tmp_mk * erfc(tmp_mk);
298         factor_aligned[kx] = tmp_fac;
299     }
300 }
301 #else
302 inline static void
303 calc_exponentials_lj(int start, int end, ArrayRef<real> r, ArrayRef<real> tmp2, ArrayRef<real> d)
304 {
305     int  kx;
306     real mk;
307     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
308     GMX_ASSERT(d.size() == tmp2.size(), "d and tmp2 must have same size");
309     for (kx = start; kx < end; kx++)
310     {
311         d[kx] = 1.0 / d[kx];
312     }
313
314     for (kx = start; kx < end; kx++)
315     {
316         r[kx] = std::exp(r[kx]);
317     }
318
319     for (kx = start; kx < end; kx++)
320     {
321         mk       = tmp2[kx];
322         tmp2[kx] = sqrt(M_PI) * mk * std::erfc(mk);
323     }
324 }
325 #endif
326
327 #if defined PME_SIMD_SOLVE
328 using PME_T = SimdReal;
329 #else
330 using PME_T = real;
331 #endif
332
333 int solve_pme_yzx(const gmx_pme_t* pme, t_complex* grid, real vol, bool computeEnergyAndVirial, int nthread, int thread)
334 {
335     /* do recip sum over local cells in grid */
336     /* y major, z middle, x minor or continuous */
337     t_complex*               p0;
338     int                      kx, ky, kz, maxkx, maxky;
339     int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
340     real                     mx, my, mz;
341     real                     ewaldcoeff = pme->ewaldcoeff_q;
342     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
343     real                     ets2, struct2, vfactor, ets2vf;
344     real                     d1, d2, energy = 0;
345     real                     by, bz;
346     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
347     real                     rxx, ryx, ryy, rzx, rzy, rzz;
348     struct pme_solve_work_t* work;
349     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
350     real                     mhxk, mhyk, mhzk, m2k;
351     real                     corner_fac;
352     ivec                     complex_order;
353     ivec                     local_ndata, local_offset, local_size;
354     real                     elfac;
355
356     elfac = ONE_4PI_EPS0 / pme->epsilon_r;
357
358     nx = pme->nkx;
359     ny = pme->nky;
360     nz = pme->nkz;
361
362     /* Dimensions should be identical for A/B grid, so we just use A here */
363     gmx_parallel_3dfft_complex_limits(
364             pme->pfft_setup[PME_GRID_QA], complex_order, local_ndata, local_offset, local_size);
365
366     rxx = pme->recipbox[XX][XX];
367     ryx = pme->recipbox[YY][XX];
368     ryy = pme->recipbox[YY][YY];
369     rzx = pme->recipbox[ZZ][XX];
370     rzy = pme->recipbox[ZZ][YY];
371     rzz = pme->recipbox[ZZ][ZZ];
372
373     GMX_ASSERT(rxx != 0.0, "Someone broke the reciprocal box again");
374
375     maxkx = (nx + 1) / 2;
376     maxky = (ny + 1) / 2;
377
378     work  = &pme->solve_work[thread];
379     mhx   = work->mhx;
380     mhy   = work->mhy;
381     mhz   = work->mhz;
382     m2    = work->m2;
383     denom = work->denom;
384     tmp1  = work->tmp1;
385     eterm = work->eterm;
386     m2inv = work->m2inv;
387
388     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
389     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
390
391     for (iyz = iyz0; iyz < iyz1; iyz++)
392     {
393         iy = iyz / local_ndata[ZZ];
394         iz = iyz - iy * local_ndata[ZZ];
395
396         ky = iy + local_offset[YY];
397
398         if (ky < maxky)
399         {
400             my = ky;
401         }
402         else
403         {
404             my = (ky - ny);
405         }
406
407         by = M_PI * vol * pme->bsp_mod[YY][ky];
408
409         kz = iz + local_offset[ZZ];
410
411         mz = kz;
412
413         bz = pme->bsp_mod[ZZ][kz];
414
415         /* 0.5 correction for corner points */
416         corner_fac = 1;
417         if (kz == 0 || kz == (nz + 1) / 2)
418         {
419             corner_fac = 0.5;
420         }
421
422         p0 = grid + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
423
424         /* We should skip the k-space point (0,0,0) */
425         /* Note that since here x is the minor index, local_offset[XX]=0 */
426         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
427         {
428             kxstart = local_offset[XX];
429         }
430         else
431         {
432             kxstart = local_offset[XX] + 1;
433             p0++;
434         }
435         kxend = local_offset[XX] + local_ndata[XX];
436
437         if (computeEnergyAndVirial)
438         {
439             /* More expensive inner loop, especially because of the storage
440              * of the mh elements in array's.
441              * Because x is the minor grid index, all mh elements
442              * depend on kx for triclinic unit cells.
443              */
444
445             /* Two explicit loops to avoid a conditional inside the loop */
446             for (kx = kxstart; kx < maxkx; kx++)
447             {
448                 mx = kx;
449
450                 mhxk      = mx * rxx;
451                 mhyk      = mx * ryx + my * ryy;
452                 mhzk      = mx * rzx + my * rzy + mz * rzz;
453                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
454                 mhx[kx]   = mhxk;
455                 mhy[kx]   = mhyk;
456                 mhz[kx]   = mhzk;
457                 m2[kx]    = m2k;
458                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
459                 tmp1[kx]  = -factor * m2k;
460             }
461
462             for (kx = maxkx; kx < kxend; kx++)
463             {
464                 mx = (kx - nx);
465
466                 mhxk      = mx * rxx;
467                 mhyk      = mx * ryx + my * ryy;
468                 mhzk      = mx * rzx + my * rzy + mz * rzz;
469                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
470                 mhx[kx]   = mhxk;
471                 mhy[kx]   = mhyk;
472                 mhz[kx]   = mhzk;
473                 m2[kx]    = m2k;
474                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
475                 tmp1[kx]  = -factor * m2k;
476             }
477
478             for (kx = kxstart; kx < kxend; kx++)
479             {
480                 m2inv[kx] = 1.0 / m2[kx];
481             }
482
483             calc_exponentials_q(
484                     kxstart,
485                     kxend,
486                     elfac,
487                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
488                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
489                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
490
491             for (kx = kxstart; kx < kxend; kx++, p0++)
492             {
493                 d1 = p0->re;
494                 d2 = p0->im;
495
496                 p0->re = d1 * eterm[kx];
497                 p0->im = d2 * eterm[kx];
498
499                 struct2 = 2.0 * (d1 * d1 + d2 * d2);
500
501                 tmp1[kx] = eterm[kx] * struct2;
502             }
503
504             for (kx = kxstart; kx < kxend; kx++)
505             {
506                 ets2    = corner_fac * tmp1[kx];
507                 vfactor = (factor * m2[kx] + 1.0) * 2.0 * m2inv[kx];
508                 energy += ets2;
509
510                 ets2vf = ets2 * vfactor;
511                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
512                 virxy += ets2vf * mhx[kx] * mhy[kx];
513                 virxz += ets2vf * mhx[kx] * mhz[kx];
514                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
515                 viryz += ets2vf * mhy[kx] * mhz[kx];
516                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
517             }
518         }
519         else
520         {
521             /* We don't need to calculate the energy and the virial.
522              * In this case the triclinic overhead is small.
523              */
524
525             /* Two explicit loops to avoid a conditional inside the loop */
526
527             for (kx = kxstart; kx < maxkx; kx++)
528             {
529                 mx = kx;
530
531                 mhxk      = mx * rxx;
532                 mhyk      = mx * ryx + my * ryy;
533                 mhzk      = mx * rzx + my * rzy + mz * rzz;
534                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
535                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
536                 tmp1[kx]  = -factor * m2k;
537             }
538
539             for (kx = maxkx; kx < kxend; kx++)
540             {
541                 mx = (kx - nx);
542
543                 mhxk      = mx * rxx;
544                 mhyk      = mx * ryx + my * ryy;
545                 mhzk      = mx * rzx + my * rzy + mz * rzz;
546                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
547                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
548                 tmp1[kx]  = -factor * m2k;
549             }
550
551             calc_exponentials_q(
552                     kxstart,
553                     kxend,
554                     elfac,
555                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
556                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
557                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
558
559
560             for (kx = kxstart; kx < kxend; kx++, p0++)
561             {
562                 d1 = p0->re;
563                 d2 = p0->im;
564
565                 p0->re = d1 * eterm[kx];
566                 p0->im = d2 * eterm[kx];
567             }
568         }
569     }
570
571     if (computeEnergyAndVirial)
572     {
573         /* Update virial with local values.
574          * The virial is symmetric by definition.
575          * this virial seems ok for isotropic scaling, but I'm
576          * experiencing problems on semiisotropic membranes.
577          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
578          */
579         work->vir_q[XX][XX] = 0.25 * virxx;
580         work->vir_q[YY][YY] = 0.25 * viryy;
581         work->vir_q[ZZ][ZZ] = 0.25 * virzz;
582         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25 * virxy;
583         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25 * virxz;
584         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25 * viryz;
585
586         /* This energy should be corrected for a charged system */
587         work->energy_q = 0.5 * energy;
588     }
589
590     /* Return the loop count */
591     return local_ndata[YY] * local_ndata[XX];
592 }
593
594 int solve_pme_lj_yzx(const gmx_pme_t* pme,
595                      t_complex**      grid,
596                      gmx_bool         bLB,
597                      real             vol,
598                      bool             computeEnergyAndVirial,
599                      int              nthread,
600                      int              thread)
601 {
602     /* do recip sum over local cells in grid */
603     /* y major, z middle, x minor or continuous */
604     int                      ig, gcount;
605     int                      kx, ky, kz, maxkx, maxky;
606     int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
607     real                     mx, my, mz;
608     real                     ewaldcoeff = pme->ewaldcoeff_lj;
609     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
610     real                     ets2, ets2vf;
611     real                     eterm, vterm, d1, d2, energy = 0;
612     real                     by, bz;
613     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
614     real                     rxx, ryx, ryy, rzx, rzy, rzz;
615     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
616     real                     mhxk, mhyk, mhzk, m2k;
617     struct pme_solve_work_t* work;
618     real                     corner_fac;
619     ivec                     complex_order;
620     ivec                     local_ndata, local_offset, local_size;
621     nx = pme->nkx;
622     ny = pme->nky;
623     nz = pme->nkz;
624
625     /* Dimensions should be identical for A/B grid, so we just use A here */
626     gmx_parallel_3dfft_complex_limits(
627             pme->pfft_setup[PME_GRID_C6A], complex_order, local_ndata, local_offset, local_size);
628     rxx = pme->recipbox[XX][XX];
629     ryx = pme->recipbox[YY][XX];
630     ryy = pme->recipbox[YY][YY];
631     rzx = pme->recipbox[ZZ][XX];
632     rzy = pme->recipbox[ZZ][YY];
633     rzz = pme->recipbox[ZZ][ZZ];
634
635     maxkx = (nx + 1) / 2;
636     maxky = (ny + 1) / 2;
637
638     work  = &pme->solve_work[thread];
639     mhx   = work->mhx;
640     mhy   = work->mhy;
641     mhz   = work->mhz;
642     m2    = work->m2;
643     denom = work->denom;
644     tmp1  = work->tmp1;
645     tmp2  = work->tmp2;
646
647     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
648     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
649
650     for (iyz = iyz0; iyz < iyz1; iyz++)
651     {
652         iy = iyz / local_ndata[ZZ];
653         iz = iyz - iy * local_ndata[ZZ];
654
655         ky = iy + local_offset[YY];
656
657         if (ky < maxky)
658         {
659             my = ky;
660         }
661         else
662         {
663             my = (ky - ny);
664         }
665
666         by = 3.0 * vol * pme->bsp_mod[YY][ky] / (M_PI * sqrt(M_PI) * ewaldcoeff * ewaldcoeff * ewaldcoeff);
667
668         kz = iz + local_offset[ZZ];
669
670         mz = kz;
671
672         bz = pme->bsp_mod[ZZ][kz];
673
674         /* 0.5 correction for corner points */
675         corner_fac = 1;
676         if (kz == 0 || kz == (nz + 1) / 2)
677         {
678             corner_fac = 0.5;
679         }
680
681         kxstart = local_offset[XX];
682         kxend   = local_offset[XX] + local_ndata[XX];
683         if (computeEnergyAndVirial)
684         {
685             /* More expensive inner loop, especially because of the
686              * storage of the mh elements in array's.  Because x is the
687              * minor grid index, all mh elements depend on kx for
688              * triclinic unit cells.
689              */
690
691             /* Two explicit loops to avoid a conditional inside the loop */
692             for (kx = kxstart; kx < maxkx; kx++)
693             {
694                 mx = kx;
695
696                 mhxk      = mx * rxx;
697                 mhyk      = mx * ryx + my * ryy;
698                 mhzk      = mx * rzx + my * rzy + mz * rzz;
699                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
700                 mhx[kx]   = mhxk;
701                 mhy[kx]   = mhyk;
702                 mhz[kx]   = mhzk;
703                 m2[kx]    = m2k;
704                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
705                 tmp1[kx]  = -factor * m2k;
706                 tmp2[kx]  = sqrt(factor * m2k);
707             }
708
709             for (kx = maxkx; kx < kxend; kx++)
710             {
711                 mx = (kx - nx);
712
713                 mhxk      = mx * rxx;
714                 mhyk      = mx * ryx + my * ryy;
715                 mhzk      = mx * rzx + my * rzy + mz * rzz;
716                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
717                 mhx[kx]   = mhxk;
718                 mhy[kx]   = mhyk;
719                 mhz[kx]   = mhzk;
720                 m2[kx]    = m2k;
721                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
722                 tmp1[kx]  = -factor * m2k;
723                 tmp2[kx]  = sqrt(factor * m2k);
724             }
725             /* Clear padding elements to avoid (harmless) fp exceptions */
726             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
727             for (; kx < kxendSimd; kx++)
728             {
729                 tmp1[kx] = 0;
730                 tmp2[kx] = 0;
731             }
732
733             calc_exponentials_lj(
734                     kxstart,
735                     kxend,
736                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
737                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
738                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
739
740             for (kx = kxstart; kx < kxend; kx++)
741             {
742                 m2k      = factor * m2[kx];
743                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
744                 vterm    = 3.0 * (-tmp1[kx] + tmp2[kx]);
745                 tmp1[kx] = eterm * denom[kx];
746                 tmp2[kx] = vterm * denom[kx];
747             }
748
749             if (!bLB)
750             {
751                 t_complex* p0;
752                 real       struct2;
753
754                 p0 = grid[0] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
755                 for (kx = kxstart; kx < kxend; kx++, p0++)
756                 {
757                     d1 = p0->re;
758                     d2 = p0->im;
759
760                     eterm  = tmp1[kx];
761                     vterm  = tmp2[kx];
762                     p0->re = d1 * eterm;
763                     p0->im = d2 * eterm;
764
765                     struct2 = 2.0 * (d1 * d1 + d2 * d2);
766
767                     tmp1[kx] = eterm * struct2;
768                     tmp2[kx] = vterm * struct2;
769                 }
770             }
771             else
772             {
773                 real* struct2 = denom;
774                 real  str2;
775
776                 for (kx = kxstart; kx < kxend; kx++)
777                 {
778                     struct2[kx] = 0.0;
779                 }
780                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
781                 for (ig = 0; ig <= 3; ++ig)
782                 {
783                     t_complex *p0, *p1;
784                     real       scale;
785
786                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
787                     p1 = grid[6 - ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
788                     scale = 2.0 * lb_scale_factor_symm[ig];
789                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
790                     {
791                         struct2[kx] += scale * (p0->re * p1->re + p0->im * p1->im);
792                     }
793                 }
794                 for (ig = 0; ig <= 6; ++ig)
795                 {
796                     t_complex* p0;
797
798                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
799                     for (kx = kxstart; kx < kxend; kx++, p0++)
800                     {
801                         d1 = p0->re;
802                         d2 = p0->im;
803
804                         eterm  = tmp1[kx];
805                         p0->re = d1 * eterm;
806                         p0->im = d2 * eterm;
807                     }
808                 }
809                 for (kx = kxstart; kx < kxend; kx++)
810                 {
811                     eterm    = tmp1[kx];
812                     vterm    = tmp2[kx];
813                     str2     = struct2[kx];
814                     tmp1[kx] = eterm * str2;
815                     tmp2[kx] = vterm * str2;
816                 }
817             }
818
819             for (kx = kxstart; kx < kxend; kx++)
820             {
821                 ets2  = corner_fac * tmp1[kx];
822                 vterm = 2.0 * factor * tmp2[kx];
823                 energy += ets2;
824                 ets2vf = corner_fac * vterm;
825                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
826                 virxy += ets2vf * mhx[kx] * mhy[kx];
827                 virxz += ets2vf * mhx[kx] * mhz[kx];
828                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
829                 viryz += ets2vf * mhy[kx] * mhz[kx];
830                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
831             }
832         }
833         else
834         {
835             /* We don't need to calculate the energy and the virial.
836              *  In this case the triclinic overhead is small.
837              */
838
839             /* Two explicit loops to avoid a conditional inside the loop */
840
841             for (kx = kxstart; kx < maxkx; kx++)
842             {
843                 mx = kx;
844
845                 mhxk      = mx * rxx;
846                 mhyk      = mx * ryx + my * ryy;
847                 mhzk      = mx * rzx + my * rzy + mz * rzz;
848                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
849                 m2[kx]    = m2k;
850                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
851                 tmp1[kx]  = -factor * m2k;
852                 tmp2[kx]  = sqrt(factor * m2k);
853             }
854
855             for (kx = maxkx; kx < kxend; kx++)
856             {
857                 mx = (kx - nx);
858
859                 mhxk      = mx * rxx;
860                 mhyk      = mx * ryx + my * ryy;
861                 mhzk      = mx * rzx + my * rzy + mz * rzz;
862                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
863                 m2[kx]    = m2k;
864                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
865                 tmp1[kx]  = -factor * m2k;
866                 tmp2[kx]  = sqrt(factor * m2k);
867             }
868             /* Clear padding elements to avoid (harmless) fp exceptions */
869             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
870             for (; kx < kxendSimd; kx++)
871             {
872                 tmp1[kx] = 0;
873                 tmp2[kx] = 0;
874             }
875
876             calc_exponentials_lj(
877                     kxstart,
878                     kxend,
879                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
880                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
881                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
882
883             for (kx = kxstart; kx < kxend; kx++)
884             {
885                 m2k      = factor * m2[kx];
886                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
887                 tmp1[kx] = eterm * denom[kx];
888             }
889             gcount = (bLB ? 7 : 1);
890             for (ig = 0; ig < gcount; ++ig)
891             {
892                 t_complex* p0;
893
894                 p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
895                 for (kx = kxstart; kx < kxend; kx++, p0++)
896                 {
897                     d1 = p0->re;
898                     d2 = p0->im;
899
900                     eterm = tmp1[kx];
901
902                     p0->re = d1 * eterm;
903                     p0->im = d2 * eterm;
904                 }
905             }
906         }
907     }
908     if (computeEnergyAndVirial)
909     {
910         work->vir_lj[XX][XX] = 0.25 * virxx;
911         work->vir_lj[YY][YY] = 0.25 * viryy;
912         work->vir_lj[ZZ][ZZ] = 0.25 * virzz;
913         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25 * virxy;
914         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25 * virxz;
915         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25 * viryz;
916
917         /* This energy should be corrected for a charged system */
918         work->energy_lj = 0.5 * energy;
919     }
920     /* Return the loop count */
921     return local_ndata[YY] * local_ndata[XX];
922 }