SYCL: Avoid using no_init read accessor in rocFFT
[alexxy/gromacs.git] / src / gromacs / fft / fft5d.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2009-2018, The GROMACS development team.
5  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 #include "gmxpre.h"
37
38 #include "fft5d.h"
39
40 #include "config.h"
41
42 #include <cassert>
43 #include <cfloat>
44 #include <cmath>
45 #include <cstdio>
46 #include <cstdlib>
47 #include <cstring>
48
49 #include <algorithm>
50
51 #include "gromacs/gpu_utils/gpu_utils.h"
52 #include "gromacs/gpu_utils/hostallocator.h"
53 #include "gromacs/gpu_utils/pinning.h"
54 #include "gromacs/utility/alignedallocator.h"
55 #include "gromacs/utility/exceptions.h"
56 #include "gromacs/utility/fatalerror.h"
57 #include "gromacs/utility/gmxassert.h"
58 #include "gromacs/utility/gmxmpi.h"
59 #include "gromacs/utility/smalloc.h"
60
61 #ifdef NOGMX
62 #    define GMX_PARALLEL_ENV_INITIALIZED 1
63 #else
64 #    if GMX_MPI
65 #        define GMX_PARALLEL_ENV_INITIALIZED 1
66 #    else
67 #        define GMX_PARALLEL_ENV_INITIALIZED 0
68 #    endif
69 #endif
70
71 #if GMX_OPENMP
72 /* TODO: Do we still need this? Are we still planning ot use fftw + OpenMP? */
73 #    define FFT5D_THREADS
74 /* requires fftw compiled with openmp */
75 /* #define FFT5D_FFTW_THREADS (now set by cmake) */
76 #endif
77
78 #ifndef __FLT_EPSILON__
79 #    define __FLT_EPSILON__ FLT_EPSILON
80 #    define __DBL_EPSILON__ DBL_EPSILON
81 #endif
82
83 #ifdef NOGMX
84 FILE* debug = 0;
85 #endif
86
87 #if GMX_FFT_FFTW3
88
89 #    include <mutex>
90
91 #    include "gromacs/utility/exceptions.h"
92
93 /* none of the fftw3 calls, except execute(), are thread-safe, so
94    we need to serialize them with this mutex. */
95 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
96 static std::mutex big_fftw_mutex;
97 #    define FFTW_LOCK              \
98         try                        \
99         {                          \
100             big_fftw_mutex.lock(); \
101         }                          \
102         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
103 #    define FFTW_UNLOCK              \
104         try                          \
105         {                            \
106             big_fftw_mutex.unlock(); \
107         }                            \
108         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
109 #endif /* GMX_FFT_FFTW3 */
110
111 #if GMX_MPI
112 /* largest factor smaller than sqrt */
113 static int lfactor(int z)
114 {
115     int i = static_cast<int>(sqrt(static_cast<double>(z)));
116     while (z % i != 0)
117     {
118         i--;
119     }
120     return i;
121 }
122 #endif
123
124 #if !GMX_MPI
125 #    if HAVE_GETTIMEOFDAY
126 #        include <sys/time.h>
127 double MPI_Wtime()
128 {
129     struct timeval tv;
130     gettimeofday(&tv, nullptr);
131     return tv.tv_sec + tv.tv_usec * 1e-6;
132 }
133 #    else
134 double MPI_Wtime()
135 {
136     return 0.0;
137 }
138 #    endif
139 #endif
140
141 static int vmax(const int* a, int s)
142 {
143     int i, max = 0;
144     for (i = 0; i < s; i++)
145     {
146         if (a[i] > max)
147         {
148             max = a[i];
149         }
150     }
151     return max;
152 }
153
154
155 /* NxMxK the size of the data
156  * comm communicator to use for fft5d
157  * P0 number of processor in 1st axes (can be null for automatic)
158  * lin is allocated by fft5d because size of array is only known after planning phase
159  * rlout2 is only used as intermediate buffer - only returned after allocation to reuse for back transform - should not be used by caller
160  */
161 fft5d_plan fft5d_plan_3d(int                NG,
162                          int                MG,
163                          int                KG,
164                          MPI_Comm           comm[2],
165                          int                flags,
166                          t_complex**        rlin,
167                          t_complex**        rlout,
168                          t_complex**        rlout2,
169                          t_complex**        rlout3,
170                          int                nthreads,
171                          gmx::PinningPolicy realGridAllocationPinningPolicy)
172 {
173
174     int  P[2], prank[2], i;
175     bool bMaster;
176     int  rNG, rMG, rKG;
177     int *N0 = nullptr, *N1 = nullptr, *M0 = nullptr, *M1 = nullptr, *K0 = nullptr, *K1 = nullptr,
178         *oN0 = nullptr, *oN1 = nullptr, *oM0 = nullptr, *oM1 = nullptr, *oK0 = nullptr, *oK1 = nullptr;
179     int N[3], M[3], K[3], pN[3], pM[3], pK[3], oM[3], oK[3],
180             *iNin[3] = { nullptr }, *oNin[3] = { nullptr }, *iNout[3] = { nullptr },
181             *oNout[3] = { nullptr };
182     int        C[3], rC[3], nP[2];
183     int        lsize;
184     t_complex *lin = nullptr, *lout = nullptr, *lout2 = nullptr, *lout3 = nullptr;
185     fft5d_plan plan;
186     int        s;
187
188     /* comm, prank and P are in the order of the decomposition (plan->cart is in the order of transposes) */
189 #if GMX_MPI
190     if (GMX_PARALLEL_ENV_INITIALIZED && comm[0] != MPI_COMM_NULL)
191     {
192         MPI_Comm_size(comm[0], &P[0]);
193         MPI_Comm_rank(comm[0], &prank[0]);
194     }
195     else
196 #endif
197     {
198         P[0]     = 1;
199         prank[0] = 0;
200     }
201 #if GMX_MPI
202     if (GMX_PARALLEL_ENV_INITIALIZED && comm[1] != MPI_COMM_NULL)
203     {
204         MPI_Comm_size(comm[1], &P[1]);
205         MPI_Comm_rank(comm[1], &prank[1]);
206     }
207     else
208 #endif
209     {
210         P[1]     = 1;
211         prank[1] = 0;
212     }
213
214     bMaster = prank[0] == 0 && prank[1] == 0;
215
216
217     if (debug)
218     {
219         fprintf(debug, "FFT5D: Using %dx%d rank grid, rank %d,%d\n", P[0], P[1], prank[0], prank[1]);
220     }
221
222     if (bMaster)
223     {
224         if (debug)
225         {
226             fprintf(debug,
227                     "FFT5D: N: %d, M: %d, K: %d, P: %dx%d, real2complex: %d, backward: %d, order "
228                     "yz: %d, debug %d\n",
229                     NG,
230                     MG,
231                     KG,
232                     P[0],
233                     P[1],
234                     int((flags & FFT5D_REALCOMPLEX) > 0),
235                     int((flags & FFT5D_BACKWARD) > 0),
236                     int((flags & FFT5D_ORDER_YZ) > 0),
237                     int((flags & FFT5D_DEBUG) > 0));
238         }
239         /* The check below is not correct, one prime factor 11 or 13 is ok.
240            if (fft5d_fmax(fft5d_fmax(lpfactor(NG),lpfactor(MG)),lpfactor(KG))>7) {
241             printf("WARNING: FFT very slow with prime factors larger 7\n");
242             printf("Change FFT size or in case you cannot change it look at\n");
243             printf("http://www.fftw.org/fftw3_doc/Generating-your-own-code.html\n");
244            }
245          */
246     }
247
248     if (NG == 0 || MG == 0 || KG == 0)
249     {
250         if (bMaster)
251         {
252             printf("FFT5D: FATAL: Datasize cannot be zero in any dimension\n");
253         }
254         return nullptr;
255     }
256
257     rNG = NG;
258     rMG = MG;
259     rKG = KG;
260
261     if (flags & FFT5D_REALCOMPLEX)
262     {
263         if (!(flags & FFT5D_BACKWARD))
264         {
265             NG = NG / 2 + 1;
266         }
267         else
268         {
269             if (!(flags & FFT5D_ORDER_YZ))
270             {
271                 MG = MG / 2 + 1;
272             }
273             else
274             {
275                 KG = KG / 2 + 1;
276             }
277         }
278     }
279
280
281     /*for transpose we need to know the size for each processor not only our own size*/
282
283     N0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
284     N1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
285     M0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
286     M1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
287     K0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
288     K1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
289     oN0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
290     oN1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
291     oM0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
292     oM1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
293     oK0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
294     oK1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
295
296     for (i = 0; i < P[0]; i++)
297     {
298 #define EVENDIST
299 #ifndef EVENDIST
300         oN0[i] = i * ceil((double)NG / P[0]);
301         oM0[i] = i * ceil((double)MG / P[0]);
302         oK0[i] = i * ceil((double)KG / P[0]);
303 #else
304         oN0[i] = (NG * i) / P[0];
305         oM0[i] = (MG * i) / P[0];
306         oK0[i] = (KG * i) / P[0];
307 #endif
308     }
309     for (i = 0; i < P[1]; i++)
310     {
311 #ifndef EVENDIST
312         oN1[i] = i * ceil((double)NG / P[1]);
313         oM1[i] = i * ceil((double)MG / P[1]);
314         oK1[i] = i * ceil((double)KG / P[1]);
315 #else
316         oN1[i] = (NG * i) / P[1];
317         oM1[i] = (MG * i) / P[1];
318         oK1[i] = (KG * i) / P[1];
319 #endif
320     }
321     for (i = 0; P[0] > 0 && i < P[0] - 1; i++)
322     {
323         N0[i] = oN0[i + 1] - oN0[i];
324         M0[i] = oM0[i + 1] - oM0[i];
325         K0[i] = oK0[i + 1] - oK0[i];
326     }
327     N0[P[0] - 1] = NG - oN0[P[0] - 1];
328     M0[P[0] - 1] = MG - oM0[P[0] - 1];
329     K0[P[0] - 1] = KG - oK0[P[0] - 1];
330     for (i = 0; P[1] > 0 && i < P[1] - 1; i++)
331     {
332         N1[i] = oN1[i + 1] - oN1[i];
333         M1[i] = oM1[i + 1] - oM1[i];
334         K1[i] = oK1[i + 1] - oK1[i];
335     }
336     N1[P[1] - 1] = NG - oN1[P[1] - 1];
337     M1[P[1] - 1] = MG - oM1[P[1] - 1];
338     K1[P[1] - 1] = KG - oK1[P[1] - 1];
339
340     /* for step 1-3 the local N,M,K sizes of the transposed system
341        C: contiguous dimension, and nP: number of processor in subcommunicator
342        for that step */
343
344     GMX_ASSERT(prank[0] < P[0], "Must have valid rank within communicator size");
345     GMX_ASSERT(prank[1] < P[1], "Must have valid rank within communicator size");
346     pM[0] = M0[prank[0]];
347     oM[0] = oM0[prank[0]];
348     pK[0] = K1[prank[1]];
349     oK[0] = oK1[prank[1]];
350     C[0]  = NG;
351     rC[0] = rNG;
352     if (!(flags & FFT5D_ORDER_YZ))
353     {
354         N[0]     = vmax(N1, P[1]);
355         M[0]     = M0[prank[0]];
356         K[0]     = vmax(K1, P[1]);
357         pN[0]    = N1[prank[1]];
358         iNout[0] = N1;
359         oNout[0] = oN1;
360         nP[0]    = P[1];
361         C[1]     = KG;
362         rC[1]    = rKG;
363         N[1]     = vmax(K0, P[0]);
364         pN[1]    = K0[prank[0]];
365         iNin[1]  = K1;
366         oNin[1]  = oK1;
367         iNout[1] = K0;
368         oNout[1] = oK0;
369         M[1]     = vmax(M0, P[0]);
370         pM[1]    = M0[prank[0]];
371         oM[1]    = oM0[prank[0]];
372         K[1]     = N1[prank[1]];
373         pK[1]    = N1[prank[1]];
374         oK[1]    = oN1[prank[1]];
375         nP[1]    = P[0];
376         C[2]     = MG;
377         rC[2]    = rMG;
378         iNin[2]  = M0;
379         oNin[2]  = oM0;
380         M[2]     = vmax(K0, P[0]);
381         pM[2]    = K0[prank[0]];
382         oM[2]    = oK0[prank[0]];
383         K[2]     = vmax(N1, P[1]);
384         pK[2]    = N1[prank[1]];
385         oK[2]    = oN1[prank[1]];
386         free(N0);
387         free(oN0); /*these are not used for this order*/
388         free(M1);
389         free(oM1); /*the rest is freed in destroy*/
390     }
391     else
392     {
393         N[0]     = vmax(N0, P[0]);
394         M[0]     = vmax(M0, P[0]);
395         K[0]     = K1[prank[1]];
396         pN[0]    = N0[prank[0]];
397         iNout[0] = N0;
398         oNout[0] = oN0;
399         nP[0]    = P[0];
400         C[1]     = MG;
401         rC[1]    = rMG;
402         N[1]     = vmax(M1, P[1]);
403         pN[1]    = M1[prank[1]];
404         iNin[1]  = M0;
405         oNin[1]  = oM0;
406         iNout[1] = M1;
407         oNout[1] = oM1;
408         M[1]     = N0[prank[0]];
409         pM[1]    = N0[prank[0]];
410         oM[1]    = oN0[prank[0]];
411         K[1]     = vmax(K1, P[1]);
412         pK[1]    = K1[prank[1]];
413         oK[1]    = oK1[prank[1]];
414         nP[1]    = P[1];
415         C[2]     = KG;
416         rC[2]    = rKG;
417         iNin[2]  = K1;
418         oNin[2]  = oK1;
419         M[2]     = vmax(N0, P[0]);
420         pM[2]    = N0[prank[0]];
421         oM[2]    = oN0[prank[0]];
422         K[2]     = vmax(M1, P[1]);
423         pK[2]    = M1[prank[1]];
424         oK[2]    = oM1[prank[1]];
425         free(N1);
426         free(oN1); /*these are not used for this order*/
427         free(K0);
428         free(oK0); /*the rest is freed in destroy*/
429     }
430     N[2] = pN[2] = -1; /*not used*/
431
432     /*
433        Difference between x-y-z regarding 2d decomposition is whether they are
434        distributed along axis 1, 2 or both
435      */
436
437     /* int lsize = fmax(N[0]*M[0]*K[0]*nP[0],N[1]*M[1]*K[1]*nP[1]); */
438     lsize = std::max(N[0] * M[0] * K[0] * nP[0], std::max(N[1] * M[1] * K[1] * nP[1], C[2] * M[2] * K[2]));
439     /* int lsize = fmax(C[0]*M[0]*K[0],fmax(C[1]*M[1]*K[1],C[2]*M[2]*K[2])); */
440     if (!(flags & FFT5D_NOMALLOC))
441     {
442         // only needed for PME GPU mixed mode
443         if (GMX_GPU_CUDA && realGridAllocationPinningPolicy == gmx::PinningPolicy::PinnedIfSupported)
444         {
445             const std::size_t numBytes = lsize * sizeof(t_complex);
446             lin = static_cast<t_complex*>(gmx::PageAlignedAllocationPolicy::malloc(numBytes));
447             gmx::pinBuffer(lin, numBytes);
448         }
449         else
450         {
451             snew_aligned(lin, lsize, 32);
452         }
453         snew_aligned(lout, lsize, 32);
454         if (nthreads > 1)
455         {
456             /* We need extra transpose buffers to avoid OpenMP barriers */
457             snew_aligned(lout2, lsize, 32);
458             snew_aligned(lout3, lsize, 32);
459         }
460         else
461         {
462             /* We can reuse the buffers to avoid cache misses */
463             lout2 = lin;
464             lout3 = lout;
465         }
466     }
467     else
468     {
469         lin  = *rlin;
470         lout = *rlout;
471         if (nthreads > 1)
472         {
473             lout2 = *rlout2;
474             lout3 = *rlout3;
475         }
476         else
477         {
478             lout2 = lin;
479             lout3 = lout;
480         }
481     }
482
483     plan = static_cast<fft5d_plan>(calloc(1, sizeof(struct fft5d_plan_t)));
484
485
486     if (debug)
487     {
488         fprintf(debug, "Running on %d threads\n", nthreads);
489     }
490
491 #if GMX_FFT_FFTW3
492     /* Don't add more stuff here! We have already had at least one bug because we are reimplementing
493      * the low-level FFT interface instead of using the Gromacs FFT module. If we need more
494      * generic functionality it is far better to extend the interface so we can use it for
495      * all FFT libraries instead of writing FFTW-specific code here.
496      */
497
498     /*if not FFTW - then we don't do a 3d plan but instead use only 1D plans */
499     /* It is possible to use the 3d plan with OMP threads - but in that case it is not allowed to be
500      * called from within a parallel region. For now deactivated. If it should be supported it has
501      * to made sure that that the execute of the 3d plan is in a master/serial block (since it
502      * contains it own parallel region) and that the 3d plan is faster than the 1d plan.
503      */
504     if ((!(flags & FFT5D_INPLACE)) && (!(P[0] > 1 || P[1] > 1))
505         && nthreads == 1) /*don't do 3d plan in parallel or if in_place requested */
506     {
507         int fftwflags = FFTW_DESTROY_INPUT;
508         FFTW(iodim) dims[3];
509         int inNG = NG, outMG = MG, outKG = KG;
510
511         FFTW_LOCK
512
513         fftwflags |= (flags & FFT5D_NOMEASURE) ? FFTW_ESTIMATE : FFTW_MEASURE;
514
515         if (flags & FFT5D_REALCOMPLEX)
516         {
517             if (!(flags & FFT5D_BACKWARD)) /*input pointer is not complex*/
518             {
519                 inNG *= 2;
520             }
521             else /*output pointer is not complex*/
522             {
523                 if (!(flags & FFT5D_ORDER_YZ))
524                 {
525                     outMG *= 2;
526                 }
527                 else
528                 {
529                     outKG *= 2;
530                 }
531             }
532         }
533
534         if (!(flags & FFT5D_BACKWARD))
535         {
536             dims[0].n = KG;
537             dims[1].n = MG;
538             dims[2].n = rNG;
539
540             dims[0].is = inNG * MG; /*N M K*/
541             dims[1].is = inNG;
542             dims[2].is = 1;
543             if (!(flags & FFT5D_ORDER_YZ))
544             {
545                 dims[0].os = MG; /*M K N*/
546                 dims[1].os = 1;
547                 dims[2].os = MG * KG;
548             }
549             else
550             {
551                 dims[0].os = 1; /*K N M*/
552                 dims[1].os = KG * NG;
553                 dims[2].os = KG;
554             }
555         }
556         else
557         {
558             if (!(flags & FFT5D_ORDER_YZ))
559             {
560                 dims[0].n = NG;
561                 dims[1].n = KG;
562                 dims[2].n = rMG;
563
564                 dims[0].is = 1;
565                 dims[1].is = NG * MG;
566                 dims[2].is = NG;
567
568                 dims[0].os = outMG * KG;
569                 dims[1].os = outMG;
570                 dims[2].os = 1;
571             }
572             else
573             {
574                 dims[0].n = MG;
575                 dims[1].n = NG;
576                 dims[2].n = rKG;
577
578                 dims[0].is = NG;
579                 dims[1].is = 1;
580                 dims[2].is = NG * MG;
581
582                 dims[0].os = outKG * NG;
583                 dims[1].os = outKG;
584                 dims[2].os = 1;
585             }
586         }
587 #    ifdef FFT5D_THREADS
588 #        ifdef FFT5D_FFTW_THREADS
589         FFTW(plan_with_nthreads)(nthreads);
590 #        endif
591 #    endif
592         if ((flags & FFT5D_REALCOMPLEX) && !(flags & FFT5D_BACKWARD))
593         {
594             plan->p3d = FFTW(plan_guru_dft_r2c)(/*rank*/ 3,
595                                                 dims,
596                                                 /*howmany*/ 0,
597                                                 /*howmany_dims*/ nullptr,
598                                                 reinterpret_cast<real*>(lin),
599                                                 reinterpret_cast<FFTW(complex)*>(lout),
600                                                 /*flags*/ fftwflags);
601         }
602         else if ((flags & FFT5D_REALCOMPLEX) && (flags & FFT5D_BACKWARD))
603         {
604             plan->p3d = FFTW(plan_guru_dft_c2r)(/*rank*/ 3,
605                                                 dims,
606                                                 /*howmany*/ 0,
607                                                 /*howmany_dims*/ nullptr,
608                                                 reinterpret_cast<FFTW(complex)*>(lin),
609                                                 reinterpret_cast<real*>(lout),
610                                                 /*flags*/ fftwflags);
611         }
612         else
613         {
614             plan->p3d = FFTW(plan_guru_dft)(
615                     /*rank*/ 3,
616                     dims,
617                     /*howmany*/ 0,
618                     /*howmany_dims*/ nullptr,
619                     reinterpret_cast<FFTW(complex)*>(lin),
620                     reinterpret_cast<FFTW(complex)*>(lout),
621                     /*sign*/ (flags & FFT5D_BACKWARD) ? 1 : -1,
622                     /*flags*/ fftwflags);
623         }
624 #    ifdef FFT5D_THREADS
625 #        ifdef FFT5D_FFTW_THREADS
626         FFTW(plan_with_nthreads)(1);
627 #        endif
628 #    endif
629         FFTW_UNLOCK
630     }
631     if (!plan->p3d) /* for decomposition and if 3d plan did not work */
632     {
633 #endif /* GMX_FFT_FFTW3 */
634         for (s = 0; s < 3; s++)
635         {
636             if (debug)
637             {
638                 fprintf(debug, "FFT5D: Plan s %d rC %d M %d pK %d C %d lsize %d\n", s, rC[s], M[s], pK[s], C[s], lsize);
639             }
640             plan->p1d[s] = static_cast<gmx_fft_t*>(malloc(sizeof(gmx_fft_t) * nthreads));
641
642             /* Make sure that the init routines are only called by one thread at a time and in order
643                (later is only important to not confuse valgrind)
644              */
645 #pragma omp parallel for num_threads(nthreads) schedule(static) ordered
646             for (int t = 0; t < nthreads; t++)
647             {
648 #pragma omp ordered
649                 {
650                     try
651                     {
652                         int tsize = ((t + 1) * pM[s] * pK[s] / nthreads) - (t * pM[s] * pK[s] / nthreads);
653
654                         if ((flags & FFT5D_REALCOMPLEX)
655                             && ((!(flags & FFT5D_BACKWARD) && s == 0)
656                                 || ((flags & FFT5D_BACKWARD) && s == 2)))
657                         {
658                             gmx_fft_init_many_1d_real(
659                                     &plan->p1d[s][t],
660                                     rC[s],
661                                     tsize,
662                                     (flags & FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0);
663                         }
664                         else
665                         {
666                             gmx_fft_init_many_1d(&plan->p1d[s][t],
667                                                  C[s],
668                                                  tsize,
669                                                  (flags & FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0);
670                         }
671                     }
672                     GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
673                 }
674             }
675         }
676
677 #if GMX_FFT_FFTW3
678     }
679 #endif
680     if ((flags & FFT5D_ORDER_YZ)) /*plan->cart is in the order of transposes */
681     {
682         plan->cart[0] = comm[0];
683         plan->cart[1] = comm[1];
684     }
685     else
686     {
687         plan->cart[1] = comm[0];
688         plan->cart[0] = comm[1];
689     }
690 #ifdef FFT5D_MPI_TRANSPOSE
691     FFTW_LOCK;
692     for (s = 0; s < 2; s++)
693     {
694         if ((s == 0 && !(flags & FFT5D_ORDER_YZ)) || (s == 1 && (flags & FFT5D_ORDER_YZ)))
695         {
696             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(
697                     nP[s], nP[s], N[s] * K[s] * pM[s] * 2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
698         }
699         else
700         {
701             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(
702                     nP[s], nP[s], N[s] * pK[s] * M[s] * 2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
703         }
704     }
705     FFTW_UNLOCK;
706 #endif
707
708
709     plan->lin   = lin;
710     plan->lout  = lout;
711     plan->lout2 = lout2;
712     plan->lout3 = lout3;
713
714     plan->NG = NG;
715     plan->MG = MG;
716     plan->KG = KG;
717
718     for (s = 0; s < 3; s++)
719     {
720         plan->N[s]     = N[s];
721         plan->M[s]     = M[s];
722         plan->K[s]     = K[s];
723         plan->pN[s]    = pN[s];
724         plan->pM[s]    = pM[s];
725         plan->pK[s]    = pK[s];
726         plan->oM[s]    = oM[s];
727         plan->oK[s]    = oK[s];
728         plan->C[s]     = C[s];
729         plan->rC[s]    = rC[s];
730         plan->iNin[s]  = iNin[s];
731         plan->oNin[s]  = oNin[s];
732         plan->iNout[s] = iNout[s];
733         plan->oNout[s] = oNout[s];
734     }
735     for (s = 0; s < 2; s++)
736     {
737         plan->P[s]    = nP[s];
738         plan->coor[s] = prank[s];
739     }
740
741     /*    plan->fftorder=fftorder;
742         plan->direction=direction;
743         plan->realcomplex=realcomplex;
744      */
745     plan->flags         = flags;
746     plan->nthreads      = nthreads;
747     plan->pinningPolicy = realGridAllocationPinningPolicy;
748     *rlin               = lin;
749     *rlout              = lout;
750     *rlout2             = lout2;
751     *rlout3             = lout3;
752     return plan;
753 }
754
755
756 enum order
757 {
758     XYZ,
759     XZY,
760     YXZ,
761     YZX,
762     ZXY,
763     ZYX
764 };
765
766
767 /*here x,y,z and N,M,K is in rotated coordinate system!!
768    x (and N) is mayor (consecutive) dimension, y (M) middle and z (K) major
769    maxN,maxM,maxK is max size of local data
770    pN, pM, pK is local size specific to current processor (only different to max if not divisible)
771    NG, MG, KG is size of global data*/
772 static void splitaxes(t_complex*       lout,
773                       const t_complex* lin,
774                       int              maxN,
775                       int              maxM,
776                       int              maxK,
777                       int              pM,
778                       int              P,
779                       int              NG,
780                       const int*       N,
781                       const int*       oN,
782                       int              starty,
783                       int              startz,
784                       int              endy,
785                       int              endz)
786 {
787     int x, y, z, i;
788     int in_i, out_i, in_z, out_z, in_y, out_y;
789     int s_y, e_y;
790
791     for (z = startz; z < endz + 1; z++) /*3. z l*/
792     {
793         if (z == startz)
794         {
795             s_y = starty;
796         }
797         else
798         {
799             s_y = 0;
800         }
801         if (z == endz)
802         {
803             e_y = endy;
804         }
805         else
806         {
807             e_y = pM;
808         }
809         out_z = z * maxN * maxM;
810         in_z  = z * NG * pM;
811
812         for (i = 0; i < P; i++) /*index cube along long axis*/
813         {
814             out_i = out_z + i * maxN * maxM * maxK;
815             in_i  = in_z + oN[i];
816             for (y = s_y; y < e_y; y++) /*2. y k*/
817             {
818                 out_y = out_i + y * maxN;
819                 in_y  = in_i + y * NG;
820                 for (x = 0; x < N[i]; x++) /*1. x j*/
821                 {
822                     lout[out_y + x] = lin[in_y + x]; /*in=z*NG*pM+oN[i]+y*NG+x*/
823                     /*after split important that each processor chunk i has size maxN*maxM*maxK and thus being the same size*/
824                     /*before split data contiguos - thus if different processor get different amount oN is different*/
825                 }
826             }
827         }
828     }
829 }
830
831 /*make axis contiguous again (after AllToAll) and also do local transpose*/
832 /*transpose mayor and major dimension
833    variables see above
834    the major, middle, minor order is only correct for x,y,z (N,M,K) for the input
835    N,M,K local dimensions
836    KG global size*/
837 static void joinAxesTrans13(t_complex*       lout,
838                             const t_complex* lin,
839                             int              maxN,
840                             int              maxM,
841                             int              maxK,
842                             int              pM,
843                             int              P,
844                             int              KG,
845                             const int*       K,
846                             const int*       oK,
847                             int              starty,
848                             int              startx,
849                             int              endy,
850                             int              endx)
851 {
852     int i, x, y, z;
853     int out_i, in_i, out_x, in_x, out_z, in_z;
854     int s_y, e_y;
855
856     for (x = startx; x < endx + 1; x++) /*1.j*/
857     {
858         if (x == startx)
859         {
860             s_y = starty;
861         }
862         else
863         {
864             s_y = 0;
865         }
866         if (x == endx)
867         {
868             e_y = endy;
869         }
870         else
871         {
872             e_y = pM;
873         }
874
875         out_x = x * KG * pM;
876         in_x  = x;
877
878         for (i = 0; i < P; i++) /*index cube along long axis*/
879         {
880             out_i = out_x + oK[i];
881             in_i  = in_x + i * maxM * maxN * maxK;
882             for (z = 0; z < K[i]; z++) /*3.l*/
883             {
884                 out_z = out_i + z;
885                 in_z  = in_i + z * maxM * maxN;
886                 for (y = s_y; y < e_y; y++) /*2.k*/
887                 {
888                     lout[out_z + y * KG] = lin[in_z + y * maxN]; /*out=x*KG*pM+oK[i]+z+y*KG*/
889                 }
890             }
891         }
892     }
893 }
894
895 /*make axis contiguous again (after AllToAll) and also do local transpose
896    tranpose mayor and middle dimension
897    variables see above
898    the minor, middle, major order is only correct for x,y,z (N,M,K) for the input
899    N,M,K local size
900    MG, global size*/
901 static void joinAxesTrans12(t_complex*       lout,
902                             const t_complex* lin,
903                             int              maxN,
904                             int              maxM,
905                             int              maxK,
906                             int              pN,
907                             int              P,
908                             int              MG,
909                             const int*       M,
910                             const int*       oM,
911                             int              startx,
912                             int              startz,
913                             int              endx,
914                             int              endz)
915 {
916     int i, z, y, x;
917     int out_i, in_i, out_z, in_z, out_x, in_x;
918     int s_x, e_x;
919
920     for (z = startz; z < endz + 1; z++)
921     {
922         if (z == startz)
923         {
924             s_x = startx;
925         }
926         else
927         {
928             s_x = 0;
929         }
930         if (z == endz)
931         {
932             e_x = endx;
933         }
934         else
935         {
936             e_x = pN;
937         }
938         out_z = z * MG * pN;
939         in_z  = z * maxM * maxN;
940
941         for (i = 0; i < P; i++) /*index cube along long axis*/
942         {
943             out_i = out_z + oM[i];
944             in_i  = in_z + i * maxM * maxN * maxK;
945             for (x = s_x; x < e_x; x++)
946             {
947                 out_x = out_i + x * MG;
948                 in_x  = in_i + x;
949                 for (y = 0; y < M[i]; y++)
950                 {
951                     lout[out_x + y] = lin[in_x + y * maxN]; /*out=z*MG*pN+oM[i]+x*MG+y*/
952                 }
953             }
954         }
955     }
956 }
957
958
959 static void rotate_offsets(int x[])
960 {
961     int t = x[0];
962     /*    x[0]=x[2];
963         x[2]=x[1];
964         x[1]=t;*/
965     x[0] = x[1];
966     x[1] = x[2];
967     x[2] = t;
968 }
969
970 /*compute the offset to compare or print transposed local data in original input coordinates
971    xs matrix dimension size, xl dimension length, xc decomposition offset
972    s: step in computation = number of transposes*/
973 static void compute_offsets(fft5d_plan plan, int xs[], int xl[], int xc[], int NG[], int s)
974 {
975     /*    int direction = plan->direction;
976         int fftorder = plan->fftorder;*/
977
978     int  o = 0;
979     int  pos[3], i;
980     int *pM = plan->pM, *pK = plan->pK, *oM = plan->oM, *oK = plan->oK, *C = plan->C, *rC = plan->rC;
981
982     NG[0] = plan->NG;
983     NG[1] = plan->MG;
984     NG[2] = plan->KG;
985
986     if (!(plan->flags & FFT5D_ORDER_YZ))
987     {
988         switch (s)
989         {
990             case 0: o = XYZ; break;
991             case 1: o = ZYX; break;
992             case 2: o = YZX; break;
993             default: assert(0);
994         }
995     }
996     else
997     {
998         switch (s)
999         {
1000             case 0: o = XYZ; break;
1001             case 1: o = YXZ; break;
1002             case 2: o = ZXY; break;
1003             default: assert(0);
1004         }
1005     }
1006
1007     switch (o)
1008     {
1009         case XYZ:
1010             pos[0] = 1;
1011             pos[1] = 2;
1012             pos[2] = 3;
1013             break;
1014         case XZY:
1015             pos[0] = 1;
1016             pos[1] = 3;
1017             pos[2] = 2;
1018             break;
1019         case YXZ:
1020             pos[0] = 2;
1021             pos[1] = 1;
1022             pos[2] = 3;
1023             break;
1024         case YZX:
1025             pos[0] = 3;
1026             pos[1] = 1;
1027             pos[2] = 2;
1028             break;
1029         case ZXY:
1030             pos[0] = 2;
1031             pos[1] = 3;
1032             pos[2] = 1;
1033             break;
1034         case ZYX:
1035             pos[0] = 3;
1036             pos[1] = 2;
1037             pos[2] = 1;
1038             break;
1039     }
1040     /*if (debug) printf("pos: %d %d %d\n",pos[0],pos[1],pos[2]);*/
1041
1042     /*xs, xl give dimension size and data length in local transposed coordinate system
1043        for 0(/1/2): x(/y/z) in original coordinate system*/
1044     for (i = 0; i < 3; i++)
1045     {
1046         switch (pos[i])
1047         {
1048             case 1:
1049                 xs[i] = 1;
1050                 xc[i] = 0;
1051                 xl[i] = C[s];
1052                 break;
1053             case 2:
1054                 xs[i] = C[s];
1055                 xc[i] = oM[s];
1056                 xl[i] = pM[s];
1057                 break;
1058             case 3:
1059                 xs[i] = C[s] * pM[s];
1060                 xc[i] = oK[s];
1061                 xl[i] = pK[s];
1062                 break;
1063         }
1064     }
1065     /*input order is different for test program to match FFTW order
1066        (important for complex to real)*/
1067     if (plan->flags & FFT5D_BACKWARD)
1068     {
1069         rotate_offsets(xs);
1070         rotate_offsets(xl);
1071         rotate_offsets(xc);
1072         rotate_offsets(NG);
1073         if (plan->flags & FFT5D_ORDER_YZ)
1074         {
1075             rotate_offsets(xs);
1076             rotate_offsets(xl);
1077             rotate_offsets(xc);
1078             rotate_offsets(NG);
1079         }
1080     }
1081     if ((plan->flags & FFT5D_REALCOMPLEX)
1082         && ((!(plan->flags & FFT5D_BACKWARD) && s == 0) || ((plan->flags & FFT5D_BACKWARD) && s == 2)))
1083     {
1084         xl[0] = rC[s];
1085     }
1086 }
1087
1088 static void print_localdata(const t_complex* lin, const char* txt, int s, fft5d_plan plan)
1089 {
1090     int  x, y, z, l;
1091     int* coor = plan->coor;
1092     int  xs[3], xl[3], xc[3], NG[3];
1093     int  ll = (plan->flags & FFT5D_REALCOMPLEX) ? 1 : 2;
1094     compute_offsets(plan, xs, xl, xc, NG, s);
1095     fprintf(debug, txt, coor[0], coor[1]);
1096     /*printf("xs: %d %d %d, xl: %d %d %d\n",xs[0],xs[1],xs[2],xl[0],xl[1],xl[2]);*/
1097     for (z = 0; z < xl[2]; z++)
1098     {
1099         for (y = 0; y < xl[1]; y++)
1100         {
1101             fprintf(debug, "%d %d: ", coor[0], coor[1]);
1102             for (x = 0; x < xl[0]; x++)
1103             {
1104                 for (l = 0; l < ll; l++)
1105                 {
1106                     fprintf(debug,
1107                             "%f ",
1108                             reinterpret_cast<const real*>(
1109                                     lin)[(z * xs[2] + y * xs[1]) * 2 + (x * xs[0]) * ll + l]);
1110                 }
1111                 fprintf(debug, ",");
1112             }
1113             fprintf(debug, "\n");
1114         }
1115     }
1116 }
1117
1118 void fft5d_execute(fft5d_plan plan, int thread, fft5d_time times)
1119 {
1120     t_complex* lin   = plan->lin;
1121     t_complex* lout  = plan->lout;
1122     t_complex* lout2 = plan->lout2;
1123     t_complex* lout3 = plan->lout3;
1124     t_complex *fftout, *joinin;
1125
1126     gmx_fft_t** p1d = plan->p1d;
1127 #ifdef FFT5D_MPI_TRANSPOSE
1128     FFTW(plan)* mpip = plan->mpip;
1129 #endif
1130 #if GMX_MPI
1131     MPI_Comm* cart = plan->cart;
1132 #endif
1133 #ifdef NOGMX
1134     double time_fft = 0, time_local = 0, time_mpi[2] = { 0 }, time = 0;
1135 #endif
1136     int *N = plan->N, *M = plan->M, *K = plan->K, *pN = plan->pN, *pM = plan->pM, *pK = plan->pK,
1137         *C = plan->C, *P = plan->P, **iNin = plan->iNin, **oNin = plan->oNin, **iNout = plan->iNout,
1138         **oNout = plan->oNout;
1139     int s       = 0, tstart, tend, bParallelDim;
1140
1141
1142 #if GMX_FFT_FFTW3
1143     if (plan->p3d)
1144     {
1145         if (thread == 0)
1146         {
1147 #    ifdef NOGMX
1148             if (times != 0)
1149             {
1150                 time = MPI_Wtime();
1151             }
1152 #    endif
1153             FFTW(execute)(plan->p3d);
1154 #    ifdef NOGMX
1155             if (times != 0)
1156             {
1157                 times->fft += MPI_Wtime() - time;
1158             }
1159 #    endif
1160         }
1161         return;
1162     }
1163 #endif
1164
1165     s = 0;
1166
1167     /*lin: x,y,z*/
1168     if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1169     {
1170         print_localdata(lin, "%d %d: copy in lin\n", s, plan);
1171     }
1172
1173     for (s = 0; s < 2; s++) /*loop over first two FFT steps (corner rotations)*/
1174
1175     {
1176 #if GMX_MPI
1177         if (GMX_PARALLEL_ENV_INITIALIZED && cart[s] != MPI_COMM_NULL && P[s] > 1)
1178         {
1179             bParallelDim = 1;
1180         }
1181         else
1182 #endif
1183         {
1184             bParallelDim = 0;
1185         }
1186
1187         /* ---------- START FFT ------------ */
1188 #ifdef NOGMX
1189         if (times != 0 && thread == 0)
1190         {
1191             time = MPI_Wtime();
1192         }
1193 #endif
1194
1195         if (bParallelDim || plan->nthreads == 1)
1196         {
1197             fftout = lout;
1198         }
1199         else
1200         {
1201             if (s == 0)
1202             {
1203                 fftout = lout3;
1204             }
1205             else
1206             {
1207                 fftout = lout2;
1208             }
1209         }
1210
1211         tstart = (thread * pM[s] * pK[s] / plan->nthreads) * C[s];
1212         if ((plan->flags & FFT5D_REALCOMPLEX) && !(plan->flags & FFT5D_BACKWARD) && s == 0)
1213         {
1214             gmx_fft_many_1d_real(p1d[s][thread],
1215                                  (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL
1216                                                                 : GMX_FFT_REAL_TO_COMPLEX,
1217                                  lin + tstart,
1218                                  fftout + tstart);
1219         }
1220         else
1221         {
1222             gmx_fft_many_1d(p1d[s][thread],
1223                             (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,
1224                             lin + tstart,
1225                             fftout + tstart);
1226         }
1227
1228 #ifdef NOGMX
1229         if (times != NULL && thread == 0)
1230         {
1231             time_fft += MPI_Wtime() - time;
1232         }
1233 #endif
1234         if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1235         {
1236             print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1237         }
1238         /* ---------- END FFT ------------ */
1239
1240         /* ---------- START SPLIT + TRANSPOSE------------ (if parallel in in this dimension)*/
1241         if (bParallelDim)
1242         {
1243 #ifdef NOGMX
1244             if (times != NULL && thread == 0)
1245             {
1246                 time = MPI_Wtime();
1247             }
1248 #endif
1249             /*prepare for A
1250                llToAll
1251                1. (most outer) axes (x) is split into P[s] parts of size N[s]
1252                for sending*/
1253             if (pM[s] > 0)
1254             {
1255                 tend = ((thread + 1) * pM[s] * pK[s] / plan->nthreads);
1256                 tstart /= C[s];
1257                 splitaxes(lout2,
1258                           lout,
1259                           N[s],
1260                           M[s],
1261                           K[s],
1262                           pM[s],
1263                           P[s],
1264                           C[s],
1265                           iNout[s],
1266                           oNout[s],
1267                           tstart % pM[s],
1268                           tstart / pM[s],
1269                           tend % pM[s],
1270                           tend / pM[s]);
1271             }
1272 #pragma omp barrier /*barrier required before AllToAll (all input has to be their) - before timing to make timing more acurate*/
1273 #ifdef NOGMX
1274             if (times != NULL && thread == 0)
1275             {
1276                 time_local += MPI_Wtime() - time;
1277             }
1278 #endif
1279
1280             /* ---------- END SPLIT , START TRANSPOSE------------ */
1281
1282             if (thread == 0)
1283             {
1284 #ifdef NOGMX
1285                 if (times != 0)
1286                 {
1287                     time = MPI_Wtime();
1288                 }
1289 #else
1290                 wallcycle_start(times, WallCycleCounter::PmeFftComm);
1291 #endif
1292 #ifdef FFT5D_MPI_TRANSPOSE
1293                 FFTW(execute)(mpip[s]);
1294 #else
1295 #    if GMX_MPI
1296                 if ((s == 0 && !(plan->flags & FFT5D_ORDER_YZ))
1297                     || (s == 1 && (plan->flags & FFT5D_ORDER_YZ)))
1298                 {
1299                     MPI_Alltoall(reinterpret_cast<real*>(lout2),
1300                                  N[s] * pM[s] * K[s] * sizeof(t_complex) / sizeof(real),
1301                                  GMX_MPI_REAL,
1302                                  reinterpret_cast<real*>(lout3),
1303                                  N[s] * pM[s] * K[s] * sizeof(t_complex) / sizeof(real),
1304                                  GMX_MPI_REAL,
1305                                  cart[s]);
1306                 }
1307                 else
1308                 {
1309                     MPI_Alltoall(reinterpret_cast<real*>(lout2),
1310                                  N[s] * M[s] * pK[s] * sizeof(t_complex) / sizeof(real),
1311                                  GMX_MPI_REAL,
1312                                  reinterpret_cast<real*>(lout3),
1313                                  N[s] * M[s] * pK[s] * sizeof(t_complex) / sizeof(real),
1314                                  GMX_MPI_REAL,
1315                                  cart[s]);
1316                 }
1317 #    else
1318                 GMX_RELEASE_ASSERT(false, "Invalid call to fft5d_execute");
1319 #    endif /*GMX_MPI*/
1320 #endif     /*FFT5D_MPI_TRANSPOSE*/
1321 #ifdef NOGMX
1322                 if (times != 0)
1323                 {
1324                     time_mpi[s] = MPI_Wtime() - time;
1325                 }
1326 #else
1327                 wallcycle_stop(times, WallCycleCounter::PmeFftComm);
1328 #endif
1329             }       /*master*/
1330         }           /* bPrallelDim */
1331 #pragma omp barrier /*both needed for parallel and non-parallel dimension (either have to wait on data from AlltoAll or from last FFT*/
1332
1333         /* ---------- END SPLIT + TRANSPOSE------------ */
1334
1335         /* ---------- START JOIN ------------ */
1336 #ifdef NOGMX
1337         if (times != NULL && thread == 0)
1338         {
1339             time = MPI_Wtime();
1340         }
1341 #endif
1342
1343         if (bParallelDim)
1344         {
1345             joinin = lout3;
1346         }
1347         else
1348         {
1349             joinin = fftout;
1350         }
1351         /*bring back in matrix form
1352            thus make  new 1. axes contiguos
1353            also local transpose 1 and 2/3
1354            runs on thread used for following FFT (thus needing a barrier before but not afterwards)
1355          */
1356         if ((s == 0 && !(plan->flags & FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags & FFT5D_ORDER_YZ)))
1357         {
1358             if (pM[s] > 0)
1359             {
1360                 tstart = (thread * pM[s] * pN[s] / plan->nthreads);
1361                 tend   = ((thread + 1) * pM[s] * pN[s] / plan->nthreads);
1362                 joinAxesTrans13(lin,
1363                                 joinin,
1364                                 N[s],
1365                                 pM[s],
1366                                 K[s],
1367                                 pM[s],
1368                                 P[s],
1369                                 C[s + 1],
1370                                 iNin[s + 1],
1371                                 oNin[s + 1],
1372                                 tstart % pM[s],
1373                                 tstart / pM[s],
1374                                 tend % pM[s],
1375                                 tend / pM[s]);
1376             }
1377         }
1378         else
1379         {
1380             if (pN[s] > 0)
1381             {
1382                 tstart = (thread * pK[s] * pN[s] / plan->nthreads);
1383                 tend   = ((thread + 1) * pK[s] * pN[s] / plan->nthreads);
1384                 joinAxesTrans12(lin,
1385                                 joinin,
1386                                 N[s],
1387                                 M[s],
1388                                 pK[s],
1389                                 pN[s],
1390                                 P[s],
1391                                 C[s + 1],
1392                                 iNin[s + 1],
1393                                 oNin[s + 1],
1394                                 tstart % pN[s],
1395                                 tstart / pN[s],
1396                                 tend % pN[s],
1397                                 tend / pN[s]);
1398             }
1399         }
1400
1401 #ifdef NOGMX
1402         if (times != NULL && thread == 0)
1403         {
1404             time_local += MPI_Wtime() - time;
1405         }
1406 #endif
1407         if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1408         {
1409             print_localdata(lin, "%d %d: tranposed %d\n", s + 1, plan);
1410         }
1411         /* ---------- END JOIN ------------ */
1412
1413         /*if (debug) print_localdata(lin, "%d %d: transposed x-z\n", N1, M0, K, ZYX, coor);*/
1414     } /* for(s=0;s<2;s++) */
1415 #ifdef NOGMX
1416     if (times != NULL && thread == 0)
1417     {
1418         time = MPI_Wtime();
1419     }
1420 #endif
1421
1422     if (plan->flags & FFT5D_INPLACE)
1423     {
1424         lout = lin; /*in place currently not supported*/
1425     }
1426     /*  ----------- FFT ----------- */
1427     tstart = (thread * pM[s] * pK[s] / plan->nthreads) * C[s];
1428     if ((plan->flags & FFT5D_REALCOMPLEX) && (plan->flags & FFT5D_BACKWARD))
1429     {
1430         gmx_fft_many_1d_real(p1d[s][thread],
1431                              (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX,
1432                              lin + tstart,
1433                              lout + tstart);
1434     }
1435     else
1436     {
1437         gmx_fft_many_1d(p1d[s][thread],
1438                         (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,
1439                         lin + tstart,
1440                         lout + tstart);
1441     }
1442     /* ------------ END FFT ---------*/
1443
1444 #ifdef NOGMX
1445     if (times != NULL && thread == 0)
1446     {
1447         time_fft += MPI_Wtime() - time;
1448
1449         times->fft += time_fft;
1450         times->local += time_local;
1451         times->mpi2 += time_mpi[1];
1452         times->mpi1 += time_mpi[0];
1453     }
1454 #endif
1455
1456     if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1457     {
1458         print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1459     }
1460 }
1461
1462 void fft5d_destroy(fft5d_plan plan)
1463 {
1464     int s, t;
1465
1466     for (s = 0; s < 3; s++)
1467     {
1468         if (plan->p1d[s])
1469         {
1470             for (t = 0; t < plan->nthreads; t++)
1471             {
1472                 gmx_many_fft_destroy(plan->p1d[s][t]);
1473             }
1474             free(plan->p1d[s]);
1475         }
1476         if (plan->iNin[s])
1477         {
1478             free(plan->iNin[s]);
1479             plan->iNin[s] = nullptr;
1480         }
1481         if (plan->oNin[s])
1482         {
1483             free(plan->oNin[s]);
1484             plan->oNin[s] = nullptr;
1485         }
1486         if (plan->iNout[s])
1487         {
1488             free(plan->iNout[s]);
1489             plan->iNout[s] = nullptr;
1490         }
1491         if (plan->oNout[s])
1492         {
1493             free(plan->oNout[s]);
1494             plan->oNout[s] = nullptr;
1495         }
1496     }
1497 #if GMX_FFT_FFTW3
1498     FFTW_LOCK
1499 #    ifdef FFT5D_MPI_TRANSPOS
1500     for (s = 0; s < 2; s++)
1501     {
1502         FFTW(destroy_plan)(plan->mpip[s]);
1503     }
1504 #    endif /* FFT5D_MPI_TRANSPOS */
1505     if (plan->p3d)
1506     {
1507         FFTW(destroy_plan)(plan->p3d);
1508     }
1509     FFTW_UNLOCK
1510 #endif /* GMX_FFT_FFTW3 */
1511
1512     if (!(plan->flags & FFT5D_NOMALLOC))
1513     {
1514         // only needed for PME GPU mixed mode
1515         if (plan->pinningPolicy == gmx::PinningPolicy::PinnedIfSupported && isHostMemoryPinned(plan->lin))
1516         {
1517             gmx::unpinBuffer(plan->lin);
1518         }
1519         sfree_aligned(plan->lin);
1520         sfree_aligned(plan->lout);
1521         if (plan->nthreads > 1)
1522         {
1523             sfree_aligned(plan->lout2);
1524             sfree_aligned(plan->lout3);
1525         }
1526     }
1527
1528 #ifdef FFT5D_THREADS
1529 #    ifdef FFT5D_FFTW_THREADS
1530     /*FFTW(cleanup_threads)();*/
1531 #    endif
1532 #endif
1533
1534     free(plan);
1535 }
1536
1537 /*Is this better than direct access of plan? enough data?
1538    here 0,1 reference divided by which processor grid dimension (not FFT step!)*/
1539 void fft5d_local_size(fft5d_plan plan, int* N1, int* M0, int* K0, int* K1, int** coor)
1540 {
1541     *N1 = plan->N[0];
1542     *M0 = plan->M[0];
1543     *K1 = plan->K[0];
1544     *K0 = plan->N[1];
1545
1546     *coor = plan->coor;
1547 }
1548
1549
1550 /*same as fft5d_plan_3d but with cartesian coordinator and automatic splitting
1551    of processor dimensions*/
1552 fft5d_plan fft5d_plan_3d_cart(int         NG,
1553                               int         MG,
1554                               int         KG,
1555                               MPI_Comm    comm,
1556                               int         P0,
1557                               int         flags,
1558                               t_complex** rlin,
1559                               t_complex** rlout,
1560                               t_complex** rlout2,
1561                               t_complex** rlout3,
1562                               int         nthreads)
1563 {
1564     MPI_Comm cart[2] = { MPI_COMM_NULL, MPI_COMM_NULL };
1565 #if GMX_MPI
1566     int      size = 1, prank = 0;
1567     int      P[2];
1568     int      coor[2];
1569     int      wrap[] = { 0, 0 };
1570     MPI_Comm gcart;
1571     int      rdim1[] = { 0, 1 }, rdim2[] = { 1, 0 };
1572
1573     MPI_Comm_size(comm, &size);
1574     MPI_Comm_rank(comm, &prank);
1575
1576     if (P0 == 0)
1577     {
1578         P0 = lfactor(size);
1579     }
1580     if (size % P0 != 0)
1581     {
1582         if (prank == 0)
1583         {
1584             printf("FFT5D: WARNING: Number of ranks %d not evenly divisible by %d\n", size, P0);
1585         }
1586         P0 = lfactor(size);
1587     }
1588
1589     P[0] = P0;
1590     P[1] = size / P0; /*number of processors in the two dimensions*/
1591
1592     /*Difference between x-y-z regarding 2d decomposition is whether they are
1593        distributed along axis 1, 2 or both*/
1594
1595     MPI_Cart_create(comm, 2, P, wrap, 1, &gcart); /*parameter 4: value 1: reorder*/
1596     MPI_Cart_get(gcart, 2, P, wrap, coor);
1597     MPI_Cart_sub(gcart, rdim1, &cart[0]);
1598     MPI_Cart_sub(gcart, rdim2, &cart[1]);
1599 #else
1600     (void)P0;
1601     (void)comm;
1602 #endif
1603     return fft5d_plan_3d(NG, MG, KG, cart, flags, rlin, rlout, rlout2, rlout3, nthreads);
1604 }
1605
1606
1607 /*prints in original coordinate system of data (as the input to FFT)*/
1608 void fft5d_compare_data(const t_complex* lin, const t_complex* in, fft5d_plan plan, int bothLocal, int normalize)
1609 {
1610     int  xs[3], xl[3], xc[3], NG[3];
1611     int  x, y, z, l;
1612     int* coor = plan->coor;
1613     int  ll   = 2; /*compare ll values per element (has to be 2 for complex)*/
1614     if ((plan->flags & FFT5D_REALCOMPLEX) && (plan->flags & FFT5D_BACKWARD))
1615     {
1616         ll = 1;
1617     }
1618
1619     compute_offsets(plan, xs, xl, xc, NG, 2);
1620     if (plan->flags & FFT5D_DEBUG)
1621     {
1622         printf("Compare2\n");
1623     }
1624     for (z = 0; z < xl[2]; z++)
1625     {
1626         for (y = 0; y < xl[1]; y++)
1627         {
1628             if (plan->flags & FFT5D_DEBUG)
1629             {
1630                 printf("%d %d: ", coor[0], coor[1]);
1631             }
1632             for (x = 0; x < xl[0]; x++)
1633             {
1634                 for (l = 0; l < ll; l++) /*loop over real/complex parts*/
1635                 {
1636                     real a, b;
1637                     a = reinterpret_cast<const real*>(lin)[(z * xs[2] + y * xs[1]) * 2 + x * xs[0] * ll + l];
1638                     if (normalize)
1639                     {
1640                         a /= plan->rC[0] * plan->rC[1] * plan->rC[2];
1641                     }
1642                     if (!bothLocal)
1643                     {
1644                         b = reinterpret_cast<const real*>(
1645                                 in)[((z + xc[2]) * NG[0] * NG[1] + (y + xc[1]) * NG[0]) * 2 + (x + xc[0]) * ll + l];
1646                     }
1647                     else
1648                     {
1649                         b = reinterpret_cast<const real*>(
1650                                 in)[(z * xs[2] + y * xs[1]) * 2 + x * xs[0] * ll + l];
1651                     }
1652                     if (plan->flags & FFT5D_DEBUG)
1653                     {
1654                         printf("%f %f, ", a, b);
1655                     }
1656                     else
1657                     {
1658                         if (std::fabs(a - b) > 2 * NG[0] * NG[1] * NG[2] * GMX_REAL_EPS)
1659                         {
1660                             printf("result incorrect on %d,%d at %d,%d,%d: FFT5D:%f reference:%f\n",
1661                                    coor[0],
1662                                    coor[1],
1663                                    x,
1664                                    y,
1665                                    z,
1666                                    a,
1667                                    b);
1668                         }
1669                         /*                        assert(fabs(a-b)<2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS);*/
1670                     }
1671                 }
1672                 if (plan->flags & FFT5D_DEBUG)
1673                 {
1674                     printf(",");
1675                 }
1676             }
1677             if (plan->flags & FFT5D_DEBUG)
1678             {
1679                 printf("\n");
1680             }
1681         }
1682     }
1683 }