Tidy: modernize-use-nullptr
[alexxy/gromacs.git] / src / gromacs / ewald / pme-solve.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37
38 #include "gmxpre.h"
39
40 #include "pme-solve.h"
41
42 #include <cmath>
43
44 #include "gromacs/fft/parallel_3dfft.h"
45 #include "gromacs/math/units.h"
46 #include "gromacs/math/utilities.h"
47 #include "gromacs/math/vec.h"
48 #include "gromacs/simd/simd.h"
49 #include "gromacs/simd/simd_math.h"
50 #include "gromacs/utility/exceptions.h"
51 #include "gromacs/utility/smalloc.h"
52
53 #include "pme-internal.h"
54
55 #if GMX_SIMD_HAVE_REAL
56 /* Turn on arbitrary width SIMD intrinsics for PME solve */
57 #    define PME_SIMD_SOLVE
58 #endif
59
60 using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
61
62 struct pme_solve_work_t
63 {
64     /* work data for solve_pme */
65     int      nalloc;
66     real *   mhx;
67     real *   mhy;
68     real *   mhz;
69     real *   m2;
70     real *   denom;
71     real *   tmp1_alloc;
72     real *   tmp1;
73     real *   tmp2;
74     real *   eterm;
75     real *   m2inv;
76
77     real     energy_q;
78     matrix   vir_q;
79     real     energy_lj;
80     matrix   vir_lj;
81 };
82
83 static void realloc_work(struct pme_solve_work_t *work, int nkx)
84 {
85     if (nkx > work->nalloc)
86     {
87         int simd_width, i;
88
89         work->nalloc = nkx;
90         srenew(work->mhx, work->nalloc);
91         srenew(work->mhy, work->nalloc);
92         srenew(work->mhz, work->nalloc);
93         srenew(work->m2, work->nalloc);
94         /* Allocate an aligned pointer for SIMD operations, including extra
95          * elements at the end for padding.
96          */
97 #ifdef PME_SIMD_SOLVE
98         simd_width = GMX_SIMD_REAL_WIDTH;
99 #else
100         /* We can use any alignment, apart from 0, so we use 4 */
101         simd_width = 4;
102 #endif
103         sfree_aligned(work->denom);
104         sfree_aligned(work->tmp1);
105         sfree_aligned(work->tmp2);
106         sfree_aligned(work->eterm);
107         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
108         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
109         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
110         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
111         srenew(work->m2inv, work->nalloc);
112
113         /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
114          * of simd padded elements.
115          */
116         for (i = 0; i < work->nalloc+simd_width; i++)
117         {
118             work->denom[i] = 1;
119         }
120     }
121 }
122
123 void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
124 {
125     int thread;
126     /* Use fft5d, order after FFT is y major, z, x minor */
127
128     snew(*work, nthread);
129     /* Allocate the work arrays thread local to optimize memory access */
130 #pragma omp parallel for num_threads(nthread) schedule(static)
131     for (thread = 0; thread < nthread; thread++)
132     {
133         try
134         {
135             realloc_work(&((*work)[thread]), nkx);
136         }
137         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
138     }
139 }
140
141 static void free_work(struct pme_solve_work_t *work)
142 {
143     sfree(work->mhx);
144     sfree(work->mhy);
145     sfree(work->mhz);
146     sfree(work->m2);
147     sfree_aligned(work->denom);
148     sfree_aligned(work->tmp1);
149     sfree_aligned(work->tmp2);
150     sfree_aligned(work->eterm);
151     sfree(work->m2inv);
152 }
153
154 void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
155 {
156     int thread;
157
158     for (thread = 0; thread < nthread; thread++)
159     {
160         free_work(&(*work)[thread]);
161     }
162     sfree(*work);
163     *work = nullptr;
164 }
165
166 void get_pme_ener_vir_q(struct pme_solve_work_t *work, int nthread,
167                         real *mesh_energy, matrix vir)
168 {
169     /* This function sums output over threads and should therefore
170      * only be called after thread synchronization.
171      */
172     int thread;
173
174     *mesh_energy = work[0].energy_q;
175     copy_mat(work[0].vir_q, vir);
176
177     for (thread = 1; thread < nthread; thread++)
178     {
179         *mesh_energy += work[thread].energy_q;
180         m_add(vir, work[thread].vir_q, vir);
181     }
182 }
183
184 void get_pme_ener_vir_lj(struct pme_solve_work_t *work, int nthread,
185                          real *mesh_energy, matrix vir)
186 {
187     /* This function sums output over threads and should therefore
188      * only be called after thread synchronization.
189      */
190     int thread;
191
192     *mesh_energy = work[0].energy_lj;
193     copy_mat(work[0].vir_lj, vir);
194
195     for (thread = 1; thread < nthread; thread++)
196     {
197         *mesh_energy += work[thread].energy_lj;
198         m_add(vir, work[thread].vir_lj, vir);
199     }
200 }
201
202 #if defined PME_SIMD_SOLVE
203 /* Calculate exponentials through SIMD */
204 gmx_inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
205 {
206     {
207         SimdReal              f_simd(f);
208         SimdReal              tmp_d1, tmp_r, tmp_e;
209         int                   kx;
210
211         /* We only need to calculate from start. But since start is 0 or 1
212          * and we want to use aligned loads/stores, we always start from 0.
213          */
214         for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
215         {
216             tmp_d1   = load(d_aligned+kx);
217             tmp_r    = load(r_aligned+kx);
218             tmp_r    = gmx::exp(tmp_r);
219             tmp_e    = f_simd / tmp_d1;
220             tmp_e    = tmp_e * tmp_r;
221             store(e_aligned+kx, tmp_e);
222         }
223     }
224 }
225 #else
226 gmx_inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
227 {
228     int kx;
229     for (kx = start; kx < end; kx++)
230     {
231         d[kx] = 1.0/d[kx];
232     }
233     for (kx = start; kx < end; kx++)
234     {
235         r[kx] = std::exp(r[kx]);
236     }
237     for (kx = start; kx < end; kx++)
238     {
239         e[kx] = f*r[kx]*d[kx];
240     }
241 }
242 #endif
243
244 #if defined PME_SIMD_SOLVE
245 /* Calculate exponentials through SIMD */
246 gmx_inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
247 {
248     SimdReal              tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
249     const SimdReal        sqr_PI = sqrt(SimdReal(M_PI));
250     int                   kx;
251     for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
252     {
253         /* We only need to calculate from start. But since start is 0 or 1
254          * and we want to use aligned loads/stores, we always start from 0.
255          */
256         tmp_d = load(d_aligned+kx);
257         d_inv = SimdReal(1.0) / tmp_d;
258         store(d_aligned+kx, d_inv);
259         tmp_r = load(r_aligned+kx);
260         tmp_r = gmx::exp(tmp_r);
261         store(r_aligned+kx, tmp_r);
262         tmp_mk  = load(factor_aligned+kx);
263         tmp_fac = sqr_PI * tmp_mk * erfc(tmp_mk);
264         store(factor_aligned+kx, tmp_fac);
265     }
266 }
267 #else
268 gmx_inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
269 {
270     int  kx;
271     real mk;
272     for (kx = start; kx < end; kx++)
273     {
274         d[kx] = 1.0/d[kx];
275     }
276
277     for (kx = start; kx < end; kx++)
278     {
279         r[kx] = std::exp(r[kx]);
280     }
281
282     for (kx = start; kx < end; kx++)
283     {
284         mk       = tmp2[kx];
285         tmp2[kx] = sqrt(M_PI)*mk*std::erfc(mk);
286     }
287 }
288 #endif
289
290 int solve_pme_yzx(struct gmx_pme_t *pme, t_complex *grid, real vol,
291                   gmx_bool bEnerVir,
292                   int nthread, int thread)
293 {
294     /* do recip sum over local cells in grid */
295     /* y major, z middle, x minor or continuous */
296     t_complex               *p0;
297     int                      kx, ky, kz, maxkx, maxky;
298     int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
299     real                     mx, my, mz;
300     real                     ewaldcoeff = pme->ewaldcoeff_q;
301     real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
302     real                     ets2, struct2, vfactor, ets2vf;
303     real                     d1, d2, energy = 0;
304     real                     by, bz;
305     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
306     real                     rxx, ryx, ryy, rzx, rzy, rzz;
307     struct pme_solve_work_t *work;
308     real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
309     real                     mhxk, mhyk, mhzk, m2k;
310     real                     corner_fac;
311     ivec                     complex_order;
312     ivec                     local_ndata, local_offset, local_size;
313     real                     elfac;
314
315     elfac = ONE_4PI_EPS0/pme->epsilon_r;
316
317     nx = pme->nkx;
318     ny = pme->nky;
319     nz = pme->nkz;
320
321     /* Dimensions should be identical for A/B grid, so we just use A here */
322     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
323                                       complex_order,
324                                       local_ndata,
325                                       local_offset,
326                                       local_size);
327
328     rxx = pme->recipbox[XX][XX];
329     ryx = pme->recipbox[YY][XX];
330     ryy = pme->recipbox[YY][YY];
331     rzx = pme->recipbox[ZZ][XX];
332     rzy = pme->recipbox[ZZ][YY];
333     rzz = pme->recipbox[ZZ][ZZ];
334
335     maxkx = (nx+1)/2;
336     maxky = (ny+1)/2;
337
338     work  = &pme->solve_work[thread];
339     mhx   = work->mhx;
340     mhy   = work->mhy;
341     mhz   = work->mhz;
342     m2    = work->m2;
343     denom = work->denom;
344     tmp1  = work->tmp1;
345     eterm = work->eterm;
346     m2inv = work->m2inv;
347
348     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
349     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
350
351     for (iyz = iyz0; iyz < iyz1; iyz++)
352     {
353         iy = iyz/local_ndata[ZZ];
354         iz = iyz - iy*local_ndata[ZZ];
355
356         ky = iy + local_offset[YY];
357
358         if (ky < maxky)
359         {
360             my = ky;
361         }
362         else
363         {
364             my = (ky - ny);
365         }
366
367         by = M_PI*vol*pme->bsp_mod[YY][ky];
368
369         kz = iz + local_offset[ZZ];
370
371         mz = kz;
372
373         bz = pme->bsp_mod[ZZ][kz];
374
375         /* 0.5 correction for corner points */
376         corner_fac = 1;
377         if (kz == 0 || kz == (nz+1)/2)
378         {
379             corner_fac = 0.5;
380         }
381
382         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
383
384         /* We should skip the k-space point (0,0,0) */
385         /* Note that since here x is the minor index, local_offset[XX]=0 */
386         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
387         {
388             kxstart = local_offset[XX];
389         }
390         else
391         {
392             kxstart = local_offset[XX] + 1;
393             p0++;
394         }
395         kxend = local_offset[XX] + local_ndata[XX];
396
397         if (bEnerVir)
398         {
399             /* More expensive inner loop, especially because of the storage
400              * of the mh elements in array's.
401              * Because x is the minor grid index, all mh elements
402              * depend on kx for triclinic unit cells.
403              */
404
405             /* Two explicit loops to avoid a conditional inside the loop */
406             for (kx = kxstart; kx < maxkx; kx++)
407             {
408                 mx = kx;
409
410                 mhxk      = mx * rxx;
411                 mhyk      = mx * ryx + my * ryy;
412                 mhzk      = mx * rzx + my * rzy + mz * rzz;
413                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
414                 mhx[kx]   = mhxk;
415                 mhy[kx]   = mhyk;
416                 mhz[kx]   = mhzk;
417                 m2[kx]    = m2k;
418                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
419                 tmp1[kx]  = -factor*m2k;
420             }
421
422             for (kx = maxkx; kx < kxend; kx++)
423             {
424                 mx = (kx - nx);
425
426                 mhxk      = mx * rxx;
427                 mhyk      = mx * ryx + my * ryy;
428                 mhzk      = mx * rzx + my * rzy + mz * rzz;
429                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
430                 mhx[kx]   = mhxk;
431                 mhy[kx]   = mhyk;
432                 mhz[kx]   = mhzk;
433                 m2[kx]    = m2k;
434                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
435                 tmp1[kx]  = -factor*m2k;
436             }
437
438             for (kx = kxstart; kx < kxend; kx++)
439             {
440                 m2inv[kx] = 1.0/m2[kx];
441             }
442
443             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
444
445             for (kx = kxstart; kx < kxend; kx++, p0++)
446             {
447                 d1      = p0->re;
448                 d2      = p0->im;
449
450                 p0->re  = d1*eterm[kx];
451                 p0->im  = d2*eterm[kx];
452
453                 struct2 = 2.0*(d1*d1+d2*d2);
454
455                 tmp1[kx] = eterm[kx]*struct2;
456             }
457
458             for (kx = kxstart; kx < kxend; kx++)
459             {
460                 ets2     = corner_fac*tmp1[kx];
461                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
462                 energy  += ets2;
463
464                 ets2vf   = ets2*vfactor;
465                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
466                 virxy   += ets2vf*mhx[kx]*mhy[kx];
467                 virxz   += ets2vf*mhx[kx]*mhz[kx];
468                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
469                 viryz   += ets2vf*mhy[kx]*mhz[kx];
470                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
471             }
472         }
473         else
474         {
475             /* We don't need to calculate the energy and the virial.
476              * In this case the triclinic overhead is small.
477              */
478
479             /* Two explicit loops to avoid a conditional inside the loop */
480
481             for (kx = kxstart; kx < maxkx; kx++)
482             {
483                 mx = kx;
484
485                 mhxk      = mx * rxx;
486                 mhyk      = mx * ryx + my * ryy;
487                 mhzk      = mx * rzx + my * rzy + mz * rzz;
488                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
489                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
490                 tmp1[kx]  = -factor*m2k;
491             }
492
493             for (kx = maxkx; kx < kxend; kx++)
494             {
495                 mx = (kx - nx);
496
497                 mhxk      = mx * rxx;
498                 mhyk      = mx * ryx + my * ryy;
499                 mhzk      = mx * rzx + my * rzy + mz * rzz;
500                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
501                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
502                 tmp1[kx]  = -factor*m2k;
503             }
504
505             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
506
507             for (kx = kxstart; kx < kxend; kx++, p0++)
508             {
509                 d1      = p0->re;
510                 d2      = p0->im;
511
512                 p0->re  = d1*eterm[kx];
513                 p0->im  = d2*eterm[kx];
514             }
515         }
516     }
517
518     if (bEnerVir)
519     {
520         /* Update virial with local values.
521          * The virial is symmetric by definition.
522          * this virial seems ok for isotropic scaling, but I'm
523          * experiencing problems on semiisotropic membranes.
524          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
525          */
526         work->vir_q[XX][XX] = 0.25*virxx;
527         work->vir_q[YY][YY] = 0.25*viryy;
528         work->vir_q[ZZ][ZZ] = 0.25*virzz;
529         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
530         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
531         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
532
533         /* This energy should be corrected for a charged system */
534         work->energy_q = 0.5*energy;
535     }
536
537     /* Return the loop count */
538     return local_ndata[YY]*local_ndata[XX];
539 }
540
541 int solve_pme_lj_yzx(struct gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real vol,
542                      gmx_bool bEnerVir, int nthread, int thread)
543 {
544     /* do recip sum over local cells in grid */
545     /* y major, z middle, x minor or continuous */
546     int                      ig, gcount;
547     int                      kx, ky, kz, maxkx, maxky;
548     int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
549     real                     mx, my, mz;
550     real                     ewaldcoeff = pme->ewaldcoeff_lj;
551     real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
552     real                     ets2, ets2vf;
553     real                     eterm, vterm, d1, d2, energy = 0;
554     real                     by, bz;
555     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
556     real                     rxx, ryx, ryy, rzx, rzy, rzz;
557     real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
558     real                     mhxk, mhyk, mhzk, m2k;
559     struct pme_solve_work_t *work;
560     real                     corner_fac;
561     ivec                     complex_order;
562     ivec                     local_ndata, local_offset, local_size;
563     nx = pme->nkx;
564     ny = pme->nky;
565     nz = pme->nkz;
566
567     /* Dimensions should be identical for A/B grid, so we just use A here */
568     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
569                                       complex_order,
570                                       local_ndata,
571                                       local_offset,
572                                       local_size);
573     rxx = pme->recipbox[XX][XX];
574     ryx = pme->recipbox[YY][XX];
575     ryy = pme->recipbox[YY][YY];
576     rzx = pme->recipbox[ZZ][XX];
577     rzy = pme->recipbox[ZZ][YY];
578     rzz = pme->recipbox[ZZ][ZZ];
579
580     maxkx = (nx+1)/2;
581     maxky = (ny+1)/2;
582
583     work  = &pme->solve_work[thread];
584     mhx   = work->mhx;
585     mhy   = work->mhy;
586     mhz   = work->mhz;
587     m2    = work->m2;
588     denom = work->denom;
589     tmp1  = work->tmp1;
590     tmp2  = work->tmp2;
591
592     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
593     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
594
595     for (iyz = iyz0; iyz < iyz1; iyz++)
596     {
597         iy = iyz/local_ndata[ZZ];
598         iz = iyz - iy*local_ndata[ZZ];
599
600         ky = iy + local_offset[YY];
601
602         if (ky < maxky)
603         {
604             my = ky;
605         }
606         else
607         {
608             my = (ky - ny);
609         }
610
611         by = 3.0*vol*pme->bsp_mod[YY][ky]
612             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
613
614         kz = iz + local_offset[ZZ];
615
616         mz = kz;
617
618         bz = pme->bsp_mod[ZZ][kz];
619
620         /* 0.5 correction for corner points */
621         corner_fac = 1;
622         if (kz == 0 || kz == (nz+1)/2)
623         {
624             corner_fac = 0.5;
625         }
626
627         kxstart = local_offset[XX];
628         kxend   = local_offset[XX] + local_ndata[XX];
629         if (bEnerVir)
630         {
631             /* More expensive inner loop, especially because of the
632              * storage of the mh elements in array's.  Because x is the
633              * minor grid index, all mh elements depend on kx for
634              * triclinic unit cells.
635              */
636
637             /* Two explicit loops to avoid a conditional inside the loop */
638             for (kx = kxstart; kx < maxkx; kx++)
639             {
640                 mx = kx;
641
642                 mhxk      = mx * rxx;
643                 mhyk      = mx * ryx + my * ryy;
644                 mhzk      = mx * rzx + my * rzy + mz * rzz;
645                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
646                 mhx[kx]   = mhxk;
647                 mhy[kx]   = mhyk;
648                 mhz[kx]   = mhzk;
649                 m2[kx]    = m2k;
650                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
651                 tmp1[kx]  = -factor*m2k;
652                 tmp2[kx]  = sqrt(factor*m2k);
653             }
654
655             for (kx = maxkx; kx < kxend; kx++)
656             {
657                 mx = (kx - nx);
658
659                 mhxk      = mx * rxx;
660                 mhyk      = mx * ryx + my * ryy;
661                 mhzk      = mx * rzx + my * rzy + mz * rzz;
662                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
663                 mhx[kx]   = mhxk;
664                 mhy[kx]   = mhyk;
665                 mhz[kx]   = mhzk;
666                 m2[kx]    = m2k;
667                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
668                 tmp1[kx]  = -factor*m2k;
669                 tmp2[kx]  = sqrt(factor*m2k);
670             }
671
672             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
673
674             for (kx = kxstart; kx < kxend; kx++)
675             {
676                 m2k   = factor*m2[kx];
677                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
678                           + 2.0*m2k*tmp2[kx]);
679                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
680                 tmp1[kx] = eterm*denom[kx];
681                 tmp2[kx] = vterm*denom[kx];
682             }
683
684             if (!bLB)
685             {
686                 t_complex *p0;
687                 real       struct2;
688
689                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
690                 for (kx = kxstart; kx < kxend; kx++, p0++)
691                 {
692                     d1      = p0->re;
693                     d2      = p0->im;
694
695                     eterm   = tmp1[kx];
696                     vterm   = tmp2[kx];
697                     p0->re  = d1*eterm;
698                     p0->im  = d2*eterm;
699
700                     struct2 = 2.0*(d1*d1+d2*d2);
701
702                     tmp1[kx] = eterm*struct2;
703                     tmp2[kx] = vterm*struct2;
704                 }
705             }
706             else
707             {
708                 real *struct2 = denom;
709                 real  str2;
710
711                 for (kx = kxstart; kx < kxend; kx++)
712                 {
713                     struct2[kx] = 0.0;
714                 }
715                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
716                 for (ig = 0; ig <= 3; ++ig)
717                 {
718                     t_complex *p0, *p1;
719                     real       scale;
720
721                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
722                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
723                     scale = 2.0*lb_scale_factor_symm[ig];
724                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
725                     {
726                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
727                     }
728
729                 }
730                 for (ig = 0; ig <= 6; ++ig)
731                 {
732                     t_complex *p0;
733
734                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
735                     for (kx = kxstart; kx < kxend; kx++, p0++)
736                     {
737                         d1     = p0->re;
738                         d2     = p0->im;
739
740                         eterm  = tmp1[kx];
741                         p0->re = d1*eterm;
742                         p0->im = d2*eterm;
743                     }
744                 }
745                 for (kx = kxstart; kx < kxend; kx++)
746                 {
747                     eterm    = tmp1[kx];
748                     vterm    = tmp2[kx];
749                     str2     = struct2[kx];
750                     tmp1[kx] = eterm*str2;
751                     tmp2[kx] = vterm*str2;
752                 }
753             }
754
755             for (kx = kxstart; kx < kxend; kx++)
756             {
757                 ets2     = corner_fac*tmp1[kx];
758                 vterm    = 2.0*factor*tmp2[kx];
759                 energy  += ets2;
760                 ets2vf   = corner_fac*vterm;
761                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
762                 virxy   += ets2vf*mhx[kx]*mhy[kx];
763                 virxz   += ets2vf*mhx[kx]*mhz[kx];
764                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
765                 viryz   += ets2vf*mhy[kx]*mhz[kx];
766                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
767             }
768         }
769         else
770         {
771             /* We don't need to calculate the energy and the virial.
772              *  In this case the triclinic overhead is small.
773              */
774
775             /* Two explicit loops to avoid a conditional inside the loop */
776
777             for (kx = kxstart; kx < maxkx; kx++)
778             {
779                 mx = kx;
780
781                 mhxk      = mx * rxx;
782                 mhyk      = mx * ryx + my * ryy;
783                 mhzk      = mx * rzx + my * rzy + mz * rzz;
784                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
785                 m2[kx]    = m2k;
786                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
787                 tmp1[kx]  = -factor*m2k;
788                 tmp2[kx]  = sqrt(factor*m2k);
789             }
790
791             for (kx = maxkx; kx < kxend; kx++)
792             {
793                 mx = (kx - nx);
794
795                 mhxk      = mx * rxx;
796                 mhyk      = mx * ryx + my * ryy;
797                 mhzk      = mx * rzx + my * rzy + mz * rzz;
798                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
799                 m2[kx]    = m2k;
800                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
801                 tmp1[kx]  = -factor*m2k;
802                 tmp2[kx]  = sqrt(factor*m2k);
803             }
804
805             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
806
807             for (kx = kxstart; kx < kxend; kx++)
808             {
809                 m2k    = factor*m2[kx];
810                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
811                            + 2.0*m2k*tmp2[kx]);
812                 tmp1[kx] = eterm*denom[kx];
813             }
814             gcount = (bLB ? 7 : 1);
815             for (ig = 0; ig < gcount; ++ig)
816             {
817                 t_complex *p0;
818
819                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
820                 for (kx = kxstart; kx < kxend; kx++, p0++)
821                 {
822                     d1      = p0->re;
823                     d2      = p0->im;
824
825                     eterm   = tmp1[kx];
826
827                     p0->re  = d1*eterm;
828                     p0->im  = d2*eterm;
829                 }
830             }
831         }
832     }
833     if (bEnerVir)
834     {
835         work->vir_lj[XX][XX] = 0.25*virxx;
836         work->vir_lj[YY][YY] = 0.25*viryy;
837         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
838         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
839         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
840         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
841
842         /* This energy should be corrected for a charged system */
843         work->energy_lj = 0.5*energy;
844     }
845     /* Return the loop count */
846     return local_ndata[YY]*local_ndata[XX];
847 }