Refactor wall potential calculation
author Berk Hess <hess@kth.se>
Mon, 2 Oct 2017 19:41:09 +0000 (21:41 +0200)
committer Mark Abraham <mark.j.abraham@gmail.com>
Wed, 29 Aug 2018 07:32:49 +0000 (09:32 +0200)
Changed the wall potential calculation to use ForceWithVirial instead
of the single-sum (coordinate times force) virial plus a wall virial
correction. This change removes the last virial buffer (vir_wall_z)
from t_forcerec.
Also refactored do_walls().

Change-Id: I23e56b6e08c57bd03363646f1f968bf0e251faa2

src/gromacs/mdlib/force.cpp
src/gromacs/mdlib/forcerec.cpp
src/gromacs/mdlib/sim_util.cpp
src/gromacs/mdlib/wall.cpp
src/gromacs/mdlib/wall.h
src/gromacs/mdtypes/forcerec.h

index f404d1b9701b8934c11646c2ad48545cca8b1a52..93e623f128bb288eb15ed3188c998fb9bc2e591f 100644 (file)
@@ -201,7 +201,8 @@ void do_force_lowlevel(t_forcerec           *fr,
     if (ir->nwall)
     {
         /* foreign lambda component for walls */
-        real dvdl_walls = do_walls(ir, fr, box, md, x, forceForUseWithShiftForces, lambda[efptVDW],
+        real dvdl_walls = do_walls(*ir, *fr, box, *md, x,
+                                   forceWithVirial, lambda[efptVDW],
                                    enerd->grpp.ener[egLJSR], nrnb);
         enerd->dvdl_lin[efptVDW] += dvdl_walls;
     }
index c60efe0d21fb90160f7579d274338f970ec4f431..c23876e2843be5507ef91fbff1e7b1ad29905455 100644 (file)
@@ -2761,6 +2761,7 @@ void init_forcerec(FILE                             *fp,
          fr->forceProviders->hasForceProvider() ||
          gmx_mtop_ftype_count(mtop, F_POSRES) > 0 ||
          gmx_mtop_ftype_count(mtop, F_FBPOSRES) > 0 ||
+         ir->nwall > 0 ||
          ir->bPull ||
          ir->bRot ||
          ir->bIMD);
index 82b986ecf45e7b6bf6c0cba6c3ead412acd15c58..2e60179c163afee35dacc4689142256e8a7c44bd 100644 (file)
@@ -251,8 +251,6 @@ static void calc_virial(int start, int homenr, rvec x[], rvec f[],
                         tensor vir_part, t_graph *graph, matrix box,
                         t_nrnb *nrnb, const t_forcerec *fr, int ePBC)
 {
-    int    i;
-
     /* The short-range virial from surrounding boxes */
     calc_vir(SHIFTS, fr->shift_vec, fr->fshift, vir_part, ePBC == epbcSCREW, box);
     inc_nrnb(nrnb, eNR_VIRIAL, SHIFTS);
@@ -263,12 +261,6 @@ static void calc_virial(int start, int homenr, rvec x[], rvec f[],
     f_calc_vir(start, start+homenr, x, f, vir_part, graph, box);
     inc_nrnb(nrnb, eNR_VIRIAL, homenr);
 
-    /* Add wall contribution */
-    for (i = 0; i < DIM; i++)
-    {
-        vir_part[i][ZZ] += fr->vir_wall_z[i];
-    }
-
     if (debug)
     {
         pr_rvecs(debug, 0, "vir_part", vir_part, DIM);
index 7933d2670d7bd60d9ded941508f0c57d88eef8b4..04bbd147113662aae0216221a086c567ca19c9af 100644 (file)
@@ -47,6 +47,7 @@
 #include "gromacs/gmxlib/nrnb.h"
 #include "gromacs/math/utilities.h"
 #include "gromacs/math/vec.h"
+#include "gromacs/mdtypes/forceoutput.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/mdtypes/mdatom.h"
@@ -114,81 +115,133 @@ void make_wall_tables(FILE *fplog,
               x[a][XX], x[a][YY], x[a][ZZ], r);
 }
 
-real do_walls(const t_inputrec *ir, t_forcerec *fr, matrix box, const t_mdatoms *md,
-              const rvec x[], rvec f[], real lambda, real Vlj[], t_nrnb *nrnb)
+static void tableForce(real                r,
+                       const t_forcetable &tab,
+                       real                Cd,
+                       real                Cr,
+                       real               *V,
+                       real               *F)
 {
-    int             nwall;
-    int             ntw[2], at, ntype, ngid, ggid, *egp_flags, *type;
-    real           *nbfp, lamfac, fac_d[2], fac_r[2], Cd, Cr;
-    real            wall_z[2], r, mr, r1, r2, r4, Vd, Vr, V = 0, Fd, Fr, F = 0, dvdlambda;
-    dvec            xf_z;
-    int             n0, nnn;
-    real            tabscale, *VFtab, rt, eps, eps2, Yt, Ft, Geps, Heps2, Fp, VV, FF;
-    unsigned short *gid = md->cENER;
-    t_forcetable   *tab;
+    const real  tabscale = tab.scale;
+    const real *VFtab    = tab.data;
 
-    nwall     = ir->nwall;
-    ngid      = ir->opts.ngener;
-    ntype     = fr->ntype;
-    nbfp      = fr->nbfp;
-    egp_flags = fr->egp_flags;
+    real        rt = r*tabscale;
+    int         n0 = static_cast<int>(rt);
+    if (n0 >= tab.n)
+    {
+        /* Beyond the table range, set V and F to zero */
+        *V         = 0;
+        *F         = 0;
+    }
+    else
+    {
+        real eps   = rt - n0;
+        real eps2  = eps*eps;
+        /* Dispersion */
+        int  nnn   = 8*n0;
+        real Yt    = VFtab[nnn];
+        real Ft    = VFtab[nnn + 1];
+        real Geps  = VFtab[nnn + 2]*eps;
+        real Heps2 = VFtab[nnn + 3]*eps2;
+        real Fp    = Ft + Geps + Heps2;
+        real VV    = Yt + Fp*eps;
+        real FF    = Fp + Geps + 2.0*Heps2;
+        real Vd    = 6*Cd*VV;
+        real Fd    = 6*Cd*FF;
+        /* Repulsion */
+        nnn        = nnn + 4;
+        Yt         = VFtab[nnn];
+        Ft         = VFtab[nnn+1];
+        Geps       = VFtab[nnn+2]*eps;
+        Heps2      = VFtab[nnn+3]*eps2;
+        Fp         = Ft + Geps + Heps2;
+        VV         = Yt + Fp*eps;
+        FF         = Fp + Geps + 2.0*Heps2;
+        real Vr    = 12*Cr*VV;
+        real Fr    = 12*Cr*FF;
+        *V         = Vd + Vr;
+        *F         = -(Fd + Fr)*tabscale;
+    }
+}
+
+real do_walls(const t_inputrec &ir, const t_forcerec &fr,
+              const matrix box, const t_mdatoms &md,
+              const rvec *x, gmx::ForceWithVirial *forceWithVirial,
+              real lambda, real Vlj[], t_nrnb *nrnb)
+{
+    constexpr real        sixth   = 1.0/6.0;
+    constexpr real        twelfth = 1.0/12.0;
+
+    int                   ntw[2];
+    real                  fac_d[2], fac_r[2];
+    const unsigned short *gid = md.cENER;
+
+    const int             nwall     = ir.nwall;
+    const int             ngid      = ir.opts.ngener;
+    const int             ntype     = fr.ntype;
+    const real           *nbfp      = fr.nbfp;
+    const int            *egp_flags = fr.egp_flags;
 
     for (int w = 0; w < nwall; w++)
     {
-        ntw[w] = 2*ntype*ir->wall_atomtype[w];
-        switch (ir->wall_type)
+        ntw[w] = 2*ntype*ir.wall_atomtype[w];
+        switch (ir.wall_type)
         {
             case ewt93:
-                fac_d[w] = ir->wall_density[w]*M_PI/6;
-                fac_r[w] = ir->wall_density[w]*M_PI/45;
+                fac_d[w] = ir.wall_density[w]*M_PI/6;
+                fac_r[w] = ir.wall_density[w]*M_PI/45;
                 break;
             case ewt104:
-                fac_d[w] = ir->wall_density[w]*M_PI/2;
-                fac_r[w] = ir->wall_density[w]*M_PI/5;
+                fac_d[w] = ir.wall_density[w]*M_PI/2;
+                fac_r[w] = ir.wall_density[w]*M_PI/5;
                 break;
             default:
                 break;
         }
     }
-    wall_z[0] = 0;
-    wall_z[1] = box[ZZ][ZZ];
+    const real          wall_z[2] = { 0, box[ZZ][ZZ] };
 
-    dvdlambda = 0;
-    clear_dvec(xf_z);
-    for (int lam = 0; lam < (md->nPerturbed ? 2 : 1); lam++)
+    rvec * gmx_restrict f = as_rvec_array(forceWithVirial->force_.data());
+
+    real                dvdlambda = 0;
+    double              sumRF     = 0;
+    for (int lam = 0; lam < (md.nPerturbed ? 2 : 1); lam++)
     {
-        if (md->nPerturbed)
+        real       lamfac;
+        const int *type;
+        if (md.nPerturbed)
         {
             if (lam == 0)
             {
                 lamfac = 1 - lambda;
-                type   = md->typeA;
+                type   = md.typeA;
             }
             else
             {
                 lamfac = lambda;
-                type   = md->typeB;
+                type   = md.typeB;
             }
         }
         else
         {
             lamfac = 1;
-            type   = md->typeA;
+            type   = md.typeA;
         }
 
         real Vlambda = 0;
-        for (int i = 0; i < md->homenr; i++)
+        for (int i = 0; i < md.homenr; i++)
         {
             for (int w = 0; w < std::min(nwall, 2); w++)
             {
                 /* The wall energy groups are always at the end of the list */
-                ggid = gid[i]*ngid + ngid - nwall + w;
-                at   = type[i];
+                const int  ggid = gid[i]*ngid + ngid - nwall + w;
+                const int  at   = type[i];
                 /* nbfp now includes the 6/12 derivative prefactors */
-                Cd = nbfp[ntw[w]+2*at]/6;
-                Cr = nbfp[ntw[w]+2*at+1]/12;
+                const real Cd = nbfp[ntw[w] + 2*at]*sixth;
+                const real Cr = nbfp[ntw[w] + 2*at + 1]*twelfth;
                 if (!((Cd == 0 && Cr == 0) || (egp_flags[ggid] & EGP_EXCL)))
                 {
+                    real r, mr;
                     if (w == 0)
                     {
                         r = x[i][ZZ];
@@ -197,69 +250,29 @@ real do_walls(const t_inputrec *ir, t_forcerec *fr, matrix box, const t_mdatoms
                     {
                         r = wall_z[1] - x[i][ZZ];
                     }
-                    if (r < ir->wall_r_linpot)
+                    if (r < ir.wall_r_linpot)
                     {
-                        mr = ir->wall_r_linpot - r;
-                        r  = ir->wall_r_linpot;
+                        mr = ir.wall_r_linpot - r;
+                        r  = ir.wall_r_linpot;
                     }
                     else
                     {
                         mr = 0;
                     }
-                    switch (ir->wall_type)
+                    if (r <= 0)
                     {
-                        case ewtTABLE:
-                            if (r < 0)
-                            {
-                                wall_error(i, x, r);
-                            }
-                            tab      = fr->wall_tab[w][gid[i]];
-                            tabscale = tab->scale;
-                            VFtab    = tab->data;
+                        wall_error(i, x, r);
+                    }
 
-                            rt    = r*tabscale;
-                            n0    = static_cast<int>(rt);
-                            if (n0 >= tab->n)
-                            {
-                                /* Beyond the table range, set V and F to zero */
-                                V     = 0;
-                                F     = 0;
-                            }
-                            else
-                            {
-                                eps   = rt - n0;
-                                eps2  = eps*eps;
-                                /* Dispersion */
-                                nnn   = 8*n0;
-                                Yt    = VFtab[nnn];
-                                Ft    = VFtab[nnn+1];
-                                Geps  = VFtab[nnn+2]*eps;
-                                Heps2 = VFtab[nnn+3]*eps2;
-                                Fp    = Ft + Geps + Heps2;
-                                VV    = Yt + Fp*eps;
-                                FF    = Fp + Geps + 2.0*Heps2;
-                                Vd    = 6*Cd*VV;
-                                Fd    = 6*Cd*FF;
-                                /* Repulsion */
-                                nnn   = nnn + 4;
-                                Yt    = VFtab[nnn];
-                                Ft    = VFtab[nnn+1];
-                                Geps  = VFtab[nnn+2]*eps;
-                                Heps2 = VFtab[nnn+3]*eps2;
-                                Fp    = Ft + Geps + Heps2;
-                                VV    = Yt + Fp*eps;
-                                FF    = Fp + Geps + 2.0*Heps2;
-                                Vr    = 12*Cr*VV;
-                                Fr    = 12*Cr*FF;
-                                V     = Vd + Vr;
-                                F     = -lamfac*(Fd + Fr)*tabscale;
-                            }
+                    real V, F;
+                    real r1, r2, r4, Vd, Vr;
+                    switch (ir.wall_type)
+                    {
+                        case ewtTABLE:
+                            tableForce(r, *fr.wall_tab[w][gid[i]], Cd, Cr, &V, &F);
+                            F *= lamfac;
                             break;
                         case ewt93:
-                            if (r <= 0)
-                            {
-                                wall_error(i, x, r);
-                            }
                             r1 = 1/r;
                             r2 = r1*r1;
                             r4 = r2*r2;
@@ -269,10 +282,6 @@ real do_walls(const t_inputrec *ir, t_forcerec *fr, matrix box, const t_mdatoms
                             F  = lamfac*(9*Vr - 3*Vd)*r1;
                             break;
                         case ewt104:
-                            if (r <= 0)
-                            {
-                                wall_error(i, x, r);
-                            }
                             r1 = 1/r;
                             r2 = r1*r1;
                             r4 = r2*r2;
@@ -282,10 +291,6 @@ real do_walls(const t_inputrec *ir, t_forcerec *fr, matrix box, const t_mdatoms
                             F  = lamfac*(10*Vr - 4*Vd)*r1;
                             break;
                         case ewt126:
-                            if (r <= 0)
-                            {
-                                wall_error(i, x, r);
-                            }
                             r1 = 1/r;
                             r2 = r1*r1;
                             r4 = r2*r2;
@@ -295,45 +300,37 @@ real do_walls(const t_inputrec *ir, t_forcerec *fr, matrix box, const t_mdatoms
                             F  = lamfac*(12*Vr - 6*Vd)*r1;
                             break;
                         default:
+                            V  = 0;
+                            F  = 0;
                             break;
                     }
                     if (mr > 0)
                     {
-                        V += mr*F;
+                        V     += mr*F;
                     }
+                    sumRF     += r*F;
                     if (w == 1)
                     {
-                        F = -F;
+                        F      = -F;
                     }
                     Vlj[ggid] += lamfac*V;
                     Vlambda   += V;
                     f[i][ZZ]  += F;
-                    /* Because of the single sum virial calculation we need
-                     * to add  the full virial contribution of the walls.
-                     * Since the force only has a z-component, there is only
-                     * a contribution to the z component of the virial tensor.
-                     * We could also determine the virial contribution directly,
-                     * which would be cheaper here, but that would require extra
-                     * communication for f_novirsum for with virtual sites
-                     * in parallel.
-                     */
-                    xf_z[XX]  -= x[i][XX]*F;
-                    xf_z[YY]  -= x[i][YY]*F;
-                    xf_z[ZZ]  -= wall_z[w]*F;
                 }
             }
         }
-        if (md->nPerturbed)
+        if (md.nPerturbed)
         {
             dvdlambda += (lam == 0 ? -1 : 1)*Vlambda;
         }
 
-        inc_nrnb(nrnb, eNR_WALLS, md->homenr);
+        inc_nrnb(nrnb, eNR_WALLS, md.homenr);
     }
 
-    for (int i = 0; i < DIM; i++)
+    if (forceWithVirial->computeVirial_)
     {
-        fr->vir_wall_z[i] = -0.5*xf_z[i];
+        rvec virial = { 0, 0, static_cast<real>(-0.5*sumRF) };
+        forceWithVirial->addVirialContribution(virial);
     }
 
     return dvdlambda;
index 2120d64921b96427e75a7805d197a86d860e90b0..a706845e941fb34453f8c69222b6f71087d9c30a 100644 (file)
@@ -46,19 +46,24 @@ struct t_inputrec;
 struct t_mdatoms;
 struct t_nrnb;
 
+namespace gmx
+{
+class ForceWithVirial;
+}
+
 void make_wall_tables(FILE *fplog,
                       const t_inputrec *ir, const char *tabfn,
                       const gmx_groups_t *groups,
                       t_forcerec *fr);
 
-real do_walls(const t_inputrec *ir,
-              t_forcerec       *fr,
-              matrix            box,
-              const t_mdatoms  *md,
-              const rvec        x[],
-              rvec              f[],
-              real              lambda,
-              real              Vlj[],
-              t_nrnb           *nrnb);
+real do_walls(const t_inputrec      &ir,
+              const t_forcerec      &fr,
+              const matrix           box,
+              const t_mdatoms       &md,
+              const rvec             x[],
+              gmx::ForceWithVirial  *forceWithVirial,
+              real                   lambda,
+              real                   Vlj[],
+              t_nrnb                *nrnb);
 
 #endif
index ea64b519ca181c50815c61e674cd951e0a59fe66..7a04515cdd9e10e340316efe2eef1a3d2a34e6e1 100644 (file)
@@ -265,9 +265,8 @@ struct t_forcerec { // NOLINT (clang-analyzer-optin.performance.Padding)
     /* PME/Ewald stuff */
     struct gmx_ewald_tab_t *ewald_table;
 
-    /* Virial Stuff */
+    /* Shift force array for computing the virial */
     rvec *fshift;
-    dvec  vir_wall_z;
 
     /* Non bonded Parameter lists */
     int      ntype; /* Number of atom types */