Make use of the DeviceStreamManager
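gmx_pme_init() now takes the DeviceContext and the DeviceStream, as managed by the
DeviceStreamManager, instead of a raw DeviceInformation pointer. The GPU data
structures are (re)initialized via pme_gpu_reinit() only on the code path where PME
actually runs on a GPU; the CPU path instead asserts that no PME GPU object exists.
CPU-only callers such as gmx_pme_reinit() pass nullptr for the context and the stream.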
[alexxy/gromacs.git] / src / gromacs / ewald / pme.cpp
index 8d8bb673c22500f46c72600ed2ecfaa7ac5f887e..120887e8bbea3567ee059aea4ccb66fef4e966a2 100644
@@ -560,19 +560,20 @@ static int div_round_up(int enumerator, int denominator)
     return (enumerator + denominator - 1) / denominator;
 }
 
-gmx_pme_t* gmx_pme_init(const t_commrec*         cr,
-                        const NumPmeDomains&     numPmeDomains,
-                        const t_inputrec*        ir,
-                        gmx_bool                 bFreeEnergy_q,
-                        gmx_bool                 bFreeEnergy_lj,
-                        gmx_bool                 bReproducible,
-                        real                     ewaldcoeff_q,
-                        real                     ewaldcoeff_lj,
-                        int                      nthread,
-                        PmeRunMode               runMode,
-                        PmeGpu*                  pmeGpu,
-                        const DeviceInformation* deviceInfo,
-                        const PmeGpuProgram*     pmeGpuProgram,
+gmx_pme_t* gmx_pme_init(const t_commrec*     cr,
+                        const NumPmeDomains& numPmeDomains,
+                        const t_inputrec*    ir,
+                        gmx_bool             bFreeEnergy_q,
+                        gmx_bool             bFreeEnergy_lj,
+                        gmx_bool             bReproducible,
+                        real                 ewaldcoeff_q,
+                        real                 ewaldcoeff_lj,
+                        int                  nthread,
+                        PmeRunMode           runMode,
+                        PmeGpu*              pmeGpu,
+                        const DeviceContext* deviceContext,
+                        const DeviceStream*  deviceStream,
+                        const PmeGpuProgram* pmeGpuProgram,
                         const gmx::MDLogger& /*mdlog*/)
 {
     int  use_threads, sum_use_threads, i;
@@ -883,8 +884,13 @@ gmx_pme_t* gmx_pme_init(const t_commrec*         cr,
         {
             GMX_THROW(gmx::NotImplementedError(errorString));
         }
+        pme_gpu_reinit(pme.get(), deviceContext, deviceStream, pmeGpuProgram);
     }
-    pme_gpu_reinit(pme.get(), deviceInfo, pmeGpuProgram);
+    else
+    {
+        GMX_ASSERT(pme->gpu == nullptr, "Should not have PME GPU object when PME is on a CPU.");
+    }
+
 
     pme_init_all_work(&pme->solve_work, pme->nthread, pme->nkx);
 
@@ -925,7 +931,7 @@ void gmx_pme_reinit(struct gmx_pme_t** pmedata,
         NumPmeDomains numPmeDomains = { pme_src->nnodes_major, pme_src->nnodes_minor };
         *pmedata = gmx_pme_init(cr, numPmeDomains, &irc, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE,
                                 ewaldcoeff_q, ewaldcoeff_lj, pme_src->nthread, pme_src->runMode,
-                                pme_src->gpu, nullptr, nullptr, dummyLogger);
+                                pme_src->gpu, nullptr, nullptr, nullptr, dummyLogger);
         /* When running PME on the CPU not using domain decomposition,
          * the atom data is allocated once only in gmx_pme_(re)init().
          */
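
For context, the sketch below shows how a GPU-enabled call site could wire the new
arguments up. It is a minimal illustration, not code from this commit: the helper
initPmeForGpu is hypothetical, and the DeviceStreamManager accessors context() and
stream(gmx::DeviceStreamType::Pme) are assumed from the API the commit title refers to.

// Illustrative only: a hypothetical helper showing how a GPU-enabled call
// site could supply the new gmx_pme_init() arguments. The DeviceStreamManager
// accessors used below are assumptions, not part of this diff.
#include "gromacs/ewald/pme.h"
#include "gromacs/gpu_utils/device_stream_manager.h"

gmx_pme_t* initPmeForGpu(const t_commrec*                cr,
                         const NumPmeDomains&            numPmeDomains,
                         const t_inputrec*               ir,
                         real                            ewaldcoeff_q,
                         real                            ewaldcoeff_lj,
                         int                             nthread,
                         const PmeGpuProgram*            pmeGpuProgram,
                         const gmx::DeviceStreamManager& streamManager,
                         const gmx::MDLogger&            mdlog)
{
    // The device context and the PME stream now come from the stream
    // manager, replacing the old DeviceInformation* argument.
    const DeviceContext& deviceContext = streamManager.context();
    const DeviceStream&  pmeStream     = streamManager.stream(gmx::DeviceStreamType::Pme);

    return gmx_pme_init(cr, numPmeDomains, ir,
                        /* bFreeEnergy_q  */ FALSE,
                        /* bFreeEnergy_lj */ FALSE,
                        /* bReproducible  */ FALSE,
                        ewaldcoeff_q, ewaldcoeff_lj, nthread, PmeRunMode::GPU,
                        /* pmeGpu */ nullptr, &deviceContext, &pmeStream,
                        pmeGpuProgram, mdlog);
}

Passing the context and stream explicitly keeps gmx_pme_init() agnostic of device
detection and lets the DeviceStreamManager own stream lifetimes centrally, which is
presumably the point of this change.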