Pad RVec force buffer in ThreadForceBuffer
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_free_energy.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 #include "gmxpre.h"
39
40 #include "nb_free_energy.h"
41
42 #include "config.h"
43
44 #include <cmath>
45 #include <set>
46
47 #include <algorithm>
48
49 #include "gromacs/gmxlib/nrnb.h"
50 #include "gromacs/gmxlib/nonbonded/nonbonded.h"
51 #include "gromacs/math/arrayrefwithpadding.h"
52 #include "gromacs/math/functions.h"
53 #include "gromacs/math/vec.h"
54 #include "gromacs/mdtypes/forceoutput.h"
55 #include "gromacs/mdtypes/forcerec.h"
56 #include "gromacs/mdtypes/interaction_const.h"
57 #include "gromacs/mdtypes/md_enums.h"
58 #include "gromacs/mdtypes/mdatom.h"
59 #include "gromacs/mdtypes/nblist.h"
60 #include "gromacs/pbcutil/ishift.h"
61 #include "gromacs/simd/simd.h"
62 #include "gromacs/simd/simd_math.h"
63 #include "gromacs/utility/fatalerror.h"
64 #include "gromacs/utility/arrayref.h"
65
66
67 //! Scalar (non-SIMD) data types.
68 struct ScalarDataTypes
69 {
70     using RealType = real; //!< The data type to use as real.
71     using IntType  = int;  //!< The data type to use as int.
72     using BoolType = bool; //!< The data type to use as bool for real value comparison.
73     static constexpr int simdRealWidth = 1; //!< The width of the RealType.
74     static constexpr int simdIntWidth  = 1; //!< The width of the IntType.
75 };
76
77 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
78 //! SIMD data types.
79 struct SimdDataTypes
80 {
81     using RealType = gmx::SimdReal;  //!< The data type to use as real.
82     using IntType  = gmx::SimdInt32; //!< The data type to use as int.
83     using BoolType = gmx::SimdBool;  //!< The data type to use as bool for real value comparison.
84     static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH; //!< The width of the RealType.
85 #    if GMX_SIMD_HAVE_DOUBLE && GMX_DOUBLE
86     static constexpr int simdIntWidth = GMX_SIMD_DINT32_WIDTH; //!< The width of the IntType.
87 #    else
88     static constexpr int simdIntWidth = GMX_SIMD_FINT32_WIDTH; //!< The width of the IntType.
89 #    endif
90 };
91 #endif
92
93 template<class RealType, class BoolType>
94 static inline void
95 pmeCoulombCorrectionVF(const RealType rSq, const real beta, RealType* pot, RealType* force, const BoolType mask)
96 {
97     const RealType brsq = gmx::selectByMask(rSq * beta * beta, mask);
98     *force              = -brsq * beta * gmx::pmeForceCorrection(brsq);
99     *pot                = beta * gmx::pmePotentialCorrection(brsq);
100 }
101
102 template<class RealType, class BoolType>
103 static inline void pmeLJCorrectionVF(const RealType rInv,
104                                      const RealType rSq,
105                                      const real     ewaldLJCoeffSq,
106                                      const real     ewaldLJCoeffSixDivSix,
107                                      RealType*      pot,
108                                      RealType*      force,
109                                      const BoolType mask,
110                                      const BoolType bIiEqJnr)
111 {
112     // We mask rInv to get zero force and potential for masked out pair interactions
113     const RealType rInvSq  = gmx::selectByMask(rInv * rInv, mask);
114     const RealType rInvSix = rInvSq * rInvSq * rInvSq;
115     // Mask rSq to avoid underflow in exp()
116     const RealType coeffSqRSq       = ewaldLJCoeffSq * gmx::selectByMask(rSq, mask);
117     const RealType expNegCoeffSqRSq = gmx::exp(-coeffSqRSq);
118     const RealType poly             = 1.0_real + coeffSqRSq + 0.5_real * coeffSqRSq * coeffSqRSq;
119     *force = rInvSix - expNegCoeffSqRSq * (rInvSix * poly + ewaldLJCoeffSixDivSix);
120     *force = *force * rInvSq;
121     // The self interaction is the limit for r -> 0 which we need to compute separately
122     *pot = gmx::blend(
123             rInvSix * (1.0_real - expNegCoeffSqRSq * poly), 0.5_real * ewaldLJCoeffSixDivSix, bIiEqJnr);
124 }
125
126 //! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
127 template<class RealType, class BoolType>
128 static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot, const BoolType mask)
129 {
130     RealType cbrtRes = gmx::cbrt(r);
131     *invPthRoot      = gmx::maskzInvsqrt(cbrtRes, mask);
132     *pthRoot         = gmx::maskzInv(*invPthRoot, mask);
133 }
134
/*! \brief Returns the sixth power of \p rInvV.
 *
 * \param[in] rInvV  Value to raise to the sixth power (an inverse distance, by its name).
 */
template<class RealType>
static inline RealType calculateRinv6(const RealType rInvV)
{
    // Square first, then cube the square
    const RealType rInvSquared = rInvV * rInvV;
    const RealType rInvFourth  = rInvSquared * rInvSquared;
    return rInvFourth * rInvSquared;
}
141
/*! \brief Returns the product \p c6 * \p rInv6.
 *
 * \param[in] c6     Coefficient of the r^-6 term.
 * \param[in] rInv6  Sixth power of the inverse distance.
 */
template<class RealType>
static inline RealType calculateVdw6(const RealType c6, const RealType rInv6)
{
    const RealType dispersion = c6 * rInv6;
    return dispersion;
}
147
/*! \brief Returns \p c12 * \p rInv6 * \p rInv6.
 *
 * \param[in] c12    Coefficient of the r^-12 term.
 * \param[in] rInv6  Sixth power of the inverse distance.
 */
template<class RealType>
static inline RealType calculateVdw12(const RealType c12, const RealType rInv6)
{
    // Multiply left to right: (c12 * rInv6) * rInv6
    const RealType c12TimesRInv6 = c12 * rInv6;
    return c12TimesRInv6 * rInv6;
}
153
154 /* reaction-field electrostatics */
155 template<class RealType>
156 static inline RealType reactionFieldScalarForce(const RealType qq,
157                                                 const RealType rInv,
158                                                 const RealType r,
159                                                 const real     krf,
160                                                 const real     two)
161 {
162     return (qq * (rInv - two * krf * r * r));
163 }
164 template<class RealType>
165 static inline RealType reactionFieldPotential(const RealType qq,
166                                               const RealType rInv,
167                                               const RealType r,
168                                               const real     krf,
169                                               const real     potentialShift)
170 {
171     return (qq * (rInv + krf * r * r - potentialShift));
172 }
173
/* Ewald electrostatics */

/*! \brief Returns the product \p coulomb * \p rInv.
 *
 * \param[in] coulomb  Coulomb prefactor.
 * \param[in] rInv     Inverse distance.
 */
template<class RealType>
static inline RealType ewaldScalarForce(const RealType coulomb, const RealType rInv)
{
    const RealType scalarForce = coulomb * rInv;
    return scalarForce;
}
180 template<class RealType>
181 static inline RealType ewaldPotential(const RealType coulomb, const RealType rInv, const real potentialShift)
182 {
183     return (coulomb * (rInv - potentialShift));
184 }
185
/* cutoff LJ */

/*! \brief Returns the difference \p v12 - \p v6.
 *
 * \param[in] v6   The r^-6 term.
 * \param[in] v12  The r^-12 term.
 */
template<class RealType>
static inline RealType lennardJonesScalarForce(const RealType v6, const RealType v12)
{
    const RealType scalarForce = v12 - v6;
    return scalarForce;
}
192 template<class RealType>
193 static inline RealType lennardJonesPotential(const RealType v6,
194                                              const RealType v12,
195                                              const RealType c6,
196                                              const RealType c12,
197                                              const real     repulsionShift,
198                                              const real     dispersionShift,
199                                              const real     oneSixth,
200                                              const real     oneTwelfth)
201 {
202     return ((v12 + c12 * repulsionShift) * oneTwelfth - (v6 + c6 * dispersionShift) * oneSixth);
203 }
204
205 /* Ewald LJ */
206 template<class RealType>
207 static inline RealType ewaldLennardJonesGridSubtract(const RealType c6grid,
208                                                      const real     potentialShift,
209                                                      const real     oneSixth)
210 {
211     return (c6grid * potentialShift * oneSixth);
212 }
213
214 /* LJ Potential switch */
215 template<class RealType, class BoolType>
216 static inline RealType potSwitchScalarForceMod(const RealType fScalarInp,
217                                                const RealType potential,
218                                                const RealType sw,
219                                                const RealType r,
220                                                const RealType dsw,
221                                                const BoolType mask)
222 {
223     /* The mask should select on rV < rVdw */
224     return (gmx::selectByMask(fScalarInp * sw - r * potential * dsw, mask));
225 }
226 template<class RealType, class BoolType>
227 static inline RealType potSwitchPotentialMod(const RealType potentialInp, const RealType sw, const BoolType mask)
228 {
229     /* The mask should select on rV < rVdw */
230     return (gmx::selectByMask(potentialInp * sw, mask));
231 }
232
233
234 //! Templated free-energy non-bonded kernel
235 template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
236 static void nb_free_energy_kernel(const t_nblist&                                  nlist,
237                                   const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
238                                   const int                                        ntype,
239                                   const real                                       rlist,
240                                   const interaction_const_t&                       ic,
241                                   gmx::ArrayRef<const gmx::RVec>                   shiftvec,
242                                   gmx::ArrayRef<const real>                        nbfp,
243                                   gmx::ArrayRef<const real> gmx_unused             nbfp_grid,
244                                   gmx::ArrayRef<const real>                        chargeA,
245                                   gmx::ArrayRef<const real>                        chargeB,
246                                   gmx::ArrayRef<const int>                         typeA,
247                                   gmx::ArrayRef<const int>                         typeB,
248                                   int                                              flags,
249                                   gmx::ArrayRef<const real>                        lambda,
250                                   t_nrnb* gmx_restrict                             nrnb,
251                                   gmx::ArrayRefWithPadding<gmx::RVec> threadForceBuffer,
252                                   rvec*                               threadForceShiftBuffer,
253                                   gmx::ArrayRef<real>                 threadVc,
254                                   gmx::ArrayRef<real>                 threadVv,
255                                   gmx::ArrayRef<real>                 threadDvdl)
256 {
257 #define STATE_A 0
258 #define STATE_B 1
259 #define NSTATES 2
260
261     using RealType = typename DataTypes::RealType;
262     using IntType  = typename DataTypes::IntType;
263     using BoolType = typename DataTypes::BoolType;
264
265     constexpr real oneTwelfth = 1.0_real / 12.0_real;
266     constexpr real oneSixth   = 1.0_real / 6.0_real;
267     constexpr real zero       = 0.0_real;
268     constexpr real half       = 0.5_real;
269     constexpr real one        = 1.0_real;
270     constexpr real two        = 2.0_real;
271     constexpr real six        = 6.0_real;
272
273     // Extract pair list data
274     const int                nri    = nlist.nri;
275     gmx::ArrayRef<const int> iinr   = nlist.iinr;
276     gmx::ArrayRef<const int> jindex = nlist.jindex;
277     gmx::ArrayRef<const int> jjnr   = nlist.jjnr;
278     gmx::ArrayRef<const int> shift  = nlist.shift;
279     gmx::ArrayRef<const int> gid    = nlist.gid;
280
281     const real  lambda_coul = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
282     const real  lambda_vdw  = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
283     const auto& scParams    = *ic.softCoreParameters;
284     const real gmx_unused alpha_coul    = scParams.alphaCoulomb;
285     const real gmx_unused alpha_vdw     = scParams.alphaVdw;
286     const real            lam_power     = scParams.lambdaPower;
287     const real gmx_unused sigma6_def    = scParams.sigma6WithInvalidSigma;
288     const real gmx_unused sigma6_min    = scParams.sigma6Minimum;
289     const bool            doForces      = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
290     const bool            doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
291     const bool            doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
292
293     // Extract data from interaction_const_t
294     const real            facel           = ic.epsfac;
295     const real            rCoulomb        = ic.rcoulomb;
296     const real            krf             = ic.reactionFieldCoefficient;
297     const real            crf             = ic.reactionFieldShift;
298     const real gmx_unused shLjEwald       = ic.sh_lj_ewald;
299     const real            rVdw            = ic.rvdw;
300     const real            dispersionShift = ic.dispersion_shift.cpot;
301     const real            repulsionShift  = ic.repulsion_shift.cpot;
302     const real            ewaldBeta       = ic.ewaldcoeff_q;
303     real gmx_unused       ewaldLJCoeffSq;
304     real gmx_unused       ewaldLJCoeffSixDivSix;
305     if constexpr (vdwInteractionTypeIsEwald)
306     {
307         ewaldLJCoeffSq        = ic.ewaldcoeff_lj * ic.ewaldcoeff_lj;
308         ewaldLJCoeffSixDivSix = ewaldLJCoeffSq * ewaldLJCoeffSq * ewaldLJCoeffSq / six;
309     }
310
311     // Note that the nbnxm kernels do not support Coulomb potential switching at all
312     GMX_ASSERT(ic.coulomb_modifier != InteractionModifiers::PotSwitch,
313                "Potential switching is not supported for Coulomb with FEP");
314
315     const real      rVdwSwitch = ic.rvdw_switch;
316     real gmx_unused vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
317     if constexpr (vdwModifierIsPotSwitch)
318     {
319         const real d = rVdw - rVdwSwitch;
320         vdw_swV3     = -10.0_real / (d * d * d);
321         vdw_swV4     = 15.0_real / (d * d * d * d);
322         vdw_swV5     = -6.0_real / (d * d * d * d * d);
323         vdw_swF2     = -30.0_real / (d * d * d);
324         vdw_swF3     = 60.0_real / (d * d * d * d);
325         vdw_swF4     = -30.0_real / (d * d * d * d * d);
326     }
327     else
328     {
329         /* Avoid warnings from stupid compilers (looking at you, Clang!) */
330         vdw_swV3 = vdw_swV4 = vdw_swV5 = vdw_swF2 = vdw_swF3 = vdw_swF4 = zero;
331     }
332
333     NbkernelElecType icoul;
334     if (ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype))
335     {
336         icoul = NbkernelElecType::ReactionField;
337     }
338     else
339     {
340         icoul = NbkernelElecType::None;
341     }
342
343     real rcutoff_max2 = std::max(ic.rcoulomb, ic.rvdw);
344     rcutoff_max2      = rcutoff_max2 * rcutoff_max2;
345
346     real gmx_unused sh_ewald = 0;
347     if constexpr (elecInteractionTypeIsEwald || vdwInteractionTypeIsEwald)
348     {
349         sh_ewald = ic.sh_ewald;
350     }
351
352     /* For Ewald/PME interactions we cannot easily apply the soft-core component to
353      * reciprocal space. When we use non-switched Ewald interactions, we
354      * assume the soft-coring does not significantly affect the grid contribution
355      * and apply the soft-core only to the full 1/r (- shift) pair contribution.
356      *
357      * However, we cannot use this approach for switch-modified since we would then
358      * effectively end up evaluating a significantly different interaction here compared to the
359      * normal (non-free-energy) kernels, either by applying a cutoff at a different
360      * position than what the user requested, or by switching different
361      * things (1/r rather than short-range Ewald). For these settings, we just
362      * use the traditional short-range Ewald interaction in that case.
363      */
364     GMX_RELEASE_ASSERT(!(vdwInteractionTypeIsEwald && vdwModifierIsPotSwitch),
365                        "Can not apply soft-core to switched Ewald potentials");
366
367     RealType dvdlCoul(zero);
368     RealType dvdlVdw(zero);
369
370     /* Lambda factor for state A, 1-lambda*/
371     real LFC[NSTATES], LFV[NSTATES];
372     LFC[STATE_A] = one - lambda_coul;
373     LFV[STATE_A] = one - lambda_vdw;
374
375     /* Lambda factor for state B, lambda*/
376     LFC[STATE_B] = lambda_coul;
377     LFV[STATE_B] = lambda_vdw;
378
379     /*derivative of the lambda factor for state A and B */
380     real DLF[NSTATES];
381     DLF[STATE_A] = -one;
382     DLF[STATE_B] = one;
383
384     real gmx_unused lFacCoul[NSTATES], dlFacCoul[NSTATES], lFacVdw[NSTATES], dlFacVdw[NSTATES];
385     constexpr real  sc_r_power = six;
386     for (int i = 0; i < NSTATES; i++)
387     {
388         lFacCoul[i]  = (lam_power == 2 ? (1 - LFC[i]) * (1 - LFC[i]) : (1 - LFC[i]));
389         dlFacCoul[i] = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFC[i]) : 1);
390         lFacVdw[i]   = (lam_power == 2 ? (1 - LFV[i]) * (1 - LFV[i]) : (1 - LFV[i]));
391         dlFacVdw[i]  = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFV[i]) : 1);
392     }
393
394     // We need pointers to real for SIMD access
395     const real* gmx_restrict x            = coords.paddedConstArrayRef().data()[0];
396     real* gmx_restrict       forceRealPtr = threadForceBuffer.paddedArrayRef().data()[0];
397
398     const real rlistSquared = gmx::square(rlist);
399
400     bool haveExcludedPairsBeyondRlist = false;
401
402     for (int n = 0; n < nri; n++)
403     {
404         bool havePairsWithinCutoff = false;
405
406         const int  is   = shift[n];
407         const real shX  = shiftvec[is][XX];
408         const real shY  = shiftvec[is][YY];
409         const real shZ  = shiftvec[is][ZZ];
410         const int  nj0  = jindex[n];
411         const int  nj1  = jindex[n + 1];
412         const int  ii   = iinr[n];
413         const int  ii3  = 3 * ii;
414         const real ix   = shX + x[ii3 + 0];
415         const real iy   = shY + x[ii3 + 1];
416         const real iz   = shZ + x[ii3 + 2];
417         const real iqA  = facel * chargeA[ii];
418         const real iqB  = facel * chargeB[ii];
419         const int  ntiA = ntype * typeA[ii];
420         const int  ntiB = ntype * typeB[ii];
421         RealType   vCTot(0);
422         RealType   vVTot(0);
423         RealType   fIX(0);
424         RealType   fIY(0);
425         RealType   fIZ(0);
426
427 #if GMX_SIMD_HAVE_REAL
428         alignas(GMX_SIMD_ALIGNMENT) int preloadIi[DataTypes::simdRealWidth];
429         alignas(GMX_SIMD_ALIGNMENT) int preloadIs[DataTypes::simdRealWidth];
430 #else
431         int preloadIi[DataTypes::simdRealWidth];
432         int preloadIs[DataTypes::simdRealWidth];
433 #endif
434         for (int s = 0; s < DataTypes::simdRealWidth; s++)
435         {
436             preloadIi[s] = ii;
437             preloadIs[s] = shift[n];
438         }
439         IntType ii_s = gmx::load<IntType>(preloadIi);
440
441         for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
442         {
443             RealType r, rInv;
444
445 #if GMX_SIMD_HAVE_REAL
446             alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIsValid[DataTypes::simdRealWidth];
447             alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIncluded[DataTypes::simdRealWidth];
448             alignas(GMX_SIMD_ALIGNMENT) int32_t preloadJnr[DataTypes::simdRealWidth];
449             alignas(GMX_SIMD_ALIGNMENT) int32_t typeIndices[NSTATES][DataTypes::simdRealWidth];
450             alignas(GMX_SIMD_ALIGNMENT) real    preloadQq[NSTATES][DataTypes::simdRealWidth];
451             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
452             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
453             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
454             alignas(GMX_SIMD_ALIGNMENT) real preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
455 #else
456             real            preloadPairIsValid[DataTypes::simdRealWidth];
457             real            preloadPairIncluded[DataTypes::simdRealWidth];
458             int             preloadJnr[DataTypes::simdRealWidth];
459             int             typeIndices[NSTATES][DataTypes::simdRealWidth];
460             real            preloadQq[NSTATES][DataTypes::simdRealWidth];
461             real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
462             real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
463             real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
464             real            preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
465 #endif
466             for (int s = 0; s < DataTypes::simdRealWidth; s++)
467             {
468                 if (k + s < nj1)
469                 {
470                     preloadPairIsValid[s] = true;
471                     /* Check if this pair on the exclusions list.*/
472                     preloadPairIncluded[s]  = (nlist.excl_fep.empty() || nlist.excl_fep[k + s]);
473                     const int jnr           = jjnr[k + s];
474                     preloadJnr[s]           = jnr;
475                     typeIndices[STATE_A][s] = ntiA + typeA[jnr];
476                     typeIndices[STATE_B][s] = ntiB + typeB[jnr];
477                     preloadQq[STATE_A][s]   = iqA * chargeA[jnr];
478                     preloadQq[STATE_B][s]   = iqB * chargeB[jnr];
479
480                     for (int i = 0; i < NSTATES; i++)
481                     {
482                         if constexpr (vdwInteractionTypeIsEwald)
483                         {
484                             preloadLjPmeC6Grid[i][s] = nbfp_grid[2 * typeIndices[i][s]];
485                         }
486                         else
487                         {
488                             preloadLjPmeC6Grid[i][s] = 0;
489                         }
490                         if constexpr (useSoftCore)
491                         {
492                             const real c6  = nbfp[2 * typeIndices[i][s]];
493                             const real c12 = nbfp[2 * typeIndices[i][s] + 1];
494                             if (c6 > 0 && c12 > 0)
495                             {
496                                 /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
497                                 preloadSigma6[i][s] = 0.5_real * c12 / c6;
498                                 if (preloadSigma6[i][s]
499                                     < sigma6_min) /* for disappearing coul and vdw with soft core at the same time */
500                                 {
501                                     preloadSigma6[i][s] = sigma6_min;
502                                 }
503                             }
504                             else
505                             {
506                                 preloadSigma6[i][s] = sigma6_def;
507                             }
508                         }
509                     }
510                     if constexpr (useSoftCore)
511                     {
512                         /* only use softcore if one of the states has a zero endstate - softcore is for avoiding infinities!*/
513                         const real c12A = nbfp[2 * typeIndices[STATE_A][s] + 1];
514                         const real c12B = nbfp[2 * typeIndices[STATE_B][s] + 1];
515                         if (c12A > 0 && c12B > 0)
516                         {
517                             preloadAlphaVdwEff[s]  = 0;
518                             preloadAlphaCoulEff[s] = 0;
519                         }
520                         else
521                         {
522                             preloadAlphaVdwEff[s]  = alpha_vdw;
523                             preloadAlphaCoulEff[s] = alpha_coul;
524                         }
525                     }
526                 }
527                 else
528                 {
529                     preloadJnr[s]          = jjnr[k];
530                     preloadPairIsValid[s]  = false;
531                     preloadPairIncluded[s] = false;
532                     preloadAlphaVdwEff[s]  = 0;
533                     preloadAlphaCoulEff[s] = 0;
534
535                     for (int i = 0; i < NSTATES; i++)
536                     {
537                         typeIndices[STATE_A][s]  = ntiA + typeA[jjnr[k]];
538                         typeIndices[STATE_B][s]  = ntiB + typeB[jjnr[k]];
539                         preloadLjPmeC6Grid[i][s] = 0;
540                         preloadQq[i][s]          = 0;
541                         preloadSigma6[i][s]      = 0;
542                     }
543                 }
544             }
545
546             RealType jx, jy, jz;
547             gmx::gatherLoadUTranspose<3>(reinterpret_cast<const real*>(x), preloadJnr, &jx, &jy, &jz);
548
549             const RealType pairIsValid   = gmx::load<RealType>(preloadPairIsValid);
550             const RealType pairIncluded  = gmx::load<RealType>(preloadPairIncluded);
551             const BoolType bPairIncluded = (pairIncluded != zero);
552             const BoolType bPairExcluded = (pairIncluded == zero && pairIsValid != zero);
553
554             const RealType dX  = ix - jx;
555             const RealType dY  = iy - jy;
556             const RealType dZ  = iz - jz;
557             const RealType rSq = dX * dX + dY * dY + dZ * dZ;
558
559             BoolType withinCutoffMask = (rSq < rcutoff_max2);
560
561             if (!gmx::anyTrue(withinCutoffMask || bPairExcluded))
562             {
563                 /* We save significant time by skipping all code below.
564                  * Note that with soft-core interactions, the actual cut-off
565                  * check might be different. But since the soft-core distance
566                  * is always larger than r, checking on r here is safe.
567                  * Exclusions outside the cutoff can not be skipped as
568                  * when using Ewald: the reciprocal-space
569                  * Ewald component still needs to be subtracted.
570                  */
571                 continue;
572             }
573             else
574             {
575                 havePairsWithinCutoff = true;
576             }
577
578             if (gmx::anyTrue(rlistSquared < rSq && bPairExcluded))
579             {
580                 haveExcludedPairsBeyondRlist = true;
581             }
582
583             const IntType  jnr_s    = gmx::load<IntType>(preloadJnr);
584             const BoolType bIiEqJnr = gmx::cvtIB2B(ii_s == jnr_s);
585
586             RealType            c6[NSTATES];
587             RealType            c12[NSTATES];
588             RealType gmx_unused sigma6[NSTATES];
589             RealType            qq[NSTATES];
590             RealType gmx_unused ljPmeC6Grid[NSTATES];
591             RealType gmx_unused alphaVdwEff;
592             RealType gmx_unused alphaCoulEff;
593             for (int i = 0; i < NSTATES; i++)
594             {
595                 gmx::gatherLoadTranspose<2>(nbfp.data(), typeIndices[i], &c6[i], &c12[i]);
596                 qq[i]          = gmx::load<RealType>(preloadQq[i]);
597                 ljPmeC6Grid[i] = gmx::load<RealType>(preloadLjPmeC6Grid[i]);
598                 if constexpr (useSoftCore)
599                 {
600                     sigma6[i] = gmx::load<RealType>(preloadSigma6[i]);
601                 }
602             }
603             if constexpr (useSoftCore)
604             {
605                 alphaVdwEff  = gmx::load<RealType>(preloadAlphaVdwEff);
606                 alphaCoulEff = gmx::load<RealType>(preloadAlphaCoulEff);
607             }
608
609             BoolType rSqValid = (zero < rSq);
610
611             /* The force at r=0 is zero, because of symmetry.
612              * But note that the potential is in general non-zero,
613              * since the soft-cored r will be non-zero.
614              */
615             rInv = gmx::maskzInvsqrt(rSq, rSqValid);
616             r    = rSq * rInv;
617
618             RealType gmx_unused rp, rpm2;
619             if constexpr (useSoftCore)
620             {
621                 rpm2 = rSq * rSq;  /* r4 */
622                 rp   = rpm2 * rSq; /* r6 */
623             }
624             else
625             {
626                 /* The soft-core power p will not affect the results
627                  * with not using soft-core, so we use power of 0 which gives
628                  * the simplest math and cheapest code.
629                  */
630                 rpm2 = rInv * rInv;
631                 rp   = one;
632             }
633
634             RealType fScal(0);
635
636             /* The following block is masked to only calculate values having bPairIncluded. If
637              * bPairIncluded is true then withinCutoffMask must also be true. */
638             if (gmx::anyTrue(withinCutoffMask && bPairIncluded))
639             {
640                 RealType fScalC[NSTATES], fScalV[NSTATES];
641                 RealType vCoul[NSTATES], vVdw[NSTATES];
642                 for (int i = 0; i < NSTATES; i++)
643                 {
644                     fScalC[i] = zero;
645                     fScalV[i] = zero;
646                     vCoul[i]  = zero;
647                     vVdw[i]   = zero;
648
649                     RealType gmx_unused rInvC, rInvV, rC, rV, rPInvC, rPInvV;
650
651                     /* The following block is masked to require (qq[i] != 0 || c6[i] != 0 || c12[i]
652                      * != 0) in addition to bPairIncluded, which in turn requires withinCutoffMask. */
653                     BoolType nonZeroState = ((qq[i] != zero || c6[i] != zero || c12[i] != zero)
654                                              && bPairIncluded && withinCutoffMask);
655                     if (gmx::anyTrue(nonZeroState))
656                     {
657                         if constexpr (useSoftCore)
658                         {
659                             RealType divisor      = (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
660                             BoolType validDivisor = (zero < divisor);
661                             rPInvC                = gmx::maskzInv(divisor, validDivisor);
662                             pthRoot(rPInvC, &rInvC, &rC, validDivisor);
663
664                             if constexpr (scLambdasOrAlphasDiffer)
665                             {
666                                 RealType divisor      = (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
667                                 BoolType validDivisor = (zero < divisor);
668                                 rPInvV                = gmx::maskzInv(divisor, validDivisor);
669                                 pthRoot(rPInvV, &rInvV, &rV, validDivisor);
670                             }
671                             else
672                             {
673                                 /* We can avoid one expensive pow and one / operation */
674                                 rPInvV = rPInvC;
675                                 rInvV  = rInvC;
676                                 rV     = rC;
677                             }
678                         }
679                         else
680                         {
681                             rPInvC = one;
682                             rInvC  = rInv;
683                             rC     = r;
684
685                             rPInvV = one;
686                             rInvV  = rInv;
687                             rV     = r;
688                         }
689
690                         /* Only process the coulomb interactions if we either
691                          * include all entries in the list (no cutoff
692                          * used in the kernel), or if we are within the cutoff.
693                          */
694                         BoolType computeElecInteraction;
695                         if constexpr (elecInteractionTypeIsEwald)
696                         {
697                             computeElecInteraction = (r < rCoulomb && qq[i] != zero && bPairIncluded);
698                         }
699                         else
700                         {
701                             computeElecInteraction = (rC < rCoulomb && qq[i] != zero && bPairIncluded);
702                         }
703                         if (gmx::anyTrue(computeElecInteraction))
704                         {
705                             if constexpr (elecInteractionTypeIsEwald)
706                             {
707                                 vCoul[i]  = ewaldPotential(qq[i], rInvC, sh_ewald);
708                                 fScalC[i] = ewaldScalarForce(qq[i], rInvC);
709                             }
710                             else
711                             {
712                                 vCoul[i]  = reactionFieldPotential(qq[i], rInvC, rC, krf, crf);
713                                 fScalC[i] = reactionFieldScalarForce(qq[i], rInvC, rC, krf, two);
714                             }
715
716                             vCoul[i]  = gmx::selectByMask(vCoul[i], computeElecInteraction);
717                             fScalC[i] = gmx::selectByMask(fScalC[i], computeElecInteraction);
718                         }
719
720                         /* Only process the VDW interactions if we either
721                          * include all entries in the list (no cutoff used
722                          * in the kernel), or if we are within the cutoff.
723                          */
724                         BoolType computeVdwInteraction;
725                         if constexpr (vdwInteractionTypeIsEwald)
726                         {
727                             computeVdwInteraction =
728                                     (r < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
729                         }
730                         else
731                         {
732                             computeVdwInteraction =
733                                     (rV < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
734                         }
735                         if (gmx::anyTrue(computeVdwInteraction))
736                         {
737                             RealType rInv6;
738                             if constexpr (useSoftCore)
739                             {
740                                 rInv6 = rPInvV;
741                             }
742                             else
743                             {
744                                 rInv6 = calculateRinv6(rInvV);
745                             }
746                             RealType vVdw6  = calculateVdw6(c6[i], rInv6);
747                             RealType vVdw12 = calculateVdw12(c12[i], rInv6);
748
749                             vVdw[i] = lennardJonesPotential(
750                                     vVdw6, vVdw12, c6[i], c12[i], repulsionShift, dispersionShift, oneSixth, oneTwelfth);
751                             fScalV[i] = lennardJonesScalarForce(vVdw6, vVdw12);
752
753                             if constexpr (vdwInteractionTypeIsEwald)
754                             {
755                                 /* Subtract the grid potential at the cut-off */
756                                 vVdw[i] = vVdw[i]
757                                           + gmx::selectByMask(ewaldLennardJonesGridSubtract(
758                                                                       ljPmeC6Grid[i], shLjEwald, oneSixth),
759                                                               computeVdwInteraction);
760                             }
761
762                             if constexpr (vdwModifierIsPotSwitch)
763                             {
764                                 RealType d             = rV - rVdwSwitch;
765                                 BoolType zeroMask      = zero < d;
766                                 BoolType potSwitchMask = rV < rVdw;
767                                 d                      = gmx::selectByMask(d, zeroMask);
768                                 const RealType d2      = d * d;
769                                 const RealType sw =
770                                         one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
771                                 const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
772
773                                 fScalV[i] = potSwitchScalarForceMod(
774                                         fScalV[i], vVdw[i], sw, rV, dsw, potSwitchMask);
775                                 vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, potSwitchMask);
776                             }
777
778                             vVdw[i]   = gmx::selectByMask(vVdw[i], computeVdwInteraction);
779                             fScalV[i] = gmx::selectByMask(fScalV[i], computeVdwInteraction);
780                         }
781
782                         /* fScalC (and fScalV) now contain: dV/drC * rC
783                          * Now we multiply by rC^-p, so it will be: dV/drC * rC^1-p
784                          * Further down we first multiply by r^p-2 and then by
785                          * the vector r, which in total gives: dV/drC * (r/rC)^1-p
786                          */
787                         fScalC[i] = fScalC[i] * rPInvC;
788                         fScalV[i] = fScalV[i] * rPInvV;
789                     } // end of block requiring nonZeroState
790                 }     // end for (int i = 0; i < NSTATES; i++)
791
792                 /* Assemble A and B states. */
793                 BoolType assembleStates = (bPairIncluded && withinCutoffMask);
794                 if (gmx::anyTrue(assembleStates))
795                 {
796                     for (int i = 0; i < NSTATES; i++)
797                     {
798                         vCTot = vCTot + LFC[i] * vCoul[i];
799                         vVTot = vVTot + LFV[i] * vVdw[i];
800
801                         fScal = fScal + LFC[i] * fScalC[i] * rpm2;
802                         fScal = fScal + LFV[i] * fScalV[i] * rpm2;
803
804                         if constexpr (useSoftCore)
805                         {
806                             dvdlCoul = dvdlCoul + vCoul[i] * DLF[i]
807                                        + LFC[i] * alphaCoulEff * dlFacCoul[i] * fScalC[i] * sigma6[i];
808                             dvdlVdw = dvdlVdw + vVdw[i] * DLF[i]
809                                       + LFV[i] * alphaVdwEff * dlFacVdw[i] * fScalV[i] * sigma6[i];
810                         }
811                         else
812                         {
813                             dvdlCoul = dvdlCoul + vCoul[i] * DLF[i];
814                             dvdlVdw  = dvdlVdw + vVdw[i] * DLF[i];
815                         }
816                     }
817                 }
818             } // end of block requiring bPairIncluded && withinCutoffMask
819             /* In the following block bPairIncluded should be false in the masks. */
820             if (icoul == NbkernelElecType::ReactionField)
821             {
822                 const BoolType computeReactionField = bPairExcluded;
823
824                 if (gmx::anyTrue(computeReactionField))
825                 {
826                     /* For excluded pairs we don't use soft-core.
827                      * As there is no singularity, there is no need for soft-core.
828                      */
829                     const RealType FF = -two * krf;
830                     RealType       VV = krf * rSq - crf;
831
832                     /* If ii == jnr the i particle (ii) has itself (jnr)
833                      * in its neighborlist. This corresponds to a self-interaction
834                      * that will occur twice. Scale it down by 50% to only include
835                      * it once.
836                      */
837                     VV = VV * gmx::blend(one, half, bIiEqJnr);
838
839                     for (int i = 0; i < NSTATES; i++)
840                     {
841                         vCTot = vCTot + gmx::selectByMask(LFC[i] * qq[i] * VV, computeReactionField);
842                         fScal = fScal + gmx::selectByMask(LFC[i] * qq[i] * FF, computeReactionField);
843                         dvdlCoul = dvdlCoul + gmx::selectByMask(DLF[i] * qq[i] * VV, computeReactionField);
844                     }
845                 }
846             }
847
848             const BoolType computeElecEwaldInteraction = (bPairExcluded || r < rCoulomb);
849             if (elecInteractionTypeIsEwald && gmx::anyTrue(computeElecEwaldInteraction))
850             {
851                 /* See comment in the preamble. When using Ewald interactions
852                  * (unless we use a switch modifier) we subtract the reciprocal-space
853                  * Ewald component here which made it possible to apply the free
854                  * energy interaction to 1/r (vanilla coulomb short-range part)
855                  * above. This gets us closer to the ideal case of applying
856                  * the softcore to the entire electrostatic interaction,
857                  * including the reciprocal-space component.
858                  */
859                 RealType v_lr, f_lr;
860
861                 pmeCoulombCorrectionVF(rSq, ewaldBeta, &v_lr, &f_lr, rSqValid);
862                 f_lr = f_lr * rInv * rInv;
863
864                 /* Note that any possible Ewald shift has already been applied in
865                  * the normal interaction part above.
866                  */
867
868                 /* If ii == jnr the i particle (ii) has itself (jnr)
869                  * in its neighborlist. This corresponds to a self-interaction
870                  * that will occur twice. Scale it down by 50% to only include
871                  * it once.
872                  */
873                 v_lr = v_lr * gmx::blend(one, half, bIiEqJnr);
874
875                 for (int i = 0; i < NSTATES; i++)
876                 {
877                     vCTot = vCTot - gmx::selectByMask(LFC[i] * qq[i] * v_lr, computeElecEwaldInteraction);
878                     fScal = fScal - gmx::selectByMask(LFC[i] * qq[i] * f_lr, computeElecEwaldInteraction);
879                     dvdlCoul = dvdlCoul
880                                - gmx::selectByMask(DLF[i] * qq[i] * v_lr, computeElecEwaldInteraction);
881                 }
882             }
883
884             const BoolType computeVdwEwaldInteraction = (bPairExcluded || r < rVdw);
885             if (vdwInteractionTypeIsEwald && gmx::anyTrue(computeVdwEwaldInteraction))
886             {
887                 /* See comment in the preamble. When using LJ-Ewald interactions
888                  * (unless we use a switch modifier) we subtract the reciprocal-space
889                  * Ewald component here which made it possible to apply the free
890                  * energy interaction to r^-6 (vanilla LJ6 short-range part)
891                  * above. This gets us closer to the ideal case of applying
892                  * the softcore to the entire VdW interaction,
893                  * including the reciprocal-space component.
894                  */
895
896                 RealType v_lr, f_lr;
897                 pmeLJCorrectionVF(
898                         rInv, rSq, ewaldLJCoeffSq, ewaldLJCoeffSixDivSix, &v_lr, &f_lr, computeVdwEwaldInteraction, bIiEqJnr);
899                 v_lr = v_lr * oneSixth;
900
901                 for (int i = 0; i < NSTATES; i++)
902                 {
903                     vVTot = vVTot + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
904                     fScal = fScal + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * f_lr, computeVdwEwaldInteraction);
905                     dvdlVdw = dvdlVdw + gmx::selectByMask(DLF[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
906                 }
907             }
908
909             if (doForces && gmx::anyTrue(fScal != zero))
910             {
911                 const RealType tX = fScal * dX;
912                 const RealType tY = fScal * dY;
913                 const RealType tZ = fScal * dZ;
914                 fIX               = fIX + tX;
915                 fIY               = fIY + tY;
916                 fIZ               = fIZ + tZ;
917
918                 gmx::transposeScatterDecrU<3>(forceRealPtr, preloadJnr, tX, tY, tZ);
919             }
920         } // end for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
921
922         if (havePairsWithinCutoff)
923         {
924             if (doForces)
925             {
926                 gmx::transposeScatterIncrU<3>(forceRealPtr, preloadIi, fIX, fIY, fIZ);
927             }
928             if (doShiftForces)
929             {
930                 gmx::transposeScatterIncrU<3>(
931                         reinterpret_cast<real*>(threadForceShiftBuffer), preloadIs, fIX, fIY, fIZ);
932             }
933             if (doPotential)
934             {
935                 int ggid = gid[n];
936                 threadVc[ggid] += gmx::reduce(vCTot);
937                 threadVv[ggid] += gmx::reduce(vVTot);
938             }
939         }
940     } // end for (int n = 0; n < nri; n++)
941
942     if (gmx::anyTrue(dvdlCoul != zero))
943     {
944         threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)] += gmx::reduce(dvdlCoul);
945     }
946     if (gmx::anyTrue(dvdlVdw != zero))
947     {
948         threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)] += gmx::reduce(dvdlVdw);
949     }
950
951     /* Estimate flops, average for free energy stuff:
952      * 12  flops per outer iteration
953      * 150 flops per inner iteration
954      * TODO: Update the number of flops and/or use different counts for different code paths.
955      */
956     atomicNrnbIncrement(nrnb, eNR_NBKERNEL_FREE_ENERGY, nlist.nri * 12 + nlist.jindex[nri] * 150);
957
958     if (haveExcludedPairsBeyondRlist > 0)
959     {
960         gmx_fatal(FARGS,
961                   "There are perturbed non-bonded pair interactions beyond the pair-list cutoff "
962                   "of %g nm, which is not supported. This can happen because the system is "
963                   "unstable or because intra-molecular interactions at long distances are "
964                   "excluded. If the "
965                   "latter is the case, you can try to increase nstlist or rlist to avoid this."
966                   "The error is likely triggered by the use of couple-intramol=no "
967                   "and the maximal distance in the decoupled molecule exceeding rlist.",
968                   rlist);
969     }
970 }
971
972 typedef void (*KernelFunction)(const t_nblist&                                  nlist,
973                                const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
974                                const int                                        ntype,
975                                const real                                       rlist,
976                                const interaction_const_t&                       ic,
977                                gmx::ArrayRef<const gmx::RVec>                   shiftvec,
978                                gmx::ArrayRef<const real>                        nbfp,
979                                gmx::ArrayRef<const real>                        nbfp_grid,
980                                gmx::ArrayRef<const real>                        chargeA,
981                                gmx::ArrayRef<const real>                        chargeB,
982                                gmx::ArrayRef<const int>                         typeA,
983                                gmx::ArrayRef<const int>                         typeB,
984                                int                                              flags,
985                                gmx::ArrayRef<const real>                        lambda,
986                                t_nrnb* gmx_restrict                             nrnb,
987                                gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
988                                rvec*               threadForceShiftBuffer,
989                                gmx::ArrayRef<real> threadVc,
990                                gmx::ArrayRef<real> threadVv,
991                                gmx::ArrayRef<real> threadDvdl);
992
993 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
994 static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
995 {
996     if (useSimd)
997     {
998 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS && GMX_USE_SIMD_KERNELS
999         return (nb_free_energy_kernel<SimdDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
1000 #else
1001         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
1002 #endif
1003     }
1004     else
1005     {
1006         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
1007     }
1008 }
1009
1010 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
1011 static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch, const bool useSimd)
1012 {
1013     if (vdwModifierIsPotSwitch)
1014     {
1015         return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, true>(
1016                 useSimd));
1017     }
1018     else
1019     {
1020         return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, false>(
1021                 useSimd));
1022     }
1023 }
1024
1025 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald>
1026 static KernelFunction dispatchKernelOnElecInteractionType(const bool elecInteractionTypeIsEwald,
1027                                                           const bool vdwModifierIsPotSwitch,
1028                                                           const bool useSimd)
1029 {
1030     if (elecInteractionTypeIsEwald)
1031     {
1032         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, true>(
1033                 vdwModifierIsPotSwitch, useSimd));
1034     }
1035     else
1036     {
1037         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, false>(
1038                 vdwModifierIsPotSwitch, useSimd));
1039     }
1040 }
1041
1042 template<bool useSoftCore, bool scLambdasOrAlphasDiffer>
1043 static KernelFunction dispatchKernelOnVdwInteractionType(const bool vdwInteractionTypeIsEwald,
1044                                                          const bool elecInteractionTypeIsEwald,
1045                                                          const bool vdwModifierIsPotSwitch,
1046                                                          const bool useSimd)
1047 {
1048     if (vdwInteractionTypeIsEwald)
1049     {
1050         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, true>(
1051                 elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1052     }
1053     else
1054     {
1055         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, false>(
1056                 elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1057     }
1058 }
1059
1060 template<bool useSoftCore>
1061 static KernelFunction dispatchKernelOnScLambdasOrAlphasDifference(const bool scLambdasOrAlphasDiffer,
1062                                                                   const bool vdwInteractionTypeIsEwald,
1063                                                                   const bool elecInteractionTypeIsEwald,
1064                                                                   const bool vdwModifierIsPotSwitch,
1065                                                                   const bool useSimd)
1066 {
1067     if (scLambdasOrAlphasDiffer)
1068     {
1069         return (dispatchKernelOnVdwInteractionType<useSoftCore, true>(
1070                 vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1071     }
1072     else
1073     {
1074         return (dispatchKernelOnVdwInteractionType<useSoftCore, false>(
1075                 vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1076     }
1077 }
1078
1079 static KernelFunction dispatchKernel(const bool                 scLambdasOrAlphasDiffer,
1080                                      const bool                 vdwInteractionTypeIsEwald,
1081                                      const bool                 elecInteractionTypeIsEwald,
1082                                      const bool                 vdwModifierIsPotSwitch,
1083                                      const bool                 useSimd,
1084                                      const interaction_const_t& ic)
1085 {
1086     if (ic.softCoreParameters->alphaCoulomb == 0 && ic.softCoreParameters->alphaVdw == 0)
1087     {
1088         return (dispatchKernelOnScLambdasOrAlphasDifference<false>(scLambdasOrAlphasDiffer,
1089                                                                    vdwInteractionTypeIsEwald,
1090                                                                    elecInteractionTypeIsEwald,
1091                                                                    vdwModifierIsPotSwitch,
1092                                                                    useSimd));
1093     }
1094     else
1095     {
1096         return (dispatchKernelOnScLambdasOrAlphasDifference<true>(scLambdasOrAlphasDiffer,
1097                                                                   vdwInteractionTypeIsEwald,
1098                                                                   elecInteractionTypeIsEwald,
1099                                                                   vdwModifierIsPotSwitch,
1100                                                                   useSimd));
1101     }
1102 }
1103
1104
1105 void gmx_nb_free_energy_kernel(const t_nblist&                                  nlist,
1106                                const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
1107                                const bool                                       useSimd,
1108                                const int                                        ntype,
1109                                const real                                       rlist,
1110                                const interaction_const_t&                       ic,
1111                                gmx::ArrayRef<const gmx::RVec>                   shiftvec,
1112                                gmx::ArrayRef<const real>                        nbfp,
1113                                gmx::ArrayRef<const real>                        nbfp_grid,
1114                                gmx::ArrayRef<const real>                        chargeA,
1115                                gmx::ArrayRef<const real>                        chargeB,
1116                                gmx::ArrayRef<const int>                         typeA,
1117                                gmx::ArrayRef<const int>                         typeB,
1118                                int                                              flags,
1119                                gmx::ArrayRef<const real>                        lambda,
1120                                t_nrnb*                                          nrnb,
1121                                gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
1122                                rvec*               threadForceShiftBuffer,
1123                                gmx::ArrayRef<real> threadVc,
1124                                gmx::ArrayRef<real> threadVv,
1125                                gmx::ArrayRef<real> threadDvdl)
1126 {
1127     GMX_ASSERT(EEL_PME_EWALD(ic.eeltype) || ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype),
1128                "Unsupported eeltype with free energy");
1129     GMX_ASSERT(ic.softCoreParameters, "We need soft-core parameters");
1130
1131     // Not all SIMD implementations need padding, but we provide padding anyhow so we can assert
1132     GMX_ASSERT(!GMX_SIMD_HAVE_REAL || threadForceBuffer.empty()
1133                        || threadForceBuffer.size() > threadForceBuffer.unpaddedArrayRef().ssize(),
1134                "We need actual padding with at least one element for SIMD scatter operations");
1135
1136     const auto& scParams                   = *ic.softCoreParameters;
1137     const bool  vdwInteractionTypeIsEwald  = (EVDW_PME(ic.vdwtype));
1138     const bool  elecInteractionTypeIsEwald = (EEL_PME_EWALD(ic.eeltype));
1139     const bool  vdwModifierIsPotSwitch     = (ic.vdw_modifier == InteractionModifiers::PotSwitch);
1140     bool        scLambdasOrAlphasDiffer    = true;
1141
1142     if (scParams.alphaCoulomb == 0 && scParams.alphaVdw == 0)
1143     {
1144         scLambdasOrAlphasDiffer = false;
1145     }
1146     else
1147     {
1148         if (lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)]
1149                     == lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)]
1150             && scParams.alphaCoulomb == scParams.alphaVdw)
1151         {
1152             scLambdasOrAlphasDiffer = false;
1153         }
1154     }
1155
1156     KernelFunction kernelFunc;
1157     kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer,
1158                                 vdwInteractionTypeIsEwald,
1159                                 elecInteractionTypeIsEwald,
1160                                 vdwModifierIsPotSwitch,
1161                                 useSimd,
1162                                 ic);
1163     kernelFunc(nlist,
1164                coords,
1165                ntype,
1166                rlist,
1167                ic,
1168                shiftvec,
1169                nbfp,
1170                nbfp_grid,
1171                chargeA,
1172                chargeB,
1173                typeA,
1174                typeB,
1175                flags,
1176                lambda,
1177                nrnb,
1178                threadForceBuffer,
1179                threadForceShiftBuffer,
1180                threadVc,
1181                threadVv,
1182                threadDvdl);
1183 }