src/gromacs/nbnxm/freeenergydispatch.cpp
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

#include "gmxpre.h"

#include "freeenergydispatch.h"

#include "gromacs/gmxlib/nrnb.h"
#include "gromacs/gmxlib/nonbonded/nb_free_energy.h"
#include "gromacs/gmxlib/nonbonded/nonbonded.h"
#include "gromacs/math/vectypes.h"
#include "gromacs/mdlib/enerdata_utils.h"
#include "gromacs/mdlib/force.h"
#include "gromacs/mdlib/gmx_omp_nthreads.h"
#include "gromacs/mdtypes/enerdata.h"
#include "gromacs/mdtypes/forceoutput.h"
#include "gromacs/mdtypes/inputrec.h"
#include "gromacs/mdtypes/interaction_const.h"
#include "gromacs/mdtypes/md_enums.h"
#include "gromacs/mdtypes/nblist.h"
#include "gromacs/mdtypes/simulation_workload.h"
#include "gromacs/mdtypes/threaded_force_buffer.h"
#include "gromacs/nbnxm/nbnxm.h"
#include "gromacs/timing/wallcycle.h"
#include "gromacs/utility/enumerationhelpers.h"
#include "gromacs/utility/gmxassert.h"
#include "gromacs/utility/real.h"

#include "pairlistset.h"
#include "pairlistsets.h"

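// The constructor creates one force buffer and one foreign-lambda energy buffer
// per nonbonded OpenMP thread, each sized for numEnergyGroups energy groups.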
FreeEnergyDispatch::FreeEnergyDispatch(const int numEnergyGroups) :
    foreignGroupPairEnergies_(numEnergyGroups),
    threadedForceBuffer_(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded), false, numEnergyGroups),
    threadedForeignEnergyBuffer_(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded), false, numEnergyGroups)
{
}

namespace
{

//! Flags all atoms present in pairlist \p nlist in the mask in \p threadForceBuffer
void setReductionMaskFromFepPairlist(const t_nblist& gmx_restrict       nlist,
                                     gmx::ThreadForceBuffer<gmx::RVec>* threadForceBuffer)
{
    // Extract pair list data
    gmx::ArrayRef<const int> iinr = nlist.iinr;
    gmx::ArrayRef<const int> jjnr = nlist.jjnr;

    for (int i : iinr)
    {
        threadForceBuffer->addAtomToMask(i);
    }
    for (int j : jjnr)
    {
        threadForceBuffer->addAtomToMask(j);
    }
}

} // namespace

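// For each nonbonded thread, build the force-reduction mask from the atoms in the
// local (and, with multiple domains, non-local) FEP pairlists, then set up the
// reduction over all thread buffers.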
void FreeEnergyDispatch::setupFepThreadedForceBuffer(const int numAtomsForce, const PairlistSets& pairlistSets)
{
    const int numThreads = threadedForceBuffer_.numThreadBuffers();

    GMX_ASSERT(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded) == numThreads,
               "The number of buffers should be the same as the number of NB threads");

#pragma omp parallel for num_threads(numThreads) schedule(static)
    for (int th = 0; th < numThreads; th++)
    {
        auto& threadForceBuffer = threadedForceBuffer_.threadForceBuffer(th);

        threadForceBuffer.resizeBufferAndClearMask(numAtomsForce);

        setReductionMaskFromFepPairlist(
                *pairlistSets.pairlistSet(gmx::InteractionLocality::Local).fepLists()[th],
                &threadForceBuffer);
        if (pairlistSets.params().haveMultipleDomains)
        {
            setReductionMaskFromFepPairlist(
                    *pairlistSets.pairlistSet(gmx::InteractionLocality::NonLocal).fepLists()[th],
                    &threadForceBuffer);
        }

        threadForceBuffer.processMask();
    }

    threadedForceBuffer_.setupReduction();
}

void nonbonded_verlet_t::setupFepThreadedForceBuffer(const int numAtomsForce)
{
    if (!pairlistSets_->params().haveFep)
    {
        return;
    }

    GMX_RELEASE_ASSERT(freeEnergyDispatch_, "Need a valid dispatch object");

    freeEnergyDispatch_->setupFepThreadedForceBuffer(numAtomsForce, *pairlistSets_);
}

namespace
{

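//! Runs the free-energy kernel over all thread-local FEP pairlists and, when foreign
//! lambdas with soft-core are requested, recomputes the foreign energies and dV/dlambda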
void dispatchFreeEnergyKernel(gmx::ArrayRef<const std::unique_ptr<t_nblist>>   nbl_fep,
                              const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
                              bool                                             useSimd,
                              int                                              ntype,
                              real                                             rlist,
                              const interaction_const_t&                       ic,
                              gmx::ArrayRef<const gmx::RVec>                   shiftvec,
                              gmx::ArrayRef<const real>                        nbfp,
                              gmx::ArrayRef<const real>                        nbfp_grid,
                              gmx::ArrayRef<const real>                        chargeA,
                              gmx::ArrayRef<const real>                        chargeB,
                              gmx::ArrayRef<const int>                         typeA,
                              gmx::ArrayRef<const int>                         typeB,
                              t_lambda*                                        fepvals,
                              gmx::ArrayRef<const real>                        lambda,
                              const bool                           clearForcesAndEnergies,
                              gmx::ThreadedForceBuffer<gmx::RVec>* threadedForceBuffer,
                              gmx::ThreadedForceBuffer<gmx::RVec>* threadedForeignEnergyBuffer,
                              gmx_grppairener_t*                   foreignGroupPairEnergies,
                              gmx_enerdata_t*                      enerd,
                              const gmx::StepWorkload&             stepWork,
                              t_nrnb*                              nrnb)
{
    int donb_flags = 0;
    /* Add short-range interactions */
    donb_flags |= GMX_NONBONDED_DO_SR;

    if (stepWork.computeForces)
    {
        donb_flags |= GMX_NONBONDED_DO_FORCE;
    }
    if (stepWork.computeVirial)
    {
        donb_flags |= GMX_NONBONDED_DO_SHIFTFORCE;
    }
    if (stepWork.computeEnergy)
    {
        donb_flags |= GMX_NONBONDED_DO_POTENTIAL;
    }

    GMX_ASSERT(gmx_omp_nthreads_get(ModuleMultiThread::Nonbonded) == nbl_fep.ssize(),
               "The number of lists should be the same as the number of NB threads");

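    // Run the free-energy kernel on each thread's pairlist, accumulating forces,
    // energies and dV/dlambda into that thread's buffers.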
#pragma omp parallel for schedule(static) num_threads(nbl_fep.ssize())
    for (gmx::index th = 0; th < nbl_fep.ssize(); th++)
    {
        try
        {
            auto& threadForceBuffer = threadedForceBuffer->threadForceBuffer(th);

            if (clearForcesAndEnergies)
            {
                threadForceBuffer.clearForcesAndEnergies();
            }

            auto  threadForces           = threadForceBuffer.forceBufferWithPadding();
            rvec* threadForceShiftBuffer = as_rvec_array(threadForceBuffer.shiftForces().data());
            gmx::ArrayRef<real> threadVc =
                    threadForceBuffer.groupPairEnergies().energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR];
            gmx::ArrayRef<real> threadVv =
                    threadForceBuffer.groupPairEnergies().energyGroupPairTerms[NonBondedEnergyTerms::LJSR];
            gmx::ArrayRef<real> threadDvdl = threadForceBuffer.dvdl();

            gmx_nb_free_energy_kernel(*nbl_fep[th],
                                      coords,
                                      useSimd,
                                      ntype,
                                      rlist,
                                      ic,
                                      shiftvec,
                                      nbfp,
                                      nbfp_grid,
                                      chargeA,
                                      chargeB,
                                      typeA,
                                      typeB,
                                      donb_flags,
                                      lambda,
                                      nrnb,
                                      threadForces,
                                      threadForceShiftBuffer,
                                      threadVc,
                                      threadVv,
                                      threadDvdl);
        }
        GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
    }

    /* If we do foreign lambdas and we have soft-core interactions,
     * we have to recalculate the (non-linear) energy contributions.
     */
    if (fepvals->n_lambda > 0 && stepWork.computeDhdl && fepvals->sc_alpha != 0)
    {
        gmx::StepWorkload stepWorkForeignEnergies = stepWork;
        stepWorkForeignEnergies.computeForces     = false;
        stepWorkForeignEnergies.computeVirial     = false;

        gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> lam_i;
        gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> dvdl_nb = { 0 };
        const int kernelFlags = (donb_flags & ~(GMX_NONBONDED_DO_FORCE | GMX_NONBONDED_DO_SHIFTFORCE))
                                | GMX_NONBONDED_DO_FOREIGNLAMBDA;

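        // Loop over the current lambda state (index 0) plus all foreign lambda points;
        // for index i > 0, lam_i is taken from all_lambda[...][i - 1].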
        for (gmx::index i = 0; i < 1 + enerd->foreignLambdaTerms.numLambdas(); i++)
        {
            std::fill(std::begin(dvdl_nb), std::end(dvdl_nb), 0);
            for (int j = 0; j < static_cast<int>(FreeEnergyPerturbationCouplingType::Count); j++)
            {
                lam_i[j] = (i == 0 ? lambda[j] : fepvals->all_lambda[j][i - 1]);
            }

#pragma omp parallel for schedule(static) num_threads(nbl_fep.ssize())
            for (gmx::index th = 0; th < nbl_fep.ssize(); th++)
            {
                try
                {
                    // Note that here we only compute energies and dV/dlambda, but we still need
                    // to pass a force buffer. No forces are computed or stored.
                    auto& threadForeignEnergyBuffer = threadedForeignEnergyBuffer->threadForceBuffer(th);

                    threadForeignEnergyBuffer.clearForcesAndEnergies();

                    gmx::ArrayRef<real> threadVc =
                            threadForeignEnergyBuffer.groupPairEnergies()
                                    .energyGroupPairTerms[NonBondedEnergyTerms::CoulombSR];
                    gmx::ArrayRef<real> threadVv =
                            threadForeignEnergyBuffer.groupPairEnergies()
                                    .energyGroupPairTerms[NonBondedEnergyTerms::LJSR];
                    gmx::ArrayRef<real> threadDvdl = threadForeignEnergyBuffer.dvdl();

                    gmx_nb_free_energy_kernel(*nbl_fep[th],
                                              coords,
                                              useSimd,
                                              ntype,
                                              rlist,
                                              ic,
                                              shiftvec,
                                              nbfp,
                                              nbfp_grid,
                                              chargeA,
                                              chargeB,
                                              typeA,
                                              typeB,
                                              kernelFlags,
                                              lam_i,
                                              nrnb,
                                              gmx::ArrayRefWithPadding<gmx::RVec>(),
                                              nullptr,
                                              threadVc,
                                              threadVv,
                                              threadDvdl);
                }
                GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
            }

            foreignGroupPairEnergies->clear();
            threadedForeignEnergyBuffer->reduce(
                    nullptr, nullptr, foreignGroupPairEnergies, dvdl_nb, stepWorkForeignEnergies, 0);

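            // Sum the reduced group-pair energies into per-term potential energies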
            std::array<real, F_NRE> foreign_term = { 0 };
            sum_epot(*foreignGroupPairEnergies, foreign_term.data());
            // Accumulate the foreign energy difference and dV/dlambda into the passed enerd
            enerd->foreignLambdaTerms.accumulate(
                    i,
                    foreign_term[F_EPOT],
                    dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw]
                            + dvdl_nb[FreeEnergyPerturbationCouplingType::Coul]);
        }
    }
}

} // namespace

void FreeEnergyDispatch::dispatchFreeEnergyKernels(const PairlistSets& pairlistSets,
                                                   const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
                                                   gmx::ForceWithShiftForces* forceWithShiftForces,
                                                   const bool                 useSimd,
                                                   const int                  ntype,
                                                   const real                 rlist,
                                                   const interaction_const_t& ic,
                                                   gmx::ArrayRef<const gmx::RVec> shiftvec,
                                                   gmx::ArrayRef<const real>      nbfp,
                                                   gmx::ArrayRef<const real>      nbfp_grid,
                                                   gmx::ArrayRef<const real>      chargeA,
                                                   gmx::ArrayRef<const real>      chargeB,
                                                   gmx::ArrayRef<const int>       typeA,
                                                   gmx::ArrayRef<const int>       typeB,
                                                   t_lambda*                      fepvals,
                                                   gmx::ArrayRef<const real>      lambda,
                                                   gmx_enerdata_t*                enerd,
                                                   const gmx::StepWorkload&       stepWork,
                                                   t_nrnb*                        nrnb,
                                                   gmx_wallcycle*                 wcycle)
{
    GMX_ASSERT(pairlistSets.params().haveFep, "We should have a free-energy pairlist");

    wallcycle_sub_start(wcycle, WallCycleSubCounter::NonbondedFep);

    const int numLocalities = (pairlistSets.params().haveMultipleDomains ? 2 : 1);
    // The first call to dispatchFreeEnergyKernel() should clear the buffers. Clearing happens
    // inside that function to avoid an extra OpenMP parallel region here, so we need a boolean
    // to track whether clearing is still needed.
    // A better solution would be to move the OpenMP parallel region here, but that first
    // requires modifying ThreadedForceBuffer.reduce() so that it can be called thread parallel.
    bool clearForcesAndEnergies = true;
    for (int i = 0; i < numLocalities; i++)
    {
        const gmx::InteractionLocality iLocality = static_cast<gmx::InteractionLocality>(i);
        const auto fepPairlists                  = pairlistSets.pairlistSet(iLocality).fepLists();
        /* When the first list is empty, all lists are empty and there is nothing to do */
        if (fepPairlists[0]->nrj > 0)
        {
            dispatchFreeEnergyKernel(fepPairlists,
                                     coords,
                                     useSimd,
                                     ntype,
                                     rlist,
                                     ic,
                                     shiftvec,
                                     nbfp,
                                     nbfp_grid,
                                     chargeA,
                                     chargeB,
                                     typeA,
                                     typeB,
                                     fepvals,
                                     lambda,
                                     clearForcesAndEnergies,
                                     &threadedForceBuffer_,
                                     &threadedForeignEnergyBuffer_,
                                     &foreignGroupPairEnergies_,
                                     enerd,
                                     stepWork,
                                     nrnb);
        }
        else if (clearForcesAndEnergies)
        {
            // We still need to clear the thread force buffers.
            // With a non-empty pairlist this is done in dispatchFreeEnergyKernel()
            // to avoid the overhead of an extra OpenMP parallel loop.
#pragma omp parallel for schedule(static) num_threads(fepPairlists.ssize())
            for (gmx::index th = 0; th < fepPairlists.ssize(); th++)
            {
                try
                {
                    threadedForceBuffer_.threadForceBuffer(th).clearForcesAndEnergies();
                }
                GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
            }
        }
        clearForcesAndEnergies = false;
    }
    wallcycle_sub_stop(wcycle, WallCycleSubCounter::NonbondedFep);

    wallcycle_sub_start(wcycle, WallCycleSubCounter::NonbondedFepReduction);

    gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> dvdl_nb = { 0 };

    threadedForceBuffer_.reduce(forceWithShiftForces, nullptr, &enerd->grpp, dvdl_nb, stepWork, 0);

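    // With soft-core (sc_alpha != 0) the potential is a non-linear function of lambda,
    // so the dV/dlambda contributions are accumulated as non-linear; otherwise as linear.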
    if (fepvals->sc_alpha != 0)
    {
        enerd->dvdl_nonlin[FreeEnergyPerturbationCouplingType::Vdw] +=
                dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw];
        enerd->dvdl_nonlin[FreeEnergyPerturbationCouplingType::Coul] +=
                dvdl_nb[FreeEnergyPerturbationCouplingType::Coul];
    }
    else
    {
        enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Vdw] +=
                dvdl_nb[FreeEnergyPerturbationCouplingType::Vdw];
        enerd->dvdl_lin[FreeEnergyPerturbationCouplingType::Coul] +=
                dvdl_nb[FreeEnergyPerturbationCouplingType::Coul];
    }

    wallcycle_sub_stop(wcycle, WallCycleSubCounter::NonbondedFepReduction);
}

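// Thin wrapper that forwards to the FreeEnergyDispatch object when FEP is in use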
void nonbonded_verlet_t::dispatchFreeEnergyKernels(const gmx::ArrayRefWithPadding<const gmx::RVec>& coords,
                                                   gmx::ForceWithShiftForces* forceWithShiftForces,
                                                   const bool                 useSimd,
                                                   const int                  ntype,
                                                   const real                 rlist,
                                                   const interaction_const_t& ic,
                                                   gmx::ArrayRef<const gmx::RVec> shiftvec,
                                                   gmx::ArrayRef<const real>      nbfp,
                                                   gmx::ArrayRef<const real>      nbfp_grid,
                                                   gmx::ArrayRef<const real>      chargeA,
                                                   gmx::ArrayRef<const real>      chargeB,
                                                   gmx::ArrayRef<const int>       typeA,
                                                   gmx::ArrayRef<const int>       typeB,
                                                   t_lambda*                      fepvals,
                                                   gmx::ArrayRef<const real>      lambda,
                                                   gmx_enerdata_t*                enerd,
                                                   const gmx::StepWorkload&       stepWork,
                                                   t_nrnb*                        nrnb)
{
    if (!pairlistSets_->params().haveFep)
    {
        return;
    }

    GMX_RELEASE_ASSERT(freeEnergyDispatch_, "Need a valid dispatch object");

    freeEnergyDispatch_->dispatchFreeEnergyKernels(*pairlistSets_,
                                                   coords,
                                                   forceWithShiftForces,
                                                   useSimd,
                                                   ntype,
                                                   rlist,
                                                   ic,
                                                   shiftvec,
                                                   nbfp,
                                                   nbfp_grid,
                                                   chargeA,
                                                   chargeB,
                                                   typeA,
                                                   typeB,
                                                   fepvals,
                                                   lambda,
                                                   enerd,
                                                   stepWork,
                                                   nrnb,
                                                   wcycle_);
}