Add gmx::isPowerOfTwo function
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38
39 #include "gmxpre.h"
40
41 #include "pme_solve.h"
42
43 #include <cmath>
44
45 #include "gromacs/fft/parallel_3dfft.h"
46 #include "gromacs/math/units.h"
47 #include "gromacs/math/utilities.h"
48 #include "gromacs/math/vec.h"
49 #include "gromacs/simd/simd.h"
50 #include "gromacs/simd/simd_math.h"
51 #include "gromacs/utility/arrayref.h"
52 #include "gromacs/utility/exceptions.h"
53 #include "gromacs/utility/smalloc.h"
54
55 #include "pme_internal.h"
56 #include "pme_output.h"
57
58 #if GMX_SIMD_HAVE_REAL
59 /* Turn on arbitrary width SIMD intrinsics for PME solve */
60 #    define PME_SIMD_SOLVE
61 #endif
62
63 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
64
65 struct pme_solve_work_t
66 {
67     /* work data for solve_pme */
68     int   nalloc;
69     real* mhx;
70     real* mhy;
71     real* mhz;
72     real* m2;
73     real* denom;
74     real* tmp1_alloc;
75     real* tmp1;
76     real* tmp2;
77     real* eterm;
78     real* m2inv;
79
80     real   energy_q;
81     matrix vir_q;
82     real   energy_lj;
83     matrix vir_lj;
84 };
85
86 #ifdef PME_SIMD_SOLVE
87 constexpr int c_simdWidth = GMX_SIMD_REAL_WIDTH;
88 #else
89 /* We can use any alignment > 0, so we use 4 */
90 constexpr int c_simdWidth = 4;
91 #endif
92
93 /* Returns the smallest number >= \p that is a multiple of \p factor, \p factor must be a power of 2 */
94 template<unsigned int factor>
95 static size_t roundUpToMultipleOfFactor(size_t number)
96 {
97     static_assert(gmx::isPowerOfTwo(factor));
98
99     /* We need to add a most factor-1 and because factor is a power of 2,
100      * we get the result by masking out the bits corresponding to factor-1.
101      */
102     return (number + factor - 1) & ~(factor - 1);
103 }
104
105 /* Allocate an aligned pointer for SIMD operations, including extra elements
106  * at the end for padding.
107  */
108 /* TODO: Replace this SIMD reallocator with a general, C++ solution */
109 static void reallocSimdAlignedAndPadded(real** ptr, int unpaddedNumElements)
110 {
111     sfree_aligned(*ptr);
112     snew_aligned(*ptr,
113                  roundUpToMultipleOfFactor<c_simdWidth>(unpaddedNumElements),
114                  c_simdWidth * sizeof(real));
115 }
116
117 static void realloc_work(struct pme_solve_work_t* work, int nkx)
118 {
119     if (nkx > work->nalloc)
120     {
121         work->nalloc = nkx;
122         srenew(work->mhx, work->nalloc);
123         srenew(work->mhy, work->nalloc);
124         srenew(work->mhz, work->nalloc);
125         srenew(work->m2, work->nalloc);
126         reallocSimdAlignedAndPadded(&work->denom, work->nalloc);
127         reallocSimdAlignedAndPadded(&work->tmp1, work->nalloc);
128         reallocSimdAlignedAndPadded(&work->tmp2, work->nalloc);
129         reallocSimdAlignedAndPadded(&work->eterm, work->nalloc);
130         srenew(work->m2inv, work->nalloc);
131
132         /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
133          * of simd padded elements.
134          */
135         for (size_t i = 0; i < roundUpToMultipleOfFactor<c_simdWidth>(work->nalloc); i++)
136         {
137             work->denom[i] = 1;
138         }
139     }
140 }
141
142 void pme_init_all_work(struct pme_solve_work_t** work, int nthread, int nkx)
143 {
144     /* Use fft5d, order after FFT is y major, z, x minor */
145
146     snew(*work, nthread);
147     /* Allocate the work arrays thread local to optimize memory access */
148 #pragma omp parallel for num_threads(nthread) schedule(static)
149     for (int thread = 0; thread < nthread; thread++)
150     {
151         try
152         {
153             realloc_work(&((*work)[thread]), nkx);
154         }
155         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
156     }
157 }
158
159 static void free_work(struct pme_solve_work_t* work)
160 {
161     if (work)
162     {
163         sfree(work->mhx);
164         sfree(work->mhy);
165         sfree(work->mhz);
166         sfree(work->m2);
167         sfree_aligned(work->denom);
168         sfree_aligned(work->tmp1);
169         sfree_aligned(work->tmp2);
170         sfree_aligned(work->eterm);
171         sfree(work->m2inv);
172     }
173 }
174
175 void pme_free_all_work(struct pme_solve_work_t** work, int nthread)
176 {
177     if (*work)
178     {
179         for (int thread = 0; thread < nthread; thread++)
180         {
181             free_work(&(*work)[thread]);
182         }
183     }
184     sfree(*work);
185     *work = nullptr;
186 }
187
188 void get_pme_ener_vir_q(pme_solve_work_t* work, int nthread, PmeOutput* output)
189 {
190     GMX_ASSERT(output != nullptr, "Need valid output buffer");
191     /* This function sums output over threads and should therefore
192      * only be called after thread synchronization.
193      */
194     output->coulombEnergy_ = work[0].energy_q;
195     copy_mat(work[0].vir_q, output->coulombVirial_);
196
197     for (int thread = 1; thread < nthread; thread++)
198     {
199         output->coulombEnergy_ += work[thread].energy_q;
200         m_add(output->coulombVirial_, work[thread].vir_q, output->coulombVirial_);
201     }
202 }
203
204 void get_pme_ener_vir_lj(pme_solve_work_t* work, int nthread, PmeOutput* output)
205 {
206     GMX_ASSERT(output != nullptr, "Need valid output buffer");
207     /* This function sums output over threads and should therefore
208      * only be called after thread synchronization.
209      */
210     output->lennardJonesEnergy_ = work[0].energy_lj;
211     copy_mat(work[0].vir_lj, output->lennardJonesVirial_);
212
213     for (int thread = 1; thread < nthread; thread++)
214     {
215         output->lennardJonesEnergy_ += work[thread].energy_lj;
216         m_add(output->lennardJonesVirial_, work[thread].vir_lj, output->lennardJonesVirial_);
217     }
218 }
219
220 #if defined PME_SIMD_SOLVE
221 /* Calculate exponentials through SIMD */
222 inline static void calc_exponentials_q(int /*unused*/,
223                                        int /*unused*/,
224                                        real                     f,
225                                        ArrayRef<const SimdReal> d_aligned,
226                                        ArrayRef<const SimdReal> r_aligned,
227                                        ArrayRef<SimdReal>       e_aligned)
228 {
229     {
230         SimdReal f_simd(f);
231         SimdReal tmp_d1, tmp_r, tmp_e;
232
233         /* We only need to calculate from start. But since start is 0 or 1
234          * and we want to use aligned loads/stores, we always start from 0.
235          */
236         GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
237         GMX_ASSERT(d_aligned.size() == e_aligned.size(), "d and e must have same size");
238         for (size_t kx = 0; kx != d_aligned.size(); ++kx)
239         {
240             tmp_d1        = d_aligned[kx];
241             tmp_r         = r_aligned[kx];
242             tmp_r         = gmx::exp(tmp_r);
243             tmp_e         = f_simd / tmp_d1;
244             tmp_e         = tmp_e * tmp_r;
245             e_aligned[kx] = tmp_e;
246         }
247     }
248 }
249 #else
250 inline static void
251 calc_exponentials_q(int start, int end, real f, ArrayRef<real> d, ArrayRef<real> r, ArrayRef<real> e)
252 {
253     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
254     GMX_ASSERT(d.size() == e.size(), "d and e must have same size");
255     int kx;
256     for (kx = start; kx < end; kx++)
257     {
258         d[kx] = 1.0 / d[kx];
259     }
260     for (kx = start; kx < end; kx++)
261     {
262         r[kx] = std::exp(r[kx]);
263     }
264     for (kx = start; kx < end; kx++)
265     {
266         e[kx] = f * r[kx] * d[kx];
267     }
268 }
269 #endif
270
271 #if defined PME_SIMD_SOLVE
272 /* Calculate exponentials through SIMD */
273 inline static void calc_exponentials_lj(int /*unused*/,
274                                         int /*unused*/,
275                                         ArrayRef<SimdReal> r_aligned,
276                                         ArrayRef<SimdReal> factor_aligned,
277                                         ArrayRef<SimdReal> d_aligned)
278 {
279     SimdReal       tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
280     const SimdReal sqr_PI = sqrt(SimdReal(M_PI));
281
282     GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
283     GMX_ASSERT(d_aligned.size() == factor_aligned.size(), "d and factor must have same size");
284     for (size_t kx = 0; kx != d_aligned.size(); ++kx)
285     {
286         /* We only need to calculate from start. But since start is 0 or 1
287          * and we want to use aligned loads/stores, we always start from 0.
288          */
289         tmp_d              = d_aligned[kx];
290         d_inv              = SimdReal(1.0) / tmp_d;
291         d_aligned[kx]      = d_inv;
292         tmp_r              = r_aligned[kx];
293         tmp_r              = gmx::exp(tmp_r);
294         r_aligned[kx]      = tmp_r;
295         tmp_mk             = factor_aligned[kx];
296         tmp_fac            = sqr_PI * tmp_mk * erfc(tmp_mk);
297         factor_aligned[kx] = tmp_fac;
298     }
299 }
300 #else
301 inline static void
302 calc_exponentials_lj(int start, int end, ArrayRef<real> r, ArrayRef<real> tmp2, ArrayRef<real> d)
303 {
304     int  kx;
305     real mk;
306     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
307     GMX_ASSERT(d.size() == tmp2.size(), "d and tmp2 must have same size");
308     for (kx = start; kx < end; kx++)
309     {
310         d[kx] = 1.0 / d[kx];
311     }
312
313     for (kx = start; kx < end; kx++)
314     {
315         r[kx] = std::exp(r[kx]);
316     }
317
318     for (kx = start; kx < end; kx++)
319     {
320         mk       = tmp2[kx];
321         tmp2[kx] = sqrt(M_PI) * mk * std::erfc(mk);
322     }
323 }
324 #endif
325
326 #if defined PME_SIMD_SOLVE
327 using PME_T = SimdReal;
328 #else
329 using PME_T = real;
330 #endif
331
332 int solve_pme_yzx(const gmx_pme_t* pme, t_complex* grid, real vol, bool computeEnergyAndVirial, int nthread, int thread)
333 {
334     /* do recip sum over local cells in grid */
335     /* y major, z middle, x minor or continuous */
336     t_complex*               p0;
337     int                      kx, ky, kz, maxkx, maxky;
338     int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
339     real                     mx, my, mz;
340     real                     ewaldcoeff = pme->ewaldcoeff_q;
341     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
342     real                     ets2, struct2, vfactor, ets2vf;
343     real                     d1, d2, energy = 0;
344     real                     by, bz;
345     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
346     real                     rxx, ryx, ryy, rzx, rzy, rzz;
347     struct pme_solve_work_t* work;
348     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
349     real                     mhxk, mhyk, mhzk, m2k;
350     real                     corner_fac;
351     ivec                     complex_order;
352     ivec                     local_ndata, local_offset, local_size;
353     real                     elfac;
354
355     elfac = ONE_4PI_EPS0 / pme->epsilon_r;
356
357     nx = pme->nkx;
358     ny = pme->nky;
359     nz = pme->nkz;
360
361     /* Dimensions should be identical for A/B grid, so we just use A here */
362     gmx_parallel_3dfft_complex_limits(
363             pme->pfft_setup[PME_GRID_QA], complex_order, local_ndata, local_offset, local_size);
364
365     rxx = pme->recipbox[XX][XX];
366     ryx = pme->recipbox[YY][XX];
367     ryy = pme->recipbox[YY][YY];
368     rzx = pme->recipbox[ZZ][XX];
369     rzy = pme->recipbox[ZZ][YY];
370     rzz = pme->recipbox[ZZ][ZZ];
371
372     GMX_ASSERT(rxx != 0.0, "Someone broke the reciprocal box again");
373
374     maxkx = (nx + 1) / 2;
375     maxky = (ny + 1) / 2;
376
377     work  = &pme->solve_work[thread];
378     mhx   = work->mhx;
379     mhy   = work->mhy;
380     mhz   = work->mhz;
381     m2    = work->m2;
382     denom = work->denom;
383     tmp1  = work->tmp1;
384     eterm = work->eterm;
385     m2inv = work->m2inv;
386
387     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
388     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
389
390     for (iyz = iyz0; iyz < iyz1; iyz++)
391     {
392         iy = iyz / local_ndata[ZZ];
393         iz = iyz - iy * local_ndata[ZZ];
394
395         ky = iy + local_offset[YY];
396
397         if (ky < maxky)
398         {
399             my = ky;
400         }
401         else
402         {
403             my = (ky - ny);
404         }
405
406         by = M_PI * vol * pme->bsp_mod[YY][ky];
407
408         kz = iz + local_offset[ZZ];
409
410         mz = kz;
411
412         bz = pme->bsp_mod[ZZ][kz];
413
414         /* 0.5 correction for corner points */
415         corner_fac = 1;
416         if (kz == 0 || kz == (nz + 1) / 2)
417         {
418             corner_fac = 0.5;
419         }
420
421         p0 = grid + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
422
423         /* We should skip the k-space point (0,0,0) */
424         /* Note that since here x is the minor index, local_offset[XX]=0 */
425         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
426         {
427             kxstart = local_offset[XX];
428         }
429         else
430         {
431             kxstart = local_offset[XX] + 1;
432             p0++;
433         }
434         kxend = local_offset[XX] + local_ndata[XX];
435
436         if (computeEnergyAndVirial)
437         {
438             /* More expensive inner loop, especially because of the storage
439              * of the mh elements in array's.
440              * Because x is the minor grid index, all mh elements
441              * depend on kx for triclinic unit cells.
442              */
443
444             /* Two explicit loops to avoid a conditional inside the loop */
445             for (kx = kxstart; kx < maxkx; kx++)
446             {
447                 mx = kx;
448
449                 mhxk      = mx * rxx;
450                 mhyk      = mx * ryx + my * ryy;
451                 mhzk      = mx * rzx + my * rzy + mz * rzz;
452                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
453                 mhx[kx]   = mhxk;
454                 mhy[kx]   = mhyk;
455                 mhz[kx]   = mhzk;
456                 m2[kx]    = m2k;
457                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
458                 tmp1[kx]  = -factor * m2k;
459             }
460
461             for (kx = maxkx; kx < kxend; kx++)
462             {
463                 mx = (kx - nx);
464
465                 mhxk      = mx * rxx;
466                 mhyk      = mx * ryx + my * ryy;
467                 mhzk      = mx * rzx + my * rzy + mz * rzz;
468                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
469                 mhx[kx]   = mhxk;
470                 mhy[kx]   = mhyk;
471                 mhz[kx]   = mhzk;
472                 m2[kx]    = m2k;
473                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
474                 tmp1[kx]  = -factor * m2k;
475             }
476
477             for (kx = kxstart; kx < kxend; kx++)
478             {
479                 m2inv[kx] = 1.0 / m2[kx];
480             }
481
482             calc_exponentials_q(
483                     kxstart,
484                     kxend,
485                     elfac,
486                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
487                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
488                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
489
490             for (kx = kxstart; kx < kxend; kx++, p0++)
491             {
492                 d1 = p0->re;
493                 d2 = p0->im;
494
495                 p0->re = d1 * eterm[kx];
496                 p0->im = d2 * eterm[kx];
497
498                 struct2 = 2.0 * (d1 * d1 + d2 * d2);
499
500                 tmp1[kx] = eterm[kx] * struct2;
501             }
502
503             for (kx = kxstart; kx < kxend; kx++)
504             {
505                 ets2    = corner_fac * tmp1[kx];
506                 vfactor = (factor * m2[kx] + 1.0) * 2.0 * m2inv[kx];
507                 energy += ets2;
508
509                 ets2vf = ets2 * vfactor;
510                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
511                 virxy += ets2vf * mhx[kx] * mhy[kx];
512                 virxz += ets2vf * mhx[kx] * mhz[kx];
513                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
514                 viryz += ets2vf * mhy[kx] * mhz[kx];
515                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
516             }
517         }
518         else
519         {
520             /* We don't need to calculate the energy and the virial.
521              * In this case the triclinic overhead is small.
522              */
523
524             /* Two explicit loops to avoid a conditional inside the loop */
525
526             for (kx = kxstart; kx < maxkx; kx++)
527             {
528                 mx = kx;
529
530                 mhxk      = mx * rxx;
531                 mhyk      = mx * ryx + my * ryy;
532                 mhzk      = mx * rzx + my * rzy + mz * rzz;
533                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
534                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
535                 tmp1[kx]  = -factor * m2k;
536             }
537
538             for (kx = maxkx; kx < kxend; kx++)
539             {
540                 mx = (kx - nx);
541
542                 mhxk      = mx * rxx;
543                 mhyk      = mx * ryx + my * ryy;
544                 mhzk      = mx * rzx + my * rzy + mz * rzz;
545                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
546                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
547                 tmp1[kx]  = -factor * m2k;
548             }
549
550             calc_exponentials_q(
551                     kxstart,
552                     kxend,
553                     elfac,
554                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
555                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
556                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
557
558
559             for (kx = kxstart; kx < kxend; kx++, p0++)
560             {
561                 d1 = p0->re;
562                 d2 = p0->im;
563
564                 p0->re = d1 * eterm[kx];
565                 p0->im = d2 * eterm[kx];
566             }
567         }
568     }
569
570     if (computeEnergyAndVirial)
571     {
572         /* Update virial with local values.
573          * The virial is symmetric by definition.
574          * this virial seems ok for isotropic scaling, but I'm
575          * experiencing problems on semiisotropic membranes.
576          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
577          */
578         work->vir_q[XX][XX] = 0.25 * virxx;
579         work->vir_q[YY][YY] = 0.25 * viryy;
580         work->vir_q[ZZ][ZZ] = 0.25 * virzz;
581         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25 * virxy;
582         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25 * virxz;
583         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25 * viryz;
584
585         /* This energy should be corrected for a charged system */
586         work->energy_q = 0.5 * energy;
587     }
588
589     /* Return the loop count */
590     return local_ndata[YY] * local_ndata[XX];
591 }
592
593 int solve_pme_lj_yzx(const gmx_pme_t* pme,
594                      t_complex**      grid,
595                      gmx_bool         bLB,
596                      real             vol,
597                      bool             computeEnergyAndVirial,
598                      int              nthread,
599                      int              thread)
600 {
601     /* do recip sum over local cells in grid */
602     /* y major, z middle, x minor or continuous */
603     int                      ig, gcount;
604     int                      kx, ky, kz, maxkx, maxky;
605     int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
606     real                     mx, my, mz;
607     real                     ewaldcoeff = pme->ewaldcoeff_lj;
608     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
609     real                     ets2, ets2vf;
610     real                     eterm, vterm, d1, d2, energy = 0;
611     real                     by, bz;
612     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
613     real                     rxx, ryx, ryy, rzx, rzy, rzz;
614     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
615     real                     mhxk, mhyk, mhzk, m2k;
616     struct pme_solve_work_t* work;
617     real                     corner_fac;
618     ivec                     complex_order;
619     ivec                     local_ndata, local_offset, local_size;
620     nx = pme->nkx;
621     ny = pme->nky;
622     nz = pme->nkz;
623
624     /* Dimensions should be identical for A/B grid, so we just use A here */
625     gmx_parallel_3dfft_complex_limits(
626             pme->pfft_setup[PME_GRID_C6A], complex_order, local_ndata, local_offset, local_size);
627     rxx = pme->recipbox[XX][XX];
628     ryx = pme->recipbox[YY][XX];
629     ryy = pme->recipbox[YY][YY];
630     rzx = pme->recipbox[ZZ][XX];
631     rzy = pme->recipbox[ZZ][YY];
632     rzz = pme->recipbox[ZZ][ZZ];
633
634     maxkx = (nx + 1) / 2;
635     maxky = (ny + 1) / 2;
636
637     work  = &pme->solve_work[thread];
638     mhx   = work->mhx;
639     mhy   = work->mhy;
640     mhz   = work->mhz;
641     m2    = work->m2;
642     denom = work->denom;
643     tmp1  = work->tmp1;
644     tmp2  = work->tmp2;
645
646     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
647     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
648
649     for (iyz = iyz0; iyz < iyz1; iyz++)
650     {
651         iy = iyz / local_ndata[ZZ];
652         iz = iyz - iy * local_ndata[ZZ];
653
654         ky = iy + local_offset[YY];
655
656         if (ky < maxky)
657         {
658             my = ky;
659         }
660         else
661         {
662             my = (ky - ny);
663         }
664
665         by = 3.0 * vol * pme->bsp_mod[YY][ky] / (M_PI * sqrt(M_PI) * ewaldcoeff * ewaldcoeff * ewaldcoeff);
666
667         kz = iz + local_offset[ZZ];
668
669         mz = kz;
670
671         bz = pme->bsp_mod[ZZ][kz];
672
673         /* 0.5 correction for corner points */
674         corner_fac = 1;
675         if (kz == 0 || kz == (nz + 1) / 2)
676         {
677             corner_fac = 0.5;
678         }
679
680         kxstart = local_offset[XX];
681         kxend   = local_offset[XX] + local_ndata[XX];
682         if (computeEnergyAndVirial)
683         {
684             /* More expensive inner loop, especially because of the
685              * storage of the mh elements in array's.  Because x is the
686              * minor grid index, all mh elements depend on kx for
687              * triclinic unit cells.
688              */
689
690             /* Two explicit loops to avoid a conditional inside the loop */
691             for (kx = kxstart; kx < maxkx; kx++)
692             {
693                 mx = kx;
694
695                 mhxk      = mx * rxx;
696                 mhyk      = mx * ryx + my * ryy;
697                 mhzk      = mx * rzx + my * rzy + mz * rzz;
698                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
699                 mhx[kx]   = mhxk;
700                 mhy[kx]   = mhyk;
701                 mhz[kx]   = mhzk;
702                 m2[kx]    = m2k;
703                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
704                 tmp1[kx]  = -factor * m2k;
705                 tmp2[kx]  = sqrt(factor * m2k);
706             }
707
708             for (kx = maxkx; kx < kxend; kx++)
709             {
710                 mx = (kx - nx);
711
712                 mhxk      = mx * rxx;
713                 mhyk      = mx * ryx + my * ryy;
714                 mhzk      = mx * rzx + my * rzy + mz * rzz;
715                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
716                 mhx[kx]   = mhxk;
717                 mhy[kx]   = mhyk;
718                 mhz[kx]   = mhzk;
719                 m2[kx]    = m2k;
720                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
721                 tmp1[kx]  = -factor * m2k;
722                 tmp2[kx]  = sqrt(factor * m2k);
723             }
724             /* Clear padding elements to avoid (harmless) fp exceptions */
725             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
726             for (; kx < kxendSimd; kx++)
727             {
728                 tmp1[kx] = 0;
729                 tmp2[kx] = 0;
730             }
731
732             calc_exponentials_lj(
733                     kxstart,
734                     kxend,
735                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
736                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
737                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
738
739             for (kx = kxstart; kx < kxend; kx++)
740             {
741                 m2k      = factor * m2[kx];
742                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
743                 vterm    = 3.0 * (-tmp1[kx] + tmp2[kx]);
744                 tmp1[kx] = eterm * denom[kx];
745                 tmp2[kx] = vterm * denom[kx];
746             }
747
748             if (!bLB)
749             {
750                 t_complex* p0;
751                 real       struct2;
752
753                 p0 = grid[0] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
754                 for (kx = kxstart; kx < kxend; kx++, p0++)
755                 {
756                     d1 = p0->re;
757                     d2 = p0->im;
758
759                     eterm  = tmp1[kx];
760                     vterm  = tmp2[kx];
761                     p0->re = d1 * eterm;
762                     p0->im = d2 * eterm;
763
764                     struct2 = 2.0 * (d1 * d1 + d2 * d2);
765
766                     tmp1[kx] = eterm * struct2;
767                     tmp2[kx] = vterm * struct2;
768                 }
769             }
770             else
771             {
772                 real* struct2 = denom;
773                 real  str2;
774
775                 for (kx = kxstart; kx < kxend; kx++)
776                 {
777                     struct2[kx] = 0.0;
778                 }
779                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
780                 for (ig = 0; ig <= 3; ++ig)
781                 {
782                     t_complex *p0, *p1;
783                     real       scale;
784
785                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
786                     p1 = grid[6 - ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
787                     scale = 2.0 * lb_scale_factor_symm[ig];
788                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
789                     {
790                         struct2[kx] += scale * (p0->re * p1->re + p0->im * p1->im);
791                     }
792                 }
793                 for (ig = 0; ig <= 6; ++ig)
794                 {
795                     t_complex* p0;
796
797                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
798                     for (kx = kxstart; kx < kxend; kx++, p0++)
799                     {
800                         d1 = p0->re;
801                         d2 = p0->im;
802
803                         eterm  = tmp1[kx];
804                         p0->re = d1 * eterm;
805                         p0->im = d2 * eterm;
806                     }
807                 }
808                 for (kx = kxstart; kx < kxend; kx++)
809                 {
810                     eterm    = tmp1[kx];
811                     vterm    = tmp2[kx];
812                     str2     = struct2[kx];
813                     tmp1[kx] = eterm * str2;
814                     tmp2[kx] = vterm * str2;
815                 }
816             }
817
818             for (kx = kxstart; kx < kxend; kx++)
819             {
820                 ets2  = corner_fac * tmp1[kx];
821                 vterm = 2.0 * factor * tmp2[kx];
822                 energy += ets2;
823                 ets2vf = corner_fac * vterm;
824                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
825                 virxy += ets2vf * mhx[kx] * mhy[kx];
826                 virxz += ets2vf * mhx[kx] * mhz[kx];
827                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
828                 viryz += ets2vf * mhy[kx] * mhz[kx];
829                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
830             }
831         }
832         else
833         {
834             /* We don't need to calculate the energy and the virial.
835              *  In this case the triclinic overhead is small.
836              */
837
838             /* Two explicit loops to avoid a conditional inside the loop */
839
840             for (kx = kxstart; kx < maxkx; kx++)
841             {
842                 mx = kx;
843
844                 mhxk      = mx * rxx;
845                 mhyk      = mx * ryx + my * ryy;
846                 mhzk      = mx * rzx + my * rzy + mz * rzz;
847                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
848                 m2[kx]    = m2k;
849                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
850                 tmp1[kx]  = -factor * m2k;
851                 tmp2[kx]  = sqrt(factor * m2k);
852             }
853
854             for (kx = maxkx; kx < kxend; kx++)
855             {
856                 mx = (kx - nx);
857
858                 mhxk      = mx * rxx;
859                 mhyk      = mx * ryx + my * ryy;
860                 mhzk      = mx * rzx + my * rzy + mz * rzz;
861                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
862                 m2[kx]    = m2k;
863                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
864                 tmp1[kx]  = -factor * m2k;
865                 tmp2[kx]  = sqrt(factor * m2k);
866             }
867             /* Clear padding elements to avoid (harmless) fp exceptions */
868             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
869             for (; kx < kxendSimd; kx++)
870             {
871                 tmp1[kx] = 0;
872                 tmp2[kx] = 0;
873             }
874
875             calc_exponentials_lj(
876                     kxstart,
877                     kxend,
878                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
879                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
880                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
881
882             for (kx = kxstart; kx < kxend; kx++)
883             {
884                 m2k      = factor * m2[kx];
885                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
886                 tmp1[kx] = eterm * denom[kx];
887             }
888             gcount = (bLB ? 7 : 1);
889             for (ig = 0; ig < gcount; ++ig)
890             {
891                 t_complex* p0;
892
893                 p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
894                 for (kx = kxstart; kx < kxend; kx++, p0++)
895                 {
896                     d1 = p0->re;
897                     d2 = p0->im;
898
899                     eterm = tmp1[kx];
900
901                     p0->re = d1 * eterm;
902                     p0->im = d2 * eterm;
903                 }
904             }
905         }
906     }
907     if (computeEnergyAndVirial)
908     {
909         work->vir_lj[XX][XX] = 0.25 * virxx;
910         work->vir_lj[YY][YY] = 0.25 * viryy;
911         work->vir_lj[ZZ][ZZ] = 0.25 * virzz;
912         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25 * virxy;
913         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25 * virxz;
914         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25 * viryz;
915
916         /* This energy should be corrected for a charged system */
917         work->energy_lj = 0.5 * energy;
918     }
919     /* Return the loop count */
920     return local_ndata[YY] * local_ndata[XX];
921 }