Make pull with COM from previous step work with MPI
[alexxy/gromacs.git] / src / gromacs / pulling / pullutil.cpp
index b9732e40eedbc6f364e8d00c37f6cdfcdea6c1f3..6d686f4689a08906a7d579e634f1c683291ace14 100644 (file)
@@ -990,29 +990,18 @@ bool pullCheckPbcWithinGroup(const pull_t                  &pull,
     return (pullGroupObeysPbcRestrictions(group, dimUsed, as_rvec_array(x.data()), pbc, pull.comm.pbcAtomBuffer[groupNr], pbcMargin));
 }
 
-void setStatePrevStepPullCom(const struct pull_t *pull, t_state *state)
-{
-    for (size_t i = 0; i < state->com_prev_step.size()/DIM; i++)
-    {
-        for (int j = 0; j < DIM; j++)
-        {
-            state->com_prev_step[i*DIM+j] = pull->group[i].x_prev_step[j];
-        }
-    }
-}
-
 void setPrevStepPullComFromState(struct pull_t *pull, const t_state *state)
 {
-    for (size_t i = 0; i < state->com_prev_step.size()/DIM; i++)
+    for (size_t g = 0; g < pull->group.size(); g++)
     {
         for (int j = 0; j < DIM; j++)
         {
-            pull->group[i].x_prev_step[j] = state->com_prev_step[i*DIM+j];
+            pull->group[g].x_prev_step[j] = state->pull_com_prev_step[g*DIM+j];
         }
     }
 }
 
-void updatePrevStepCom(struct pull_t *pull)
+void updatePrevStepPullCom(struct pull_t *pull, t_state *state)
 {
     for (size_t g = 0; g < pull->group.size(); g++)
     {
@@ -1020,7 +1009,8 @@ void updatePrevStepCom(struct pull_t *pull)
         {
             for (int j = 0; j < DIM; j++)
             {
-                pull->group[g].x_prev_step[j] = pull->group[g].x[j];
+                pull->group[g].x_prev_step[j]      = pull->group[g].x[j];
+                state->pull_com_prev_step[g*DIM+j] = pull->group[g].x[j];
             }
         }
     }
@@ -1030,13 +1020,13 @@ void allocStatePrevStepPullCom(t_state *state, pull_t *pull)
 {
     if (!pull)
     {
-        state->com_prev_step.clear();
+        state->pull_com_prev_step.clear();
         return;
     }
     size_t ngroup = pull->group.size();
-    if (state->com_prev_step.size()/DIM != ngroup)
+    if (state->pull_com_prev_step.size()/DIM != ngroup)
     {
-        state->com_prev_step.resize(ngroup * DIM, NAN);
+        state->pull_com_prev_step.resize(ngroup * DIM, NAN);
     }
 }
 
@@ -1049,8 +1039,16 @@ void initPullComFromPrevStep(const t_commrec *cr,
     pull_comm_t *comm   = &pull->comm;
     size_t       ngroup = pull->group.size();
 
-    comm->pbcAtomBuffer.resize(ngroup);
-    comm->comBuffer.resize(ngroup*DIM);
+    if (!comm->bParticipate)
+    {
+        return;
+    }
+
+    GMX_ASSERT(comm->pbcAtomBuffer.size() == pull->group.size(), "pbcAtomBuffer should have size number of groups");
+    GMX_ASSERT(comm->comBuffer.size() == pull->group.size()*c_comBufferStride,
+               "comBuffer should have size #group*c_comBufferStride");
+
+    pull_set_pbcatoms(cr, pull, x, comm->pbcAtomBuffer);
 
     for (size_t g = 0; g < ngroup; g++)
     {
@@ -1064,7 +1062,6 @@ void initPullComFromPrevStep(const t_commrec *cr,
                        "use the COM from the previous step as reference.");
 
             rvec x_pbc = { 0, 0, 0 };
-            pull_set_pbcatoms(cr, pull, x, comm->pbcAtomBuffer);
             copy_rvec(comm->pbcAtomBuffer[g], x_pbc);
 
             if (debug)
@@ -1118,11 +1115,14 @@ void initPullComFromPrevStep(const t_commrec *cr,
             }
 
             /* Copy local sums to a buffer for global summing */
-            copy_dvec(comSumsTotal.sum_wmx,  comm->comBuffer[g*3]);
-            copy_dvec(comSumsTotal.sum_wmxp, comm->comBuffer[g*3 + 1]);
-            comm->comBuffer[g*3 + 2][0] = comSumsTotal.sum_wm;
-            comm->comBuffer[g*3 + 2][1] = comSumsTotal.sum_wwm;
-            comm->comBuffer[g*3 + 2][2] = 0;
+            auto localSums =
+                gmx::arrayRefFromArray(comm->comBuffer.data() + g*c_comBufferStride, c_comBufferStride);
+
+            localSums[0]    = comSumsTotal.sum_wmx;
+            localSums[1]    = comSumsTotal.sum_wmxp;
+            localSums[2][0] = comSumsTotal.sum_wm;
+            localSums[2][1] = comSumsTotal.sum_wwm;
+            localSums[2][2] = 0;
         }
     }
 
@@ -1137,11 +1137,13 @@ void initPullComFromPrevStep(const t_commrec *cr,
         {
             if (pgrp->epgrppbc == epgrppbcPREVSTEPCOM)
             {
+                auto   localSums =
+                    gmx::arrayRefFromArray(comm->comBuffer.data() + g*c_comBufferStride, c_comBufferStride);
                 double wmass, wwmass;
 
                 /* Determine the inverse mass */
-                wmass             = comm->comBuffer[g*3+2][0];
-                wwmass            = comm->comBuffer[g*3+2][1];
+                wmass             = localSums[2][0];
+                wwmass            = localSums[2][1];
                 pgrp->mwscale     = 1.0/wmass;
                 /* invtm==0 signals a frozen group, so then we should keep it zero */
                 if (pgrp->invtm != 0)
@@ -1152,11 +1154,8 @@ void initPullComFromPrevStep(const t_commrec *cr,
                 /* Divide by the total mass */
                 for (int m = 0; m < DIM; m++)
                 {
-                    pgrp->x[m]    = comm->comBuffer[g*3  ][m]*pgrp->mwscale;
-                    if (pgrp->epgrppbc == epgrppbcREFAT || pgrp->epgrppbc == epgrppbcPREVSTEPCOM)
-                    {
-                        pgrp->x[m]     += comm->pbcAtomBuffer[g][m];
-                    }
+                    pgrp->x[m]  = localSums[0][m]*pgrp->mwscale;
+                    pgrp->x[m] += comm->pbcAtomBuffer[g][m];
                 }
                 if (debug)
                 {