Fix cmake policy warning
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37
38 #include "gmxpre.h"
39
40 #include "pme_solve.h"
41
42 #include <cmath>
43
44 #include "gromacs/fft/parallel_3dfft.h"
45 #include "gromacs/math/units.h"
46 #include "gromacs/math/utilities.h"
47 #include "gromacs/math/vec.h"
48 #include "gromacs/simd/simd.h"
49 #include "gromacs/simd/simd_math.h"
50 #include "gromacs/utility/arrayref.h"
51 #include "gromacs/utility/exceptions.h"
52 #include "gromacs/utility/smalloc.h"
53
54 #include "pme_internal.h"
55
56 #if GMX_SIMD_HAVE_REAL
57 /* Turn on arbitrary width SIMD intrinsics for PME solve */
58 #    define PME_SIMD_SOLVE
59 #endif
60
61 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
62
63 struct pme_solve_work_t
64 {
65     /* work data for solve_pme */
66     int      nalloc;
67     real *   mhx;
68     real *   mhy;
69     real *   mhz;
70     real *   m2;
71     real *   denom;
72     real *   tmp1_alloc;
73     real *   tmp1;
74     real *   tmp2;
75     real *   eterm;
76     real *   m2inv;
77
78     real     energy_q;
79     matrix   vir_q;
80     real     energy_lj;
81     matrix   vir_lj;
82 };
83
84 #ifdef PME_SIMD_SOLVE
85 constexpr int c_simdWidth = GMX_SIMD_REAL_WIDTH;
86 #else
87 /* We can use any alignment > 0, so we use 4 */
88 constexpr int c_simdWidth = 4;
89 #endif
90
91 /* Returns the smallest number >= \p that is a multiple of \p factor, \p factor must be a power of 2 */
92 template <unsigned int factor>
93 static size_t roundUpToMultipleOfFactor(size_t number)
94 {
95     static_assert(factor > 0 && (factor & (factor - 1)) == 0, "factor should be >0 and a power of 2");
96
97     /* We need to add a most factor-1 and because factor is a power of 2,
98      * we get the result by masking out the bits corresponding to factor-1.
99      */
100     return (number + factor - 1) & ~(factor - 1);
101 }
102
103 /* Allocate an aligned pointer for SIMD operations, including extra elements
104  * at the end for padding.
105  */
106 /* TODO: Replace this SIMD reallocator with a general, C++ solution */
107 static void reallocSimdAlignedAndPadded(real **ptr, int unpaddedNumElements)
108 {
109     sfree_aligned(*ptr);
110     snew_aligned(*ptr, roundUpToMultipleOfFactor<c_simdWidth>(unpaddedNumElements), c_simdWidth*sizeof(real));
111 }
112
113 static void realloc_work(struct pme_solve_work_t *work, int nkx)
114 {
115     if (nkx > work->nalloc)
116     {
117         work->nalloc = nkx;
118         srenew(work->mhx, work->nalloc);
119         srenew(work->mhy, work->nalloc);
120         srenew(work->mhz, work->nalloc);
121         srenew(work->m2, work->nalloc);
122         reallocSimdAlignedAndPadded(&work->denom, work->nalloc);
123         reallocSimdAlignedAndPadded(&work->tmp1, work->nalloc);
124         reallocSimdAlignedAndPadded(&work->tmp2, work->nalloc);
125         reallocSimdAlignedAndPadded(&work->eterm, work->nalloc);
126         srenew(work->m2inv, work->nalloc);
127
128         /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
129          * of simd padded elements.
130          */
131         for (size_t i = 0; i < roundUpToMultipleOfFactor<c_simdWidth>(work->nalloc ); i++)
132         {
133             work->denom[i] = 1;
134         }
135     }
136 }
137
138 void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
139 {
140     int thread;
141     /* Use fft5d, order after FFT is y major, z, x minor */
142
143     snew(*work, nthread);
144     /* Allocate the work arrays thread local to optimize memory access */
145 #pragma omp parallel for num_threads(nthread) schedule(static)
146     for (thread = 0; thread < nthread; thread++)
147     {
148         try
149         {
150             realloc_work(&((*work)[thread]), nkx);
151         }
152         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
153     }
154 }
155
156 static void free_work(struct pme_solve_work_t *work)
157 {
158     if (work)
159     {
160         sfree(work->mhx);
161         sfree(work->mhy);
162         sfree(work->mhz);
163         sfree(work->m2);
164         sfree_aligned(work->denom);
165         sfree_aligned(work->tmp1);
166         sfree_aligned(work->tmp2);
167         sfree_aligned(work->eterm);
168         sfree(work->m2inv);
169     }
170 }
171
172 void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
173 {
174     if (*work)
175     {
176         for (int thread = 0; thread < nthread; thread++)
177         {
178             free_work(&(*work)[thread]);
179         }
180     }
181     sfree(*work);
182     *work = nullptr;
183 }
184
185 void get_pme_ener_vir_q(pme_solve_work_t *work, int nthread, PmeOutput *output)
186 {
187     GMX_ASSERT(output != nullptr, "Need valid output buffer");
188     /* This function sums output over threads and should therefore
189      * only be called after thread synchronization.
190      */
191     output->coulombEnergy_ = work[0].energy_q;
192     copy_mat(work[0].vir_q, output->coulombVirial_);
193
194     for (int thread = 1; thread < nthread; thread++)
195     {
196         output->coulombEnergy_ += work[thread].energy_q;
197         m_add(output->coulombVirial_, work[thread].vir_q, output->coulombVirial_);
198     }
199 }
200
201 void get_pme_ener_vir_lj(pme_solve_work_t *work, int nthread, PmeOutput *output)
202 {
203     GMX_ASSERT(output != nullptr, "Need valid output buffer");
204     /* This function sums output over threads and should therefore
205      * only be called after thread synchronization.
206      */
207     output->lennardJonesEnergy_ = work[0].energy_lj;
208     copy_mat(work[0].vir_lj, output->lennardJonesVirial_);
209
210     for (int thread = 1; thread < nthread; thread++)
211     {
212         output->lennardJonesEnergy_ += work[thread].energy_lj;
213         m_add(output->lennardJonesVirial_, work[thread].vir_lj, output->lennardJonesVirial_);
214     }
215 }
216
217 #if defined PME_SIMD_SOLVE
218 /* Calculate exponentials through SIMD */
219 inline static void calc_exponentials_q(int /*unused*/, int /*unused*/, real f, ArrayRef<const SimdReal> d_aligned, ArrayRef<const SimdReal> r_aligned, ArrayRef<SimdReal> e_aligned)
220 {
221     {
222         SimdReal              f_simd(f);
223         SimdReal              tmp_d1, tmp_r, tmp_e;
224
225         /* We only need to calculate from start. But since start is 0 or 1
226          * and we want to use aligned loads/stores, we always start from 0.
227          */
228         GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
229         GMX_ASSERT(d_aligned.size() == e_aligned.size(), "d and e must have same size");
230         for (size_t kx = 0; kx != d_aligned.size(); ++kx)
231         {
232             tmp_d1        = d_aligned[kx];
233             tmp_r         = r_aligned[kx];
234             tmp_r         = gmx::exp(tmp_r);
235             tmp_e         = f_simd / tmp_d1;
236             tmp_e         = tmp_e * tmp_r;
237             e_aligned[kx] = tmp_e;
238         }
239     }
240 }
241 #else
242 inline static void calc_exponentials_q(int start, int end, real f, ArrayRef<real> d, ArrayRef<real> r, ArrayRef<real> e)
243 {
244     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
245     GMX_ASSERT(d.size() == e.size(), "d and e must have same size");
246     int kx;
247     for (kx = start; kx < end; kx++)
248     {
249         d[kx] = 1.0/d[kx];
250     }
251     for (kx = start; kx < end; kx++)
252     {
253         r[kx] = std::exp(r[kx]);
254     }
255     for (kx = start; kx < end; kx++)
256     {
257         e[kx] = f*r[kx]*d[kx];
258     }
259 }
260 #endif
261
262 #if defined PME_SIMD_SOLVE
263 /* Calculate exponentials through SIMD */
264 inline static void calc_exponentials_lj(int /*unused*/, int /*unused*/, ArrayRef<SimdReal> r_aligned, ArrayRef<SimdReal> factor_aligned, ArrayRef<SimdReal> d_aligned)
265 {
266     SimdReal              tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
267     const SimdReal        sqr_PI = sqrt(SimdReal(M_PI));
268
269     GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
270     GMX_ASSERT(d_aligned.size() == factor_aligned.size(), "d and factor must have same size");
271     for (size_t kx = 0; kx != d_aligned.size(); ++kx)
272     {
273         /* We only need to calculate from start. But since start is 0 or 1
274          * and we want to use aligned loads/stores, we always start from 0.
275          */
276         tmp_d              = d_aligned[kx];
277         d_inv              = SimdReal(1.0) / tmp_d;
278         d_aligned[kx]      = d_inv;
279         tmp_r              = r_aligned[kx];
280         tmp_r              = gmx::exp(tmp_r);
281         r_aligned[kx]      = tmp_r;
282         tmp_mk             = factor_aligned[kx];
283         tmp_fac            = sqr_PI * tmp_mk * erfc(tmp_mk);
284         factor_aligned[kx] = tmp_fac;
285     }
286 }
287 #else
288 inline static void calc_exponentials_lj(int start, int end, ArrayRef<real> r, ArrayRef<real> tmp2, ArrayRef<real> d)
289 {
290     int  kx;
291     real mk;
292     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
293     GMX_ASSERT(d.size() == tmp2.size(), "d and tmp2 must have same size");
294     for (kx = start; kx < end; kx++)
295     {
296         d[kx] = 1.0/d[kx];
297     }
298
299     for (kx = start; kx < end; kx++)
300     {
301         r[kx] = std::exp(r[kx]);
302     }
303
304     for (kx = start; kx < end; kx++)
305     {
306         mk       = tmp2[kx];
307         tmp2[kx] = sqrt(M_PI)*mk*std::erfc(mk);
308     }
309 }
310 #endif
311
312 #if defined PME_SIMD_SOLVE
313 using PME_T = SimdReal;
314 #else
315 using PME_T = real;
316 #endif
317
318 int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
319                   gmx_bool bEnerVir,
320                   int nthread, int thread)
321 {
322     /* do recip sum over local cells in grid */
323     /* y major, z middle, x minor or continuous */
324     t_complex               *p0;
325     int                      kx, ky, kz, maxkx, maxky;
326     int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
327     real                     mx, my, mz;
328     real                     ewaldcoeff = pme->ewaldcoeff_q;
329     real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
330     real                     ets2, struct2, vfactor, ets2vf;
331     real                     d1, d2, energy = 0;
332     real                     by, bz;
333     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
334     real                     rxx, ryx, ryy, rzx, rzy, rzz;
335     struct pme_solve_work_t *work;
336     real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
337     real                     mhxk, mhyk, mhzk, m2k;
338     real                     corner_fac;
339     ivec                     complex_order;
340     ivec                     local_ndata, local_offset, local_size;
341     real                     elfac;
342
343     elfac = ONE_4PI_EPS0/pme->epsilon_r;
344
345     nx = pme->nkx;
346     ny = pme->nky;
347     nz = pme->nkz;
348
349     /* Dimensions should be identical for A/B grid, so we just use A here */
350     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
351                                       complex_order,
352                                       local_ndata,
353                                       local_offset,
354                                       local_size);
355
356     rxx = pme->recipbox[XX][XX];
357     ryx = pme->recipbox[YY][XX];
358     ryy = pme->recipbox[YY][YY];
359     rzx = pme->recipbox[ZZ][XX];
360     rzy = pme->recipbox[ZZ][YY];
361     rzz = pme->recipbox[ZZ][ZZ];
362
363     GMX_ASSERT(rxx != 0.0, "Someone broke the reciprocal box again");
364
365     maxkx = (nx+1)/2;
366     maxky = (ny+1)/2;
367
368     work  = &pme->solve_work[thread];
369     mhx   = work->mhx;
370     mhy   = work->mhy;
371     mhz   = work->mhz;
372     m2    = work->m2;
373     denom = work->denom;
374     tmp1  = work->tmp1;
375     eterm = work->eterm;
376     m2inv = work->m2inv;
377
378     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
379     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
380
381     for (iyz = iyz0; iyz < iyz1; iyz++)
382     {
383         iy = iyz/local_ndata[ZZ];
384         iz = iyz - iy*local_ndata[ZZ];
385
386         ky = iy + local_offset[YY];
387
388         if (ky < maxky)
389         {
390             my = ky;
391         }
392         else
393         {
394             my = (ky - ny);
395         }
396
397         by = M_PI*vol*pme->bsp_mod[YY][ky];
398
399         kz = iz + local_offset[ZZ];
400
401         mz = kz;
402
403         bz = pme->bsp_mod[ZZ][kz];
404
405         /* 0.5 correction for corner points */
406         corner_fac = 1;
407         if (kz == 0 || kz == (nz+1)/2)
408         {
409             corner_fac = 0.5;
410         }
411
412         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
413
414         /* We should skip the k-space point (0,0,0) */
415         /* Note that since here x is the minor index, local_offset[XX]=0 */
416         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
417         {
418             kxstart = local_offset[XX];
419         }
420         else
421         {
422             kxstart = local_offset[XX] + 1;
423             p0++;
424         }
425         kxend = local_offset[XX] + local_ndata[XX];
426
427         if (bEnerVir)
428         {
429             /* More expensive inner loop, especially because of the storage
430              * of the mh elements in array's.
431              * Because x is the minor grid index, all mh elements
432              * depend on kx for triclinic unit cells.
433              */
434
435             /* Two explicit loops to avoid a conditional inside the loop */
436             for (kx = kxstart; kx < maxkx; kx++)
437             {
438                 mx = kx;
439
440                 mhxk      = mx * rxx;
441                 mhyk      = mx * ryx + my * ryy;
442                 mhzk      = mx * rzx + my * rzy + mz * rzz;
443                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
444                 mhx[kx]   = mhxk;
445                 mhy[kx]   = mhyk;
446                 mhz[kx]   = mhzk;
447                 m2[kx]    = m2k;
448                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
449                 tmp1[kx]  = -factor*m2k;
450             }
451
452             for (kx = maxkx; kx < kxend; kx++)
453             {
454                 mx = (kx - nx);
455
456                 mhxk      = mx * rxx;
457                 mhyk      = mx * ryx + my * ryy;
458                 mhzk      = mx * rzx + my * rzy + mz * rzz;
459                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
460                 mhx[kx]   = mhxk;
461                 mhy[kx]   = mhyk;
462                 mhz[kx]   = mhzk;
463                 m2[kx]    = m2k;
464                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
465                 tmp1[kx]  = -factor*m2k;
466             }
467
468             for (kx = kxstart; kx < kxend; kx++)
469             {
470                 m2inv[kx] = 1.0/m2[kx];
471             }
472
473             calc_exponentials_q(kxstart, kxend, elfac,
474                                 ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
475                                 ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
476                                 ArrayRef<PME_T>(eterm, eterm+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
477
478             for (kx = kxstart; kx < kxend; kx++, p0++)
479             {
480                 d1      = p0->re;
481                 d2      = p0->im;
482
483                 p0->re  = d1*eterm[kx];
484                 p0->im  = d2*eterm[kx];
485
486                 struct2 = 2.0*(d1*d1+d2*d2);
487
488                 tmp1[kx] = eterm[kx]*struct2;
489             }
490
491             for (kx = kxstart; kx < kxend; kx++)
492             {
493                 ets2     = corner_fac*tmp1[kx];
494                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
495                 energy  += ets2;
496
497                 ets2vf   = ets2*vfactor;
498                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
499                 virxy   += ets2vf*mhx[kx]*mhy[kx];
500                 virxz   += ets2vf*mhx[kx]*mhz[kx];
501                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
502                 viryz   += ets2vf*mhy[kx]*mhz[kx];
503                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
504             }
505         }
506         else
507         {
508             /* We don't need to calculate the energy and the virial.
509              * In this case the triclinic overhead is small.
510              */
511
512             /* Two explicit loops to avoid a conditional inside the loop */
513
514             for (kx = kxstart; kx < maxkx; kx++)
515             {
516                 mx = kx;
517
518                 mhxk      = mx * rxx;
519                 mhyk      = mx * ryx + my * ryy;
520                 mhzk      = mx * rzx + my * rzy + mz * rzz;
521                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
522                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
523                 tmp1[kx]  = -factor*m2k;
524             }
525
526             for (kx = maxkx; kx < kxend; kx++)
527             {
528                 mx = (kx - nx);
529
530                 mhxk      = mx * rxx;
531                 mhyk      = mx * ryx + my * ryy;
532                 mhzk      = mx * rzx + my * rzy + mz * rzz;
533                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
534                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
535                 tmp1[kx]  = -factor*m2k;
536             }
537
538             calc_exponentials_q(kxstart, kxend, elfac,
539                                 ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
540                                 ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
541                                 ArrayRef<PME_T>(eterm, eterm+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
542
543
544             for (kx = kxstart; kx < kxend; kx++, p0++)
545             {
546                 d1      = p0->re;
547                 d2      = p0->im;
548
549                 p0->re  = d1*eterm[kx];
550                 p0->im  = d2*eterm[kx];
551             }
552         }
553     }
554
555     if (bEnerVir)
556     {
557         /* Update virial with local values.
558          * The virial is symmetric by definition.
559          * this virial seems ok for isotropic scaling, but I'm
560          * experiencing problems on semiisotropic membranes.
561          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
562          */
563         work->vir_q[XX][XX] = 0.25*virxx;
564         work->vir_q[YY][YY] = 0.25*viryy;
565         work->vir_q[ZZ][ZZ] = 0.25*virzz;
566         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
567         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
568         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
569
570         /* This energy should be corrected for a charged system */
571         work->energy_q = 0.5*energy;
572     }
573
574     /* Return the loop count */
575     return local_ndata[YY]*local_ndata[XX];
576 }
577
578 int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real vol,
579                      gmx_bool bEnerVir, int nthread, int thread)
580 {
581     /* do recip sum over local cells in grid */
582     /* y major, z middle, x minor or continuous */
583     int                      ig, gcount;
584     int                      kx, ky, kz, maxkx, maxky;
585     int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
586     real                     mx, my, mz;
587     real                     ewaldcoeff = pme->ewaldcoeff_lj;
588     real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
589     real                     ets2, ets2vf;
590     real                     eterm, vterm, d1, d2, energy = 0;
591     real                     by, bz;
592     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
593     real                     rxx, ryx, ryy, rzx, rzy, rzz;
594     real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
595     real                     mhxk, mhyk, mhzk, m2k;
596     struct pme_solve_work_t *work;
597     real                     corner_fac;
598     ivec                     complex_order;
599     ivec                     local_ndata, local_offset, local_size;
600     nx = pme->nkx;
601     ny = pme->nky;
602     nz = pme->nkz;
603
604     /* Dimensions should be identical for A/B grid, so we just use A here */
605     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
606                                       complex_order,
607                                       local_ndata,
608                                       local_offset,
609                                       local_size);
610     rxx = pme->recipbox[XX][XX];
611     ryx = pme->recipbox[YY][XX];
612     ryy = pme->recipbox[YY][YY];
613     rzx = pme->recipbox[ZZ][XX];
614     rzy = pme->recipbox[ZZ][YY];
615     rzz = pme->recipbox[ZZ][ZZ];
616
617     maxkx = (nx+1)/2;
618     maxky = (ny+1)/2;
619
620     work  = &pme->solve_work[thread];
621     mhx   = work->mhx;
622     mhy   = work->mhy;
623     mhz   = work->mhz;
624     m2    = work->m2;
625     denom = work->denom;
626     tmp1  = work->tmp1;
627     tmp2  = work->tmp2;
628
629     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
630     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
631
632     for (iyz = iyz0; iyz < iyz1; iyz++)
633     {
634         iy = iyz/local_ndata[ZZ];
635         iz = iyz - iy*local_ndata[ZZ];
636
637         ky = iy + local_offset[YY];
638
639         if (ky < maxky)
640         {
641             my = ky;
642         }
643         else
644         {
645             my = (ky - ny);
646         }
647
648         by = 3.0*vol*pme->bsp_mod[YY][ky]
649             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
650
651         kz = iz + local_offset[ZZ];
652
653         mz = kz;
654
655         bz = pme->bsp_mod[ZZ][kz];
656
657         /* 0.5 correction for corner points */
658         corner_fac = 1;
659         if (kz == 0 || kz == (nz+1)/2)
660         {
661             corner_fac = 0.5;
662         }
663
664         kxstart = local_offset[XX];
665         kxend   = local_offset[XX] + local_ndata[XX];
666         if (bEnerVir)
667         {
668             /* More expensive inner loop, especially because of the
669              * storage of the mh elements in array's.  Because x is the
670              * minor grid index, all mh elements depend on kx for
671              * triclinic unit cells.
672              */
673
674             /* Two explicit loops to avoid a conditional inside the loop */
675             for (kx = kxstart; kx < maxkx; kx++)
676             {
677                 mx = kx;
678
679                 mhxk      = mx * rxx;
680                 mhyk      = mx * ryx + my * ryy;
681                 mhzk      = mx * rzx + my * rzy + mz * rzz;
682                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
683                 mhx[kx]   = mhxk;
684                 mhy[kx]   = mhyk;
685                 mhz[kx]   = mhzk;
686                 m2[kx]    = m2k;
687                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
688                 tmp1[kx]  = -factor*m2k;
689                 tmp2[kx]  = sqrt(factor*m2k);
690             }
691
692             for (kx = maxkx; kx < kxend; kx++)
693             {
694                 mx = (kx - nx);
695
696                 mhxk      = mx * rxx;
697                 mhyk      = mx * ryx + my * ryy;
698                 mhzk      = mx * rzx + my * rzy + mz * rzz;
699                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
700                 mhx[kx]   = mhxk;
701                 mhy[kx]   = mhyk;
702                 mhz[kx]   = mhzk;
703                 m2[kx]    = m2k;
704                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
705                 tmp1[kx]  = -factor*m2k;
706                 tmp2[kx]  = sqrt(factor*m2k);
707             }
708             /* Clear padding elements to avoid (harmless) fp exceptions */
709             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
710             for (; kx < kxendSimd; kx++)
711             {
712                 tmp1[kx] = 0;
713                 tmp2[kx] = 0;
714             }
715
716             calc_exponentials_lj(kxstart, kxend,
717                                  ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
718                                  ArrayRef<PME_T>(tmp2, tmp2+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
719                                  ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
720
721             for (kx = kxstart; kx < kxend; kx++)
722             {
723                 m2k   = factor*m2[kx];
724                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
725                           + 2.0*m2k*tmp2[kx]);
726                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
727                 tmp1[kx] = eterm*denom[kx];
728                 tmp2[kx] = vterm*denom[kx];
729             }
730
731             if (!bLB)
732             {
733                 t_complex *p0;
734                 real       struct2;
735
736                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
737                 for (kx = kxstart; kx < kxend; kx++, p0++)
738                 {
739                     d1      = p0->re;
740                     d2      = p0->im;
741
742                     eterm   = tmp1[kx];
743                     vterm   = tmp2[kx];
744                     p0->re  = d1*eterm;
745                     p0->im  = d2*eterm;
746
747                     struct2 = 2.0*(d1*d1+d2*d2);
748
749                     tmp1[kx] = eterm*struct2;
750                     tmp2[kx] = vterm*struct2;
751                 }
752             }
753             else
754             {
755                 real *struct2 = denom;
756                 real  str2;
757
758                 for (kx = kxstart; kx < kxend; kx++)
759                 {
760                     struct2[kx] = 0.0;
761                 }
762                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
763                 for (ig = 0; ig <= 3; ++ig)
764                 {
765                     t_complex *p0, *p1;
766                     real       scale;
767
768                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
769                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
770                     scale = 2.0*lb_scale_factor_symm[ig];
771                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
772                     {
773                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
774                     }
775
776                 }
777                 for (ig = 0; ig <= 6; ++ig)
778                 {
779                     t_complex *p0;
780
781                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
782                     for (kx = kxstart; kx < kxend; kx++, p0++)
783                     {
784                         d1     = p0->re;
785                         d2     = p0->im;
786
787                         eterm  = tmp1[kx];
788                         p0->re = d1*eterm;
789                         p0->im = d2*eterm;
790                     }
791                 }
792                 for (kx = kxstart; kx < kxend; kx++)
793                 {
794                     eterm    = tmp1[kx];
795                     vterm    = tmp2[kx];
796                     str2     = struct2[kx];
797                     tmp1[kx] = eterm*str2;
798                     tmp2[kx] = vterm*str2;
799                 }
800             }
801
802             for (kx = kxstart; kx < kxend; kx++)
803             {
804                 ets2     = corner_fac*tmp1[kx];
805                 vterm    = 2.0*factor*tmp2[kx];
806                 energy  += ets2;
807                 ets2vf   = corner_fac*vterm;
808                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
809                 virxy   += ets2vf*mhx[kx]*mhy[kx];
810                 virxz   += ets2vf*mhx[kx]*mhz[kx];
811                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
812                 viryz   += ets2vf*mhy[kx]*mhz[kx];
813                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
814             }
815         }
816         else
817         {
818             /* We don't need to calculate the energy and the virial.
819              *  In this case the triclinic overhead is small.
820              */
821
822             /* Two explicit loops to avoid a conditional inside the loop */
823
824             for (kx = kxstart; kx < maxkx; kx++)
825             {
826                 mx = kx;
827
828                 mhxk      = mx * rxx;
829                 mhyk      = mx * ryx + my * ryy;
830                 mhzk      = mx * rzx + my * rzy + mz * rzz;
831                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
832                 m2[kx]    = m2k;
833                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
834                 tmp1[kx]  = -factor*m2k;
835                 tmp2[kx]  = sqrt(factor*m2k);
836             }
837
838             for (kx = maxkx; kx < kxend; kx++)
839             {
840                 mx = (kx - nx);
841
842                 mhxk      = mx * rxx;
843                 mhyk      = mx * ryx + my * ryy;
844                 mhzk      = mx * rzx + my * rzy + mz * rzz;
845                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
846                 m2[kx]    = m2k;
847                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
848                 tmp1[kx]  = -factor*m2k;
849                 tmp2[kx]  = sqrt(factor*m2k);
850             }
851             /* Clear padding elements to avoid (harmless) fp exceptions */
852             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
853             for (; kx < kxendSimd; kx++)
854             {
855                 tmp1[kx] = 0;
856                 tmp2[kx] = 0;
857             }
858
859             calc_exponentials_lj(kxstart, kxend,
860                                  ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
861                                  ArrayRef<PME_T>(tmp2, tmp2+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
862                                  ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
863
864             for (kx = kxstart; kx < kxend; kx++)
865             {
866                 m2k    = factor*m2[kx];
867                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
868                            + 2.0*m2k*tmp2[kx]);
869                 tmp1[kx] = eterm*denom[kx];
870             }
871             gcount = (bLB ? 7 : 1);
872             for (ig = 0; ig < gcount; ++ig)
873             {
874                 t_complex *p0;
875
876                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
877                 for (kx = kxstart; kx < kxend; kx++, p0++)
878                 {
879                     d1      = p0->re;
880                     d2      = p0->im;
881
882                     eterm   = tmp1[kx];
883
884                     p0->re  = d1*eterm;
885                     p0->im  = d2*eterm;
886                 }
887             }
888         }
889     }
890     if (bEnerVir)
891     {
892         work->vir_lj[XX][XX] = 0.25*virxx;
893         work->vir_lj[YY][YY] = 0.25*viryy;
894         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
895         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
896         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
897         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
898
899         /* This energy should be corrected for a charged system */
900         work->energy_lj = 0.5*energy;
901     }
902     /* Return the loop count */
903     return local_ndata[YY]*local_ndata[XX];
904 }