Add second LINCS atom update task
[alexxy/gromacs.git] / src / gromacs / mdlib / lincs.cpp
index 3c2f2d2547a9a859a0866596dafbe4dc6ce8cc86..d7440620b142188fde444e35d7625d9fe0e4240c 100644 (file)
@@ -115,10 +115,12 @@ struct Task
     std::vector<int> triangle;
     //! The bits tell if the matrix element should be used.
     std::vector<int> tri_bits;
-    //! Constraint index for updating atom data.
-    std::vector<int> ind;
-    //! Constraint index for updating atom data.
-    std::vector<int> ind_r;
+    //! Constraint indices for updating atom data.
+    std::vector<int> updateConstraintIndices1;
+    //! Constraint indices for updating atom data, second group.
+    std::vector<int> updateConstraintIndices2;
+    //! Temporary constraint indices used while setting up the atom data update.
+    std::vector<int> updateConstraintIndicesRest;
     //! Temporary variable for virial calculation.
     tensor vir_r_m_dr = { { 0 } };
     //! Temporary variable for lambda derivative.
@@ -198,6 +200,8 @@ public:
     bool bTaskDep = false;
     //! Are there triangle constraints that cross task borders?
     bool bTaskDepTri = false;
+    //! Whether any task has constraints in the second update list.
+    bool haveSecondUpdateTask = false;
     //! Arrays for temporary storage in the LINCS algorithm.
     /*! @{ */
     PaddedVector<gmx::RVec>                   tmpv;
@@ -472,9 +476,18 @@ static void lincs_update_atoms(Lincs*                         li,
          * constraints that only access our local atom range.
          * This can be done without a barrier.
          */
-        lincs_update_atoms_ind(li->task[th].ind, li->atoms, preFactor, fac, r, invmass, x);
+        lincs_update_atoms_ind(
+                li->task[th].updateConstraintIndices1, li->atoms, preFactor, fac, r, invmass, x);
 
-        if (!li->task[li->ntask].ind.empty())
+        if (li->haveSecondUpdateTask)
+        {
+            /* Second round of updates; we need a barrier for cross-task access of x */
+#pragma omp barrier
+            lincs_update_atoms_ind(
+                    li->task[th].updateConstraintIndices2, li->atoms, preFactor, fac, r, invmass, x);
+        }
+
+        if (!li->task[li->ntask].updateConstraintIndices1.empty())
         {
             /* Update the constraints that operate on atoms
              * in multiple thread atom blocks on the master thread.
@@ -482,7 +495,8 @@ static void lincs_update_atoms(Lincs*                         li,
 #pragma omp barrier
 #pragma omp master
             {
-                lincs_update_atoms_ind(li->task[li->ntask].ind, li->atoms, preFactor, fac, r, invmass, x);
+                lincs_update_atoms_ind(
+                        li->task[li->ntask].updateConstraintIndices1, li->atoms, preFactor, fac, r, invmass, x);
             }
         }
     }
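
The hunk above only changes which index lists are fed to lincs_update_atoms_ind(); the synchronization pattern it relies on is easier to see in isolation. The following self-contained sketch is not GROMACS code (applyList, list1, list2 and masterList are made-up stand-ins): each thread first applies its private list without synchronization, all threads pass one barrier before the cross-task list when such a list exists, and a final leftover list is applied by a single thread behind a further barrier.

#include <cstdio>
#include <vector>
#ifdef _OPENMP
#    include <omp.h>
#endif

/* Stand-in for lincs_update_atoms_ind(): apply one list of constraint indices */
static void applyList(const std::vector<int>& list, int thread)
{
    for (int c : list)
    {
        std::printf("thread %d updates constraint %d\n", thread, c);
    }
}

int main()
{
    const int numThreads = 2;
    /* Per-thread lists; the flags below are shared so all threads branch identically */
    const std::vector<std::vector<int>> list1          = { { 0, 1 }, { 3 } }; /* no barrier needed */
    const std::vector<std::vector<int>> list2          = { { 2 }, {} };       /* needs one barrier */
    const std::vector<int>              masterList     = { 4 };               /* handled by one thread */
    const bool                          haveSecondList = true;

#pragma omp parallel num_threads(numThreads)
    {
        int th = 0;
#ifdef _OPENMP
        th = omp_get_thread_num();
#endif
        /* Round 1: constraints touching only this thread's atoms */
        applyList(list1[th], th);

        if (haveSecondList)
        {
            /* Round 2: cross-task constraints become safe to touch after a barrier */
#pragma omp barrier
            applyList(list2[th], th);
        }

        if (!masterList.empty())
        {
            /* Remaining constraints are serialized on one thread */
#pragma omp barrier
#pragma omp master
            applyList(masterList, th);
        }
    }
    return 0;
}
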
@@ -1636,8 +1650,9 @@ static void lincs_thread_setup(Lincs* li, int natoms)
 
             bitmask_init_low_bits(&mask, th);
 
-            li_task->ind.clear();
-            li_task->ind_r.clear();
+            li_task->updateConstraintIndices1.clear();
+            li_task->updateConstraintIndices2.clear();
+            li_task->updateConstraintIndicesRest.clear();
             for (b = li_task->b0; b < li_task->b1; b++)
             {
                 /* We let the constraint with the lowest thread index
@@ -1647,42 +1662,104 @@ static void lincs_thread_setup(Lincs* li, int natoms)
                     && bitmask_is_disjoint(atf[li->atoms[b].index2], mask))
                 {
                     /* Add the constraint to the local atom update index */
-                    li_task->ind.push_back(b);
+                    li_task->updateConstraintIndices1.push_back(b);
                 }
                 else
                 {
-                    /* Add the constraint to the rest block */
-                    li_task->ind_r.push_back(b);
+                    /* Store the constraint in the rest block */
+                    li_task->updateConstraintIndicesRest.push_back(b);
                 }
             }
         }
         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
     }
 
-    /* We need to copy all constraints which have not be assigned
-     * to a thread to a separate list which will be handled by one thread.
-     */
-    Task* li_m = &li->task[li->ntask];
-
-    li_m->ind.clear();
-    for (int th = 0; th < li->ntask; th++)
+    if (li->bTaskDep)
     {
-        const Task& li_task = li->task[th];
+        /* Assign the rest constraints to a second thread task or to the rest task handled by the master thread */
 
-        for (int ind_r : li_task.ind_r)
+        /* Clear the atom flags */
+        for (gmx_bitmask_t& mask : atf)
         {
-            li_m->ind.push_back(ind_r);
+            bitmask_clear(&mask);
         }
 
-        if (debug)
+        for (int th = 0; th < li->ntask; th++)
         {
-            fprintf(debug, "LINCS thread %d: %zu constraints\n", th, li_task.ind.size());
+            const Task* li_task = &li->task[th];
+
+            /* For each atom, set a flag for every task whose rest constraints involve it */
+            for (int b : li_task->updateConstraintIndicesRest)
+            {
+                bitmask_set_bit(&atf[li->atoms[b].index1], th);
+                bitmask_set_bit(&atf[li->atoms[b].index2], th);
+            }
         }
-    }
 
-    if (debug)
-    {
-        fprintf(debug, "LINCS thread r: %zu constraints\n", li_m->ind.size());
+#pragma omp parallel for num_threads(li->ntask) schedule(static)
+        for (int th = 0; th < li->ntask; th++)
+        {
+            try
+            {
+                Task& li_task = li->task[th];
+
+                gmx_bitmask_t mask;
+                bitmask_init_low_bits(&mask, th);
+
+                for (int& b : li_task.updateConstraintIndicesRest)
+                {
+                    /* We let the constraint with the lowest thread index
+                     * operate on atoms with constraints from multiple threads.
+                     */
+                    if (bitmask_is_disjoint(atf[li->atoms[b].index1], mask)
+                        && bitmask_is_disjoint(atf[li->atoms[b].index2], mask))
+                    {
+                        li_task.updateConstraintIndices2.push_back(b);
+                        /* Mark the entry in updateConstraintIndicesRest as invalid, so we do not assign it again */
+                        b = -1;
+                    }
+                }
+            }
+            GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
+        }
+
+        /* We need to copy all constraints which have not been assigned
+         * to a thread to a separate list which will be handled by one thread.
+         */
+        Task* li_m = &li->task[li->ntask];
+
+        li->haveSecondUpdateTask = false;
+        li_m->updateConstraintIndices1.clear();
+        for (int th = 0; th < li->ntask; th++)
+        {
+            const Task& li_task = li->task[th];
+
+            for (int constraint : li_task.updateConstraintIndicesRest)
+            {
+                if (constraint >= 0)
+                {
+                    li_m->updateConstraintIndices1.push_back(constraint);
+                }
+                else
+                {
+                    li->haveSecondUpdateTask = true;
+                }
+            }
+
+            if (debug)
+            {
+                fprintf(debug,
+                        "LINCS thread %d: %zu constraints, %zu constraints\n",
+                        th,
+                        li_task.updateConstraintIndices1.size(),
+                        li_task.updateConstraintIndices2.size());
+            }
+        }
+
+        if (debug)
+        {
+            fprintf(debug, "LINCS thread r: %zu constraints\n", li_m->updateConstraintIndices1.size());
+        }
     }
 }
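
For reference, the constraint-to-task assignment added to lincs_thread_setup() above condenses to the following standalone sketch. It is not GROMACS code: gmx_bitmask_t is replaced by a plain 64-bit mask, and Constraint, ThreadTask, ownerOf and the toy constraint chain are invented for illustration. It only shows how the first update list, the second (barrier-separated) list and the remaining single-thread list get populated.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using Bitmask = std::uint64_t; /* stand-in for gmx_bitmask_t */

struct Constraint { int atom1, atom2; }; /* a constrained atom pair */

struct ThreadTask
{
    std::vector<int> updateList1; /* applied without a barrier */
    std::vector<int> updateList2; /* applied after one barrier */
    std::vector<int> rest;        /* still unassigned after pass 1 */
};

int main()
{
    const int numThreads = 3;
    const int numAtoms   = 7;
    /* A chain of constraints, two per thread (toy layout) */
    const std::vector<Constraint> cons = { { 0, 1 }, { 1, 2 }, { 2, 3 }, { 3, 4 }, { 4, 5 }, { 5, 6 } };
    auto ownerOf = [](int c) { return c / 2; };

    std::vector<ThreadTask> task(numThreads + 1); /* last entry is handled by one thread only */
    std::vector<Bitmask>    atomFlags(numAtoms, 0);

    /* Flag, for every atom, which threads have constraints touching it */
    for (int c = 0; c < int(cons.size()); c++)
    {
        atomFlags[cons[c].atom1] |= Bitmask(1) << ownerOf(c);
        atomFlags[cons[c].atom2] |= Bitmask(1) << ownerOf(c);
    }

    /* Pass 1: a thread keeps a constraint in its first list when no
     * lower-indexed thread touches either of its atoms.
     */
    for (int c = 0; c < int(cons.size()); c++)
    {
        const int     th        = ownerOf(c);
        const Bitmask lowerBits = (Bitmask(1) << th) - 1;
        if ((atomFlags[cons[c].atom1] & lowerBits) == 0 && (atomFlags[cons[c].atom2] & lowerBits) == 0)
        {
            task[th].updateList1.push_back(c);
        }
        else
        {
            task[th].rest.push_back(c);
        }
    }

    /* Pass 2: re-flag atoms using only the rest constraints, then give each
     * thread a second, barrier-separated list using the same lowest-index rule.
     */
    std::fill(atomFlags.begin(), atomFlags.end(), 0);
    for (int th = 0; th < numThreads; th++)
    {
        for (int c : task[th].rest)
        {
            atomFlags[cons[c].atom1] |= Bitmask(1) << th;
            atomFlags[cons[c].atom2] |= Bitmask(1) << th;
        }
    }
    for (int th = 0; th < numThreads; th++)
    {
        const Bitmask lowerBits = (Bitmask(1) << th) - 1;
        for (int& c : task[th].rest)
        {
            if ((atomFlags[cons[c].atom1] & lowerBits) == 0 && (atomFlags[cons[c].atom2] & lowerBits) == 0)
            {
                task[th].updateList2.push_back(c);
                c = -1; /* mark as assigned so it is not copied below */
            }
        }
    }

    /* Constraints that are still unassigned go to the single-thread list */
    for (int th = 0; th < numThreads; th++)
    {
        for (int c : task[th].rest)
        {
            if (c >= 0)
            {
                task[numThreads].updateList1.push_back(c);
            }
        }
    }

    for (int th = 0; th <= numThreads; th++)
    {
        std::printf("task %d: %zu + %zu constraints\n", th, task[th].updateList1.size(), task[th].updateList2.size());
    }
    return 0;
}
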
 
@@ -1909,11 +1986,11 @@ void set_lincs(const InteractionDefinitions& idef,
     {
         li->task[i].b0 = 0;
         li->task[i].b1 = 0;
-        li->task[i].ind.clear();
+        li->task[i].updateConstraintIndices1.clear();
     }
     if (li->ntask > 1)
     {
-        li->task[li->ntask].ind.clear();
+        li->task[li->ntask].updateConstraintIndices1.clear();
     }
 
     /* This is the local topology, so there are only F_CONSTR constraints */