Apply clang-format to source tree
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve.cpp
index 23d181367e0eac9ba05916eec9927ede11ede298..406e701e2d9976bf0664b5064bc93305f6558923 100644 (file)
@@ -63,22 +63,22 @@ using namespace gmx; // TODO: Remove when this file is moved into gmx namespace
 struct pme_solve_work_t
 {
     /* work data for solve_pme */
-    int      nalloc;
-    real *   mhx;
-    real *   mhy;
-    real *   mhz;
-    real *   m2;
-    real *   denom;
-    real *   tmp1_alloc;
-    real *   tmp1;
-    real *   tmp2;
-    real *   eterm;
-    real *   m2inv;
-
-    real     energy_q;
-    matrix   vir_q;
-    real     energy_lj;
-    matrix   vir_lj;
+    int   nalloc;
+    real* mhx;
+    real* mhy;
+    real* mhz;
+    real* m2;
+    real* denom;
+    real* tmp1_alloc;
+    real* tmp1;
+    real* tmp2;
+    real* eterm;
+    real* m2inv;
+
+    real   energy_q;
+    matrix vir_q;
+    real   energy_lj;
+    matrix vir_lj;
 };
 
 #ifdef PME_SIMD_SOLVE
@@ -89,10 +89,11 @@ constexpr int c_simdWidth = 4;
 #endif
 
 /* Returns the smallest number >= \p that is a multiple of \p factor, \p factor must be a power of 2 */
-template <unsigned int factor>
+template<unsigned int factor>
 static size_t roundUpToMultipleOfFactor(size_t number)
 {
-    static_assert(factor > 0 && (factor & (factor - 1)) == 0, "factor should be >0 and a power of 2");
+    static_assert(factor > 0 && (factor & (factor - 1)) == 0,
+                  "factor should be >0 and a power of 2");
 
     /* We need to add a most factor-1 and because factor is a power of 2,
      * we get the result by masking out the bits corresponding to factor-1.
@@ -104,13 +105,14 @@ static size_t roundUpToMultipleOfFactor(size_t number)
  * at the end for padding.
  */
 /* TODO: Replace this SIMD reallocator with a general, C++ solution */
-static void reallocSimdAlignedAndPadded(real **ptr, int unpaddedNumElements)
+static void reallocSimdAlignedAndPadded(real** ptr, int unpaddedNumElements)
 {
     sfree_aligned(*ptr);
-    snew_aligned(*ptr, roundUpToMultipleOfFactor<c_simdWidth>(unpaddedNumElements), c_simdWidth*sizeof(real));
+    snew_aligned(*ptr, roundUpToMultipleOfFactor<c_simdWidth>(unpaddedNumElements),
+                 c_simdWidth * sizeof(real));
 }
 
-static void realloc_work(struct pme_solve_work_t *work, int nkx)
+static void realloc_work(struct pme_solve_work_t* work, int nkx)
 {
     if (nkx > work->nalloc)
     {
@@ -128,14 +130,14 @@ static void realloc_work(struct pme_solve_work_t *work, int nkx)
         /* Init all allocated elements of denom to 1 to avoid 1/0 exceptions
          * of simd padded elements.
          */
-        for (size_t i = 0; i < roundUpToMultipleOfFactor<c_simdWidth>(work->nalloc ); i++)
+        for (size_t i = 0; i < roundUpToMultipleOfFactor<c_simdWidth>(work->nalloc); i++)
         {
             work->denom[i] = 1;
         }
     }
 }
 
-void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
+void pme_init_all_work(struct pme_solve_work_t** work, int nthread, int nkx)
 {
     /* Use fft5d, order after FFT is y major, z, x minor */
 
@@ -148,11 +150,11 @@ void pme_init_all_work(struct pme_solve_work_t **work, int nthread, int nkx)
         {
             realloc_work(&((*work)[thread]), nkx);
         }
-        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
+        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
     }
 }
 
-static void free_work(struct pme_solve_work_t *work)
+static void free_work(struct pme_solve_work_t* work)
 {
     if (work)
     {
@@ -168,7 +170,7 @@ static void free_work(struct pme_solve_work_t *work)
     }
 }
 
-void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
+void pme_free_all_work(struct pme_solve_work_t** work, int nthread)
 {
     if (*work)
     {
@@ -181,7 +183,7 @@ void pme_free_all_work(struct pme_solve_work_t **work, int nthread)
     *work = nullptr;
 }
 
-void get_pme_ener_vir_q(pme_solve_work_t *work, int nthread, PmeOutput *output)
+void get_pme_ener_vir_q(pme_solve_work_t* work, int nthread, PmeOutput* output)
 {
     GMX_ASSERT(output != nullptr, "Need valid output buffer");
     /* This function sums output over threads and should therefore
@@ -197,7 +199,7 @@ void get_pme_ener_vir_q(pme_solve_work_t *work, int nthread, PmeOutput *output)
     }
 }
 
-void get_pme_ener_vir_lj(pme_solve_work_t *work, int nthread, PmeOutput *output)
+void get_pme_ener_vir_lj(pme_solve_work_t* work, int nthread, PmeOutput* output)
 {
     GMX_ASSERT(output != nullptr, "Need valid output buffer");
     /* This function sums output over threads and should therefore
@@ -215,11 +217,16 @@ void get_pme_ener_vir_lj(pme_solve_work_t *work, int nthread, PmeOutput *output)
 
 #if defined PME_SIMD_SOLVE
 /* Calculate exponentials through SIMD */
-inline static void calc_exponentials_q(int /*unused*/, int /*unused*/, real f, ArrayRef<const SimdReal> d_aligned, ArrayRef<const SimdReal> r_aligned, ArrayRef<SimdReal> e_aligned)
+inline static void calc_exponentials_q(int /*unused*/,
+                                       int /*unused*/,
+                                       real                     f,
+                                       ArrayRef<const SimdReal> d_aligned,
+                                       ArrayRef<const SimdReal> r_aligned,
+                                       ArrayRef<SimdReal>       e_aligned)
 {
     {
-        SimdReal              f_simd(f);
-        SimdReal              tmp_d1, tmp_r, tmp_e;
+        SimdReal f_simd(f);
+        SimdReal tmp_d1, tmp_r, tmp_e;
 
         /* We only need to calculate from start. But since start is 0 or 1
          * and we want to use aligned loads/stores, we always start from 0.
@@ -238,14 +245,15 @@ inline static void calc_exponentials_q(int /*unused*/, int /*unused*/, real f, A
     }
 }
 #else
-inline static void calc_exponentials_q(int start, int end, real f, ArrayRef<real> d, ArrayRef<real> r, ArrayRef<real> e)
+inline static void
+calc_exponentials_q(int start, int end, real f, ArrayRef<real> d, ArrayRef<real> r, ArrayRef<real> e)
 {
     GMX_ASSERT(d.size() == r.size(), "d and r must have same size");
     GMX_ASSERT(d.size() == e.size(), "d and e must have same size");
     int kx;
     for (kx = start; kx < end; kx++)
     {
-        d[kx] = 1.0/d[kx];
+        d[kx] = 1.0 / d[kx];
     }
     for (kx = start; kx < end; kx++)
     {
@@ -253,17 +261,21 @@ inline static void calc_exponentials_q(int start, int end, real f, ArrayRef<real
     }
     for (kx = start; kx < end; kx++)
     {
-        e[kx] = f*r[kx]*d[kx];
+        e[kx] = f * r[kx] * d[kx];
     }
 }
 #endif
 
 #if defined PME_SIMD_SOLVE
 /* Calculate exponentials through SIMD */
-inline static void calc_exponentials_lj(int /*unused*/, int /*unused*/, ArrayRef<SimdReal> r_aligned, ArrayRef<SimdReal> factor_aligned, ArrayRef<SimdReal> d_aligned)
+inline static void calc_exponentials_lj(int /*unused*/,
+                                        int /*unused*/,
+                                        ArrayRef<SimdReal> r_aligned,
+                                        ArrayRef<SimdReal> factor_aligned,
+                                        ArrayRef<SimdReal> d_aligned)
 {
-    SimdReal              tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
-    const SimdReal        sqr_PI = sqrt(SimdReal(M_PI));
+    SimdReal       tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
+    const SimdReal sqr_PI = sqrt(SimdReal(M_PI));
 
     GMX_ASSERT(d_aligned.size() == r_aligned.size(), "d and r must have same size");
     GMX_ASSERT(d_aligned.size() == factor_aligned.size(), "d and factor must have same size");
@@ -284,7 +296,8 @@ inline static void calc_exponentials_lj(int /*unused*/, int /*unused*/, ArrayRef
     }
 }
 #else
-inline static void calc_exponentials_lj(int start, int end, ArrayRef<real> r, ArrayRef<real> tmp2, ArrayRef<real> d)
+inline static void
+calc_exponentials_lj(int start, int end, ArrayRef<real> r, ArrayRef<real> tmp2, ArrayRef<real> d)
 {
     int  kx;
     real mk;
@@ -292,7 +305,7 @@ inline static void calc_exponentials_lj(int start, int end, ArrayRef<real> r, Ar
     GMX_ASSERT(d.size() == tmp2.size(), "d and tmp2 must have same size");
     for (kx = start; kx < end; kx++)
     {
-        d[kx] = 1.0/d[kx];
+        d[kx] = 1.0 / d[kx];
     }
 
     for (kx = start; kx < end; kx++)
@@ -303,7 +316,7 @@ inline static void calc_exponentials_lj(int start, int end, ArrayRef<real> r, Ar
     for (kx = start; kx < end; kx++)
     {
         mk       = tmp2[kx];
-        tmp2[kx] = sqrt(M_PI)*mk*std::erfc(mk);
+        tmp2[kx] = sqrt(M_PI) * mk * std::erfc(mk);
     }
 }
 #endif
@@ -314,43 +327,38 @@ using PME_T = SimdReal;
 using PME_T = real;
 #endif
 
-int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
-                  gmx_bool bEnerVir,
-                  int nthread, int thread)
+int solve_pme_yzx(const gmx_pme_t* pme, t_complex* grid, real vol, gmx_bool bEnerVir, int nthread, int thread)
 {
     /* do recip sum over local cells in grid */
     /* y major, z middle, x minor or continuous */
-    t_complex               *p0;
+    t_complex*               p0;
     int                      kx, ky, kz, maxkx, maxky;
     int                      nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
     real                     mx, my, mz;
     real                     ewaldcoeff = pme->ewaldcoeff_q;
-    real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
+    real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
     real                     ets2, struct2, vfactor, ets2vf;
     real                     d1, d2, energy = 0;
     real                     by, bz;
     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
     real                     rxx, ryx, ryy, rzx, rzy, rzz;
-    struct pme_solve_work_t *work;
-    real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
+    struct pme_solve_work_t* work;
+    real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
     real                     mhxk, mhyk, mhzk, m2k;
     real                     corner_fac;
     ivec                     complex_order;
     ivec                     local_ndata, local_offset, local_size;
     real                     elfac;
 
-    elfac = ONE_4PI_EPS0/pme->epsilon_r;
+    elfac = ONE_4PI_EPS0 / pme->epsilon_r;
 
     nx = pme->nkx;
     ny = pme->nky;
     nz = pme->nkz;
 
     /* Dimensions should be identical for A/B grid, so we just use A here */
-    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
-                                      complex_order,
-                                      local_ndata,
-                                      local_offset,
-                                      local_size);
+    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA], complex_order, local_ndata,
+                                      local_offset, local_size);
 
     rxx = pme->recipbox[XX][XX];
     ryx = pme->recipbox[YY][XX];
@@ -361,8 +369,8 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
 
     GMX_ASSERT(rxx != 0.0, "Someone broke the reciprocal box again");
 
-    maxkx = (nx+1)/2;
-    maxky = (ny+1)/2;
+    maxkx = (nx + 1) / 2;
+    maxky = (ny + 1) / 2;
 
     work  = &pme->solve_work[thread];
     mhx   = work->mhx;
@@ -374,13 +382,13 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
     eterm = work->eterm;
     m2inv = work->m2inv;
 
-    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
-    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
+    iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
+    iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
 
     for (iyz = iyz0; iyz < iyz1; iyz++)
     {
-        iy = iyz/local_ndata[ZZ];
-        iz = iyz - iy*local_ndata[ZZ];
+        iy = iyz / local_ndata[ZZ];
+        iz = iyz - iy * local_ndata[ZZ];
 
         ky = iy + local_offset[YY];
 
@@ -393,7 +401,7 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
             my = (ky - ny);
         }
 
-        by = M_PI*vol*pme->bsp_mod[YY][ky];
+        by = M_PI * vol * pme->bsp_mod[YY][ky];
 
         kz = iz + local_offset[ZZ];
 
@@ -403,12 +411,12 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
 
         /* 0.5 correction for corner points */
         corner_fac = 1;
-        if (kz == 0 || kz == (nz+1)/2)
+        if (kz == 0 || kz == (nz + 1) / 2)
         {
             corner_fac = 0.5;
         }
 
-        p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
+        p0 = grid + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
 
         /* We should skip the k-space point (0,0,0) */
         /* Note that since here x is the minor index, local_offset[XX]=0 */
@@ -439,13 +447,13 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
                 mhx[kx]   = mhxk;
                 mhy[kx]   = mhyk;
                 mhz[kx]   = mhzk;
                 m2[kx]    = m2k;
-                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
+                denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
             }
 
             for (kx = maxkx; kx < kxend; kx++)
@@ -455,51 +463,52 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
                 mhx[kx]   = mhxk;
                 mhy[kx]   = mhyk;
                 mhz[kx]   = mhzk;
                 m2[kx]    = m2k;
-                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
+                denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
             }
 
             for (kx = kxstart; kx < kxend; kx++)
             {
-                m2inv[kx] = 1.0/m2[kx];
+                m2inv[kx] = 1.0 / m2[kx];
             }
 
-            calc_exponentials_q(kxstart, kxend, elfac,
-                                ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                ArrayRef<PME_T>(eterm, eterm+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
+            calc_exponentials_q(
+                    kxstart, kxend, elfac,
+                    ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
 
             for (kx = kxstart; kx < kxend; kx++, p0++)
             {
-                d1      = p0->re;
-                d2      = p0->im;
+                d1 = p0->re;
+                d2 = p0->im;
 
-                p0->re  = d1*eterm[kx];
-                p0->im  = d2*eterm[kx];
+                p0->re = d1 * eterm[kx];
+                p0->im = d2 * eterm[kx];
 
-                struct2 = 2.0*(d1*d1+d2*d2);
+                struct2 = 2.0 * (d1 * d1 + d2 * d2);
 
-                tmp1[kx] = eterm[kx]*struct2;
+                tmp1[kx] = eterm[kx] * struct2;
             }
 
             for (kx = kxstart; kx < kxend; kx++)
             {
-                ets2     = corner_fac*tmp1[kx];
-                vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
-                energy  += ets2;
-
-                ets2vf   = ets2*vfactor;
-                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
-                virxy   += ets2vf*mhx[kx]*mhy[kx];
-                virxz   += ets2vf*mhx[kx]*mhz[kx];
-                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
-                viryz   += ets2vf*mhy[kx]*mhz[kx];
-                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
+                ets2    = corner_fac * tmp1[kx];
+                vfactor = (factor * m2[kx] + 1.0) * 2.0 * m2inv[kx];
+                energy += ets2;
+
+                ets2vf = ets2 * vfactor;
+                virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
+                virxy += ets2vf * mhx[kx] * mhy[kx];
+                virxz += ets2vf * mhx[kx] * mhz[kx];
+                viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
+                viryz += ets2vf * mhy[kx] * mhz[kx];
+                virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
             }
         }
         else
@@ -517,9 +526,9 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
-                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
+                denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
             }
 
             for (kx = maxkx; kx < kxend; kx++)
@@ -529,24 +538,25 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
-                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
+                denom[kx] = m2k * bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
             }
 
-            calc_exponentials_q(kxstart, kxend, elfac,
-                                ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                ArrayRef<PME_T>(eterm, eterm+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
+            calc_exponentials_q(
+                    kxstart, kxend, elfac,
+                    ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(eterm, eterm + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
 
 
             for (kx = kxstart; kx < kxend; kx++, p0++)
             {
-                d1      = p0->re;
-                d2      = p0->im;
+                d1 = p0->re;
+                d2 = p0->im;
 
-                p0->re  = d1*eterm[kx];
-                p0->im  = d2*eterm[kx];
+                p0->re = d1 * eterm[kx];
+                p0->im = d2 * eterm[kx];
             }
         }
     }
@@ -559,23 +569,22 @@ int solve_pme_yzx(const gmx_pme_t *pme, t_complex *grid, real vol,
          * experiencing problems on semiisotropic membranes.
          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
          */
-        work->vir_q[XX][XX] = 0.25*virxx;
-        work->vir_q[YY][YY] = 0.25*viryy;
-        work->vir_q[ZZ][ZZ] = 0.25*virzz;
-        work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
-        work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
-        work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
+        work->vir_q[XX][XX] = 0.25 * virxx;
+        work->vir_q[YY][YY] = 0.25 * viryy;
+        work->vir_q[ZZ][ZZ] = 0.25 * virzz;
+        work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25 * virxy;
+        work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25 * virxz;
+        work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25 * viryz;
 
         /* This energy should be corrected for a charged system */
-        work->energy_q = 0.5*energy;
+        work->energy_q = 0.5 * energy;
     }
 
     /* Return the loop count */
-    return local_ndata[YY]*local_ndata[XX];
+    return local_ndata[YY] * local_ndata[XX];
 }
 
-int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real vol,
-                     gmx_bool bEnerVir, int nthread, int thread)
+int solve_pme_lj_yzx(const gmx_pme_t* pme, t_complex** grid, gmx_bool bLB, real vol, gmx_bool bEnerVir, int nthread, int thread)
 {
     /* do recip sum over local cells in grid */
     /* y major, z middle, x minor or continuous */
@@ -584,15 +593,15 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
     int                      nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
     real                     mx, my, mz;
     real                     ewaldcoeff = pme->ewaldcoeff_lj;
-    real                     factor     = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
+    real                     factor     = M_PI * M_PI / (ewaldcoeff * ewaldcoeff);
     real                     ets2, ets2vf;
     real                     eterm, vterm, d1, d2, energy = 0;
     real                     by, bz;
     real                     virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
     real                     rxx, ryx, ryy, rzx, rzy, rzz;
-    real                    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
+    real *                   mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
     real                     mhxk, mhyk, mhzk, m2k;
-    struct pme_solve_work_t *work;
+    struct pme_solve_work_t* work;
     real                     corner_fac;
     ivec                     complex_order;
     ivec                     local_ndata, local_offset, local_size;
@@ -601,11 +610,8 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
     nz = pme->nkz;
 
     /* Dimensions should be identical for A/B grid, so we just use A here */
-    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
-                                      complex_order,
-                                      local_ndata,
-                                      local_offset,
-                                      local_size);
+    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A], complex_order, local_ndata,
+                                      local_offset, local_size);
     rxx = pme->recipbox[XX][XX];
     ryx = pme->recipbox[YY][XX];
     ryy = pme->recipbox[YY][YY];
@@ -613,8 +619,8 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
     rzy = pme->recipbox[ZZ][YY];
     rzz = pme->recipbox[ZZ][ZZ];
 
-    maxkx = (nx+1)/2;
-    maxky = (ny+1)/2;
+    maxkx = (nx + 1) / 2;
+    maxky = (ny + 1) / 2;
 
     work  = &pme->solve_work[thread];
     mhx   = work->mhx;
@@ -625,13 +631,13 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
     tmp1  = work->tmp1;
     tmp2  = work->tmp2;
 
-    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
-    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
+    iyz0 = local_ndata[YY] * local_ndata[ZZ] * thread / nthread;
+    iyz1 = local_ndata[YY] * local_ndata[ZZ] * (thread + 1) / nthread;
 
     for (iyz = iyz0; iyz < iyz1; iyz++)
     {
-        iy = iyz/local_ndata[ZZ];
-        iz = iyz - iy*local_ndata[ZZ];
+        iy = iyz / local_ndata[ZZ];
+        iz = iyz - iy * local_ndata[ZZ];
 
         ky = iy + local_offset[YY];
 
@@ -644,8 +650,7 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
             my = (ky - ny);
         }
 
-        by = 3.0*vol*pme->bsp_mod[YY][ky]
-            / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
+        by = 3.0 * vol * pme->bsp_mod[YY][ky] / (M_PI * sqrt(M_PI) * ewaldcoeff * ewaldcoeff * ewaldcoeff);
 
         kz = iz + local_offset[ZZ];
 
@@ -655,7 +660,7 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
 
         /* 0.5 correction for corner points */
         corner_fac = 1;
-        if (kz == 0 || kz == (nz+1)/2)
+        if (kz == 0 || kz == (nz + 1) / 2)
         {
             corner_fac = 0.5;
         }
@@ -678,14 +683,14 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
                 mhx[kx]   = mhxk;
                 mhy[kx]   = mhyk;
                 mhz[kx]   = mhzk;
                 m2[kx]    = m2k;
-                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
-                tmp2[kx]  = sqrt(factor*m2k);
+                denom[kx] = bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
+                tmp2[kx]  = sqrt(factor * m2k);
             }
 
             for (kx = maxkx; kx < kxend; kx++)
@@ -695,14 +700,14 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
                 mhx[kx]   = mhxk;
                 mhy[kx]   = mhyk;
                 mhz[kx]   = mhzk;
                 m2[kx]    = m2k;
-                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
-                tmp2[kx]  = sqrt(factor*m2k);
+                denom[kx] = bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
+                tmp2[kx]  = sqrt(factor * m2k);
             }
             /* Clear padding elements to avoid (harmless) fp exceptions */
             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
@@ -712,46 +717,46 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                 tmp2[kx] = 0;
             }
 
-            calc_exponentials_lj(kxstart, kxend,
-                                 ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                 ArrayRef<PME_T>(tmp2, tmp2+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                 ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
+            calc_exponentials_lj(
+                    kxstart, kxend,
+                    ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
 
             for (kx = kxstart; kx < kxend; kx++)
             {
-                m2k   = factor*m2[kx];
-                eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
-                          + 2.0*m2k*tmp2[kx]);
-                vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
-                tmp1[kx] = eterm*denom[kx];
-                tmp2[kx] = vterm*denom[kx];
+                m2k      = factor * m2[kx];
+                eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
+                vterm    = 3.0 * (-tmp1[kx] + tmp2[kx]);
+                tmp1[kx] = eterm * denom[kx];
+                tmp2[kx] = vterm * denom[kx];
             }
 
             if (!bLB)
             {
-                t_complex *p0;
+                t_complex* p0;
                 real       struct2;
 
-                p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
+                p0 = grid[0] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
                 for (kx = kxstart; kx < kxend; kx++, p0++)
                 {
-                    d1      = p0->re;
-                    d2      = p0->im;
+                    d1 = p0->re;
+                    d2 = p0->im;
 
-                    eterm   = tmp1[kx];
-                    vterm   = tmp2[kx];
-                    p0->re  = d1*eterm;
-                    p0->im  = d2*eterm;
+                    eterm  = tmp1[kx];
+                    vterm  = tmp2[kx];
+                    p0->re = d1 * eterm;
+                    p0->im = d2 * eterm;
 
-                    struct2 = 2.0*(d1*d1+d2*d2);
+                    struct2 = 2.0 * (d1 * d1 + d2 * d2);
 
-                    tmp1[kx] = eterm*struct2;
-                    tmp2[kx] = vterm*struct2;
+                    tmp1[kx] = eterm * struct2;
+                    tmp2[kx] = vterm * struct2;
                 }
             }
             else
             {
-                real *struct2 = denom;
+                real* struct2 = denom;
                 real  str2;
 
                 for (kx = kxstart; kx < kxend; kx++)
@@ -764,28 +769,27 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                     t_complex *p0, *p1;
                     real       scale;
 
-                    p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
-                    p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
-                    scale = 2.0*lb_scale_factor_symm[ig];
+                    p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
+                    p1 = grid[6 - ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
+                    scale = 2.0 * lb_scale_factor_symm[ig];
                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
                     {
-                        struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
+                        struct2[kx] += scale * (p0->re * p1->re + p0->im * p1->im);
                     }
-
                 }
                 for (ig = 0; ig <= 6; ++ig)
                 {
-                    t_complex *p0;
+                    t_complex* p0;
 
-                    p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
+                    p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
                     for (kx = kxstart; kx < kxend; kx++, p0++)
                     {
-                        d1     = p0->re;
-                        d2     = p0->im;
+                        d1 = p0->re;
+                        d2 = p0->im;
 
                         eterm  = tmp1[kx];
-                        p0->re = d1*eterm;
-                        p0->im = d2*eterm;
+                        p0->re = d1 * eterm;
+                        p0->im = d2 * eterm;
                     }
                 }
                 for (kx = kxstart; kx < kxend; kx++)
@@ -793,23 +797,23 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                     eterm    = tmp1[kx];
                     vterm    = tmp2[kx];
                     str2     = struct2[kx];
-                    tmp1[kx] = eterm*str2;
-                    tmp2[kx] = vterm*str2;
+                    tmp1[kx] = eterm * str2;
+                    tmp2[kx] = vterm * str2;
                 }
             }
 
             for (kx = kxstart; kx < kxend; kx++)
             {
-                ets2     = corner_fac*tmp1[kx];
-                vterm    = 2.0*factor*tmp2[kx];
-                energy  += ets2;
-                ets2vf   = corner_fac*vterm;
-                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
-                virxy   += ets2vf*mhx[kx]*mhy[kx];
-                virxz   += ets2vf*mhx[kx]*mhz[kx];
-                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
-                viryz   += ets2vf*mhy[kx]*mhz[kx];
-                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
+                ets2  = corner_fac * tmp1[kx];
+                vterm = 2.0 * factor * tmp2[kx];
+                energy += ets2;
+                ets2vf = corner_fac * vterm;
+                virxx += ets2vf * mhx[kx] * mhx[kx] - ets2;
+                virxy += ets2vf * mhx[kx] * mhy[kx];
+                virxz += ets2vf * mhx[kx] * mhz[kx];
+                viryy += ets2vf * mhy[kx] * mhy[kx] - ets2;
+                viryz += ets2vf * mhy[kx] * mhz[kx];
+                virzz += ets2vf * mhz[kx] * mhz[kx] - ets2;
             }
         }
         else
@@ -827,11 +831,11 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
                 m2[kx]    = m2k;
-                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
-                tmp2[kx]  = sqrt(factor*m2k);
+                denom[kx] = bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
+                tmp2[kx]  = sqrt(factor * m2k);
             }
 
             for (kx = maxkx; kx < kxend; kx++)
@@ -841,11 +845,11 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                 mhxk      = mx * rxx;
                 mhyk      = mx * ryx + my * ryy;
                 mhzk      = mx * rzx + my * rzy + mz * rzz;
-                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
+                m2k       = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
                 m2[kx]    = m2k;
-                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
-                tmp1[kx]  = -factor*m2k;
-                tmp2[kx]  = sqrt(factor*m2k);
+                denom[kx] = bz * by * pme->bsp_mod[XX][kx];
+                tmp1[kx]  = -factor * m2k;
+                tmp2[kx]  = sqrt(factor * m2k);
             }
             /* Clear padding elements to avoid (harmless) fp exceptions */
             const int kxendSimd = roundUpToMultipleOfFactor<c_simdWidth>(kxend);
@@ -855,49 +859,49 @@ int solve_pme_lj_yzx(const gmx_pme_t *pme, t_complex **grid, gmx_bool bLB, real
                 tmp2[kx] = 0;
             }
 
-            calc_exponentials_lj(kxstart, kxend,
-                                 ArrayRef<PME_T>(tmp1, tmp1+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                 ArrayRef<PME_T>(tmp2, tmp2+roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
-                                 ArrayRef<PME_T>(denom, denom+roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
+            calc_exponentials_lj(
+                    kxstart, kxend,
+                    ArrayRef<PME_T>(tmp1, tmp1 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(tmp2, tmp2 + roundUpToMultipleOfFactor<c_simdWidth>(kxend)),
+                    ArrayRef<PME_T>(denom, denom + roundUpToMultipleOfFactor<c_simdWidth>(kxend)));
 
             for (kx = kxstart; kx < kxend; kx++)
             {
-                m2k    = factor*m2[kx];
-                eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
-                           + 2.0*m2k*tmp2[kx]);
-                tmp1[kx] = eterm*denom[kx];
+                m2k      = factor * m2[kx];
+                eterm    = -((1.0 - 2.0 * m2k) * tmp1[kx] + 2.0 * m2k * tmp2[kx]);
+                tmp1[kx] = eterm * denom[kx];
             }
             gcount = (bLB ? 7 : 1);
             for (ig = 0; ig < gcount; ++ig)
             {
-                t_complex *p0;
+                t_complex* p0;
 
-                p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
+                p0 = grid[ig] + iy * local_size[ZZ] * local_size[XX] + iz * local_size[XX];
                 for (kx = kxstart; kx < kxend; kx++, p0++)
                 {
-                    d1      = p0->re;
-                    d2      = p0->im;
+                    d1 = p0->re;
+                    d2 = p0->im;
 
-                    eterm   = tmp1[kx];
+                    eterm = tmp1[kx];
 
-                    p0->re  = d1*eterm;
-                    p0->im  = d2*eterm;
+                    p0->re = d1 * eterm;
+                    p0->im = d2 * eterm;
                 }
             }
         }
     }
     if (bEnerVir)
     {
-        work->vir_lj[XX][XX] = 0.25*virxx;
-        work->vir_lj[YY][YY] = 0.25*viryy;
-        work->vir_lj[ZZ][ZZ] = 0.25*virzz;
-        work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
-        work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
-        work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
+        work->vir_lj[XX][XX] = 0.25 * virxx;
+        work->vir_lj[YY][YY] = 0.25 * viryy;
+        work->vir_lj[ZZ][ZZ] = 0.25 * virzz;
+        work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25 * virxy;
+        work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25 * virxz;
+        work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25 * viryz;
 
         /* This energy should be corrected for a charged system */
-        work->energy_lj = 0.5*energy;
+        work->energy_lj = 0.5 * energy;
     }
     /* Return the loop count */
-    return local_ndata[YY]*local_ndata[XX];
+    return local_ndata[YY] * local_ndata[XX];
 }