99ffa7f12038ed6b041fdafc8f27119c59f752d4
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38
39 #include "gmxpre.h"
40
41 #include "pme_solve.h"
42
43 #include <cmath>
44
45 #include "gromacs/fft/parallel_3dfft.h"
46 #include "gromacs/math/units.h"
47 #include "gromacs/math/utilities.h"
48 #include "gromacs/math/vec.h"
49 #include "gromacs/simd/simd.h"
50 #include "gromacs/simd/simd_math.h"
51 #include "gromacs/utility/arrayref.h"
52 #include "gromacs/utility/exceptions.h"
53 #include "gromacs/utility/smalloc.h"
54
55 #include "pme_internal.h"
56 #include "pme_output.h"
57
58 #if GMX_SIMD_HAVE_REAL
59 /* Turn on arbitrary width SIMD intrinsics for PME solve */
60 #    define PME_SIMD_SOLVE
61 #endif
62
63 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
64
65 struct pme_solve_work_t
66 {
67     /* work data for solve_pme */
68     int   nalloc;
69     real* mhx;
70     real* mhy;
71     real* mhz;
72     real* m2;
73     real* denom;
74     real* tmp1_alloc;
75     real* tmp1;
76     real* tmp2;
77     real* eterm;
78     real* m2inv;
79
80     real   energy_q;
81     matrix vir_q;
82     real   energy_lj;
83     matrix vir_lj;
84 };
85
86 #ifdef PME_SIMD_SOLVE
87 constexpr int c_simdWidth = GMX_SIMD_REAL_WIDTH;
88 #else
89 /* We can use any alignment > 0, so we use 4 */
90 constexpr int c_simdWidth = 4;
91 #endif
92
93 /* Returns the smallest number >= \p that is a multiple of \p factor, \p factor must be a power of 2 */
94 template<unsigned int factor>
95 static size_t roundUpToMultipleOfFactor(size_t number)
96 {
97     static_assert(factor > 0 && (factor & (factor - 1)) == 0,
98                   "factor should be >0 and a power of 2");
99
100     /* We need to add a most factor-1 and because factor is a power of 2,
101      * we get the result by masking out the bits corresponding to factor-1.
102      */
103     return (number + factor - 1) & ~(factor - 1);
104 }
105
106 /* Allocate an aligned pointer for SIMD operations, including extra elements
107  * at the end for padding.
108  */
109 /* TODO: Replace this SIMD reallocator with a general, C++ solution */
110 static void reallocSimdAlignedAndPadded(real** ptr, int unpaddedNumElements)
111 {
112     sfree_aligned(*ptr);
113     snew_aligned(*ptr, roundUpToMultipleOfFactor<c_simdWidth>(unpaddedNumElements),
114                  c_simdWidth * sizeof(real));
115 }
116
117 static void realloc_work(struct pme_solve_work_t* work, int nkx)
118 {
119     if (nkx > work->nalloc)
120     {
121         work->nalloc = nkx;
122         srenew(work->mhx, work->nalloc);
123         srenew(work->mhy, work->nalloc);
124         srenew(work->mhz, work->nalloc);
125         srenew(work->m2, work->nalloc);
126         reallocSimdAlignedAndPadded(&work->denom, work->nalloc);
127         reallocSimdAlignedAndPadded(&work->tmp1, work->nalloc);
128         reallocSimdAlignedAndPadded(&work->tmp2, work->nalloc);
129         reallocSimdAlignedAndPadded(&work->eterm, work->nalloc);
130         srenew(work->m2inv, work->nalloc);
131
132         /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
133          * of simd padded elements.
134          */
135         for (size_t i = 0; i < roundUpToMultipleOfFactor<c_simdWidth>(work->nalloc); i++)
136         {
137             work->denom[i] = 1;
138         }
139     }
140 }
141
142 void pme_init_all_work(struct pme_solve_work_t** work, int nthread, int nkx)
143 {
144     /* Use fft5d, order after FFT is y major, z, x minor */
145
146     snew(*work, nthread);
147     /* Allocate the work arrays thread local to optimize memory access */
148 #pragma omp parallel for num_threads(nthread) schedule(static)
149     for (int thread = 0; thread < nthread; thread++)
150     {
151         try
152         {
153             realloc_work(&((*work)[thread]), nkx);
154         }
155         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
156     }
157 }
158
159 static void free_work(struct pme_solve_work_t* work)
160 {
161     if (work)
162     {
163         sfree(work->mhx);
164         sfree(work->mhy);
165         sfree(work->mhz);
166         sfree(work->m2);
167         sfree_aligned(work->denom);
168         sfree_aligned(work->tmp1);
169         sfree_aligned(work->tmp2);
170         sfree_aligned(work->eterm);
171         sfree(work->m2inv);
172     }
173 }
174
175 void pme_free_all_work(struct pme_solve_work_t** work, int nthread)
176 {
177     if (*work)
178     {
179         for (int thread = 0; thread < nthread; thread++)
180         {
181             free_work(&(*work)[thread]);
182         }
183     }
184     sfree(*work);
185     *work = nullptr;
186 }
187
188 void get_pme_ener_vir_q(pme_solve_work_t* work, int nthread, PmeOutput* output)
189 {
190     GMX_ASSERT(output != nullptr, "Need valid output buffer");
191     /* This function sums output over threads and should therefore
192      * only be called after thread synchronization.
193      */
194     output->coulombEnergy_ = work[0].energy_q;
195     copy_mat(work[0].vir_q, output->coulombVirial_);
196
197     for (int thread = 1; thread < nthread; thread++)
198     {
199         output->coulombEnergy_ += work[thread].energy_q;
200         m_add(output->coulombVirial_, work[thread].vir_q, output->coulombVirial_);
201     }
202 }
203
204 void get_pme_ener_vir_lj(pme_solve_work_t* work, int nthread, PmeOutput* output)
205 {
206     GMX_ASSERT(output != nullptr, "Need valid output buffer");
207     /* This function sums output over threads and should therefore
208      * only be called after thread synchronization.
209      */
210     output->lennardJonesEnergy_ = work[0].energy_lj;
211     copy_mat(work[0].vir_lj, output->lennardJonesVirial_);
212
213     for (int thread = 1; thread < nthread; thread++)
214     {
215         output->lennardJonesEnergy_ += work[thread].energy_lj;
216         m_add(output->lennardJonesVirial_, work[thread].vir_lj, output->lennardJonesVirial_);
217     }
218 }
219
220 #if defined PME_SIMD_SOLVE
221 /* Calculate exponentials through SIMD */
222 inline static void calc_exponentials_q(int /*unused*/,
223                                        int /*unused*/,
224                                        real                     f,
225                                        ArrayRef<const SimdReal> d_aligned,
226                                        ArrayRef<const SimdReal> r_aligned,
227                                        ArrayRef<SimdReal>       e_aligned)
228 {
229     {
230         SimdReal f_simd(f);
231         SimdReal tmp_d1, tmp_r, tmp_e;
232
233         /* We only need to calculate from start. But since start is 0 or 1
234          * and we want to use aligned loads/stores, we always start from 0.
235          */
236         GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
237         GMX_ASSERT(d_aligned.size() == e_aligned.size(), "d and e must have same size");
238         for (size_t kx = 0; kx != d_aligned.size(); ++kx)
239         {
240             tmp_d1        = d_aligned[kx];
241             tmp_r         = r_aligned[kx];
242             tmp_r         = gmx::exp(tmp_r);
243             tmp_e         = f_simd / tmp_d1;
244             tmp_e         = tmp_e * tmp_r;
245             e_aligned[kx] = tmp_e;
246         }
247     }
248 }
249 #else
250 inline static void
251 calc_exponentials_q(int start, int end, real f, ArrayRef<real> d, ArrayRef<real> r, ArrayRef<real> e)
252 {
253     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
254     GMX_ASSERT(d.size() == e.size(), "d and e must have same size");
255     int kx;
256     for (kx = start; kx < end; kx++)
257     {
258         d[kx] = 1.0 / d[kx];
259     }
260     for (kx = start; kx < end; kx++)
261     {
262         r[kx] = std::exp(r[kx]);
263     }
264     for (kx = start; kx < end; kx++)
265     {
266         e[kx] = f * r[kx] * d[kx];
267     }
268 }
269 #endif
270
271 #if defined PME_SIMD_SOLVE
272 /* Calculate exponentials through SIMD */
273 inline static void calc_exponentials_lj(int /*unused*/,
274                                         int /*unused*/,
275                                         ArrayRef<SimdReal> r_aligned,
276                                         ArrayRef<SimdReal> factor_aligned,
277                                         ArrayRef<SimdReal> d_aligned)
278 {
279     SimdReal       tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
280     const SimdReal sqr_PI = sqrt(SimdReal(M_PI));
281
282     GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
283     GMX_ASSERT(d_aligned.size() == factor_aligned.size(), "d and factor must have same size");
284     for (size_t kx = 0; kx != d_aligned.size(); ++kx)
285     {
286         /* We only need to calculate from start. But since start is 0 or 1
287          * and we want to use aligned loads/stores, we always start from 0.
288          */
289         tmp_d              = d_aligned[kx];
290         d_inv              = SimdReal(1.0) / tmp_d;
291         d_aligned[kx]      = d_inv;
292         tmp_r              = r_aligned[kx];
293         tmp_r              = gmx::exp(tmp_r);
294         r_aligned[kx]      = tmp_r;
295         tmp_mk             = factor_aligned[kx];
296         tmp_fac            = sqr_PI * tmp_mk * erfc(tmp_mk);
297         factor_aligned[kx] = tmp_fac;
298     }
299 }
300 #else
301 inline static void
302 calc_exponentials_lj(int start, int end, ArrayRef<real> r, ArrayRef<real> tmp2, ArrayRef<real> d)
303 {
304     int  kx;
305     real mk;
306     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
307     GMX_ASSERT(d.size() == tmp2.size(), "d and tmp2 must have same size");
308     for (kx = start; kx < end; kx++)
309     {
310         d[kx] = 1.0 / d[kx];
311     }
312
313     for (kx = start; kx < end; kx++)
314     {
315         r[kx] = std::exp(r[kx]);
316     }
317
318     for (kx = start; kx < end; kx++)
319     {
320         mk       = tmp2[kx];
321         tmp2[kx] = sqrt(M_PI) * mk * std::erfc(mk);
322     }
323 }
324 #endif
325
326 #if defined PME_SIMD_SOLVE
327 using PME_T = SimdReal;
328 #else
329 using PME_T = real;
330 #endif
331
332 int solve_pme_yzx(const gmx_pme_t* pme, t_complex* grid, real vol, bool computeEnergyAndVirial, int nthread, int thread)
333 {
334     /* do recip sum over local cells in grid */
335     /* y major, z middle, x minor or continuous */
336     t_complex*               p0;
337     int                      kx, ky, kz, maxkx, maxky;
338     int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
339     real                     mx, my, mz;
340     real                     ewaldcoeff = pme->ewaldcoeff_q;
341     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
342     real                     ets2, struct2, vfactor, ets2vf;
343     real                     d1, d2, energy = 0;
344     real                     by, bz;
345     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
346     real                     rxx, ryx, ryy, rzx, rzy, rzz;
347     struct pme_solve_work_t* work;
348     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
349     real                     mhxk, mhyk, mhzk, m2k;
350     real                     corner_fac;
351     ivec                     complex_order;
352     ivec                     local_ndata, local_offset, local_size;
353     real                     elfac;
354
355     elfac = ONE_4PI_EPS0 / pme->epsilon_r;
356
357     nx = pme->nkx;
358     ny = pme->nky;
359     nz = pme->nkz;
360
361     /* Dimensions should be identical for A/B grid, so we just use A here */
362     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA], complex_order, local_ndata,
363                                       local_offset, local_size);
364
365     rxx = pme->recipbox[XX][XX];
366     ryx = pme->recipbox[YY][XX];
367     ryy = pme->recipbox[YY][YY];
368     rzx = pme->recipbox[ZZ][XX];
369     rzy = pme->recipbox[ZZ][YY];
370     rzz = pme->recipbox[ZZ][ZZ];
371
372     GMX_ASSERT(rxx != 0.0, "Someone broke the reciprocal box again");
373
374     maxkx = (nx + 1) / 2;
375     maxky = (ny + 1) / 2;
376
377     work  = &pme->solve_work[thread];
378     mhx   = work->mhx;
379     mhy   = work->mhy;
380     mhz   = work->mhz;
381     m2    = work->m2;
382     denom = work->denom;
383     tmp1  = work->tmp1;
384     eterm = work->eterm;
385     m2inv = work->m2inv;
386
387     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
388     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
389
390     for (iyz = iyz0; iyz < iyz1; iyz++)
391     {
392         iy = iyz / local_ndata[ZZ];
393         iz = iyz - iy * local_ndata[ZZ];
394
395         ky = iy + local_offset[YY];
396
397         if (ky < maxky)
398         {
399             my = ky;
400         }
401         else
402         {
403             my = (ky - ny);
404         }
405
406         by = M_PI * vol * pme->bsp_mod[YY][ky];
407
408         kz = iz + local_offset[ZZ];
409
410         mz = kz;
411
412         bz = pme->bsp_mod[ZZ][kz];
413
414         /* 0.5 correction for corner points */
415         corner_fac = 1;
416         if (kz == 0 || kz == (nz + 1) / 2)
417         {
418             corner_fac = 0.5;
419         }
420
421         p0 = grid + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
422
423         /* We should skip the k-space point (0,0,0) */
424         /* Note that since here x is the minor index, local_offset[XX]=0 */
425         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
426         {
427             kxstart = local_offset[XX];
428         }
429         else
430         {
431             kxstart = local_offset[XX] + 1;
432             p0++;
433         }
434         kxend = local_offset[XX] + local_ndata[XX];
435
436         if (computeEnergyAndVirial)
437         {
438             /* More expensive inner loop, especially because of the storage
439              * of the mh elements in array's.
440              * Because x is the minor grid index, all mh elements
441              * depend on kx for triclinic unit cells.
442              */
443
444             /* Two explicit loops to avoid a conditional inside the loop */
445             for (kx = kxstart; kx < maxkx; kx++)
446             {
447                 mx = kx;
448
449                 mhxk      = mx * rxx;
450                 mhyk      = mx * ryx + my * ryy;
451                 mhzk      = mx * rzx + my * rzy + mz * rzz;
452                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
453                 mhx[kx]   = mhxk;
454                 mhy[kx]   = mhyk;
455                 mhz[kx]   = mhzk;
456                 m2[kx]    = m2k;
457                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
458                 tmp1[kx]  = -factor * m2k;
459             }
460
461             for (kx = maxkx; kx < kxend; kx++)
462             {
463                 mx = (kx - nx);
464
465                 mhxk      = mx * rxx;
466                 mhyk      = mx * ryx + my * ryy;
467                 mhzk      = mx * rzx + my * rzy + mz * rzz;
468                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
469                 mhx[kx]   = mhxk;
470                 mhy[kx]   = mhyk;
471                 mhz[kx]   = mhzk;
472                 m2[kx]    = m2k;
473                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
474                 tmp1[kx]  = -factor * m2k;
475             }
476
477             for (kx = kxstart; kx < kxend; kx++)
478             {
479                 m2inv[kx] = 1.0 / m2[kx];
480             }
481
482             calc_exponentials_q(
483                     kxstart, kxend, elfac,
484                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
485                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
486                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
487
488             for (kx = kxstart; kx < kxend; kx++, p0++)
489             {
490                 d1 = p0->re;
491                 d2 = p0->im;
492
493                 p0->re = d1 * eterm[kx];
494                 p0->im = d2 * eterm[kx];
495
496                 struct2 = 2.0 * (d1 * d1 + d2 * d2);
497
498                 tmp1[kx] = eterm[kx] * struct2;
499             }
500
501             for (kx = kxstart; kx < kxend; kx++)
502             {
503                 ets2    = corner_fac * tmp1[kx];
504                 vfactor = (factor * m2[kx] + 1.0) * 2.0 * m2inv[kx];
505                 energy += ets2;
506
507                 ets2vf = ets2 * vfactor;
508                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
509                 virxy += ets2vf * mhx[kx] * mhy[kx];
510                 virxz += ets2vf * mhx[kx] * mhz[kx];
511                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
512                 viryz += ets2vf * mhy[kx] * mhz[kx];
513                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
514             }
515         }
516         else
517         {
518             /* We don't need to calculate the energy and the virial.
519              * In this case the triclinic overhead is small.
520              */
521
522             /* Two explicit loops to avoid a conditional inside the loop */
523
524             for (kx = kxstart; kx < maxkx; kx++)
525             {
526                 mx = kx;
527
528                 mhxk      = mx * rxx;
529                 mhyk      = mx * ryx + my * ryy;
530                 mhzk      = mx * rzx + my * rzy + mz * rzz;
531                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
532                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
533                 tmp1[kx]  = -factor * m2k;
534             }
535
536             for (kx = maxkx; kx < kxend; kx++)
537             {
538                 mx = (kx - nx);
539
540                 mhxk      = mx * rxx;
541                 mhyk      = mx * ryx + my * ryy;
542                 mhzk      = mx * rzx + my * rzy + mz * rzz;
543                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
544                 denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
545                 tmp1[kx]  = -factor * m2k;
546             }
547
548             calc_exponentials_q(
549                     kxstart, kxend, elfac,
550                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
551                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
552                     ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
553
554
555             for (kx = kxstart; kx < kxend; kx++, p0++)
556             {
557                 d1 = p0->re;
558                 d2 = p0->im;
559
560                 p0->re = d1 * eterm[kx];
561                 p0->im = d2 * eterm[kx];
562             }
563         }
564     }
565
566     if (computeEnergyAndVirial)
567     {
568         /* Update virial with local values.
569          * The virial is symmetric by definition.
570          * this virial seems ok for isotropic scaling, but I'm
571          * experiencing problems on semiisotropic membranes.
572          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
573          */
574         work->vir_q[XX][XX] = 0.25 * virxx;
575         work->vir_q[YY][YY] = 0.25 * viryy;
576         work->vir_q[ZZ][ZZ] = 0.25 * virzz;
577         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25 * virxy;
578         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25 * virxz;
579         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25 * viryz;
580
581         /* This energy should be corrected for a charged system */
582         work->energy_q = 0.5 * energy;
583     }
584
585     /* Return the loop count */
586     return local_ndata[YY] * local_ndata[XX];
587 }
588
589 int solve_pme_lj_yzx(const gmx_pme_t* pme,
590                      t_complex**      grid,
591                      gmx_bool         bLB,
592                      real             vol,
593                      bool             computeEnergyAndVirial,
594                      int              nthread,
595                      int              thread)
596 {
597     /* do recip sum over local cells in grid */
598     /* y major, z middle, x minor or continuous */
599     int                      ig, gcount;
600     int                      kx, ky, kz, maxkx, maxky;
601     int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
602     real                     mx, my, mz;
603     real                     ewaldcoeff = pme->ewaldcoeff_lj;
604     real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
605     real                     ets2, ets2vf;
606     real                     eterm, vterm, d1, d2, energy = 0;
607     real                     by, bz;
608     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
609     real                     rxx, ryx, ryy, rzx, rzy, rzz;
610     real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
611     real                     mhxk, mhyk, mhzk, m2k;
612     struct pme_solve_work_t* work;
613     real                     corner_fac;
614     ivec                     complex_order;
615     ivec                     local_ndata, local_offset, local_size;
616     nx = pme->nkx;
617     ny = pme->nky;
618     nz = pme->nkz;
619
620     /* Dimensions should be identical for A/B grid, so we just use A here */
621     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A], complex_order, local_ndata,
622                                       local_offset, local_size);
623     rxx = pme->recipbox[XX][XX];
624     ryx = pme->recipbox[YY][XX];
625     ryy = pme->recipbox[YY][YY];
626     rzx = pme->recipbox[ZZ][XX];
627     rzy = pme->recipbox[ZZ][YY];
628     rzz = pme->recipbox[ZZ][ZZ];
629
630     maxkx = (nx + 1) / 2;
631     maxky = (ny + 1) / 2;
632
633     work  = &pme->solve_work[thread];
634     mhx   = work->mhx;
635     mhy   = work->mhy;
636     mhz   = work->mhz;
637     m2    = work->m2;
638     denom = work->denom;
639     tmp1  = work->tmp1;
640     tmp2  = work->tmp2;
641
642     iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
643     iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
644
645     for (iyz = iyz0; iyz < iyz1; iyz++)
646     {
647         iy = iyz / local_ndata[ZZ];
648         iz = iyz - iy * local_ndata[ZZ];
649
650         ky = iy + local_offset[YY];
651
652         if (ky < maxky)
653         {
654             my = ky;
655         }
656         else
657         {
658             my = (ky - ny);
659         }
660
661         by = 3.0 * vol * pme->bsp_mod[YY][ky] / (M_PI * sqrt(M_PI) * ewaldcoeff * ewaldcoeff * ewaldcoeff);
662
663         kz = iz + local_offset[ZZ];
664
665         mz = kz;
666
667         bz = pme->bsp_mod[ZZ][kz];
668
669         /* 0.5 correction for corner points */
670         corner_fac = 1;
671         if (kz == 0 || kz == (nz + 1) / 2)
672         {
673             corner_fac = 0.5;
674         }
675
676         kxstart = local_offset[XX];
677         kxend   = local_offset[XX] + local_ndata[XX];
678         if (computeEnergyAndVirial)
679         {
680             /* More expensive inner loop, especially because of the
681              * storage of the mh elements in array's.  Because x is the
682              * minor grid index, all mh elements depend on kx for
683              * triclinic unit cells.
684              */
685
686             /* Two explicit loops to avoid a conditional inside the loop */
687             for (kx = kxstart; kx < maxkx; kx++)
688             {
689                 mx = kx;
690
691                 mhxk      = mx * rxx;
692                 mhyk      = mx * ryx + my * ryy;
693                 mhzk      = mx * rzx + my * rzy + mz * rzz;
694                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
695                 mhx[kx]   = mhxk;
696                 mhy[kx]   = mhyk;
697                 mhz[kx]   = mhzk;
698                 m2[kx]    = m2k;
699                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
700                 tmp1[kx]  = -factor * m2k;
701                 tmp2[kx]  = sqrt(factor * m2k);
702             }
703
704             for (kx = maxkx; kx < kxend; kx++)
705             {
706                 mx = (kx - nx);
707
708                 mhxk      = mx * rxx;
709                 mhyk      = mx * ryx + my * ryy;
710                 mhzk      = mx * rzx + my * rzy + mz * rzz;
711                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
712                 mhx[kx]   = mhxk;
713                 mhy[kx]   = mhyk;
714                 mhz[kx]   = mhzk;
715                 m2[kx]    = m2k;
716                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
717                 tmp1[kx]  = -factor * m2k;
718                 tmp2[kx]  = sqrt(factor * m2k);
719             }
720             /* Clear padding elements to avoid (harmless) fp exceptions */
721             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
722             for (; kx < kxendSimd; kx++)
723             {
724                 tmp1[kx] = 0;
725                 tmp2[kx] = 0;
726             }
727
728             calc_exponentials_lj(
729                     kxstart, kxend,
730                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
731                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
732                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
733
734             for (kx = kxstart; kx < kxend; kx++)
735             {
736                 m2k      = factor * m2[kx];
737                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
738                 vterm    = 3.0 * (-tmp1[kx] + tmp2[kx]);
739                 tmp1[kx] = eterm * denom[kx];
740                 tmp2[kx] = vterm * denom[kx];
741             }
742
743             if (!bLB)
744             {
745                 t_complex* p0;
746                 real       struct2;
747
748                 p0 = grid[0] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
749                 for (kx = kxstart; kx < kxend; kx++, p0++)
750                 {
751                     d1 = p0->re;
752                     d2 = p0->im;
753
754                     eterm  = tmp1[kx];
755                     vterm  = tmp2[kx];
756                     p0->re = d1 * eterm;
757                     p0->im = d2 * eterm;
758
759                     struct2 = 2.0 * (d1 * d1 + d2 * d2);
760
761                     tmp1[kx] = eterm * struct2;
762                     tmp2[kx] = vterm * struct2;
763                 }
764             }
765             else
766             {
767                 real* struct2 = denom;
768                 real  str2;
769
770                 for (kx = kxstart; kx < kxend; kx++)
771                 {
772                     struct2[kx] = 0.0;
773                 }
774                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
775                 for (ig = 0; ig <= 3; ++ig)
776                 {
777                     t_complex *p0, *p1;
778                     real       scale;
779
780                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
781                     p1 = grid[6 - ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
782                     scale = 2.0 * lb_scale_factor_symm[ig];
783                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
784                     {
785                         struct2[kx] += scale * (p0->re * p1->re + p0->im * p1->im);
786                     }
787                 }
788                 for (ig = 0; ig <= 6; ++ig)
789                 {
790                     t_complex* p0;
791
792                     p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
793                     for (kx = kxstart; kx < kxend; kx++, p0++)
794                     {
795                         d1 = p0->re;
796                         d2 = p0->im;
797
798                         eterm  = tmp1[kx];
799                         p0->re = d1 * eterm;
800                         p0->im = d2 * eterm;
801                     }
802                 }
803                 for (kx = kxstart; kx < kxend; kx++)
804                 {
805                     eterm    = tmp1[kx];
806                     vterm    = tmp2[kx];
807                     str2     = struct2[kx];
808                     tmp1[kx] = eterm * str2;
809                     tmp2[kx] = vterm * str2;
810                 }
811             }
812
813             for (kx = kxstart; kx < kxend; kx++)
814             {
815                 ets2  = corner_fac * tmp1[kx];
816                 vterm = 2.0 * factor * tmp2[kx];
817                 energy += ets2;
818                 ets2vf = corner_fac * vterm;
819                 virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
820                 virxy += ets2vf * mhx[kx] * mhy[kx];
821                 virxz += ets2vf * mhx[kx] * mhz[kx];
822                 viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
823                 viryz += ets2vf * mhy[kx] * mhz[kx];
824                 virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
825             }
826         }
827         else
828         {
829             /* We don't need to calculate the energy and the virial.
830              *  In this case the triclinic overhead is small.
831              */
832
833             /* Two explicit loops to avoid a conditional inside the loop */
834
835             for (kx = kxstart; kx < maxkx; kx++)
836             {
837                 mx = kx;
838
839                 mhxk      = mx * rxx;
840                 mhyk      = mx * ryx + my * ryy;
841                 mhzk      = mx * rzx + my * rzy + mz * rzz;
842                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
843                 m2[kx]    = m2k;
844                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
845                 tmp1[kx]  = -factor * m2k;
846                 tmp2[kx]  = sqrt(factor * m2k);
847             }
848
849             for (kx = maxkx; kx < kxend; kx++)
850             {
851                 mx = (kx - nx);
852
853                 mhxk      = mx * rxx;
854                 mhyk      = mx * ryx + my * ryy;
855                 mhzk      = mx * rzx + my * rzy + mz * rzz;
856                 m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
857                 m2[kx]    = m2k;
858                 denom[kx] = bz * by * pme->bsp_mod[XX][kx];
859                 tmp1[kx]  = -factor * m2k;
860                 tmp2[kx]  = sqrt(factor * m2k);
861             }
862             /* Clear padding elements to avoid (harmless) fp exceptions */
863             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
864             for (; kx < kxendSimd; kx++)
865             {
866                 tmp1[kx] = 0;
867                 tmp2[kx] = 0;
868             }
869
870             calc_exponentials_lj(
871                     kxstart, kxend,
872                     ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
873                     ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
874                     ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
875
876             for (kx = kxstart; kx < kxend; kx++)
877             {
878                 m2k      = factor * m2[kx];
879                 eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
880                 tmp1[kx] = eterm * denom[kx];
881             }
882             gcount = (bLB ? 7 : 1);
883             for (ig = 0; ig < gcount; ++ig)
884             {
885                 t_complex* p0;
886
887                 p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
888                 for (kx = kxstart; kx < kxend; kx++, p0++)
889                 {
890                     d1 = p0->re;
891                     d2 = p0->im;
892
893                     eterm = tmp1[kx];
894
895                     p0->re = d1 * eterm;
896                     p0->im = d2 * eterm;
897                 }
898             }
899         }
900     }
901     if (computeEnergyAndVirial)
902     {
903         work->vir_lj[XX][XX] = 0.25 * virxx;
904         work->vir_lj[YY][YY] = 0.25 * viryy;
905         work->vir_lj[ZZ][ZZ] = 0.25 * virzz;
906         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25 * virxy;
907         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25 * virxz;
908         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25 * viryz;
909
910         /* This energy should be corrected for a charged system */
911         work->energy_lj = 0.5 * energy;
912     }
913     /* Return the loop count */
914     return local_ndata[YY] * local_ndata[XX];
915 }