SIMD support for nonbonded free-energy kernels
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_free_energy.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 #include "gmxpre.h"
39
40 #include "nb_free_energy.h"
41
42 #include "config.h"
43
44 #include <cmath>
45 #include <set>
46
47 #include <algorithm>
48
49 #include "gromacs/gmxlib/nrnb.h"
50 #include "gromacs/gmxlib/nonbonded/nonbonded.h"
51 #include "gromacs/math/arrayrefwithpadding.h"
52 #include "gromacs/math/functions.h"
53 #include "gromacs/math/vec.h"
54 #include "gromacs/mdtypes/forceoutput.h"
55 #include "gromacs/mdtypes/forcerec.h"
56 #include "gromacs/mdtypes/interaction_const.h"
57 #include "gromacs/mdtypes/md_enums.h"
58 #include "gromacs/mdtypes/mdatom.h"
59 #include "gromacs/mdtypes/nblist.h"
60 #include "gromacs/pbcutil/ishift.h"
61 #include "gromacs/simd/simd.h"
62 #include "gromacs/simd/simd_math.h"
63 #include "gromacs/utility/fatalerror.h"
64 #include "gromacs/utility/arrayref.h"
65
66
67 //! Scalar (non-SIMD) data types.
68 struct ScalarDataTypes
69 {
70     using RealType = real; //!< The data type to use as real.
71     using IntType  = int;  //!< The data type to use as int.
72     using BoolType = bool; //!< The data type to use as bool for real value comparison.
73     static constexpr int simdRealWidth = 1; //!< The width of the RealType.
74     static constexpr int simdIntWidth  = 1; //!< The width of the IntType.
75 };
76
#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
/*! \brief SIMD data-type bundle.
 *
 * Only compiled when the build provides real-valued SIMD together with
 * 32-bit integer SIMD arithmetic.
 */
struct SimdDataTypes
{
    //! Number of lanes in RealType.
    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH;
    /*! \brief Number of lanes in IntType.
     *
     * Uses the double-precision int32 width when building in double precision
     * with double SIMD available, and the single-precision int32 width otherwise.
     */
    static constexpr int simdIntWidth =
#    if GMX_SIMD_HAVE_DOUBLE && GMX_DOUBLE
            GMX_SIMD_DINT32_WIDTH;
#    else
            GMX_SIMD_FINT32_WIDTH;
#    endif

    //! SIMD type for real-valued quantities.
    using RealType = gmx::SimdReal;
    //! SIMD type for 32-bit integer quantities.
    using IntType = gmx::SimdInt32;
    //! Result type of comparisons between RealType values.
    using BoolType = gmx::SimdBool;
};
#endif
92
93 template<class RealType, class BoolType>
94 static inline void
95 pmeCoulombCorrectionVF(const RealType rSq, const real beta, RealType* pot, RealType* force, const BoolType mask)
96 {
97     const RealType brsq = gmx::selectByMask(rSq * beta * beta, mask);
98     *force              = -brsq * beta * gmx::pmeForceCorrection(brsq);
99     *pot                = beta * gmx::pmePotentialCorrection(brsq);
100 }
101
102 template<class RealType, class BoolType>
103 static inline void pmeLJCorrectionVF(const RealType rInv,
104                                      const RealType rSq,
105                                      const real     ewaldLJCoeffSq,
106                                      const real     ewaldLJCoeffSixDivSix,
107                                      RealType*      pot,
108                                      RealType*      force,
109                                      const BoolType mask,
110                                      const BoolType bIiEqJnr)
111 {
112     // We mask rInv to get zero force and potential for masked out pair interactions
113     const RealType rInvSq  = gmx::selectByMask(rInv * rInv, mask);
114     const RealType rInvSix = rInvSq * rInvSq * rInvSq;
115     // Mask rSq to avoid underflow in exp()
116     const RealType coeffSqRSq       = ewaldLJCoeffSq * gmx::selectByMask(rSq, mask);
117     const RealType expNegCoeffSqRSq = gmx::exp(-coeffSqRSq);
118     const RealType poly             = 1.0_real + coeffSqRSq + 0.5_real * coeffSqRSq * coeffSqRSq;
119     *force = rInvSix - expNegCoeffSqRSq * (rInvSix * poly + ewaldLJCoeffSixDivSix);
120     *force = *force * rInvSq;
121     // The self interaction is the limit for r -> 0 which we need to compute separately
122     *pot = gmx::blend(
123             rInvSix * (1.0_real - expNegCoeffSqRSq * poly), 0.5_real * ewaldLJCoeffSixDivSix, bIiEqJnr);
124 }
125
126 //! Computes r^(1/p) and 1/r^(1/p) for the standard p=6
127 template<class RealType, class BoolType>
128 static inline void pthRoot(const RealType r, RealType* pthRoot, RealType* invPthRoot, const BoolType mask)
129 {
130     RealType cbrtRes = gmx::cbrt(r);
131     *invPthRoot      = gmx::maskzInvsqrt(cbrtRes, mask);
132     *pthRoot         = gmx::maskzInv(*invPthRoot, mask);
133 }
134
//! Returns (1/r)^6, given \p rInvV = 1/r.
template<class RealType>
static inline RealType calculateRinv6(const RealType rInvV)
{
    // Square once, then cube the square
    const RealType rInvSquared = rInvV * rInvV;
    return rInvSquared * rInvSquared * rInvSquared;
}
141
//! Returns the r^-6 Lennard-Jones term, c6 * rInv6.
template<class RealType>
static inline RealType calculateVdw6(const RealType c6, const RealType rInv6)
{
    const RealType dispersionTerm = c6 * rInv6;
    return dispersionTerm;
}
147
//! Returns the r^-12 Lennard-Jones term, c12 * rInv6 * rInv6.
template<class RealType>
static inline RealType calculateVdw12(const RealType c12, const RealType rInv6)
{
    const RealType repulsionTerm = c12 * rInv6 * rInv6;
    return repulsionTerm;
}
153
154 /* reaction-field electrostatics */
155 template<class RealType>
156 static inline RealType reactionFieldScalarForce(const RealType qq,
157                                                 const RealType rInv,
158                                                 const RealType r,
159                                                 const real     krf,
160                                                 const real     two)
161 {
162     return (qq * (rInv - two * krf * r * r));
163 }
164 template<class RealType>
165 static inline RealType reactionFieldPotential(const RealType qq,
166                                               const RealType rInv,
167                                               const RealType r,
168                                               const real     krf,
169                                               const real     potentialShift)
170 {
171     return (qq * (rInv + krf * r * r - potentialShift));
172 }
173
/* Ewald electrostatics */

//! Ewald scalar force term: \p coulomb multiplied by 1/r.
template<class RealType>
static inline RealType ewaldScalarForce(const RealType coulomb, const RealType rInv)
{
    const RealType scalarForce = coulomb * rInv;
    return scalarForce;
}
180 template<class RealType>
181 static inline RealType ewaldPotential(const RealType coulomb, const RealType rInv, const real potentialShift)
182 {
183     return (coulomb * (rInv - potentialShift));
184 }
185
/* Cut-off Lennard-Jones */

//! Lennard-Jones scalar force term: repulsive part \p v12 minus dispersive part \p v6.
template<class RealType>
static inline RealType lennardJonesScalarForce(const RealType v6, const RealType v12)
{
    const RealType scalarForce = v12 - v6;
    return scalarForce;
}
192 template<class RealType>
193 static inline RealType lennardJonesPotential(const RealType v6,
194                                              const RealType v12,
195                                              const RealType c6,
196                                              const RealType c12,
197                                              const real     repulsionShift,
198                                              const real     dispersionShift,
199                                              const real     oneSixth,
200                                              const real     oneTwelfth)
201 {
202     return ((v12 + c12 * repulsionShift) * oneTwelfth - (v6 + c6 * dispersionShift) * oneSixth);
203 }
204
205 /* Ewald LJ */
206 template<class RealType>
207 static inline RealType ewaldLennardJonesGridSubtract(const RealType c6grid,
208                                                      const real     potentialShift,
209                                                      const real     oneSixth)
210 {
211     return (c6grid * potentialShift * oneSixth);
212 }
213
214 /* LJ Potential switch */
215 template<class RealType, class BoolType>
216 static inline RealType potSwitchScalarForceMod(const RealType fScalarInp,
217                                                const RealType potential,
218                                                const RealType sw,
219                                                const RealType r,
220                                                const RealType dsw,
221                                                const BoolType mask)
222 {
223     /* The mask should select on rV < rVdw */
224     return (gmx::selectByMask(fScalarInp * sw - r * potential * dsw, mask));
225 }
226 template<class RealType, class BoolType>
227 static inline RealType potSwitchPotentialMod(const RealType potentialInp, const RealType sw, const BoolType mask)
228 {
229     /* The mask should select on rV < rVdw */
230     return (gmx::selectByMask(potentialInp * sw, mask));
231 }
232
233
234 //! Templated free-energy non-bonded kernel
235 template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
236 static void nb_free_energy_kernel(const t_nblist&                           nlist,
237                                   const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
238                                   const int                                 ntype,
239                                   const real                                rlist,
240                                   const interaction_const_t&                ic,
241                                   gmx::ArrayRef<const gmx::RVec>            shiftvec,
242                                   gmx::ArrayRef<const real>                 nbfp,
243                                   gmx::ArrayRef<const real> gmx_unused      nbfp_grid,
244                                   gmx::ArrayRef<const real>                 chargeA,
245                                   gmx::ArrayRef<const real>                 chargeB,
246                                   gmx::ArrayRef<const int>                  typeA,
247                                   gmx::ArrayRef<const int>                  typeB,
248                                   int                                       flags,
249                                   gmx::ArrayRef<const real>                 lambda,
250                                   t_nrnb* gmx_restrict                      nrnb,
251                                   gmx::RVec*                                threadForceBuffer,
252                                   rvec*                                     threadForceShiftBuffer,
253                                   gmx::ArrayRef<real>                       threadVc,
254                                   gmx::ArrayRef<real>                       threadVv,
255                                   gmx::ArrayRef<real>                       threadDvdl)
256 {
257 #define STATE_A 0
258 #define STATE_B 1
259 #define NSTATES 2
260
261     using RealType = typename DataTypes::RealType;
262     using IntType  = typename DataTypes::IntType;
263     using BoolType = typename DataTypes::BoolType;
264
265     constexpr real oneTwelfth = 1.0_real / 12.0_real;
266     constexpr real oneSixth   = 1.0_real / 6.0_real;
267     constexpr real zero       = 0.0_real;
268     constexpr real half       = 0.5_real;
269     constexpr real one        = 1.0_real;
270     constexpr real two        = 2.0_real;
271     constexpr real six        = 6.0_real;
272
273     // Extract pair list data
274     const int                nri    = nlist.nri;
275     gmx::ArrayRef<const int> iinr   = nlist.iinr;
276     gmx::ArrayRef<const int> jindex = nlist.jindex;
277     gmx::ArrayRef<const int> jjnr   = nlist.jjnr;
278     gmx::ArrayRef<const int> shift  = nlist.shift;
279     gmx::ArrayRef<const int> gid    = nlist.gid;
280
281     const real  lambda_coul = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
282     const real  lambda_vdw  = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
283     const auto& scParams    = *ic.softCoreParameters;
284     const real gmx_unused alpha_coul    = scParams.alphaCoulomb;
285     const real gmx_unused alpha_vdw     = scParams.alphaVdw;
286     const real            lam_power     = scParams.lambdaPower;
287     const real gmx_unused sigma6_def    = scParams.sigma6WithInvalidSigma;
288     const real gmx_unused sigma6_min    = scParams.sigma6Minimum;
289     const bool            doForces      = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
290     const bool            doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
291     const bool            doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
292
293     // Extract data from interaction_const_t
294     const real            facel           = ic.epsfac;
295     const real            rCoulomb        = ic.rcoulomb;
296     const real            krf             = ic.reactionFieldCoefficient;
297     const real            crf             = ic.reactionFieldShift;
298     const real gmx_unused shLjEwald       = ic.sh_lj_ewald;
299     const real            rVdw            = ic.rvdw;
300     const real            dispersionShift = ic.dispersion_shift.cpot;
301     const real            repulsionShift  = ic.repulsion_shift.cpot;
302     const real            ewaldBeta       = ic.ewaldcoeff_q;
303     real gmx_unused       ewaldLJCoeffSq;
304     real gmx_unused       ewaldLJCoeffSixDivSix;
305     if constexpr (vdwInteractionTypeIsEwald)
306     {
307         ewaldLJCoeffSq        = ic.ewaldcoeff_lj * ic.ewaldcoeff_lj;
308         ewaldLJCoeffSixDivSix = ewaldLJCoeffSq * ewaldLJCoeffSq * ewaldLJCoeffSq / six;
309     }
310
311     // Note that the nbnxm kernels do not support Coulomb potential switching at all
312     GMX_ASSERT(ic.coulomb_modifier != InteractionModifiers::PotSwitch,
313                "Potential switching is not supported for Coulomb with FEP");
314
315     const real      rVdwSwitch = ic.rvdw_switch;
316     real gmx_unused vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
317     if constexpr (vdwModifierIsPotSwitch)
318     {
319         const real d = rVdw - rVdwSwitch;
320         vdw_swV3     = -10.0_real / (d * d * d);
321         vdw_swV4     = 15.0_real / (d * d * d * d);
322         vdw_swV5     = -6.0_real / (d * d * d * d * d);
323         vdw_swF2     = -30.0_real / (d * d * d);
324         vdw_swF3     = 60.0_real / (d * d * d * d);
325         vdw_swF4     = -30.0_real / (d * d * d * d * d);
326     }
327     else
328     {
329         /* Avoid warnings from stupid compilers (looking at you, Clang!) */
330         vdw_swV3 = vdw_swV4 = vdw_swV5 = vdw_swF2 = vdw_swF3 = vdw_swF4 = zero;
331     }
332
333     NbkernelElecType icoul;
334     if (ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype))
335     {
336         icoul = NbkernelElecType::ReactionField;
337     }
338     else
339     {
340         icoul = NbkernelElecType::None;
341     }
342
343     real rcutoff_max2 = std::max(ic.rcoulomb, ic.rvdw);
344     rcutoff_max2      = rcutoff_max2 * rcutoff_max2;
345
346     real gmx_unused sh_ewald = 0;
347     if constexpr (elecInteractionTypeIsEwald || vdwInteractionTypeIsEwald)
348     {
349         sh_ewald = ic.sh_ewald;
350     }
351
352     /* For Ewald/PME interactions we cannot easily apply the soft-core component to
353      * reciprocal space. When we use non-switched Ewald interactions, we
354      * assume the soft-coring does not significantly affect the grid contribution
355      * and apply the soft-core only to the full 1/r (- shift) pair contribution.
356      *
357      * However, we cannot use this approach for switch-modified since we would then
358      * effectively end up evaluating a significantly different interaction here compared to the
359      * normal (non-free-energy) kernels, either by applying a cutoff at a different
360      * position than what the user requested, or by switching different
361      * things (1/r rather than short-range Ewald). For these settings, we just
362      * use the traditional short-range Ewald interaction in that case.
363      */
364     GMX_RELEASE_ASSERT(!(vdwInteractionTypeIsEwald && vdwModifierIsPotSwitch),
365                        "Can not apply soft-core to switched Ewald potentials");
366
367     RealType dvdlCoul(zero);
368     RealType dvdlVdw(zero);
369
370     /* Lambda factor for state A, 1-lambda*/
371     real LFC[NSTATES], LFV[NSTATES];
372     LFC[STATE_A] = one - lambda_coul;
373     LFV[STATE_A] = one - lambda_vdw;
374
375     /* Lambda factor for state B, lambda*/
376     LFC[STATE_B] = lambda_coul;
377     LFV[STATE_B] = lambda_vdw;
378
379     /*derivative of the lambda factor for state A and B */
380     real DLF[NSTATES];
381     DLF[STATE_A] = -one;
382     DLF[STATE_B] = one;
383
384     real gmx_unused lFacCoul[NSTATES], dlFacCoul[NSTATES], lFacVdw[NSTATES], dlFacVdw[NSTATES];
385     constexpr real  sc_r_power = six;
386     for (int i = 0; i < NSTATES; i++)
387     {
388         lFacCoul[i]  = (lam_power == 2 ? (1 - LFC[i]) * (1 - LFC[i]) : (1 - LFC[i]));
389         dlFacCoul[i] = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFC[i]) : 1);
390         lFacVdw[i]   = (lam_power == 2 ? (1 - LFV[i]) * (1 - LFV[i]) : (1 - LFV[i]));
391         dlFacVdw[i]  = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFV[i]) : 1);
392     }
393
394     // TODO: We should get rid of using pointers to real
395     const real* gmx_restrict x = coords.paddedConstArrayRef().data()[0];
396
397     const real rlistSquared = gmx::square(rlist);
398
399     bool haveExcludedPairsBeyondRlist = false;
400
401     for (int n = 0; n < nri; n++)
402     {
403         bool havePairsWithinCutoff = false;
404
405         const int  is   = shift[n];
406         const real shX  = shiftvec[is][XX];
407         const real shY  = shiftvec[is][YY];
408         const real shZ  = shiftvec[is][ZZ];
409         const int  nj0  = jindex[n];
410         const int  nj1  = jindex[n + 1];
411         const int  ii   = iinr[n];
412         const int  ii3  = 3 * ii;
413         const real ix   = shX + x[ii3 + 0];
414         const real iy   = shY + x[ii3 + 1];
415         const real iz   = shZ + x[ii3 + 2];
416         const real iqA  = facel * chargeA[ii];
417         const real iqB  = facel * chargeB[ii];
418         const int  ntiA = ntype * typeA[ii];
419         const int  ntiB = ntype * typeB[ii];
420         RealType   vCTot(0);
421         RealType   vVTot(0);
422         RealType   fIX(0);
423         RealType   fIY(0);
424         RealType   fIZ(0);
425
426 #if GMX_SIMD_HAVE_REAL
427         alignas(GMX_SIMD_ALIGNMENT) int preloadIi[DataTypes::simdRealWidth];
428         alignas(GMX_SIMD_ALIGNMENT) int preloadIs[DataTypes::simdRealWidth];
429 #else
430         int preloadIi[DataTypes::simdRealWidth];
431         int preloadIs[DataTypes::simdRealWidth];
432 #endif
433         for (int s = 0; s < DataTypes::simdRealWidth; s++)
434         {
435             preloadIi[s] = ii;
436             preloadIs[s] = shift[n];
437         }
438         IntType ii_s = gmx::load<IntType>(preloadIi);
439
440         for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
441         {
442             RealType r, rInv;
443
444 #if GMX_SIMD_HAVE_REAL
445             alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIsValid[DataTypes::simdRealWidth];
446             alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIncluded[DataTypes::simdRealWidth];
447             alignas(GMX_SIMD_ALIGNMENT) int32_t preloadJnr[DataTypes::simdRealWidth];
448             alignas(GMX_SIMD_ALIGNMENT) int32_t typeIndices[NSTATES][DataTypes::simdRealWidth];
449             alignas(GMX_SIMD_ALIGNMENT) real    preloadQq[NSTATES][DataTypes::simdRealWidth];
450             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
451             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
452             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
453             alignas(GMX_SIMD_ALIGNMENT) real preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
454 #else
455             real            preloadPairIsValid[DataTypes::simdRealWidth];
456             real            preloadPairIncluded[DataTypes::simdRealWidth];
457             int             preloadJnr[DataTypes::simdRealWidth];
458             int             typeIndices[NSTATES][DataTypes::simdRealWidth];
459             real            preloadQq[NSTATES][DataTypes::simdRealWidth];
460             real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
461             real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
462             real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
463             real            preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
464 #endif
465             for (int s = 0; s < DataTypes::simdRealWidth; s++)
466             {
467                 if (k + s < nj1)
468                 {
469                     preloadPairIsValid[s] = true;
470                     /* Check if this pair on the exclusions list.*/
471                     preloadPairIncluded[s]  = (nlist.excl_fep.empty() || nlist.excl_fep[k + s]);
472                     const int jnr           = jjnr[k + s];
473                     preloadJnr[s]           = jnr;
474                     typeIndices[STATE_A][s] = ntiA + typeA[jnr];
475                     typeIndices[STATE_B][s] = ntiB + typeB[jnr];
476                     preloadQq[STATE_A][s]   = iqA * chargeA[jnr];
477                     preloadQq[STATE_B][s]   = iqB * chargeB[jnr];
478
479                     for (int i = 0; i < NSTATES; i++)
480                     {
481                         if constexpr (vdwInteractionTypeIsEwald)
482                         {
483                             preloadLjPmeC6Grid[i][s] = nbfp_grid[2 * typeIndices[i][s]];
484                         }
485                         else
486                         {
487                             preloadLjPmeC6Grid[i][s] = 0;
488                         }
489                         if constexpr (useSoftCore)
490                         {
491                             const real c6  = nbfp[2 * typeIndices[i][s]];
492                             const real c12 = nbfp[2 * typeIndices[i][s] + 1];
493                             if (c6 > 0 && c12 > 0)
494                             {
495                                 /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
496                                 preloadSigma6[i][s] = 0.5_real * c12 / c6;
497                                 if (preloadSigma6[i][s]
498                                     < sigma6_min) /* for disappearing coul and vdw with soft core at the same time */
499                                 {
500                                     preloadSigma6[i][s] = sigma6_min;
501                                 }
502                             }
503                             else
504                             {
505                                 preloadSigma6[i][s] = sigma6_def;
506                             }
507                         }
508                     }
509                     if constexpr (useSoftCore)
510                     {
511                         /* only use softcore if one of the states has a zero endstate - softcore is for avoiding infinities!*/
512                         const real c12A = nbfp[2 * typeIndices[STATE_A][s] + 1];
513                         const real c12B = nbfp[2 * typeIndices[STATE_B][s] + 1];
514                         if (c12A > 0 && c12B > 0)
515                         {
516                             preloadAlphaVdwEff[s]  = 0;
517                             preloadAlphaCoulEff[s] = 0;
518                         }
519                         else
520                         {
521                             preloadAlphaVdwEff[s]  = alpha_vdw;
522                             preloadAlphaCoulEff[s] = alpha_coul;
523                         }
524                     }
525                 }
526                 else
527                 {
528                     preloadJnr[s]          = jjnr[k];
529                     preloadPairIsValid[s]  = false;
530                     preloadPairIncluded[s] = false;
531                     preloadAlphaVdwEff[s]  = 0;
532                     preloadAlphaCoulEff[s] = 0;
533
534                     for (int i = 0; i < NSTATES; i++)
535                     {
536                         typeIndices[STATE_A][s]  = ntiA + typeA[jjnr[k]];
537                         typeIndices[STATE_B][s]  = ntiB + typeB[jjnr[k]];
538                         preloadLjPmeC6Grid[i][s] = 0;
539                         preloadQq[i][s]          = 0;
540                         preloadSigma6[i][s]      = 0;
541                     }
542                 }
543             }
544
545             RealType jx, jy, jz;
546             gmx::gatherLoadUTranspose<3>(reinterpret_cast<const real*>(x), preloadJnr, &jx, &jy, &jz);
547
548             const RealType pairIsValid   = gmx::load<RealType>(preloadPairIsValid);
549             const RealType pairIncluded  = gmx::load<RealType>(preloadPairIncluded);
550             const BoolType bPairIncluded = (pairIncluded != zero);
551             const BoolType bPairExcluded = (pairIncluded == zero && pairIsValid != zero);
552
553             const RealType dX  = ix - jx;
554             const RealType dY  = iy - jy;
555             const RealType dZ  = iz - jz;
556             const RealType rSq = dX * dX + dY * dY + dZ * dZ;
557
558             BoolType withinCutoffMask = (rSq < rcutoff_max2);
559
560             if (!gmx::anyTrue(withinCutoffMask || bPairExcluded))
561             {
562                 /* We save significant time by skipping all code below.
563                  * Note that with soft-core interactions, the actual cut-off
564                  * check might be different. But since the soft-core distance
565                  * is always larger than r, checking on r here is safe.
566                  * Exclusions outside the cutoff can not be skipped as
567                  * when using Ewald: the reciprocal-space
568                  * Ewald component still needs to be subtracted.
569                  */
570                 continue;
571             }
572             else
573             {
574                 havePairsWithinCutoff = true;
575             }
576
577             if (gmx::anyTrue(rlistSquared < rSq && bPairExcluded))
578             {
579                 haveExcludedPairsBeyondRlist = true;
580             }
581
582             const IntType  jnr_s    = gmx::load<IntType>(preloadJnr);
583             const BoolType bIiEqJnr = gmx::cvtIB2B(ii_s == jnr_s);
584
585             RealType            c6[NSTATES];
586             RealType            c12[NSTATES];
587             RealType gmx_unused sigma6[NSTATES];
588             RealType            qq[NSTATES];
589             RealType gmx_unused ljPmeC6Grid[NSTATES];
590             RealType gmx_unused alphaVdwEff;
591             RealType gmx_unused alphaCoulEff;
592             for (int i = 0; i < NSTATES; i++)
593             {
594                 gmx::gatherLoadTranspose<2>(nbfp.data(), typeIndices[i], &c6[i], &c12[i]);
595                 qq[i]          = gmx::load<RealType>(preloadQq[i]);
596                 ljPmeC6Grid[i] = gmx::load<RealType>(preloadLjPmeC6Grid[i]);
597                 if constexpr (useSoftCore)
598                 {
599                     sigma6[i] = gmx::load<RealType>(preloadSigma6[i]);
600                 }
601             }
602             if constexpr (useSoftCore)
603             {
604                 alphaVdwEff  = gmx::load<RealType>(preloadAlphaVdwEff);
605                 alphaCoulEff = gmx::load<RealType>(preloadAlphaCoulEff);
606             }
607
608             BoolType rSqValid = (zero < rSq);
609
610             /* The force at r=0 is zero, because of symmetry.
611              * But note that the potential is in general non-zero,
612              * since the soft-cored r will be non-zero.
613              */
614             rInv = gmx::maskzInvsqrt(rSq, rSqValid);
615             r    = rSq * rInv;
616
617             RealType gmx_unused rp, rpm2;
618             if constexpr (useSoftCore)
619             {
620                 rpm2 = rSq * rSq;  /* r4 */
621                 rp   = rpm2 * rSq; /* r6 */
622             }
623             else
624             {
625                 /* The soft-core power p will not affect the results
626                  * with not using soft-core, so we use power of 0 which gives
627                  * the simplest math and cheapest code.
628                  */
629                 rpm2 = rInv * rInv;
630                 rp   = one;
631             }
632
633             RealType fScal(0);
634
635             /* The following block is masked to only calculate values having bPairIncluded. If
636              * bPairIncluded is true then withinCutoffMask must also be true. */
637             if (gmx::anyTrue(withinCutoffMask && bPairIncluded))
638             {
639                 RealType fScalC[NSTATES], fScalV[NSTATES];
640                 RealType vCoul[NSTATES], vVdw[NSTATES];
641                 for (int i = 0; i < NSTATES; i++)
642                 {
643                     fScalC[i] = zero;
644                     fScalV[i] = zero;
645                     vCoul[i]  = zero;
646                     vVdw[i]   = zero;
647
648                     RealType gmx_unused rInvC, rInvV, rC, rV, rPInvC, rPInvV;
649
650                     /* The following block is masked to require (qq[i] != 0 || c6[i] != 0 || c12[i]
651                      * != 0) in addition to bPairIncluded, which in turn requires withinCutoffMask. */
652                     BoolType nonZeroState = ((qq[i] != zero || c6[i] != zero || c12[i] != zero)
653                                              && bPairIncluded && withinCutoffMask);
654                     if (gmx::anyTrue(nonZeroState))
655                     {
656                         if constexpr (useSoftCore)
657                         {
658                             RealType divisor      = (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
659                             BoolType validDivisor = (zero < divisor);
660                             rPInvC                = gmx::maskzInv(divisor, validDivisor);
661                             pthRoot(rPInvC, &rInvC, &rC, validDivisor);
662
663                             if constexpr (scLambdasOrAlphasDiffer)
664                             {
665                                 RealType divisor      = (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
666                                 BoolType validDivisor = (zero < divisor);
667                                 rPInvV                = gmx::maskzInv(divisor, validDivisor);
668                                 pthRoot(rPInvV, &rInvV, &rV, validDivisor);
669                             }
670                             else
671                             {
672                                 /* We can avoid one expensive pow and one / operation */
673                                 rPInvV = rPInvC;
674                                 rInvV  = rInvC;
675                                 rV     = rC;
676                             }
677                         }
678                         else
679                         {
680                             rPInvC = one;
681                             rInvC  = rInv;
682                             rC     = r;
683
684                             rPInvV = one;
685                             rInvV  = rInv;
686                             rV     = r;
687                         }
688
689                         /* Only process the coulomb interactions if we either
690                          * include all entries in the list (no cutoff
691                          * used in the kernel), or if we are within the cutoff.
692                          */
693                         BoolType computeElecInteraction;
694                         if constexpr (elecInteractionTypeIsEwald)
695                         {
696                             computeElecInteraction = (r < rCoulomb && qq[i] != zero && bPairIncluded);
697                         }
698                         else
699                         {
700                             computeElecInteraction = (rC < rCoulomb && qq[i] != zero && bPairIncluded);
701                         }
702                         if (gmx::anyTrue(computeElecInteraction))
703                         {
704                             if constexpr (elecInteractionTypeIsEwald)
705                             {
706                                 vCoul[i]  = ewaldPotential(qq[i], rInvC, sh_ewald);
707                                 fScalC[i] = ewaldScalarForce(qq[i], rInvC);
708                             }
709                             else
710                             {
711                                 vCoul[i]  = reactionFieldPotential(qq[i], rInvC, rC, krf, crf);
712                                 fScalC[i] = reactionFieldScalarForce(qq[i], rInvC, rC, krf, two);
713                             }
714
715                             vCoul[i]  = gmx::selectByMask(vCoul[i], computeElecInteraction);
716                             fScalC[i] = gmx::selectByMask(fScalC[i], computeElecInteraction);
717                         }
718
719                         /* Only process the VDW interactions if we either
720                          * include all entries in the list (no cutoff used
721                          * in the kernel), or if we are within the cutoff.
722                          */
723                         BoolType computeVdwInteraction;
724                         if constexpr (vdwInteractionTypeIsEwald)
725                         {
726                             computeVdwInteraction =
727                                     (r < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
728                         }
729                         else
730                         {
731                             computeVdwInteraction =
732                                     (rV < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
733                         }
734                         if (gmx::anyTrue(computeVdwInteraction))
735                         {
736                             RealType rInv6;
737                             if constexpr (useSoftCore)
738                             {
739                                 rInv6 = rPInvV;
740                             }
741                             else
742                             {
743                                 rInv6 = calculateRinv6(rInvV);
744                             }
745                             RealType vVdw6  = calculateVdw6(c6[i], rInv6);
746                             RealType vVdw12 = calculateVdw12(c12[i], rInv6);
747
748                             vVdw[i] = lennardJonesPotential(
749                                     vVdw6, vVdw12, c6[i], c12[i], repulsionShift, dispersionShift, oneSixth, oneTwelfth);
750                             fScalV[i] = lennardJonesScalarForce(vVdw6, vVdw12);
751
752                             if constexpr (vdwInteractionTypeIsEwald)
753                             {
754                                 /* Subtract the grid potential at the cut-off */
755                                 vVdw[i] = vVdw[i]
756                                           + gmx::selectByMask(ewaldLennardJonesGridSubtract(
757                                                                       ljPmeC6Grid[i], shLjEwald, oneSixth),
758                                                               computeVdwInteraction);
759                             }
760
761                             if constexpr (vdwModifierIsPotSwitch)
762                             {
763                                 RealType d             = rV - rVdwSwitch;
764                                 BoolType zeroMask      = zero < d;
765                                 BoolType potSwitchMask = rV < rVdw;
766                                 d                      = gmx::selectByMask(d, zeroMask);
767                                 const RealType d2      = d * d;
768                                 const RealType sw =
769                                         one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
770                                 const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
771
772                                 fScalV[i] = potSwitchScalarForceMod(
773                                         fScalV[i], vVdw[i], sw, rV, dsw, potSwitchMask);
774                                 vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, potSwitchMask);
775                             }
776
777                             vVdw[i]   = gmx::selectByMask(vVdw[i], computeVdwInteraction);
778                             fScalV[i] = gmx::selectByMask(fScalV[i], computeVdwInteraction);
779                         }
780
781                         /* fScalC (and fScalV) now contain: dV/drC * rC
782                          * Now we multiply by rC^-p, so it will be: dV/drC * rC^1-p
783                          * Further down we first multiply by r^p-2 and then by
784                          * the vector r, which in total gives: dV/drC * (r/rC)^1-p
785                          */
786                         fScalC[i] = fScalC[i] * rPInvC;
787                         fScalV[i] = fScalV[i] * rPInvV;
788                     } // end of block requiring nonZeroState
789                 }     // end for (int i = 0; i < NSTATES; i++)
790
791                 /* Assemble A and B states. */
792                 BoolType assembleStates = (bPairIncluded && withinCutoffMask);
793                 if (gmx::anyTrue(assembleStates))
794                 {
795                     for (int i = 0; i < NSTATES; i++)
796                     {
797                         vCTot = vCTot + LFC[i] * vCoul[i];
798                         vVTot = vVTot + LFV[i] * vVdw[i];
799
800                         fScal = fScal + LFC[i] * fScalC[i] * rpm2;
801                         fScal = fScal + LFV[i] * fScalV[i] * rpm2;
802
803                         if constexpr (useSoftCore)
804                         {
805                             dvdlCoul = dvdlCoul + vCoul[i] * DLF[i]
806                                        + LFC[i] * alphaCoulEff * dlFacCoul[i] * fScalC[i] * sigma6[i];
807                             dvdlVdw = dvdlVdw + vVdw[i] * DLF[i]
808                                       + LFV[i] * alphaVdwEff * dlFacVdw[i] * fScalV[i] * sigma6[i];
809                         }
810                         else
811                         {
812                             dvdlCoul = dvdlCoul + vCoul[i] * DLF[i];
813                             dvdlVdw  = dvdlVdw + vVdw[i] * DLF[i];
814                         }
815                     }
816                 }
817             } // end of block requiring bPairIncluded && withinCutoffMask
818             /* In the following block bPairIncluded should be false in the masks. */
819             if (icoul == NbkernelElecType::ReactionField)
820             {
821                 const BoolType computeReactionField = bPairExcluded;
822
823                 if (gmx::anyTrue(computeReactionField))
824                 {
825                     /* For excluded pairs we don't use soft-core.
826                      * As there is no singularity, there is no need for soft-core.
827                      */
828                     const RealType FF = -two * krf;
829                     RealType       VV = krf * rSq - crf;
830
831                     /* If ii == jnr the i particle (ii) has itself (jnr)
832                      * in its neighborlist. This corresponds to a self-interaction
833                      * that will occur twice. Scale it down by 50% to only include
834                      * it once.
835                      */
836                     VV = VV * gmx::blend(one, half, bIiEqJnr);
837
838                     for (int i = 0; i < NSTATES; i++)
839                     {
840                         vCTot = vCTot + gmx::selectByMask(LFC[i] * qq[i] * VV, computeReactionField);
841                         fScal = fScal + gmx::selectByMask(LFC[i] * qq[i] * FF, computeReactionField);
842                         dvdlCoul = dvdlCoul + gmx::selectByMask(DLF[i] * qq[i] * VV, computeReactionField);
843                     }
844                 }
845             }
846
847             const BoolType computeElecEwaldInteraction = (bPairExcluded || r < rCoulomb);
848             if (elecInteractionTypeIsEwald && gmx::anyTrue(computeElecEwaldInteraction))
849             {
850                 /* See comment in the preamble. When using Ewald interactions
851                  * (unless we use a switch modifier) we subtract the reciprocal-space
852                  * Ewald component here which made it possible to apply the free
853                  * energy interaction to 1/r (vanilla coulomb short-range part)
854                  * above. This gets us closer to the ideal case of applying
855                  * the softcore to the entire electrostatic interaction,
856                  * including the reciprocal-space component.
857                  */
858                 RealType v_lr, f_lr;
859
860                 pmeCoulombCorrectionVF(rSq, ewaldBeta, &v_lr, &f_lr, rSqValid);
861                 f_lr = f_lr * rInv * rInv;
862
863                 /* Note that any possible Ewald shift has already been applied in
864                  * the normal interaction part above.
865                  */
866
867                 /* If ii == jnr the i particle (ii) has itself (jnr)
868                  * in its neighborlist. This corresponds to a self-interaction
869                  * that will occur twice. Scale it down by 50% to only include
870                  * it once.
871                  */
872                 v_lr = v_lr * gmx::blend(one, half, bIiEqJnr);
873
874                 for (int i = 0; i < NSTATES; i++)
875                 {
876                     vCTot = vCTot - gmx::selectByMask(LFC[i] * qq[i] * v_lr, computeElecEwaldInteraction);
877                     fScal = fScal - gmx::selectByMask(LFC[i] * qq[i] * f_lr, computeElecEwaldInteraction);
878                     dvdlCoul = dvdlCoul
879                                - gmx::selectByMask(DLF[i] * qq[i] * v_lr, computeElecEwaldInteraction);
880                 }
881             }
882
883             const BoolType computeVdwEwaldInteraction = (bPairExcluded || r < rVdw);
884             if (vdwInteractionTypeIsEwald && gmx::anyTrue(computeVdwEwaldInteraction))
885             {
886                 /* See comment in the preamble. When using LJ-Ewald interactions
887                  * (unless we use a switch modifier) we subtract the reciprocal-space
888                  * Ewald component here which made it possible to apply the free
889                  * energy interaction to r^-6 (vanilla LJ6 short-range part)
890                  * above. This gets us closer to the ideal case of applying
891                  * the softcore to the entire VdW interaction,
892                  * including the reciprocal-space component.
893                  */
894
895                 RealType v_lr, f_lr;
896                 pmeLJCorrectionVF(
897                         rInv, rSq, ewaldLJCoeffSq, ewaldLJCoeffSixDivSix, &v_lr, &f_lr, computeVdwEwaldInteraction, bIiEqJnr);
898                 v_lr = v_lr * oneSixth;
899
900                 for (int i = 0; i < NSTATES; i++)
901                 {
902                     vVTot = vVTot + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
903                     fScal = fScal + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * f_lr, computeVdwEwaldInteraction);
904                     dvdlVdw = dvdlVdw + gmx::selectByMask(DLF[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
905                 }
906             }
907
908             if (doForces && gmx::anyTrue(fScal != zero))
909             {
910                 const RealType tX = fScal * dX;
911                 const RealType tY = fScal * dY;
912                 const RealType tZ = fScal * dZ;
913                 fIX               = fIX + tX;
914                 fIY               = fIY + tY;
915                 fIZ               = fIZ + tZ;
916
917                 gmx::transposeScatterDecrU<3>(
918                         reinterpret_cast<real*>(threadForceBuffer), preloadJnr, tX, tY, tZ);
919             }
920         } // end for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
921
922         if (havePairsWithinCutoff)
923         {
924             if (doForces)
925             {
926                 gmx::transposeScatterIncrU<3>(
927                         reinterpret_cast<real*>(threadForceBuffer), preloadIi, fIX, fIY, fIZ);
928             }
929             if (doShiftForces)
930             {
931                 gmx::transposeScatterIncrU<3>(
932                         reinterpret_cast<real*>(threadForceShiftBuffer), preloadIs, fIX, fIY, fIZ);
933             }
934             if (doPotential)
935             {
936                 int ggid = gid[n];
937                 threadVc[ggid] += gmx::reduce(vCTot);
938                 threadVv[ggid] += gmx::reduce(vVTot);
939             }
940         }
941     } // end for (int n = 0; n < nri; n++)
942
943     if (gmx::anyTrue(dvdlCoul != zero))
944     {
945         threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)] += gmx::reduce(dvdlCoul);
946     }
947     if (gmx::anyTrue(dvdlVdw != zero))
948     {
949         threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)] += gmx::reduce(dvdlVdw);
950     }
951
952     /* Estimate flops, average for free energy stuff:
953      * 12  flops per outer iteration
954      * 150 flops per inner iteration
955      * TODO: Update the number of flops and/or use different counts for different code paths.
956      */
957     atomicNrnbIncrement(nrnb, eNR_NBKERNEL_FREE_ENERGY, nlist.nri * 12 + nlist.jindex[nri] * 150);
958
959     if (haveExcludedPairsBeyondRlist > 0)
960     {
961         gmx_fatal(FARGS,
962                   "There are perturbed non-bonded pair interactions beyond the pair-list cutoff "
963                   "of %g nm, which is not supported. This can happen because the system is "
964                   "unstable or because intra-molecular interactions at long distances are "
965                   "excluded. If the "
966                   "latter is the case, you can try to increase nstlist or rlist to avoid this."
967                   "The error is likely triggered by the use of couple-intramol=no "
968                   "and the maximal distance in the decoupled molecule exceeding rlist.",
969                   rlist);
970     }
971 }
972
973 typedef void (*KernelFunction)(const t_nblist&                           nlist,
974                                const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
975                                const int                                 ntype,
976                                const real                                rlist,
977                                const interaction_const_t&                ic,
978                                gmx::ArrayRef<const gmx::RVec>            shiftvec,
979                                gmx::ArrayRef<const real>                 nbfp,
980                                gmx::ArrayRef<const real>                 nbfp_grid,
981                                gmx::ArrayRef<const real>                 chargeA,
982                                gmx::ArrayRef<const real>                 chargeB,
983                                gmx::ArrayRef<const int>                  typeA,
984                                gmx::ArrayRef<const int>                  typeB,
985                                int                                       flags,
986                                gmx::ArrayRef<const real>                 lambda,
987                                t_nrnb* gmx_restrict                      nrnb,
988                                gmx::RVec*                                threadForceBuffer,
989                                rvec*                                     threadForceShiftBuffer,
990                                gmx::ArrayRef<real>                       threadVc,
991                                gmx::ArrayRef<real>                       threadVv,
992                                gmx::ArrayRef<real>                       threadDvdl);
993
994 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
995 static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
996 {
997     if (useSimd)
998     {
999 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS && GMX_USE_SIMD_KERNELS
1000         return (nb_free_energy_kernel<SimdDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
1001 #else
1002         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
1003 #endif
1004     }
1005     else
1006     {
1007         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch>);
1008     }
1009 }
1010
1011 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
1012 static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch, const bool useSimd)
1013 {
1014     if (vdwModifierIsPotSwitch)
1015     {
1016         return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, true>(
1017                 useSimd));
1018     }
1019     else
1020     {
1021         return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, false>(
1022                 useSimd));
1023     }
1024 }
1025
1026 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald>
1027 static KernelFunction dispatchKernelOnElecInteractionType(const bool elecInteractionTypeIsEwald,
1028                                                           const bool vdwModifierIsPotSwitch,
1029                                                           const bool useSimd)
1030 {
1031     if (elecInteractionTypeIsEwald)
1032     {
1033         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, true>(
1034                 vdwModifierIsPotSwitch, useSimd));
1035     }
1036     else
1037     {
1038         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, false>(
1039                 vdwModifierIsPotSwitch, useSimd));
1040     }
1041 }
1042
1043 template<bool useSoftCore, bool scLambdasOrAlphasDiffer>
1044 static KernelFunction dispatchKernelOnVdwInteractionType(const bool vdwInteractionTypeIsEwald,
1045                                                          const bool elecInteractionTypeIsEwald,
1046                                                          const bool vdwModifierIsPotSwitch,
1047                                                          const bool useSimd)
1048 {
1049     if (vdwInteractionTypeIsEwald)
1050     {
1051         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, true>(
1052                 elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1053     }
1054     else
1055     {
1056         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, false>(
1057                 elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1058     }
1059 }
1060
1061 template<bool useSoftCore>
1062 static KernelFunction dispatchKernelOnScLambdasOrAlphasDifference(const bool scLambdasOrAlphasDiffer,
1063                                                                   const bool vdwInteractionTypeIsEwald,
1064                                                                   const bool elecInteractionTypeIsEwald,
1065                                                                   const bool vdwModifierIsPotSwitch,
1066                                                                   const bool useSimd)
1067 {
1068     if (scLambdasOrAlphasDiffer)
1069     {
1070         return (dispatchKernelOnVdwInteractionType<useSoftCore, true>(
1071                 vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1072     }
1073     else
1074     {
1075         return (dispatchKernelOnVdwInteractionType<useSoftCore, false>(
1076                 vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, useSimd));
1077     }
1078 }
1079
1080 static KernelFunction dispatchKernel(const bool                 scLambdasOrAlphasDiffer,
1081                                      const bool                 vdwInteractionTypeIsEwald,
1082                                      const bool                 elecInteractionTypeIsEwald,
1083                                      const bool                 vdwModifierIsPotSwitch,
1084                                      const bool                 useSimd,
1085                                      const interaction_const_t& ic)
1086 {
1087     if (ic.softCoreParameters->alphaCoulomb == 0 && ic.softCoreParameters->alphaVdw == 0)
1088     {
1089         return (dispatchKernelOnScLambdasOrAlphasDifference<false>(scLambdasOrAlphasDiffer,
1090                                                                    vdwInteractionTypeIsEwald,
1091                                                                    elecInteractionTypeIsEwald,
1092                                                                    vdwModifierIsPotSwitch,
1093                                                                    useSimd));
1094     }
1095     else
1096     {
1097         return (dispatchKernelOnScLambdasOrAlphasDifference<true>(scLambdasOrAlphasDiffer,
1098                                                                   vdwInteractionTypeIsEwald,
1099                                                                   elecInteractionTypeIsEwald,
1100                                                                   vdwModifierIsPotSwitch,
1101                                                                   useSimd));
1102     }
1103 }
1104
1105
1106 void gmx_nb_free_energy_kernel(const t_nblist&                           nlist,
1107                                const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
1108                                const bool                                useSimd,
1109                                const int                                 ntype,
1110                                const real                                rlist,
1111                                const interaction_const_t&                ic,
1112                                gmx::ArrayRef<const gmx::RVec>            shiftvec,
1113                                gmx::ArrayRef<const real>                 nbfp,
1114                                gmx::ArrayRef<const real>                 nbfp_grid,
1115                                gmx::ArrayRef<const real>                 chargeA,
1116                                gmx::ArrayRef<const real>                 chargeB,
1117                                gmx::ArrayRef<const int>                  typeA,
1118                                gmx::ArrayRef<const int>                  typeB,
1119                                int                                       flags,
1120                                gmx::ArrayRef<const real>                 lambda,
1121                                t_nrnb*                                   nrnb,
1122                                gmx::RVec*                                threadForceBuffer,
1123                                rvec*                                     threadForceShiftBuffer,
1124                                gmx::ArrayRef<real>                       threadVc,
1125                                gmx::ArrayRef<real>                       threadVv,
1126                                gmx::ArrayRef<real>                       threadDvdl)
1127 {
1128     GMX_ASSERT(EEL_PME_EWALD(ic.eeltype) || ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype),
1129                "Unsupported eeltype with free energy");
1130     GMX_ASSERT(ic.softCoreParameters, "We need soft-core parameters");
1131
1132     const auto& scParams                   = *ic.softCoreParameters;
1133     const bool  vdwInteractionTypeIsEwald  = (EVDW_PME(ic.vdwtype));
1134     const bool  elecInteractionTypeIsEwald = (EEL_PME_EWALD(ic.eeltype));
1135     const bool  vdwModifierIsPotSwitch     = (ic.vdw_modifier == InteractionModifiers::PotSwitch);
1136     bool        scLambdasOrAlphasDiffer    = true;
1137
1138     if (scParams.alphaCoulomb == 0 && scParams.alphaVdw == 0)
1139     {
1140         scLambdasOrAlphasDiffer = false;
1141     }
1142     else
1143     {
1144         if (lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)]
1145                     == lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)]
1146             && scParams.alphaCoulomb == scParams.alphaVdw)
1147         {
1148             scLambdasOrAlphasDiffer = false;
1149         }
1150     }
1151
1152     KernelFunction kernelFunc;
1153     kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer,
1154                                 vdwInteractionTypeIsEwald,
1155                                 elecInteractionTypeIsEwald,
1156                                 vdwModifierIsPotSwitch,
1157                                 useSimd,
1158                                 ic);
1159     kernelFunc(nlist,
1160                coords,
1161                ntype,
1162                rlist,
1163                ic,
1164                shiftvec,
1165                nbfp,
1166                nbfp_grid,
1167                chargeA,
1168                chargeB,
1169                typeA,
1170                typeB,
1171                flags,
1172                lambda,
1173                nrnb,
1174                threadForceBuffer,
1175                threadForceShiftBuffer,
1176                threadVc,
1177                threadVv,
1178                threadDvdl);
1179 }