c8e46d15d3a5cb381c876ddc79941fc6fabf9a39
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_free_energy.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 #include "gmxpre.h"
39
40 #include "nb_free_energy.h"
41
42 #include "config.h"
43
44 #include <cmath>
45 #include <set>
46
47 #include <algorithm>
48
49 #include "gromacs/gmxlib/nrnb.h"
50 #include "gromacs/gmxlib/nonbonded/nonbonded.h"
51 #include "gromacs/math/arrayrefwithpadding.h"
52 #include "gromacs/math/functions.h"
53 #include "gromacs/math/vec.h"
54 #include "gromacs/mdtypes/forceoutput.h"
55 #include "gromacs/mdtypes/forcerec.h"
56 #include "gromacs/mdtypes/interaction_const.h"
57 #include "gromacs/mdtypes/md_enums.h"
58 #include "gromacs/mdtypes/mdatom.h"
59 #include "gromacs/mdtypes/nblist.h"
60 #include "gromacs/pbcutil/ishift.h"
61 #include "gromacs/simd/simd.h"
62 #include "gromacs/simd/simd_math.h"
63 #include "gromacs/utility/fatalerror.h"
64 #include "gromacs/utility/arrayref.h"
65
66
67 //! Scalar (non-SIMD) data types.
68 struct ScalarDataTypes
69 {
70     using RealType = real; //!< The data type to use as real.
71     using IntType  = int;  //!< The data type to use as int.
72     using BoolType = bool; //!< The data type to use as bool for real value comparison.
73     static constexpr int simdRealWidth = 1; //!< The width of the RealType.
74     static constexpr int simdIntWidth  = 1; //!< The width of the IntType.
75 };
76
#if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS
/*! \brief Type bundle for instantiating the kernel with the gmx SIMD layer.
 *
 * Only defined when the SIMD layer provides both real and int32 arithmetic.
 */
struct SimdDataTypes
{
    using BoolType = gmx::SimdBool;  //!< Result type of comparisons between RealType values.
    using IntType  = gmx::SimdInt32; //!< SIMD integer type paired with RealType.
    using RealType = gmx::SimdReal;  //!< SIMD floating-point type (real precision).

    static constexpr int simdRealWidth = GMX_SIMD_REAL_WIDTH; //!< Lane count of RealType.
    // The int32 lane count follows the precision of real: with double-precision
    // SIMD in a double build the double-int width applies, otherwise the float-int width.
#    if GMX_SIMD_HAVE_DOUBLE && GMX_DOUBLE
    static constexpr int simdIntWidth = GMX_SIMD_DINT32_WIDTH; //!< Lane count of IntType.
#    else
    static constexpr int simdIntWidth = GMX_SIMD_FINT32_WIDTH; //!< Lane count of IntType.
#    endif
};
#endif
92
93 /*! \brief Lower limit for square interaction distances in nonbonded kernels.
94  *
95  * This is a mimimum on r^2 to avoid overflows when computing r^6.
96  * This will only affect results for soft-cored interaction at distances smaller
97  * than 1e-6 and will limit extremely high foreign energies for overlapping atoms.
98  * Note that we could use a somewhat smaller minimum in double precision.
99  * But because invsqrt in double precision can use single precision, this number
100  * can not be much smaller, we use the same number for simplicity.
101  */
102 constexpr real c_minDistanceSquared = 1.0e-12_real;
103
104 /*! \brief Higher limit for r^-6 used for Lennard-Jones interactions
105  *
106  * This is needed to avoid overflow of LJ energy and force terms for excluded
107  * atoms and foreign energies of hard-core states of overlapping atoms.
108  * Note that in single precision this value leaves room for C12 coefficients up to 3.4e8.
109  */
110 constexpr real c_maxRInvSix = 1.0e15_real;
111
112 template<bool computeForces, class RealType>
113 static inline void
114 pmeCoulombCorrectionVF(const RealType rSq, const real beta, RealType* pot, RealType gmx_unused* force)
115 {
116     const RealType brsq = rSq * beta * beta;
117     if constexpr (computeForces)
118     {
119         *force = -brsq * beta * gmx::pmeForceCorrection(brsq);
120     }
121     *pot = beta * gmx::pmePotentialCorrection(brsq);
122 }
123
124 template<bool computeForces, class RealType, class BoolType>
125 static inline void pmeLJCorrectionVF(const RealType rInv,
126                                      const RealType rSq,
127                                      const real     ewaldLJCoeffSq,
128                                      const real     ewaldLJCoeffSixDivSix,
129                                      RealType*      pot,
130                                      RealType gmx_unused* force,
131                                      const BoolType       mask,
132                                      const BoolType       bIiEqJnr)
133 {
134     // We mask rInv to get zero force and potential for masked out pair interactions
135     const RealType rInvSq  = rInv * rInv;
136     const RealType rInvSix = rInvSq * rInvSq * rInvSq;
137     // Mask rSq to avoid underflow in exp()
138     const RealType coeffSqRSq       = ewaldLJCoeffSq * gmx::selectByMask(rSq, mask);
139     const RealType expNegCoeffSqRSq = gmx::exp(-coeffSqRSq);
140     const RealType poly             = 1.0_real + coeffSqRSq + 0.5_real * coeffSqRSq * coeffSqRSq;
141     if constexpr (computeForces)
142     {
143         *force = rInvSix - expNegCoeffSqRSq * (rInvSix * poly + ewaldLJCoeffSixDivSix);
144         *force = *force * rInvSq;
145     }
146     // The self interaction is the limit for r -> 0 which we need to compute separately
147     *pot = gmx::blend(
148             rInvSix * (1.0_real - expNegCoeffSqRSq * poly), 0.5_real * ewaldLJCoeffSixDivSix, bIiEqJnr);
149 }
150
151 //! Computes r^(1/6) and 1/r^(1/6)
152 template<class RealType>
153 static inline void sixthRoot(const RealType r, RealType* sixthRoot, RealType* invSixthRoot)
154 {
155     RealType cbrtRes = gmx::cbrt(r);
156     *invSixthRoot    = gmx::invsqrt(cbrtRes);
157     *sixthRoot       = gmx::inv(*invSixthRoot);
158 }
159
//! Returns rInvV^6, i.e. r^-6 given r^-1.
template<class RealType>
static inline RealType calculateRinv6(const RealType rInvV)
{
    const RealType rInvSq = rInvV * rInvV;
    return rInvSq * rInvSq * rInvSq;
}
166
//! Returns the dispersion term c6 * r^-6.
template<class RealType>
static inline RealType calculateVdw6(const RealType c6, const RealType rInv6)
{
    const RealType dispersion = c6 * rInv6;
    return dispersion;
}
172
//! Returns the repulsion term c12 * (r^-6)^2.
template<class RealType>
static inline RealType calculateVdw12(const RealType c12, const RealType rInv6)
{
    const RealType c12TimesRInv6 = c12 * rInv6;
    return c12TimesRInv6 * rInv6;
}
178
179 /* reaction-field electrostatics */
180 template<class RealType>
181 static inline RealType reactionFieldScalarForce(const RealType qq,
182                                                 const RealType rInv,
183                                                 const RealType r,
184                                                 const real     krf,
185                                                 const real     two)
186 {
187     return (qq * (rInv - two * krf * r * r));
188 }
189 template<class RealType>
190 static inline RealType reactionFieldPotential(const RealType qq,
191                                               const RealType rInv,
192                                               const RealType r,
193                                               const real     krf,
194                                               const real     potentialShift)
195 {
196     return (qq * (rInv + krf * r * r - potentialShift));
197 }
198
199 /* Ewald electrostatics */
//! Ewald scalar force term: the Coulomb prefactor times 1/r.
template<class RealType>
static inline RealType ewaldScalarForce(const RealType coulomb, const RealType rInv)
{
    const RealType scalarForce = coulomb * rInv;
    return scalarForce;
}
205 template<class RealType>
206 static inline RealType ewaldPotential(const RealType coulomb, const RealType rInv, const real potentialShift)
207 {
208     return (coulomb * (rInv - potentialShift));
209 }
210
211 /* cutoff LJ */
//! Plain cut-off Lennard-Jones scalar force term: repulsion minus dispersion.
template<class RealType>
static inline RealType lennardJonesScalarForce(const RealType v6, const RealType v12)
{
    const RealType scalarForce = v12 - v6;
    return scalarForce;
}
217 template<class RealType>
218 static inline RealType lennardJonesPotential(const RealType v6,
219                                              const RealType v12,
220                                              const RealType c6,
221                                              const RealType c12,
222                                              const real     repulsionShift,
223                                              const real     dispersionShift,
224                                              const real     oneSixth,
225                                              const real     oneTwelfth)
226 {
227     return ((v12 + c12 * repulsionShift) * oneTwelfth - (v6 + c6 * dispersionShift) * oneSixth);
228 }
229
230 /* Ewald LJ */
231 template<class RealType>
232 static inline RealType ewaldLennardJonesGridSubtract(const RealType c6grid,
233                                                      const real     potentialShift,
234                                                      const real     oneSixth)
235 {
236     return (c6grid * potentialShift * oneSixth);
237 }
238
239 /* LJ Potential switch */
240 template<class RealType, class BoolType>
241 static inline RealType potSwitchScalarForceMod(const RealType fScalarInp,
242                                                const RealType potential,
243                                                const RealType sw,
244                                                const RealType r,
245                                                const RealType dsw,
246                                                const BoolType mask)
247 {
248     /* The mask should select on rV < rVdw */
249     return (gmx::selectByMask(fScalarInp * sw - r * potential * dsw, mask));
250 }
251 template<class RealType, class BoolType>
252 static inline RealType potSwitchPotentialMod(const RealType potentialInp, const RealType sw, const BoolType mask)
253 {
254     /* The mask should select on rV < rVdw */
255     return (gmx::selectByMask(potentialInp * sw, mask));
256 }
257
258
259 //! Templated free-energy non-bonded kernel
260 template<typename DataTypes, bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch, bool computeForces>
261 static void nb_free_energy_kernel(const t_nblist&                                  nlist,
262                                   const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
263                                   const int                                        ntype,
264                                   const real                                       rlist,
265                                   const interaction_const_t&                       ic,
266                                   gmx::ArrayRef<const gmx::RVec>                   shiftvec,
267                                   gmx::ArrayRef<const real>                        nbfp,
268                                   gmx::ArrayRef<const real> gmx_unused             nbfp_grid,
269                                   gmx::ArrayRef<const real>                        chargeA,
270                                   gmx::ArrayRef<const real>                        chargeB,
271                                   gmx::ArrayRef<const int>                         typeA,
272                                   gmx::ArrayRef<const int>                         typeB,
273                                   int                                              flags,
274                                   gmx::ArrayRef<const real>                        lambda,
275                                   t_nrnb* gmx_restrict                             nrnb,
276                                   gmx::ArrayRefWithPadding<gmx::RVec> threadForceBuffer,
277                                   rvec gmx_unused*    threadForceShiftBuffer,
278                                   gmx::ArrayRef<real> threadVc,
279                                   gmx::ArrayRef<real> threadVv,
280                                   gmx::ArrayRef<real> threadDvdl)
281 {
282 #define STATE_A 0
283 #define STATE_B 1
284 #define NSTATES 2
285
286     using RealType = typename DataTypes::RealType;
287     using IntType  = typename DataTypes::IntType;
288     using BoolType = typename DataTypes::BoolType;
289
290     constexpr real oneTwelfth = 1.0_real / 12.0_real;
291     constexpr real oneSixth   = 1.0_real / 6.0_real;
292     constexpr real zero       = 0.0_real;
293     constexpr real half       = 0.5_real;
294     constexpr real one        = 1.0_real;
295     constexpr real two        = 2.0_real;
296     constexpr real six        = 6.0_real;
297
298     // Extract pair list data
299     const int                nri    = nlist.nri;
300     gmx::ArrayRef<const int> iinr   = nlist.iinr;
301     gmx::ArrayRef<const int> jindex = nlist.jindex;
302     gmx::ArrayRef<const int> jjnr   = nlist.jjnr;
303     gmx::ArrayRef<const int> shift  = nlist.shift;
304     gmx::ArrayRef<const int> gid    = nlist.gid;
305
306     const real  lambda_coul = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)];
307     const real  lambda_vdw  = lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)];
308     const auto& scParams    = *ic.softCoreParameters;
309     const real gmx_unused alpha_coul    = scParams.alphaCoulomb;
310     const real gmx_unused alpha_vdw     = scParams.alphaVdw;
311     const real            lam_power     = scParams.lambdaPower;
312     const real gmx_unused sigma6_def    = scParams.sigma6WithInvalidSigma;
313     const real gmx_unused sigma6_min    = scParams.sigma6Minimum;
314     const bool gmx_unused doShiftForces = ((flags & GMX_NONBONDED_DO_SHIFTFORCE) != 0);
315     const bool            doPotential   = ((flags & GMX_NONBONDED_DO_POTENTIAL) != 0);
316
317     // Extract data from interaction_const_t
318     const real            facel           = ic.epsfac;
319     const real            rCoulomb        = ic.rcoulomb;
320     const real            krf             = ic.reactionFieldCoefficient;
321     const real            crf             = ic.reactionFieldShift;
322     const real gmx_unused shLjEwald       = ic.sh_lj_ewald;
323     const real            rVdw            = ic.rvdw;
324     const real            dispersionShift = ic.dispersion_shift.cpot;
325     const real            repulsionShift  = ic.repulsion_shift.cpot;
326     const real            ewaldBeta       = ic.ewaldcoeff_q;
327     real gmx_unused       ewaldLJCoeffSq;
328     real gmx_unused       ewaldLJCoeffSixDivSix;
329     if constexpr (vdwInteractionTypeIsEwald)
330     {
331         ewaldLJCoeffSq        = ic.ewaldcoeff_lj * ic.ewaldcoeff_lj;
332         ewaldLJCoeffSixDivSix = ewaldLJCoeffSq * ewaldLJCoeffSq * ewaldLJCoeffSq / six;
333     }
334
335     // Note that the nbnxm kernels do not support Coulomb potential switching at all
336     GMX_ASSERT(ic.coulomb_modifier != InteractionModifiers::PotSwitch,
337                "Potential switching is not supported for Coulomb with FEP");
338
339     const real      rVdwSwitch = ic.rvdw_switch;
340     real gmx_unused vdw_swV3, vdw_swV4, vdw_swV5, vdw_swF2, vdw_swF3, vdw_swF4;
341     if constexpr (vdwModifierIsPotSwitch)
342     {
343         const real d = rVdw - rVdwSwitch;
344         vdw_swV3     = -10.0_real / (d * d * d);
345         vdw_swV4     = 15.0_real / (d * d * d * d);
346         vdw_swV5     = -6.0_real / (d * d * d * d * d);
347         vdw_swF2     = -30.0_real / (d * d * d);
348         vdw_swF3     = 60.0_real / (d * d * d * d);
349         vdw_swF4     = -30.0_real / (d * d * d * d * d);
350     }
351     else
352     {
353         /* Avoid warnings from stupid compilers (looking at you, Clang!) */
354         vdw_swV3 = vdw_swV4 = vdw_swV5 = vdw_swF2 = vdw_swF3 = vdw_swF4 = zero;
355     }
356
357     NbkernelElecType icoul;
358     if (ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype))
359     {
360         icoul = NbkernelElecType::ReactionField;
361     }
362     else
363     {
364         icoul = NbkernelElecType::None;
365     }
366
367     real rcutoff_max2 = std::max(ic.rcoulomb, ic.rvdw);
368     rcutoff_max2      = rcutoff_max2 * rcutoff_max2;
369
370     real gmx_unused sh_ewald = 0;
371     if constexpr (elecInteractionTypeIsEwald || vdwInteractionTypeIsEwald)
372     {
373         sh_ewald = ic.sh_ewald;
374     }
375
376     /* For Ewald/PME interactions we cannot easily apply the soft-core component to
377      * reciprocal space. When we use non-switched Ewald interactions, we
378      * assume the soft-coring does not significantly affect the grid contribution
379      * and apply the soft-core only to the full 1/r (- shift) pair contribution.
380      *
381      * However, we cannot use this approach for switch-modified since we would then
382      * effectively end up evaluating a significantly different interaction here compared to the
383      * normal (non-free-energy) kernels, either by applying a cutoff at a different
384      * position than what the user requested, or by switching different
385      * things (1/r rather than short-range Ewald). For these settings, we just
386      * use the traditional short-range Ewald interaction in that case.
387      */
388     GMX_RELEASE_ASSERT(!(vdwInteractionTypeIsEwald && vdwModifierIsPotSwitch),
389                        "Can not apply soft-core to switched Ewald potentials");
390
391     const RealType            minDistanceSquared(c_minDistanceSquared);
392     const RealType            maxRInvSix(c_maxRInvSix);
393     const RealType gmx_unused floatMin(GMX_FLOAT_MIN);
394
395     RealType dvdlCoul(zero);
396     RealType dvdlVdw(zero);
397
398     /* Lambda factor for state A, 1-lambda*/
399     real LFC[NSTATES], LFV[NSTATES];
400     LFC[STATE_A] = one - lambda_coul;
401     LFV[STATE_A] = one - lambda_vdw;
402
403     /* Lambda factor for state B, lambda*/
404     LFC[STATE_B] = lambda_coul;
405     LFV[STATE_B] = lambda_vdw;
406
407     /*derivative of the lambda factor for state A and B */
408     real DLF[NSTATES];
409     DLF[STATE_A] = -one;
410     DLF[STATE_B] = one;
411
412     real gmx_unused lFacCoul[NSTATES], dlFacCoul[NSTATES], lFacVdw[NSTATES], dlFacVdw[NSTATES];
413     constexpr real  sc_r_power = six;
414     for (int i = 0; i < NSTATES; i++)
415     {
416         lFacCoul[i]  = (lam_power == 2 ? (1 - LFC[i]) * (1 - LFC[i]) : (1 - LFC[i]));
417         dlFacCoul[i] = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFC[i]) : 1);
418         lFacVdw[i]   = (lam_power == 2 ? (1 - LFV[i]) * (1 - LFV[i]) : (1 - LFV[i]));
419         dlFacVdw[i]  = DLF[i] * lam_power / sc_r_power * (lam_power == 2 ? (1 - LFV[i]) : 1);
420     }
421
422     // We need pointers to real for SIMD access
423     const real* gmx_restrict x = coords.paddedConstArrayRef().data()[0];
424     real* gmx_restrict       forceRealPtr;
425     if constexpr (computeForces)
426     {
427         GMX_ASSERT(nri == 0 || !threadForceBuffer.empty(), "need a valid threadForceBuffer");
428
429         forceRealPtr = threadForceBuffer.paddedArrayRef().data()[0];
430
431         if (doShiftForces)
432         {
433             GMX_ASSERT(threadForceShiftBuffer != nullptr, "need a valid threadForceShiftBuffer");
434         }
435     }
436
437     const real rlistSquared = gmx::square(rlist);
438
439     bool haveExcludedPairsBeyondRlist = false;
440
441     for (int n = 0; n < nri; n++)
442     {
443         bool havePairsWithinCutoff = false;
444
445         const int  is   = shift[n];
446         const real shX  = shiftvec[is][XX];
447         const real shY  = shiftvec[is][YY];
448         const real shZ  = shiftvec[is][ZZ];
449         const int  nj0  = jindex[n];
450         const int  nj1  = jindex[n + 1];
451         const int  ii   = iinr[n];
452         const int  ii3  = 3 * ii;
453         const real ix   = shX + x[ii3 + 0];
454         const real iy   = shY + x[ii3 + 1];
455         const real iz   = shZ + x[ii3 + 2];
456         const real iqA  = facel * chargeA[ii];
457         const real iqB  = facel * chargeB[ii];
458         const int  ntiA = ntype * typeA[ii];
459         const int  ntiB = ntype * typeB[ii];
460         RealType   vCTot(0);
461         RealType   vVTot(0);
462         RealType   fIX(0);
463         RealType   fIY(0);
464         RealType   fIZ(0);
465
466 #if GMX_SIMD_HAVE_REAL
467         alignas(GMX_SIMD_ALIGNMENT) int            preloadIi[DataTypes::simdRealWidth];
468         alignas(GMX_SIMD_ALIGNMENT) int gmx_unused preloadIs[DataTypes::simdRealWidth];
469 #else
470         int            preloadIi[DataTypes::simdRealWidth];
471         int gmx_unused preloadIs[DataTypes::simdRealWidth];
472 #endif
473         for (int s = 0; s < DataTypes::simdRealWidth; s++)
474         {
475             preloadIi[s] = ii;
476             preloadIs[s] = shift[n];
477         }
478         IntType ii_s = gmx::load<IntType>(preloadIi);
479
480         for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
481         {
482             RealType r, rInv;
483
484 #if GMX_SIMD_HAVE_REAL
485             alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIsValid[DataTypes::simdRealWidth];
486             alignas(GMX_SIMD_ALIGNMENT) real    preloadPairIncluded[DataTypes::simdRealWidth];
487             alignas(GMX_SIMD_ALIGNMENT) int32_t preloadJnr[DataTypes::simdRealWidth];
488             alignas(GMX_SIMD_ALIGNMENT) int32_t typeIndices[NSTATES][DataTypes::simdRealWidth];
489             alignas(GMX_SIMD_ALIGNMENT) real    preloadQq[NSTATES][DataTypes::simdRealWidth];
490             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
491             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
492             alignas(GMX_SIMD_ALIGNMENT) real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
493             alignas(GMX_SIMD_ALIGNMENT) real preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
494 #else
495             real            preloadPairIsValid[DataTypes::simdRealWidth];
496             real            preloadPairIncluded[DataTypes::simdRealWidth];
497             int             preloadJnr[DataTypes::simdRealWidth];
498             int             typeIndices[NSTATES][DataTypes::simdRealWidth];
499             real            preloadQq[NSTATES][DataTypes::simdRealWidth];
500             real gmx_unused preloadSigma6[NSTATES][DataTypes::simdRealWidth];
501             real gmx_unused preloadAlphaVdwEff[DataTypes::simdRealWidth];
502             real gmx_unused preloadAlphaCoulEff[DataTypes::simdRealWidth];
503             real            preloadLjPmeC6Grid[NSTATES][DataTypes::simdRealWidth];
504 #endif
505             for (int s = 0; s < DataTypes::simdRealWidth; s++)
506             {
507                 if (k + s < nj1)
508                 {
509                     preloadPairIsValid[s] = true;
510                     /* Check if this pair on the exclusions list.*/
511                     preloadPairIncluded[s]  = (nlist.excl_fep.empty() || nlist.excl_fep[k + s]);
512                     const int jnr           = jjnr[k + s];
513                     preloadJnr[s]           = jnr;
514                     typeIndices[STATE_A][s] = ntiA + typeA[jnr];
515                     typeIndices[STATE_B][s] = ntiB + typeB[jnr];
516                     preloadQq[STATE_A][s]   = iqA * chargeA[jnr];
517                     preloadQq[STATE_B][s]   = iqB * chargeB[jnr];
518
519                     for (int i = 0; i < NSTATES; i++)
520                     {
521                         if constexpr (vdwInteractionTypeIsEwald)
522                         {
523                             preloadLjPmeC6Grid[i][s] = nbfp_grid[2 * typeIndices[i][s]];
524                         }
525                         else
526                         {
527                             preloadLjPmeC6Grid[i][s] = 0;
528                         }
529                         if constexpr (useSoftCore)
530                         {
531                             const real c6  = nbfp[2 * typeIndices[i][s]];
532                             const real c12 = nbfp[2 * typeIndices[i][s] + 1];
533                             if (c6 > 0 && c12 > 0)
534                             {
535                                 /* c12 is stored scaled with 12.0 and c6 is scaled with 6.0 - correct for this */
536                                 preloadSigma6[i][s] = 0.5_real * c12 / c6;
537                                 if (preloadSigma6[i][s]
538                                     < sigma6_min) /* for disappearing coul and vdw with soft core at the same time */
539                                 {
540                                     preloadSigma6[i][s] = sigma6_min;
541                                 }
542                             }
543                             else
544                             {
545                                 preloadSigma6[i][s] = sigma6_def;
546                             }
547                         }
548                     }
549                     if constexpr (useSoftCore)
550                     {
551                         /* only use softcore if one of the states has a zero endstate - softcore is for avoiding infinities!*/
552                         const real c12A = nbfp[2 * typeIndices[STATE_A][s] + 1];
553                         const real c12B = nbfp[2 * typeIndices[STATE_B][s] + 1];
554                         if (c12A > 0 && c12B > 0)
555                         {
556                             preloadAlphaVdwEff[s]  = 0;
557                             preloadAlphaCoulEff[s] = 0;
558                         }
559                         else
560                         {
561                             preloadAlphaVdwEff[s]  = alpha_vdw;
562                             preloadAlphaCoulEff[s] = alpha_coul;
563                         }
564                     }
565                 }
566                 else
567                 {
568                     preloadJnr[s]          = jjnr[k];
569                     preloadPairIsValid[s]  = false;
570                     preloadPairIncluded[s] = false;
571                     preloadAlphaVdwEff[s]  = 0;
572                     preloadAlphaCoulEff[s] = 0;
573
574                     for (int i = 0; i < NSTATES; i++)
575                     {
576                         typeIndices[STATE_A][s]  = ntiA + typeA[jjnr[k]];
577                         typeIndices[STATE_B][s]  = ntiB + typeB[jjnr[k]];
578                         preloadLjPmeC6Grid[i][s] = 0;
579                         preloadQq[i][s]          = 0;
580                         preloadSigma6[i][s]      = 0;
581                     }
582                 }
583             }
584
585             RealType jx, jy, jz;
586             gmx::gatherLoadUTranspose<3>(reinterpret_cast<const real*>(x), preloadJnr, &jx, &jy, &jz);
587
588             const RealType pairIsValid   = gmx::load<RealType>(preloadPairIsValid);
589             const RealType pairIncluded  = gmx::load<RealType>(preloadPairIncluded);
590             const BoolType bPairIncluded = (pairIncluded != zero);
591             const BoolType bPairExcluded = (pairIncluded == zero && pairIsValid != zero);
592
593             const RealType dX  = ix - jx;
594             const RealType dY  = iy - jy;
595             const RealType dZ  = iz - jz;
596             RealType       rSq = dX * dX + dY * dY + dZ * dZ;
597
598             BoolType withinCutoffMask = (rSq < rcutoff_max2);
599
600             if (!gmx::anyTrue(withinCutoffMask || bPairExcluded))
601             {
602                 /* We save significant time by skipping all code below.
603                  * Note that with soft-core interactions, the actual cut-off
604                  * check might be different. But since the soft-core distance
605                  * is always larger than r, checking on r here is safe.
606                  * Exclusions outside the cutoff can not be skipped as
607                  * when using Ewald: the reciprocal-space
608                  * Ewald component still needs to be subtracted.
609                  */
610                 continue;
611             }
612             else
613             {
614                 havePairsWithinCutoff = true;
615             }
616
617             if (gmx::anyTrue(rlistSquared < rSq && bPairExcluded))
618             {
619                 haveExcludedPairsBeyondRlist = true;
620             }
621
622             const IntType  jnr_s    = gmx::load<IntType>(preloadJnr);
623             const BoolType bIiEqJnr = gmx::cvtIB2B(ii_s == jnr_s);
624
625             RealType            c6[NSTATES];
626             RealType            c12[NSTATES];
627             RealType gmx_unused sigma6[NSTATES];
628             RealType            qq[NSTATES];
629             RealType gmx_unused ljPmeC6Grid[NSTATES];
630             RealType gmx_unused alphaVdwEff;
631             RealType gmx_unused alphaCoulEff;
632             for (int i = 0; i < NSTATES; i++)
633             {
634                 gmx::gatherLoadTranspose<2>(nbfp.data(), typeIndices[i], &c6[i], &c12[i]);
635                 qq[i]          = gmx::load<RealType>(preloadQq[i]);
636                 ljPmeC6Grid[i] = gmx::load<RealType>(preloadLjPmeC6Grid[i]);
637                 if constexpr (useSoftCore)
638                 {
639                     sigma6[i] = gmx::load<RealType>(preloadSigma6[i]);
640                 }
641             }
642             if constexpr (useSoftCore)
643             {
644                 alphaVdwEff  = gmx::load<RealType>(preloadAlphaVdwEff);
645                 alphaCoulEff = gmx::load<RealType>(preloadAlphaCoulEff);
646             }
647
648             // Avoid overflow of r^-12 at distances near zero
649             rSq  = gmx::max(rSq, minDistanceSquared);
650             rInv = gmx::invsqrt(rSq);
651             r    = rSq * rInv;
652
653             RealType gmx_unused rp, rpm2;
654             if constexpr (useSoftCore)
655             {
656                 rpm2 = rSq * rSq;  /* r4 */
657                 rp   = rpm2 * rSq; /* r6 */
658             }
659             else
660             {
661                 /* The soft-core power p will not affect the results
662                  * with not using soft-core, so we use power of 0 which gives
663                  * the simplest math and cheapest code.
664                  */
665                 rpm2 = rInv * rInv;
666                 rp   = one;
667             }
668
669             RealType fScal(0);
670
671             /* The following block is masked to only calculate values having bPairIncluded. If
672              * bPairIncluded is true then withinCutoffMask must also be true. */
673             if (gmx::anyTrue(withinCutoffMask && bPairIncluded))
674             {
675                 RealType gmx_unused fScalC[NSTATES], fScalV[NSTATES];
676                 RealType            vCoul[NSTATES], vVdw[NSTATES];
677                 for (int i = 0; i < NSTATES; i++)
678                 {
679                     fScalC[i] = zero;
680                     fScalV[i] = zero;
681                     vCoul[i]  = zero;
682                     vVdw[i]   = zero;
683
684                     RealType gmx_unused rInvC, rInvV, rC, rV, rPInvC, rPInvV;
685
686                     /* The following block is masked to require (qq[i] != 0 || c6[i] != 0 || c12[i]
687                      * != 0) in addition to bPairIncluded, which in turn requires withinCutoffMask. */
688                     BoolType nonZeroState = ((qq[i] != zero || c6[i] != zero || c12[i] != zero)
689                                              && bPairIncluded && withinCutoffMask);
690                     if (gmx::anyTrue(nonZeroState))
691                     {
692                         if constexpr (useSoftCore)
693                         {
694                             RealType divisor = (alphaCoulEff * lFacCoul[i] * sigma6[i] + rp);
695                             rPInvC           = gmx::inv(divisor);
696                             sixthRoot(rPInvC, &rInvC, &rC);
697
698                             if constexpr (scLambdasOrAlphasDiffer)
699                             {
700                                 RealType divisor = (alphaVdwEff * lFacVdw[i] * sigma6[i] + rp);
701                                 rPInvV           = gmx::inv(divisor);
702                                 sixthRoot(rPInvV, &rInvV, &rV);
703                             }
704                             else
705                             {
706                                 /* We can avoid one expensive pow and one / operation */
707                                 rPInvV = rPInvC;
708                                 rInvV  = rInvC;
709                                 rV     = rC;
710                             }
711                         }
712                         else
713                         {
714                             rPInvC = one;
715                             rInvC  = rInv;
716                             rC     = r;
717
718                             rPInvV = one;
719                             rInvV  = rInv;
720                             rV     = r;
721                         }
722
723                         /* Only process the coulomb interactions if we either
724                          * include all entries in the list (no cutoff
725                          * used in the kernel), or if we are within the cutoff.
726                          */
727                         BoolType computeElecInteraction;
728                         if constexpr (elecInteractionTypeIsEwald)
729                         {
730                             computeElecInteraction = (r < rCoulomb && qq[i] != zero && bPairIncluded);
731                         }
732                         else
733                         {
734                             computeElecInteraction = (rC < rCoulomb && qq[i] != zero && bPairIncluded);
735                         }
736                         if (gmx::anyTrue(computeElecInteraction))
737                         {
738                             if constexpr (elecInteractionTypeIsEwald)
739                             {
740                                 vCoul[i] = ewaldPotential(qq[i], rInvC, sh_ewald);
741                                 if constexpr (computeForces)
742                                 {
743                                     fScalC[i] = ewaldScalarForce(qq[i], rInvC);
744                                 }
745                             }
746                             else
747                             {
748                                 vCoul[i] = reactionFieldPotential(qq[i], rInvC, rC, krf, crf);
749                                 if constexpr (computeForces)
750                                 {
751                                     fScalC[i] = reactionFieldScalarForce(qq[i], rInvC, rC, krf, two);
752                                 }
753                             }
754
755                             vCoul[i] = gmx::selectByMask(vCoul[i], computeElecInteraction);
756                             if constexpr (computeForces)
757                             {
758                                 fScalC[i] = gmx::selectByMask(fScalC[i], computeElecInteraction);
759                             }
760                         }
761
762                         /* Only process the VDW interactions if we either
763                          * include all entries in the list (no cutoff used
764                          * in the kernel), or if we are within the cutoff.
765                          */
766                         BoolType computeVdwInteraction;
767                         if constexpr (vdwInteractionTypeIsEwald)
768                         {
769                             computeVdwInteraction =
770                                     (r < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
771                         }
772                         else
773                         {
774                             computeVdwInteraction =
775                                     (rV < rVdw && (c6[i] != 0 || c12[i] != 0) && bPairIncluded);
776                         }
777                         if (gmx::anyTrue(computeVdwInteraction))
778                         {
779                             RealType rInv6;
780                             if constexpr (useSoftCore)
781                             {
782                                 rInv6 = rPInvV;
783                             }
784                             else
785                             {
786                                 rInv6 = calculateRinv6(rInvV);
787                             }
788                             // Avoid overflow at short distance for masked exclusions and
789                             // for foreign energy calculations at a hard core end state.
790                             // Note that we should limit r^-6, and thus also r^-12, and
791                             // not only r^-12, as that could lead to erroneously low instead
792                             // of very high foreign energies.
793                             rInv6           = gmx::min(rInv6, maxRInvSix);
794                             RealType vVdw6  = calculateVdw6(c6[i], rInv6);
795                             RealType vVdw12 = calculateVdw12(c12[i], rInv6);
796
797                             vVdw[i] = lennardJonesPotential(
798                                     vVdw6, vVdw12, c6[i], c12[i], repulsionShift, dispersionShift, oneSixth, oneTwelfth);
799                             if constexpr (computeForces)
800                             {
801                                 fScalV[i] = lennardJonesScalarForce(vVdw6, vVdw12);
802                             }
803
804                             if constexpr (vdwInteractionTypeIsEwald)
805                             {
806                                 /* Subtract the grid potential at the cut-off */
807                                 vVdw[i] = vVdw[i]
808                                           + gmx::selectByMask(ewaldLennardJonesGridSubtract(
809                                                                       ljPmeC6Grid[i], shLjEwald, oneSixth),
810                                                               computeVdwInteraction);
811                             }
812
813                             if constexpr (vdwModifierIsPotSwitch)
814                             {
815                                 RealType d             = rV - rVdwSwitch;
816                                 BoolType zeroMask      = zero < d;
817                                 BoolType potSwitchMask = rV < rVdw;
818                                 d                      = gmx::selectByMask(d, zeroMask);
819                                 const RealType d2      = d * d;
820                                 const RealType sw =
821                                         one + d2 * d * (vdw_swV3 + d * (vdw_swV4 + d * vdw_swV5));
822
823                                 if constexpr (computeForces)
824                                 {
825                                     const RealType dsw = d2 * (vdw_swF2 + d * (vdw_swF3 + d * vdw_swF4));
826                                     fScalV[i]          = potSwitchScalarForceMod(
827                                             fScalV[i], vVdw[i], sw, rV, dsw, potSwitchMask);
828                                 }
829                                 vVdw[i] = potSwitchPotentialMod(vVdw[i], sw, potSwitchMask);
830                             }
831
832                             vVdw[i] = gmx::selectByMask(vVdw[i], computeVdwInteraction);
833                             if constexpr (computeForces)
834                             {
835                                 fScalV[i] = gmx::selectByMask(fScalV[i], computeVdwInteraction);
836                             }
837                         }
838
839                         if constexpr (computeForces)
840                         {
841                             /* fScalC (and fScalV) now contain: dV/drC * rC
842                              * Now we multiply by rC^-6, so it will be: dV/drC * rC^-5
843                              * Further down we first multiply by r^4 and then by
844                              * the vector r, which in total gives: dV/drC * (r/rC)^-5
845                              */
846                             fScalC[i] = fScalC[i] * rPInvC;
847                             fScalV[i] = fScalV[i] * rPInvV;
848                         }
849                     } // end of block requiring nonZeroState
850                 }     // end for (int i = 0; i < NSTATES; i++)
851
852                 /* Assemble A and B states. */
853                 BoolType assembleStates = (bPairIncluded && withinCutoffMask);
854                 if (gmx::anyTrue(assembleStates))
855                 {
856                     for (int i = 0; i < NSTATES; i++)
857                     {
858                         vCTot = vCTot + LFC[i] * vCoul[i];
859                         vVTot = vVTot + LFV[i] * vVdw[i];
860
861                         if constexpr (computeForces)
862                         {
863                             fScal = fScal + LFC[i] * fScalC[i] * rpm2;
864                             fScal = fScal + LFV[i] * fScalV[i] * rpm2;
865                         }
866
867                         if constexpr (useSoftCore)
868                         {
869                             dvdlCoul = dvdlCoul + vCoul[i] * DLF[i]
870                                        + LFC[i] * alphaCoulEff * dlFacCoul[i] * fScalC[i] * sigma6[i];
871                             dvdlVdw = dvdlVdw + vVdw[i] * DLF[i]
872                                       + LFV[i] * alphaVdwEff * dlFacVdw[i] * fScalV[i] * sigma6[i];
873                         }
874                         else
875                         {
876                             dvdlCoul = dvdlCoul + vCoul[i] * DLF[i];
877                             dvdlVdw  = dvdlVdw + vVdw[i] * DLF[i];
878                         }
879                     }
880                 }
881             } // end of block requiring bPairIncluded && withinCutoffMask
882             /* In the following block bPairIncluded should be false in the masks. */
883             if (icoul == NbkernelElecType::ReactionField)
884             {
885                 const BoolType computeReactionField = bPairExcluded;
886
887                 if (gmx::anyTrue(computeReactionField))
888                 {
889                     /* For excluded pairs we don't use soft-core.
890                      * As there is no singularity, there is no need for soft-core.
891                      */
892                     const RealType FF = -two * krf;
893                     RealType       VV = krf * rSq - crf;
894
895                     /* If ii == jnr the i particle (ii) has itself (jnr)
896                      * in its neighborlist. This corresponds to a self-interaction
897                      * that will occur twice. Scale it down by 50% to only include
898                      * it once.
899                      */
900                     VV = VV * gmx::blend(one, half, bIiEqJnr);
901
902                     for (int i = 0; i < NSTATES; i++)
903                     {
904                         vCTot = vCTot + gmx::selectByMask(LFC[i] * qq[i] * VV, computeReactionField);
905                         fScal = fScal + gmx::selectByMask(LFC[i] * qq[i] * FF, computeReactionField);
906                         dvdlCoul = dvdlCoul + gmx::selectByMask(DLF[i] * qq[i] * VV, computeReactionField);
907                     }
908                 }
909             }
910
911             const BoolType computeElecEwaldInteraction = (bPairExcluded || r < rCoulomb);
912             if (elecInteractionTypeIsEwald && gmx::anyTrue(computeElecEwaldInteraction))
913             {
914                 /* See comment in the preamble. When using Ewald interactions
915                  * (unless we use a switch modifier) we subtract the reciprocal-space
916                  * Ewald component here which made it possible to apply the free
917                  * energy interaction to 1/r (vanilla coulomb short-range part)
918                  * above. This gets us closer to the ideal case of applying
919                  * the softcore to the entire electrostatic interaction,
920                  * including the reciprocal-space component.
921                  */
922                 RealType v_lr, f_lr;
923
924                 pmeCoulombCorrectionVF<computeForces>(rSq, ewaldBeta, &v_lr, &f_lr);
925                 if constexpr (computeForces)
926                 {
927                     f_lr = f_lr * rInv * rInv;
928                 }
929
930                 /* Note that any possible Ewald shift has already been applied in
931                  * the normal interaction part above.
932                  */
933
934                 /* If ii == jnr the i particle (ii) has itself (jnr)
935                  * in its neighborlist. This corresponds to a self-interaction
936                  * that will occur twice. Scale it down by 50% to only include
937                  * it once.
938                  */
939                 v_lr = v_lr * gmx::blend(one, half, bIiEqJnr);
940
941                 for (int i = 0; i < NSTATES; i++)
942                 {
943                     vCTot = vCTot - gmx::selectByMask(LFC[i] * qq[i] * v_lr, computeElecEwaldInteraction);
944                     if constexpr (computeForces)
945                     {
946                         fScal = fScal - gmx::selectByMask(LFC[i] * qq[i] * f_lr, computeElecEwaldInteraction);
947                     }
948                     dvdlCoul = dvdlCoul
949                                - gmx::selectByMask(DLF[i] * qq[i] * v_lr, computeElecEwaldInteraction);
950                 }
951             }
952
953             const BoolType computeVdwEwaldInteraction = (bPairExcluded || r < rVdw);
954             if (vdwInteractionTypeIsEwald && gmx::anyTrue(computeVdwEwaldInteraction))
955             {
956                 /* See comment in the preamble. When using LJ-Ewald interactions
957                  * (unless we use a switch modifier) we subtract the reciprocal-space
958                  * Ewald component here which made it possible to apply the free
959                  * energy interaction to r^-6 (vanilla LJ6 short-range part)
960                  * above. This gets us closer to the ideal case of applying
961                  * the softcore to the entire VdW interaction,
962                  * including the reciprocal-space component.
963                  */
964
965                 RealType v_lr, f_lr;
966                 pmeLJCorrectionVF<computeForces>(
967                         rInv, rSq, ewaldLJCoeffSq, ewaldLJCoeffSixDivSix, &v_lr, &f_lr, computeVdwEwaldInteraction, bIiEqJnr);
968                 v_lr = v_lr * oneSixth;
969
970                 for (int i = 0; i < NSTATES; i++)
971                 {
972                     vVTot = vVTot + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
973                     if constexpr (computeForces)
974                     {
975                         fScal = fScal + gmx::selectByMask(LFV[i] * ljPmeC6Grid[i] * f_lr, computeVdwEwaldInteraction);
976                     }
977                     dvdlVdw = dvdlVdw + gmx::selectByMask(DLF[i] * ljPmeC6Grid[i] * v_lr, computeVdwEwaldInteraction);
978                 }
979             }
980
981             if (computeForces && gmx::anyTrue(fScal != zero))
982             {
983                 const RealType tX = fScal * dX;
984                 const RealType tY = fScal * dY;
985                 const RealType tZ = fScal * dZ;
986                 fIX               = fIX + tX;
987                 fIY               = fIY + tY;
988                 fIZ               = fIZ + tZ;
989
990                 gmx::transposeScatterDecrU<3>(forceRealPtr, preloadJnr, tX, tY, tZ);
991             }
992         } // end for (int k = nj0; k < nj1; k += DataTypes::simdRealWidth)
993
994         if (havePairsWithinCutoff)
995         {
996             if constexpr (computeForces)
997             {
998                 gmx::transposeScatterIncrU<3>(forceRealPtr, preloadIi, fIX, fIY, fIZ);
999
1000                 if (doShiftForces)
1001                 {
1002                     gmx::transposeScatterIncrU<3>(
1003                             reinterpret_cast<real*>(threadForceShiftBuffer), preloadIs, fIX, fIY, fIZ);
1004                 }
1005             }
1006             if (doPotential)
1007             {
1008                 int ggid = gid[n];
1009                 threadVc[ggid] += gmx::reduce(vCTot);
1010                 threadVv[ggid] += gmx::reduce(vVTot);
1011             }
1012         }
1013     } // end for (int n = 0; n < nri; n++)
1014
1015     if (gmx::anyTrue(dvdlCoul != zero))
1016     {
1017         threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)] += gmx::reduce(dvdlCoul);
1018     }
1019     if (gmx::anyTrue(dvdlVdw != zero))
1020     {
1021         threadDvdl[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)] += gmx::reduce(dvdlVdw);
1022     }
1023
1024     /* Estimate flops, average for free energy stuff:
1025      * 12  flops per outer iteration
1026      * 150 flops per inner iteration
1027      * TODO: Update the number of flops and/or use different counts for different code paths.
1028      */
1029     atomicNrnbIncrement(nrnb, eNR_NBKERNEL_FREE_ENERGY, nlist.nri * 12 + nlist.jindex[nri] * 150);
1030
1031     if (haveExcludedPairsBeyondRlist > 0)
1032     {
1033         gmx_fatal(FARGS,
1034                   "There are perturbed non-bonded pair interactions beyond the pair-list cutoff "
1035                   "of %g nm, which is not supported. This can happen because the system is "
1036                   "unstable or because intra-molecular interactions at long distances are "
1037                   "excluded. If the "
1038                   "latter is the case, you can try to increase nstlist or rlist to avoid this."
1039                   "The error is likely triggered by the use of couple-intramol=no "
1040                   "and the maximal distance in the decoupled molecule exceeding rlist.",
1041                   rlist);
1042     }
1043 }
1044
1045 typedef void (*KernelFunction)(const t_nblist&                                  nlist,
1046                                const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
1047                                const int                                        ntype,
1048                                const real                                       rlist,
1049                                const interaction_const_t&                       ic,
1050                                gmx::ArrayRef<const gmx::RVec>                   shiftvec,
1051                                gmx::ArrayRef<const real>                        nbfp,
1052                                gmx::ArrayRef<const real>                        nbfp_grid,
1053                                gmx::ArrayRef<const real>                        chargeA,
1054                                gmx::ArrayRef<const real>                        chargeB,
1055                                gmx::ArrayRef<const int>                         typeA,
1056                                gmx::ArrayRef<const int>                         typeB,
1057                                int                                              flags,
1058                                gmx::ArrayRef<const real>                        lambda,
1059                                t_nrnb* gmx_restrict                             nrnb,
1060                                gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
1061                                rvec*               threadForceShiftBuffer,
1062                                gmx::ArrayRef<real> threadVc,
1063                                gmx::ArrayRef<real> threadVv,
1064                                gmx::ArrayRef<real> threadDvdl);
1065
1066 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch, bool computeForces>
1067 static KernelFunction dispatchKernelOnUseSimd(const bool useSimd)
1068 {
1069     if (useSimd)
1070     {
1071 #if GMX_SIMD_HAVE_REAL && GMX_SIMD_HAVE_INT32_ARITHMETICS && GMX_USE_SIMD_KERNELS
1072         return (nb_free_energy_kernel<SimdDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces>);
1073 #else
1074         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces>);
1075 #endif
1076     }
1077     else
1078     {
1079         return (nb_free_energy_kernel<ScalarDataTypes, useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces>);
1080     }
1081 }
1082
1083 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald, bool vdwModifierIsPotSwitch>
1084 static KernelFunction dispatchKernelOnComputeForces(const bool computeForces, const bool useSimd)
1085 {
1086     if (computeForces)
1087     {
1088         return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, true>(
1089                 useSimd));
1090     }
1091     else
1092     {
1093         return (dispatchKernelOnUseSimd<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, false>(
1094                 useSimd));
1095     }
1096 }
1097
1098 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald, bool elecInteractionTypeIsEwald>
1099 static KernelFunction dispatchKernelOnVdwModifier(const bool vdwModifierIsPotSwitch,
1100                                                   const bool computeForces,
1101                                                   const bool useSimd)
1102 {
1103     if (vdwModifierIsPotSwitch)
1104     {
1105         return (dispatchKernelOnComputeForces<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, true>(
1106                 computeForces, useSimd));
1107     }
1108     else
1109     {
1110         return (dispatchKernelOnComputeForces<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, false>(
1111                 computeForces, useSimd));
1112     }
1113 }
1114
1115 template<bool useSoftCore, bool scLambdasOrAlphasDiffer, bool vdwInteractionTypeIsEwald>
1116 static KernelFunction dispatchKernelOnElecInteractionType(const bool elecInteractionTypeIsEwald,
1117                                                           const bool vdwModifierIsPotSwitch,
1118                                                           const bool computeForces,
1119                                                           const bool useSimd)
1120 {
1121     if (elecInteractionTypeIsEwald)
1122     {
1123         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, true>(
1124                 vdwModifierIsPotSwitch, computeForces, useSimd));
1125     }
1126     else
1127     {
1128         return (dispatchKernelOnVdwModifier<useSoftCore, scLambdasOrAlphasDiffer, vdwInteractionTypeIsEwald, false>(
1129                 vdwModifierIsPotSwitch, computeForces, useSimd));
1130     }
1131 }
1132
1133 template<bool useSoftCore, bool scLambdasOrAlphasDiffer>
1134 static KernelFunction dispatchKernelOnVdwInteractionType(const bool vdwInteractionTypeIsEwald,
1135                                                          const bool elecInteractionTypeIsEwald,
1136                                                          const bool vdwModifierIsPotSwitch,
1137                                                          const bool computeForces,
1138                                                          const bool useSimd)
1139 {
1140     if (vdwInteractionTypeIsEwald)
1141     {
1142         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, true>(
1143                 elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
1144     }
1145     else
1146     {
1147         return (dispatchKernelOnElecInteractionType<useSoftCore, scLambdasOrAlphasDiffer, false>(
1148                 elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
1149     }
1150 }
1151
1152 template<bool useSoftCore>
1153 static KernelFunction dispatchKernelOnScLambdasOrAlphasDifference(const bool scLambdasOrAlphasDiffer,
1154                                                                   const bool vdwInteractionTypeIsEwald,
1155                                                                   const bool elecInteractionTypeIsEwald,
1156                                                                   const bool vdwModifierIsPotSwitch,
1157                                                                   const bool computeForces,
1158                                                                   const bool useSimd)
1159 {
1160     if (scLambdasOrAlphasDiffer)
1161     {
1162         return (dispatchKernelOnVdwInteractionType<useSoftCore, true>(
1163                 vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
1164     }
1165     else
1166     {
1167         return (dispatchKernelOnVdwInteractionType<useSoftCore, false>(
1168                 vdwInteractionTypeIsEwald, elecInteractionTypeIsEwald, vdwModifierIsPotSwitch, computeForces, useSimd));
1169     }
1170 }
1171
1172 static KernelFunction dispatchKernel(const bool                 scLambdasOrAlphasDiffer,
1173                                      const bool                 vdwInteractionTypeIsEwald,
1174                                      const bool                 elecInteractionTypeIsEwald,
1175                                      const bool                 vdwModifierIsPotSwitch,
1176                                      const bool                 computeForces,
1177                                      const bool                 useSimd,
1178                                      const interaction_const_t& ic)
1179 {
1180     if (ic.softCoreParameters->alphaCoulomb == 0 && ic.softCoreParameters->alphaVdw == 0)
1181     {
1182         return (dispatchKernelOnScLambdasOrAlphasDifference<false>(scLambdasOrAlphasDiffer,
1183                                                                    vdwInteractionTypeIsEwald,
1184                                                                    elecInteractionTypeIsEwald,
1185                                                                    vdwModifierIsPotSwitch,
1186                                                                    computeForces,
1187                                                                    useSimd));
1188     }
1189     else
1190     {
1191         return (dispatchKernelOnScLambdasOrAlphasDifference<true>(scLambdasOrAlphasDiffer,
1192                                                                   vdwInteractionTypeIsEwald,
1193                                                                   elecInteractionTypeIsEwald,
1194                                                                   vdwModifierIsPotSwitch,
1195                                                                   computeForces,
1196                                                                   useSimd));
1197     }
1198 }
1199
1200
1201 void gmx_nb_free_energy_kernel(const t_nblist&                                  nlist,
1202                                const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
1203                                const bool                                       useSimd,
1204                                const int                                        ntype,
1205                                const real                                       rlist,
1206                                const interaction_const_t&                       ic,
1207                                gmx::ArrayRef<const gmx::RVec>                   shiftvec,
1208                                gmx::ArrayRef<const real>                        nbfp,
1209                                gmx::ArrayRef<const real>                        nbfp_grid,
1210                                gmx::ArrayRef<const real>                        chargeA,
1211                                gmx::ArrayRef<const real>                        chargeB,
1212                                gmx::ArrayRef<const int>                         typeA,
1213                                gmx::ArrayRef<const int>                         typeB,
1214                                int                                              flags,
1215                                gmx::ArrayRef<const real>                        lambda,
1216                                t_nrnb*                                          nrnb,
1217                                gmx::ArrayRefWithPadding<gmx::RVec>              threadForceBuffer,
1218                                rvec*               threadForceShiftBuffer,
1219                                gmx::ArrayRef<real> threadVc,
1220                                gmx::ArrayRef<real> threadVv,
1221                                gmx::ArrayRef<real> threadDvdl)
1222 {
1223     GMX_ASSERT(EEL_PME_EWALD(ic.eeltype) || ic.eeltype == CoulombInteractionType::Cut || EEL_RF(ic.eeltype),
1224                "Unsupported eeltype with free energy");
1225     GMX_ASSERT(ic.softCoreParameters, "We need soft-core parameters");
1226
1227     // Not all SIMD implementations need padding, but we provide padding anyhow so we can assert
1228     GMX_ASSERT(!GMX_SIMD_HAVE_REAL || threadForceBuffer.empty()
1229                        || threadForceBuffer.size() > threadForceBuffer.unpaddedArrayRef().ssize(),
1230                "We need actual padding with at least one element for SIMD scatter operations");
1231
1232     const auto& scParams                   = *ic.softCoreParameters;
1233     const bool  vdwInteractionTypeIsEwald  = (EVDW_PME(ic.vdwtype));
1234     const bool  elecInteractionTypeIsEwald = (EEL_PME_EWALD(ic.eeltype));
1235     const bool  vdwModifierIsPotSwitch     = (ic.vdw_modifier == InteractionModifiers::PotSwitch);
1236     const bool  computeForces              = ((flags & GMX_NONBONDED_DO_FORCE) != 0);
1237     bool        scLambdasOrAlphasDiffer    = true;
1238
1239     if (scParams.alphaCoulomb == 0 && scParams.alphaVdw == 0)
1240     {
1241         scLambdasOrAlphasDiffer = false;
1242     }
1243     else
1244     {
1245         if (lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Coul)]
1246                     == lambda[static_cast<int>(FreeEnergyPerturbationCouplingType::Vdw)]
1247             && scParams.alphaCoulomb == scParams.alphaVdw)
1248         {
1249             scLambdasOrAlphasDiffer = false;
1250         }
1251     }
1252
1253     KernelFunction kernelFunc;
1254     kernelFunc = dispatchKernel(scLambdasOrAlphasDiffer,
1255                                 vdwInteractionTypeIsEwald,
1256                                 elecInteractionTypeIsEwald,
1257                                 vdwModifierIsPotSwitch,
1258                                 computeForces,
1259                                 useSimd,
1260                                 ic);
1261     kernelFunc(nlist,
1262                coords,
1263                ntype,
1264                rlist,
1265                ic,
1266                shiftvec,
1267                nbfp,
1268                nbfp_grid,
1269                chargeA,
1270                chargeB,
1271                typeA,
1272                typeB,
1273                flags,
1274                lambda,
1275                nrnb,
1276                threadForceBuffer,
1277                threadForceShiftBuffer,
1278                threadVc,
1279                threadVv,
1280                threadDvdl);
1281 }