Replace gmx::Mutex with std::mutex
[alexxy/gromacs.git] / src / gromacs / fft / fft5d.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2009-2018, The GROMACS development team.
5  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 #include "gmxpre.h"
37
38 #include "fft5d.h"
39
40 #include "config.h"
41
42 #include <cassert>
43 #include <cfloat>
44 #include <cmath>
45 #include <cstdio>
46 #include <cstdlib>
47 #include <cstring>
48
49 #include <algorithm>
50
51 #include "gromacs/gpu_utils/gpu_utils.h"
52 #include "gromacs/gpu_utils/hostallocator.h"
53 #include "gromacs/gpu_utils/pinning.h"
54 #include "gromacs/utility/alignedallocator.h"
55 #include "gromacs/utility/exceptions.h"
56 #include "gromacs/utility/fatalerror.h"
57 #include "gromacs/utility/gmxassert.h"
58 #include "gromacs/utility/gmxmpi.h"
59 #include "gromacs/utility/smalloc.h"
60
61 #ifdef NOGMX
62 #    define GMX_PARALLEL_ENV_INITIALIZED 1
63 #else
64 #    if GMX_MPI
65 #        define GMX_PARALLEL_ENV_INITIALIZED 1
66 #    else
67 #        define GMX_PARALLEL_ENV_INITIALIZED 0
68 #    endif
69 #endif
70
71 #if GMX_OPENMP
72 /* TODO: Do we still need this? Are we still planning ot use fftw + OpenMP? */
73 #    define FFT5D_THREADS
74 /* requires fftw compiled with openmp */
75 /* #define FFT5D_FFTW_THREADS (now set by cmake) */
76 #endif
77
78 #ifndef __FLT_EPSILON__
79 #    define __FLT_EPSILON__ FLT_EPSILON
80 #    define __DBL_EPSILON__ DBL_EPSILON
81 #endif
82
83 #ifdef NOGMX
84 FILE* debug = 0;
85 #endif
86
87 #if GMX_FFT_FFTW3
88
89 #    include <mutex>
90
91 #    include "gromacs/utility/exceptions.h"
92
93 /* none of the fftw3 calls, except execute(), are thread-safe, so
94    we need to serialize them with this mutex. */
95 static std::mutex big_fftw_mutex;
96 #    define FFTW_LOCK              \
97         try                        \
98         {                          \
99             big_fftw_mutex.lock(); \
100         }                          \
101         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
102 #    define FFTW_UNLOCK              \
103         try                          \
104         {                            \
105             big_fftw_mutex.unlock(); \
106         }                            \
107         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
108 #endif /* GMX_FFT_FFTW3 */
109
110 #if GMX_MPI
111 /* largest factor smaller than sqrt */
112 static int lfactor(int z)
113 {
114     int i = static_cast<int>(sqrt(static_cast<double>(z)));
115     while (z % i != 0)
116     {
117         i--;
118     }
119     return i;
120 }
121 #endif
122
123 #if !GMX_MPI
124 #    if HAVE_GETTIMEOFDAY
125 #        include <sys/time.h>
126 double MPI_Wtime()
127 {
128     struct timeval tv;
129     gettimeofday(&tv, nullptr);
130     return tv.tv_sec + tv.tv_usec * 1e-6;
131 }
132 #    else
133 double MPI_Wtime()
134 {
135     return 0.0;
136 }
137 #    endif
138 #endif
139
140 static int vmax(const int* a, int s)
141 {
142     int i, max = 0;
143     for (i = 0; i < s; i++)
144     {
145         if (a[i] > max)
146         {
147             max = a[i];
148         }
149     }
150     return max;
151 }
152
153
154 /* NxMxK the size of the data
155  * comm communicator to use for fft5d
156  * P0 number of processor in 1st axes (can be null for automatic)
157  * lin is allocated by fft5d because size of array is only known after planning phase
158  * rlout2 is only used as intermediate buffer - only returned after allocation to reuse for back transform - should not be used by caller
159  */
160 fft5d_plan fft5d_plan_3d(int                NG,
161                          int                MG,
162                          int                KG,
163                          MPI_Comm           comm[2],
164                          int                flags,
165                          t_complex**        rlin,
166                          t_complex**        rlout,
167                          t_complex**        rlout2,
168                          t_complex**        rlout3,
169                          int                nthreads,
170                          gmx::PinningPolicy realGridAllocationPinningPolicy)
171 {
172
173     int  P[2], prank[2], i;
174     bool bMaster;
175     int  rNG, rMG, rKG;
176     int *N0 = nullptr, *N1 = nullptr, *M0 = nullptr, *M1 = nullptr, *K0 = nullptr, *K1 = nullptr,
177         *oN0 = nullptr, *oN1 = nullptr, *oM0 = nullptr, *oM1 = nullptr, *oK0 = nullptr, *oK1 = nullptr;
178     int N[3], M[3], K[3], pN[3], pM[3], pK[3], oM[3], oK[3],
179             *iNin[3] = { nullptr }, *oNin[3] = { nullptr }, *iNout[3] = { nullptr },
180             *oNout[3] = { nullptr };
181     int        C[3], rC[3], nP[2];
182     int        lsize;
183     t_complex *lin = nullptr, *lout = nullptr, *lout2 = nullptr, *lout3 = nullptr;
184     fft5d_plan plan;
185     int        s;
186
187     /* comm, prank and P are in the order of the decomposition (plan->cart is in the order of transposes) */
188 #if GMX_MPI
189     if (GMX_PARALLEL_ENV_INITIALIZED && comm[0] != MPI_COMM_NULL)
190     {
191         MPI_Comm_size(comm[0], &P[0]);
192         MPI_Comm_rank(comm[0], &prank[0]);
193     }
194     else
195 #endif
196     {
197         P[0]     = 1;
198         prank[0] = 0;
199     }
200 #if GMX_MPI
201     if (GMX_PARALLEL_ENV_INITIALIZED && comm[1] != MPI_COMM_NULL)
202     {
203         MPI_Comm_size(comm[1], &P[1]);
204         MPI_Comm_rank(comm[1], &prank[1]);
205     }
206     else
207 #endif
208     {
209         P[1]     = 1;
210         prank[1] = 0;
211     }
212
213     bMaster = prank[0] == 0 && prank[1] == 0;
214
215
216     if (debug)
217     {
218         fprintf(debug, "FFT5D: Using %dx%d rank grid, rank %d,%d\n", P[0], P[1], prank[0], prank[1]);
219     }
220
221     if (bMaster)
222     {
223         if (debug)
224         {
225             fprintf(debug,
226                     "FFT5D: N: %d, M: %d, K: %d, P: %dx%d, real2complex: %d, backward: %d, order "
227                     "yz: %d, debug %d\n",
228                     NG,
229                     MG,
230                     KG,
231                     P[0],
232                     P[1],
233                     int((flags & FFT5D_REALCOMPLEX) > 0),
234                     int((flags & FFT5D_BACKWARD) > 0),
235                     int((flags & FFT5D_ORDER_YZ) > 0),
236                     int((flags & FFT5D_DEBUG) > 0));
237         }
238         /* The check below is not correct, one prime factor 11 or 13 is ok.
239            if (fft5d_fmax(fft5d_fmax(lpfactor(NG),lpfactor(MG)),lpfactor(KG))>7) {
240             printf("WARNING: FFT very slow with prime factors larger 7\n");
241             printf("Change FFT size or in case you cannot change it look at\n");
242             printf("http://www.fftw.org/fftw3_doc/Generating-your-own-code.html\n");
243            }
244          */
245     }
246
247     if (NG == 0 || MG == 0 || KG == 0)
248     {
249         if (bMaster)
250         {
251             printf("FFT5D: FATAL: Datasize cannot be zero in any dimension\n");
252         }
253         return nullptr;
254     }
255
256     rNG = NG;
257     rMG = MG;
258     rKG = KG;
259
260     if (flags & FFT5D_REALCOMPLEX)
261     {
262         if (!(flags & FFT5D_BACKWARD))
263         {
264             NG = NG / 2 + 1;
265         }
266         else
267         {
268             if (!(flags & FFT5D_ORDER_YZ))
269             {
270                 MG = MG / 2 + 1;
271             }
272             else
273             {
274                 KG = KG / 2 + 1;
275             }
276         }
277     }
278
279
280     /*for transpose we need to know the size for each processor not only our own size*/
281
282     N0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
283     N1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
284     M0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
285     M1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
286     K0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
287     K1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
288     oN0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
289     oN1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
290     oM0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
291     oM1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
292     oK0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
293     oK1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
294
295     for (i = 0; i < P[0]; i++)
296     {
297 #define EVENDIST
298 #ifndef EVENDIST
299         oN0[i] = i * ceil((double)NG / P[0]);
300         oM0[i] = i * ceil((double)MG / P[0]);
301         oK0[i] = i * ceil((double)KG / P[0]);
302 #else
303         oN0[i] = (NG * i) / P[0];
304         oM0[i] = (MG * i) / P[0];
305         oK0[i] = (KG * i) / P[0];
306 #endif
307     }
308     for (i = 0; i < P[1]; i++)
309     {
310 #ifndef EVENDIST
311         oN1[i] = i * ceil((double)NG / P[1]);
312         oM1[i] = i * ceil((double)MG / P[1]);
313         oK1[i] = i * ceil((double)KG / P[1]);
314 #else
315         oN1[i] = (NG * i) / P[1];
316         oM1[i] = (MG * i) / P[1];
317         oK1[i] = (KG * i) / P[1];
318 #endif
319     }
320     for (i = 0; P[0] > 0 && i < P[0] - 1; i++)
321     {
322         N0[i] = oN0[i + 1] - oN0[i];
323         M0[i] = oM0[i + 1] - oM0[i];
324         K0[i] = oK0[i + 1] - oK0[i];
325     }
326     N0[P[0] - 1] = NG - oN0[P[0] - 1];
327     M0[P[0] - 1] = MG - oM0[P[0] - 1];
328     K0[P[0] - 1] = KG - oK0[P[0] - 1];
329     for (i = 0; P[1] > 0 && i < P[1] - 1; i++)
330     {
331         N1[i] = oN1[i + 1] - oN1[i];
332         M1[i] = oM1[i + 1] - oM1[i];
333         K1[i] = oK1[i + 1] - oK1[i];
334     }
335     N1[P[1] - 1] = NG - oN1[P[1] - 1];
336     M1[P[1] - 1] = MG - oM1[P[1] - 1];
337     K1[P[1] - 1] = KG - oK1[P[1] - 1];
338
339     /* for step 1-3 the local N,M,K sizes of the transposed system
340        C: contiguous dimension, and nP: number of processor in subcommunicator
341        for that step */
342
343     GMX_ASSERT(prank[0] < P[0], "Must have valid rank within communicator size");
344     GMX_ASSERT(prank[1] < P[1], "Must have valid rank within communicator size");
345     pM[0] = M0[prank[0]];
346     oM[0] = oM0[prank[0]];
347     pK[0] = K1[prank[1]];
348     oK[0] = oK1[prank[1]];
349     C[0]  = NG;
350     rC[0] = rNG;
351     if (!(flags & FFT5D_ORDER_YZ))
352     {
353         N[0]     = vmax(N1, P[1]);
354         M[0]     = M0[prank[0]];
355         K[0]     = vmax(K1, P[1]);
356         pN[0]    = N1[prank[1]];
357         iNout[0] = N1;
358         oNout[0] = oN1;
359         nP[0]    = P[1];
360         C[1]     = KG;
361         rC[1]    = rKG;
362         N[1]     = vmax(K0, P[0]);
363         pN[1]    = K0[prank[0]];
364         iNin[1]  = K1;
365         oNin[1]  = oK1;
366         iNout[1] = K0;
367         oNout[1] = oK0;
368         M[1]     = vmax(M0, P[0]);
369         pM[1]    = M0[prank[0]];
370         oM[1]    = oM0[prank[0]];
371         K[1]     = N1[prank[1]];
372         pK[1]    = N1[prank[1]];
373         oK[1]    = oN1[prank[1]];
374         nP[1]    = P[0];
375         C[2]     = MG;
376         rC[2]    = rMG;
377         iNin[2]  = M0;
378         oNin[2]  = oM0;
379         M[2]     = vmax(K0, P[0]);
380         pM[2]    = K0[prank[0]];
381         oM[2]    = oK0[prank[0]];
382         K[2]     = vmax(N1, P[1]);
383         pK[2]    = N1[prank[1]];
384         oK[2]    = oN1[prank[1]];
385         free(N0);
386         free(oN0); /*these are not used for this order*/
387         free(M1);
388         free(oM1); /*the rest is freed in destroy*/
389     }
390     else
391     {
392         N[0]     = vmax(N0, P[0]);
393         M[0]     = vmax(M0, P[0]);
394         K[0]     = K1[prank[1]];
395         pN[0]    = N0[prank[0]];
396         iNout[0] = N0;
397         oNout[0] = oN0;
398         nP[0]    = P[0];
399         C[1]     = MG;
400         rC[1]    = rMG;
401         N[1]     = vmax(M1, P[1]);
402         pN[1]    = M1[prank[1]];
403         iNin[1]  = M0;
404         oNin[1]  = oM0;
405         iNout[1] = M1;
406         oNout[1] = oM1;
407         M[1]     = N0[prank[0]];
408         pM[1]    = N0[prank[0]];
409         oM[1]    = oN0[prank[0]];
410         K[1]     = vmax(K1, P[1]);
411         pK[1]    = K1[prank[1]];
412         oK[1]    = oK1[prank[1]];
413         nP[1]    = P[1];
414         C[2]     = KG;
415         rC[2]    = rKG;
416         iNin[2]  = K1;
417         oNin[2]  = oK1;
418         M[2]     = vmax(N0, P[0]);
419         pM[2]    = N0[prank[0]];
420         oM[2]    = oN0[prank[0]];
421         K[2]     = vmax(M1, P[1]);
422         pK[2]    = M1[prank[1]];
423         oK[2]    = oM1[prank[1]];
424         free(N1);
425         free(oN1); /*these are not used for this order*/
426         free(K0);
427         free(oK0); /*the rest is freed in destroy*/
428     }
429     N[2] = pN[2] = -1; /*not used*/
430
431     /*
432        Difference between x-y-z regarding 2d decomposition is whether they are
433        distributed along axis 1, 2 or both
434      */
435
436     /* int lsize = fmax(N[0]*M[0]*K[0]*nP[0],N[1]*M[1]*K[1]*nP[1]); */
437     lsize = std::max(N[0] * M[0] * K[0] * nP[0], std::max(N[1] * M[1] * K[1] * nP[1], C[2] * M[2] * K[2]));
438     /* int lsize = fmax(C[0]*M[0]*K[0],fmax(C[1]*M[1]*K[1],C[2]*M[2]*K[2])); */
439     if (!(flags & FFT5D_NOMALLOC))
440     {
441         // only needed for PME GPU mixed mode
442         if (GMX_GPU_CUDA && realGridAllocationPinningPolicy == gmx::PinningPolicy::PinnedIfSupported)
443         {
444             const std::size_t numBytes = lsize * sizeof(t_complex);
445             lin = static_cast<t_complex*>(gmx::PageAlignedAllocationPolicy::malloc(numBytes));
446             gmx::pinBuffer(lin, numBytes);
447         }
448         else
449         {
450             snew_aligned(lin, lsize, 32);
451         }
452         snew_aligned(lout, lsize, 32);
453         if (nthreads > 1)
454         {
455             /* We need extra transpose buffers to avoid OpenMP barriers */
456             snew_aligned(lout2, lsize, 32);
457             snew_aligned(lout3, lsize, 32);
458         }
459         else
460         {
461             /* We can reuse the buffers to avoid cache misses */
462             lout2 = lin;
463             lout3 = lout;
464         }
465     }
466     else
467     {
468         lin  = *rlin;
469         lout = *rlout;
470         if (nthreads > 1)
471         {
472             lout2 = *rlout2;
473             lout3 = *rlout3;
474         }
475         else
476         {
477             lout2 = lin;
478             lout3 = lout;
479         }
480     }
481
482     plan = static_cast<fft5d_plan>(calloc(1, sizeof(struct fft5d_plan_t)));
483
484
485     if (debug)
486     {
487         fprintf(debug, "Running on %d threads\n", nthreads);
488     }
489
490 #if GMX_FFT_FFTW3
491     /* Don't add more stuff here! We have already had at least one bug because we are reimplementing
492      * the low-level FFT interface instead of using the Gromacs FFT module. If we need more
493      * generic functionality it is far better to extend the interface so we can use it for
494      * all FFT libraries instead of writing FFTW-specific code here.
495      */
496
497     /*if not FFTW - then we don't do a 3d plan but instead use only 1D plans */
498     /* It is possible to use the 3d plan with OMP threads - but in that case it is not allowed to be
499      * called from within a parallel region. For now deactivated. If it should be supported it has
500      * to made sure that that the execute of the 3d plan is in a master/serial block (since it
501      * contains it own parallel region) and that the 3d plan is faster than the 1d plan.
502      */
503     if ((!(flags & FFT5D_INPLACE)) && (!(P[0] > 1 || P[1] > 1))
504         && nthreads == 1) /*don't do 3d plan in parallel or if in_place requested */
505     {
506         int fftwflags = FFTW_DESTROY_INPUT;
507         FFTW(iodim) dims[3];
508         int inNG = NG, outMG = MG, outKG = KG;
509
510         FFTW_LOCK
511
512         fftwflags |= (flags & FFT5D_NOMEASURE) ? FFTW_ESTIMATE : FFTW_MEASURE;
513
514         if (flags & FFT5D_REALCOMPLEX)
515         {
516             if (!(flags & FFT5D_BACKWARD)) /*input pointer is not complex*/
517             {
518                 inNG *= 2;
519             }
520             else /*output pointer is not complex*/
521             {
522                 if (!(flags & FFT5D_ORDER_YZ))
523                 {
524                     outMG *= 2;
525                 }
526                 else
527                 {
528                     outKG *= 2;
529                 }
530             }
531         }
532
533         if (!(flags & FFT5D_BACKWARD))
534         {
535             dims[0].n = KG;
536             dims[1].n = MG;
537             dims[2].n = rNG;
538
539             dims[0].is = inNG * MG; /*N M K*/
540             dims[1].is = inNG;
541             dims[2].is = 1;
542             if (!(flags & FFT5D_ORDER_YZ))
543             {
544                 dims[0].os = MG; /*M K N*/
545                 dims[1].os = 1;
546                 dims[2].os = MG * KG;
547             }
548             else
549             {
550                 dims[0].os = 1; /*K N M*/
551                 dims[1].os = KG * NG;
552                 dims[2].os = KG;
553             }
554         }
555         else
556         {
557             if (!(flags & FFT5D_ORDER_YZ))
558             {
559                 dims[0].n = NG;
560                 dims[1].n = KG;
561                 dims[2].n = rMG;
562
563                 dims[0].is = 1;
564                 dims[1].is = NG * MG;
565                 dims[2].is = NG;
566
567                 dims[0].os = outMG * KG;
568                 dims[1].os = outMG;
569                 dims[2].os = 1;
570             }
571             else
572             {
573                 dims[0].n = MG;
574                 dims[1].n = NG;
575                 dims[2].n = rKG;
576
577                 dims[0].is = NG;
578                 dims[1].is = 1;
579                 dims[2].is = NG * MG;
580
581                 dims[0].os = outKG * NG;
582                 dims[1].os = outKG;
583                 dims[2].os = 1;
584             }
585         }
586 #    ifdef FFT5D_THREADS
587 #        ifdef FFT5D_FFTW_THREADS
588         FFTW(plan_with_nthreads)(nthreads);
589 #        endif
590 #    endif
591         if ((flags & FFT5D_REALCOMPLEX) && !(flags & FFT5D_BACKWARD))
592         {
593             plan->p3d = FFTW(plan_guru_dft_r2c)(/*rank*/ 3,
594                                                 dims,
595                                                 /*howmany*/ 0,
596                                                 /*howmany_dims*/ nullptr,
597                                                 reinterpret_cast<real*>(lin),
598                                                 reinterpret_cast<FFTW(complex)*>(lout),
599                                                 /*flags*/ fftwflags);
600         }
601         else if ((flags & FFT5D_REALCOMPLEX) && (flags & FFT5D_BACKWARD))
602         {
603             plan->p3d = FFTW(plan_guru_dft_c2r)(/*rank*/ 3,
604                                                 dims,
605                                                 /*howmany*/ 0,
606                                                 /*howmany_dims*/ nullptr,
607                                                 reinterpret_cast<FFTW(complex)*>(lin),
608                                                 reinterpret_cast<real*>(lout),
609                                                 /*flags*/ fftwflags);
610         }
611         else
612         {
613             plan->p3d = FFTW(plan_guru_dft)(
614                     /*rank*/ 3,
615                     dims,
616                     /*howmany*/ 0,
617                     /*howmany_dims*/ nullptr,
618                     reinterpret_cast<FFTW(complex)*>(lin),
619                     reinterpret_cast<FFTW(complex)*>(lout),
620                     /*sign*/ (flags & FFT5D_BACKWARD) ? 1 : -1,
621                     /*flags*/ fftwflags);
622         }
623 #    ifdef FFT5D_THREADS
624 #        ifdef FFT5D_FFTW_THREADS
625         FFTW(plan_with_nthreads)(1);
626 #        endif
627 #    endif
628         FFTW_UNLOCK
629     }
630     if (!plan->p3d) /* for decomposition and if 3d plan did not work */
631     {
632 #endif /* GMX_FFT_FFTW3 */
633         for (s = 0; s < 3; s++)
634         {
635             if (debug)
636             {
637                 fprintf(debug, "FFT5D: Plan s %d rC %d M %d pK %d C %d lsize %d\n", s, rC[s], M[s], pK[s], C[s], lsize);
638             }
639             plan->p1d[s] = static_cast<gmx_fft_t*>(malloc(sizeof(gmx_fft_t) * nthreads));
640
641             /* Make sure that the init routines are only called by one thread at a time and in order
642                (later is only important to not confuse valgrind)
643              */
644 #pragma omp parallel for num_threads(nthreads) schedule(static) ordered
645             for (int t = 0; t < nthreads; t++)
646             {
647 #pragma omp ordered
648                 {
649                     try
650                     {
651                         int tsize = ((t + 1) * pM[s] * pK[s] / nthreads) - (t * pM[s] * pK[s] / nthreads);
652
653                         if ((flags & FFT5D_REALCOMPLEX)
654                             && ((!(flags & FFT5D_BACKWARD) && s == 0)
655                                 || ((flags & FFT5D_BACKWARD) && s == 2)))
656                         {
657                             gmx_fft_init_many_1d_real(
658                                     &plan->p1d[s][t],
659                                     rC[s],
660                                     tsize,
661                                     (flags & FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0);
662                         }
663                         else
664                         {
665                             gmx_fft_init_many_1d(&plan->p1d[s][t],
666                                                  C[s],
667                                                  tsize,
668                                                  (flags & FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0);
669                         }
670                     }
671                     GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
672                 }
673             }
674         }
675
676 #if GMX_FFT_FFTW3
677     }
678 #endif
679     if ((flags & FFT5D_ORDER_YZ)) /*plan->cart is in the order of transposes */
680     {
681         plan->cart[0] = comm[0];
682         plan->cart[1] = comm[1];
683     }
684     else
685     {
686         plan->cart[1] = comm[0];
687         plan->cart[0] = comm[1];
688     }
689 #ifdef FFT5D_MPI_TRANSPOSE
690     FFTW_LOCK;
691     for (s = 0; s < 2; s++)
692     {
693         if ((s == 0 && !(flags & FFT5D_ORDER_YZ)) || (s == 1 && (flags & FFT5D_ORDER_YZ)))
694         {
695             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(
696                     nP[s], nP[s], N[s] * K[s] * pM[s] * 2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
697         }
698         else
699         {
700             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(
701                     nP[s], nP[s], N[s] * pK[s] * M[s] * 2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
702         }
703     }
704     FFTW_UNLOCK;
705 #endif
706
707
708     plan->lin   = lin;
709     plan->lout  = lout;
710     plan->lout2 = lout2;
711     plan->lout3 = lout3;
712
713     plan->NG = NG;
714     plan->MG = MG;
715     plan->KG = KG;
716
717     for (s = 0; s < 3; s++)
718     {
719         plan->N[s]     = N[s];
720         plan->M[s]     = M[s];
721         plan->K[s]     = K[s];
722         plan->pN[s]    = pN[s];
723         plan->pM[s]    = pM[s];
724         plan->pK[s]    = pK[s];
725         plan->oM[s]    = oM[s];
726         plan->oK[s]    = oK[s];
727         plan->C[s]     = C[s];
728         plan->rC[s]    = rC[s];
729         plan->iNin[s]  = iNin[s];
730         plan->oNin[s]  = oNin[s];
731         plan->iNout[s] = iNout[s];
732         plan->oNout[s] = oNout[s];
733     }
734     for (s = 0; s < 2; s++)
735     {
736         plan->P[s]    = nP[s];
737         plan->coor[s] = prank[s];
738     }
739
740     /*    plan->fftorder=fftorder;
741         plan->direction=direction;
742         plan->realcomplex=realcomplex;
743      */
744     plan->flags         = flags;
745     plan->nthreads      = nthreads;
746     plan->pinningPolicy = realGridAllocationPinningPolicy;
747     *rlin               = lin;
748     *rlout              = lout;
749     *rlout2             = lout2;
750     *rlout3             = lout3;
751     return plan;
752 }
753
754
755 enum order
756 {
757     XYZ,
758     XZY,
759     YXZ,
760     YZX,
761     ZXY,
762     ZYX
763 };
764
765
766 /*here x,y,z and N,M,K is in rotated coordinate system!!
767    x (and N) is mayor (consecutive) dimension, y (M) middle and z (K) major
768    maxN,maxM,maxK is max size of local data
769    pN, pM, pK is local size specific to current processor (only different to max if not divisible)
770    NG, MG, KG is size of global data*/
771 static void splitaxes(t_complex*       lout,
772                       const t_complex* lin,
773                       int              maxN,
774                       int              maxM,
775                       int              maxK,
776                       int              pM,
777                       int              P,
778                       int              NG,
779                       const int*       N,
780                       const int*       oN,
781                       int              starty,
782                       int              startz,
783                       int              endy,
784                       int              endz)
785 {
786     int x, y, z, i;
787     int in_i, out_i, in_z, out_z, in_y, out_y;
788     int s_y, e_y;
789
790     for (z = startz; z < endz + 1; z++) /*3. z l*/
791     {
792         if (z == startz)
793         {
794             s_y = starty;
795         }
796         else
797         {
798             s_y = 0;
799         }
800         if (z == endz)
801         {
802             e_y = endy;
803         }
804         else
805         {
806             e_y = pM;
807         }
808         out_z = z * maxN * maxM;
809         in_z  = z * NG * pM;
810
811         for (i = 0; i < P; i++) /*index cube along long axis*/
812         {
813             out_i = out_z + i * maxN * maxM * maxK;
814             in_i  = in_z + oN[i];
815             for (y = s_y; y < e_y; y++) /*2. y k*/
816             {
817                 out_y = out_i + y * maxN;
818                 in_y  = in_i + y * NG;
819                 for (x = 0; x < N[i]; x++) /*1. x j*/
820                 {
821                     lout[out_y + x] = lin[in_y + x]; /*in=z*NG*pM+oN[i]+y*NG+x*/
822                     /*after split important that each processor chunk i has size maxN*maxM*maxK and thus being the same size*/
823                     /*before split data contiguos - thus if different processor get different amount oN is different*/
824                 }
825             }
826         }
827     }
828 }
829
830 /*make axis contiguous again (after AllToAll) and also do local transpose*/
831 /*transpose mayor and major dimension
832    variables see above
833    the major, middle, minor order is only correct for x,y,z (N,M,K) for the input
834    N,M,K local dimensions
835    KG global size*/
836 static void joinAxesTrans13(t_complex*       lout,
837                             const t_complex* lin,
838                             int              maxN,
839                             int              maxM,
840                             int              maxK,
841                             int              pM,
842                             int              P,
843                             int              KG,
844                             const int*       K,
845                             const int*       oK,
846                             int              starty,
847                             int              startx,
848                             int              endy,
849                             int              endx)
850 {
851     int i, x, y, z;
852     int out_i, in_i, out_x, in_x, out_z, in_z;
853     int s_y, e_y;
854
855     for (x = startx; x < endx + 1; x++) /*1.j*/
856     {
857         if (x == startx)
858         {
859             s_y = starty;
860         }
861         else
862         {
863             s_y = 0;
864         }
865         if (x == endx)
866         {
867             e_y = endy;
868         }
869         else
870         {
871             e_y = pM;
872         }
873
874         out_x = x * KG * pM;
875         in_x  = x;
876
877         for (i = 0; i < P; i++) /*index cube along long axis*/
878         {
879             out_i = out_x + oK[i];
880             in_i  = in_x + i * maxM * maxN * maxK;
881             for (z = 0; z < K[i]; z++) /*3.l*/
882             {
883                 out_z = out_i + z;
884                 in_z  = in_i + z * maxM * maxN;
885                 for (y = s_y; y < e_y; y++) /*2.k*/
886                 {
887                     lout[out_z + y * KG] = lin[in_z + y * maxN]; /*out=x*KG*pM+oK[i]+z+y*KG*/
888                 }
889             }
890         }
891     }
892 }
893
894 /*make axis contiguous again (after AllToAll) and also do local transpose
895    tranpose mayor and middle dimension
896    variables see above
897    the minor, middle, major order is only correct for x,y,z (N,M,K) for the input
898    N,M,K local size
899    MG, global size*/
900 static void joinAxesTrans12(t_complex*       lout,
901                             const t_complex* lin,
902                             int              maxN,
903                             int              maxM,
904                             int              maxK,
905                             int              pN,
906                             int              P,
907                             int              MG,
908                             const int*       M,
909                             const int*       oM,
910                             int              startx,
911                             int              startz,
912                             int              endx,
913                             int              endz)
914 {
915     int i, z, y, x;
916     int out_i, in_i, out_z, in_z, out_x, in_x;
917     int s_x, e_x;
918
919     for (z = startz; z < endz + 1; z++)
920     {
921         if (z == startz)
922         {
923             s_x = startx;
924         }
925         else
926         {
927             s_x = 0;
928         }
929         if (z == endz)
930         {
931             e_x = endx;
932         }
933         else
934         {
935             e_x = pN;
936         }
937         out_z = z * MG * pN;
938         in_z  = z * maxM * maxN;
939
940         for (i = 0; i < P; i++) /*index cube along long axis*/
941         {
942             out_i = out_z + oM[i];
943             in_i  = in_z + i * maxM * maxN * maxK;
944             for (x = s_x; x < e_x; x++)
945             {
946                 out_x = out_i + x * MG;
947                 in_x  = in_i + x;
948                 for (y = 0; y < M[i]; y++)
949                 {
950                     lout[out_x + y] = lin[in_x + y * maxN]; /*out=z*MG*pN+oM[i]+x*MG+y*/
951                 }
952             }
953         }
954     }
955 }
956
957
958 static void rotate_offsets(int x[])
959 {
960     int t = x[0];
961     /*    x[0]=x[2];
962         x[2]=x[1];
963         x[1]=t;*/
964     x[0] = x[1];
965     x[1] = x[2];
966     x[2] = t;
967 }
968
969 /*compute the offset to compare or print transposed local data in original input coordinates
970    xs matrix dimension size, xl dimension length, xc decomposition offset
971    s: step in computation = number of transposes*/
972 static void compute_offsets(fft5d_plan plan, int xs[], int xl[], int xc[], int NG[], int s)
973 {
974     /*    int direction = plan->direction;
975         int fftorder = plan->fftorder;*/
976
977     int  o = 0;
978     int  pos[3], i;
979     int *pM = plan->pM, *pK = plan->pK, *oM = plan->oM, *oK = plan->oK, *C = plan->C, *rC = plan->rC;
980
981     NG[0] = plan->NG;
982     NG[1] = plan->MG;
983     NG[2] = plan->KG;
984
985     if (!(plan->flags & FFT5D_ORDER_YZ))
986     {
987         switch (s)
988         {
989             case 0: o = XYZ; break;
990             case 1: o = ZYX; break;
991             case 2: o = YZX; break;
992             default: assert(0);
993         }
994     }
995     else
996     {
997         switch (s)
998         {
999             case 0: o = XYZ; break;
1000             case 1: o = YXZ; break;
1001             case 2: o = ZXY; break;
1002             default: assert(0);
1003         }
1004     }
1005
1006     switch (o)
1007     {
1008         case XYZ:
1009             pos[0] = 1;
1010             pos[1] = 2;
1011             pos[2] = 3;
1012             break;
1013         case XZY:
1014             pos[0] = 1;
1015             pos[1] = 3;
1016             pos[2] = 2;
1017             break;
1018         case YXZ:
1019             pos[0] = 2;
1020             pos[1] = 1;
1021             pos[2] = 3;
1022             break;
1023         case YZX:
1024             pos[0] = 3;
1025             pos[1] = 1;
1026             pos[2] = 2;
1027             break;
1028         case ZXY:
1029             pos[0] = 2;
1030             pos[1] = 3;
1031             pos[2] = 1;
1032             break;
1033         case ZYX:
1034             pos[0] = 3;
1035             pos[1] = 2;
1036             pos[2] = 1;
1037             break;
1038     }
1039     /*if (debug) printf("pos: %d %d %d\n",pos[0],pos[1],pos[2]);*/
1040
1041     /*xs, xl give dimension size and data length in local transposed coordinate system
1042        for 0(/1/2): x(/y/z) in original coordinate system*/
1043     for (i = 0; i < 3; i++)
1044     {
1045         switch (pos[i])
1046         {
1047             case 1:
1048                 xs[i] = 1;
1049                 xc[i] = 0;
1050                 xl[i] = C[s];
1051                 break;
1052             case 2:
1053                 xs[i] = C[s];
1054                 xc[i] = oM[s];
1055                 xl[i] = pM[s];
1056                 break;
1057             case 3:
1058                 xs[i] = C[s] * pM[s];
1059                 xc[i] = oK[s];
1060                 xl[i] = pK[s];
1061                 break;
1062         }
1063     }
1064     /*input order is different for test program to match FFTW order
1065        (important for complex to real)*/
1066     if (plan->flags & FFT5D_BACKWARD)
1067     {
1068         rotate_offsets(xs);
1069         rotate_offsets(xl);
1070         rotate_offsets(xc);
1071         rotate_offsets(NG);
1072         if (plan->flags & FFT5D_ORDER_YZ)
1073         {
1074             rotate_offsets(xs);
1075             rotate_offsets(xl);
1076             rotate_offsets(xc);
1077             rotate_offsets(NG);
1078         }
1079     }
1080     if ((plan->flags & FFT5D_REALCOMPLEX)
1081         && ((!(plan->flags & FFT5D_BACKWARD) && s == 0) || ((plan->flags & FFT5D_BACKWARD) && s == 2)))
1082     {
1083         xl[0] = rC[s];
1084     }
1085 }
1086
1087 static void print_localdata(const t_complex* lin, const char* txt, int s, fft5d_plan plan)
1088 {
1089     int  x, y, z, l;
1090     int* coor = plan->coor;
1091     int  xs[3], xl[3], xc[3], NG[3];
1092     int  ll = (plan->flags & FFT5D_REALCOMPLEX) ? 1 : 2;
1093     compute_offsets(plan, xs, xl, xc, NG, s);
1094     fprintf(debug, txt, coor[0], coor[1]);
1095     /*printf("xs: %d %d %d, xl: %d %d %d\n",xs[0],xs[1],xs[2],xl[0],xl[1],xl[2]);*/
1096     for (z = 0; z < xl[2]; z++)
1097     {
1098         for (y = 0; y < xl[1]; y++)
1099         {
1100             fprintf(debug, "%d %d: ", coor[0], coor[1]);
1101             for (x = 0; x < xl[0]; x++)
1102             {
1103                 for (l = 0; l < ll; l++)
1104                 {
1105                     fprintf(debug,
1106                             "%f ",
1107                             reinterpret_cast<const real*>(
1108                                     lin)[(z * xs[2] + y * xs[1]) * 2 + (x * xs[0]) * ll + l]);
1109                 }
1110                 fprintf(debug, ",");
1111             }
1112             fprintf(debug, "\n");
1113         }
1114     }
1115 }
1116
1117 void fft5d_execute(fft5d_plan plan, int thread, fft5d_time times)
1118 {
1119     t_complex* lin   = plan->lin;
1120     t_complex* lout  = plan->lout;
1121     t_complex* lout2 = plan->lout2;
1122     t_complex* lout3 = plan->lout3;
1123     t_complex *fftout, *joinin;
1124
1125     gmx_fft_t** p1d = plan->p1d;
1126 #ifdef FFT5D_MPI_TRANSPOSE
1127     FFTW(plan)* mpip = plan->mpip;
1128 #endif
1129 #if GMX_MPI
1130     MPI_Comm* cart = plan->cart;
1131 #endif
1132 #ifdef NOGMX
1133     double time_fft = 0, time_local = 0, time_mpi[2] = { 0 }, time = 0;
1134 #endif
1135     int *N = plan->N, *M = plan->M, *K = plan->K, *pN = plan->pN, *pM = plan->pM, *pK = plan->pK,
1136         *C = plan->C, *P = plan->P, **iNin = plan->iNin, **oNin = plan->oNin, **iNout = plan->iNout,
1137         **oNout = plan->oNout;
1138     int s       = 0, tstart, tend, bParallelDim;
1139
1140
1141 #if GMX_FFT_FFTW3
1142     if (plan->p3d)
1143     {
1144         if (thread == 0)
1145         {
1146 #    ifdef NOGMX
1147             if (times != 0)
1148             {
1149                 time = MPI_Wtime();
1150             }
1151 #    endif
1152             FFTW(execute)(plan->p3d);
1153 #    ifdef NOGMX
1154             if (times != 0)
1155             {
1156                 times->fft += MPI_Wtime() - time;
1157             }
1158 #    endif
1159         }
1160         return;
1161     }
1162 #endif
1163
1164     s = 0;
1165
1166     /*lin: x,y,z*/
1167     if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1168     {
1169         print_localdata(lin, "%d %d: copy in lin\n", s, plan);
1170     }
1171
1172     for (s = 0; s < 2; s++) /*loop over first two FFT steps (corner rotations)*/
1173
1174     {
1175 #if GMX_MPI
1176         if (GMX_PARALLEL_ENV_INITIALIZED && cart[s] != MPI_COMM_NULL && P[s] > 1)
1177         {
1178             bParallelDim = 1;
1179         }
1180         else
1181 #endif
1182         {
1183             bParallelDim = 0;
1184         }
1185
1186         /* ---------- START FFT ------------ */
1187 #ifdef NOGMX
1188         if (times != 0 && thread == 0)
1189         {
1190             time = MPI_Wtime();
1191         }
1192 #endif
1193
1194         if (bParallelDim || plan->nthreads == 1)
1195         {
1196             fftout = lout;
1197         }
1198         else
1199         {
1200             if (s == 0)
1201             {
1202                 fftout = lout3;
1203             }
1204             else
1205             {
1206                 fftout = lout2;
1207             }
1208         }
1209
1210         tstart = (thread * pM[s] * pK[s] / plan->nthreads) * C[s];
1211         if ((plan->flags & FFT5D_REALCOMPLEX) && !(plan->flags & FFT5D_BACKWARD) && s == 0)
1212         {
1213             gmx_fft_many_1d_real(p1d[s][thread],
1214                                  (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL
1215                                                                 : GMX_FFT_REAL_TO_COMPLEX,
1216                                  lin + tstart,
1217                                  fftout + tstart);
1218         }
1219         else
1220         {
1221             gmx_fft_many_1d(p1d[s][thread],
1222                             (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,
1223                             lin + tstart,
1224                             fftout + tstart);
1225         }
1226
1227 #ifdef NOGMX
1228         if (times != NULL && thread == 0)
1229         {
1230             time_fft += MPI_Wtime() - time;
1231         }
1232 #endif
1233         if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1234         {
1235             print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1236         }
1237         /* ---------- END FFT ------------ */
1238
1239         /* ---------- START SPLIT + TRANSPOSE------------ (if parallel in in this dimension)*/
1240         if (bParallelDim)
1241         {
1242 #ifdef NOGMX
1243             if (times != NULL && thread == 0)
1244             {
1245                 time = MPI_Wtime();
1246             }
1247 #endif
1248             /*prepare for A
1249                llToAll
1250                1. (most outer) axes (x) is split into P[s] parts of size N[s]
1251                for sending*/
1252             if (pM[s] > 0)
1253             {
1254                 tend = ((thread + 1) * pM[s] * pK[s] / plan->nthreads);
1255                 tstart /= C[s];
1256                 splitaxes(lout2,
1257                           lout,
1258                           N[s],
1259                           M[s],
1260                           K[s],
1261                           pM[s],
1262                           P[s],
1263                           C[s],
1264                           iNout[s],
1265                           oNout[s],
1266                           tstart % pM[s],
1267                           tstart / pM[s],
1268                           tend % pM[s],
1269                           tend / pM[s]);
1270             }
1271 #pragma omp barrier /*barrier required before AllToAll (all input has to be their) - before timing to make timing more acurate*/
1272 #ifdef NOGMX
1273             if (times != NULL && thread == 0)
1274             {
1275                 time_local += MPI_Wtime() - time;
1276             }
1277 #endif
1278
1279             /* ---------- END SPLIT , START TRANSPOSE------------ */
1280
1281             if (thread == 0)
1282             {
1283 #ifdef NOGMX
1284                 if (times != 0)
1285                 {
1286                     time = MPI_Wtime();
1287                 }
1288 #else
1289                 wallcycle_start(times, ewcPME_FFTCOMM);
1290 #endif
1291 #ifdef FFT5D_MPI_TRANSPOSE
1292                 FFTW(execute)(mpip[s]);
1293 #else
1294 #    if GMX_MPI
1295                 if ((s == 0 && !(plan->flags & FFT5D_ORDER_YZ))
1296                     || (s == 1 && (plan->flags & FFT5D_ORDER_YZ)))
1297                 {
1298                     MPI_Alltoall(reinterpret_cast<real*>(lout2),
1299                                  N[s] * pM[s] * K[s] * sizeof(t_complex) / sizeof(real),
1300                                  GMX_MPI_REAL,
1301                                  reinterpret_cast<real*>(lout3),
1302                                  N[s] * pM[s] * K[s] * sizeof(t_complex) / sizeof(real),
1303                                  GMX_MPI_REAL,
1304                                  cart[s]);
1305                 }
1306                 else
1307                 {
1308                     MPI_Alltoall(reinterpret_cast<real*>(lout2),
1309                                  N[s] * M[s] * pK[s] * sizeof(t_complex) / sizeof(real),
1310                                  GMX_MPI_REAL,
1311                                  reinterpret_cast<real*>(lout3),
1312                                  N[s] * M[s] * pK[s] * sizeof(t_complex) / sizeof(real),
1313                                  GMX_MPI_REAL,
1314                                  cart[s]);
1315                 }
1316 #    else
1317                 GMX_RELEASE_ASSERT(false, "Invalid call to fft5d_execute");
1318 #    endif /*GMX_MPI*/
1319 #endif     /*FFT5D_MPI_TRANSPOSE*/
1320 #ifdef NOGMX
1321                 if (times != 0)
1322                 {
1323                     time_mpi[s] = MPI_Wtime() - time;
1324                 }
1325 #else
1326                 wallcycle_stop(times, ewcPME_FFTCOMM);
1327 #endif
1328             }       /*master*/
1329         }           /* bPrallelDim */
1330 #pragma omp barrier /*both needed for parallel and non-parallel dimension (either have to wait on data from AlltoAll or from last FFT*/
1331
1332         /* ---------- END SPLIT + TRANSPOSE------------ */
1333
1334         /* ---------- START JOIN ------------ */
1335 #ifdef NOGMX
1336         if (times != NULL && thread == 0)
1337         {
1338             time = MPI_Wtime();
1339         }
1340 #endif
1341
1342         if (bParallelDim)
1343         {
1344             joinin = lout3;
1345         }
1346         else
1347         {
1348             joinin = fftout;
1349         }
1350         /*bring back in matrix form
1351            thus make  new 1. axes contiguos
1352            also local transpose 1 and 2/3
1353            runs on thread used for following FFT (thus needing a barrier before but not afterwards)
1354          */
1355         if ((s == 0 && !(plan->flags & FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags & FFT5D_ORDER_YZ)))
1356         {
1357             if (pM[s] > 0)
1358             {
1359                 tstart = (thread * pM[s] * pN[s] / plan->nthreads);
1360                 tend   = ((thread + 1) * pM[s] * pN[s] / plan->nthreads);
1361                 joinAxesTrans13(lin,
1362                                 joinin,
1363                                 N[s],
1364                                 pM[s],
1365                                 K[s],
1366                                 pM[s],
1367                                 P[s],
1368                                 C[s + 1],
1369                                 iNin[s + 1],
1370                                 oNin[s + 1],
1371                                 tstart % pM[s],
1372                                 tstart / pM[s],
1373                                 tend % pM[s],
1374                                 tend / pM[s]);
1375             }
1376         }
1377         else
1378         {
1379             if (pN[s] > 0)
1380             {
1381                 tstart = (thread * pK[s] * pN[s] / plan->nthreads);
1382                 tend   = ((thread + 1) * pK[s] * pN[s] / plan->nthreads);
1383                 joinAxesTrans12(lin,
1384                                 joinin,
1385                                 N[s],
1386                                 M[s],
1387                                 pK[s],
1388                                 pN[s],
1389                                 P[s],
1390                                 C[s + 1],
1391                                 iNin[s + 1],
1392                                 oNin[s + 1],
1393                                 tstart % pN[s],
1394                                 tstart / pN[s],
1395                                 tend % pN[s],
1396                                 tend / pN[s]);
1397             }
1398         }
1399
1400 #ifdef NOGMX
1401         if (times != NULL && thread == 0)
1402         {
1403             time_local += MPI_Wtime() - time;
1404         }
1405 #endif
1406         if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1407         {
1408             print_localdata(lin, "%d %d: tranposed %d\n", s + 1, plan);
1409         }
1410         /* ---------- END JOIN ------------ */
1411
1412         /*if (debug) print_localdata(lin, "%d %d: transposed x-z\n", N1, M0, K, ZYX, coor);*/
1413     } /* for(s=0;s<2;s++) */
1414 #ifdef NOGMX
1415     if (times != NULL && thread == 0)
1416     {
1417         time = MPI_Wtime();
1418     }
1419 #endif
1420
1421     if (plan->flags & FFT5D_INPLACE)
1422     {
1423         lout = lin; /*in place currently not supported*/
1424     }
1425     /*  ----------- FFT ----------- */
1426     tstart = (thread * pM[s] * pK[s] / plan->nthreads) * C[s];
1427     if ((plan->flags & FFT5D_REALCOMPLEX) && (plan->flags & FFT5D_BACKWARD))
1428     {
1429         gmx_fft_many_1d_real(p1d[s][thread],
1430                              (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX,
1431                              lin + tstart,
1432                              lout + tstart);
1433     }
1434     else
1435     {
1436         gmx_fft_many_1d(p1d[s][thread],
1437                         (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,
1438                         lin + tstart,
1439                         lout + tstart);
1440     }
1441     /* ------------ END FFT ---------*/
1442
1443 #ifdef NOGMX
1444     if (times != NULL && thread == 0)
1445     {
1446         time_fft += MPI_Wtime() - time;
1447
1448         times->fft += time_fft;
1449         times->local += time_local;
1450         times->mpi2 += time_mpi[1];
1451         times->mpi1 += time_mpi[0];
1452     }
1453 #endif
1454
1455     if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1456     {
1457         print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1458     }
1459 }
1460
1461 void fft5d_destroy(fft5d_plan plan)
1462 {
1463     int s, t;
1464
1465     for (s = 0; s < 3; s++)
1466     {
1467         if (plan->p1d[s])
1468         {
1469             for (t = 0; t < plan->nthreads; t++)
1470             {
1471                 gmx_many_fft_destroy(plan->p1d[s][t]);
1472             }
1473             free(plan->p1d[s]);
1474         }
1475         if (plan->iNin[s])
1476         {
1477             free(plan->iNin[s]);
1478             plan->iNin[s] = nullptr;
1479         }
1480         if (plan->oNin[s])
1481         {
1482             free(plan->oNin[s]);
1483             plan->oNin[s] = nullptr;
1484         }
1485         if (plan->iNout[s])
1486         {
1487             free(plan->iNout[s]);
1488             plan->iNout[s] = nullptr;
1489         }
1490         if (plan->oNout[s])
1491         {
1492             free(plan->oNout[s]);
1493             plan->oNout[s] = nullptr;
1494         }
1495     }
1496 #if GMX_FFT_FFTW3
1497     FFTW_LOCK
1498 #    ifdef FFT5D_MPI_TRANSPOS
1499     for (s = 0; s < 2; s++)
1500     {
1501         FFTW(destroy_plan)(plan->mpip[s]);
1502     }
1503 #    endif /* FFT5D_MPI_TRANSPOS */
1504     if (plan->p3d)
1505     {
1506         FFTW(destroy_plan)(plan->p3d);
1507     }
1508     FFTW_UNLOCK
1509 #endif /* GMX_FFT_FFTW3 */
1510
1511     if (!(plan->flags & FFT5D_NOMALLOC))
1512     {
1513         // only needed for PME GPU mixed mode
1514         if (plan->pinningPolicy == gmx::PinningPolicy::PinnedIfSupported && isHostMemoryPinned(plan->lin))
1515         {
1516             gmx::unpinBuffer(plan->lin);
1517         }
1518         sfree_aligned(plan->lin);
1519         sfree_aligned(plan->lout);
1520         if (plan->nthreads > 1)
1521         {
1522             sfree_aligned(plan->lout2);
1523             sfree_aligned(plan->lout3);
1524         }
1525     }
1526
1527 #ifdef FFT5D_THREADS
1528 #    ifdef FFT5D_FFTW_THREADS
1529     /*FFTW(cleanup_threads)();*/
1530 #    endif
1531 #endif
1532
1533     free(plan);
1534 }
1535
1536 /*Is this better than direct access of plan? enough data?
1537    here 0,1 reference divided by which processor grid dimension (not FFT step!)*/
1538 void fft5d_local_size(fft5d_plan plan, int* N1, int* M0, int* K0, int* K1, int** coor)
1539 {
1540     *N1 = plan->N[0];
1541     *M0 = plan->M[0];
1542     *K1 = plan->K[0];
1543     *K0 = plan->N[1];
1544
1545     *coor = plan->coor;
1546 }
1547
1548
1549 /*same as fft5d_plan_3d but with cartesian coordinator and automatic splitting
1550    of processor dimensions*/
1551 fft5d_plan fft5d_plan_3d_cart(int         NG,
1552                               int         MG,
1553                               int         KG,
1554                               MPI_Comm    comm,
1555                               int         P0,
1556                               int         flags,
1557                               t_complex** rlin,
1558                               t_complex** rlout,
1559                               t_complex** rlout2,
1560                               t_complex** rlout3,
1561                               int         nthreads)
1562 {
1563     MPI_Comm cart[2] = { MPI_COMM_NULL, MPI_COMM_NULL };
1564 #if GMX_MPI
1565     int      size = 1, prank = 0;
1566     int      P[2];
1567     int      coor[2];
1568     int      wrap[] = { 0, 0 };
1569     MPI_Comm gcart;
1570     int      rdim1[] = { 0, 1 }, rdim2[] = { 1, 0 };
1571
1572     MPI_Comm_size(comm, &size);
1573     MPI_Comm_rank(comm, &prank);
1574
1575     if (P0 == 0)
1576     {
1577         P0 = lfactor(size);
1578     }
1579     if (size % P0 != 0)
1580     {
1581         if (prank == 0)
1582         {
1583             printf("FFT5D: WARNING: Number of ranks %d not evenly divisible by %d\n", size, P0);
1584         }
1585         P0 = lfactor(size);
1586     }
1587
1588     P[0] = P0;
1589     P[1] = size / P0; /*number of processors in the two dimensions*/
1590
1591     /*Difference between x-y-z regarding 2d decomposition is whether they are
1592        distributed along axis 1, 2 or both*/
1593
1594     MPI_Cart_create(comm, 2, P, wrap, 1, &gcart); /*parameter 4: value 1: reorder*/
1595     MPI_Cart_get(gcart, 2, P, wrap, coor);
1596     MPI_Cart_sub(gcart, rdim1, &cart[0]);
1597     MPI_Cart_sub(gcart, rdim2, &cart[1]);
1598 #else
1599     (void)P0;
1600     (void)comm;
1601 #endif
1602     return fft5d_plan_3d(NG, MG, KG, cart, flags, rlin, rlout, rlout2, rlout3, nthreads);
1603 }
1604
1605
1606 /*prints in original coordinate system of data (as the input to FFT)*/
1607 void fft5d_compare_data(const t_complex* lin, const t_complex* in, fft5d_plan plan, int bothLocal, int normalize)
1608 {
1609     int  xs[3], xl[3], xc[3], NG[3];
1610     int  x, y, z, l;
1611     int* coor = plan->coor;
1612     int  ll   = 2; /*compare ll values per element (has to be 2 for complex)*/
1613     if ((plan->flags & FFT5D_REALCOMPLEX) && (plan->flags & FFT5D_BACKWARD))
1614     {
1615         ll = 1;
1616     }
1617
1618     compute_offsets(plan, xs, xl, xc, NG, 2);
1619     if (plan->flags & FFT5D_DEBUG)
1620     {
1621         printf("Compare2\n");
1622     }
1623     for (z = 0; z < xl[2]; z++)
1624     {
1625         for (y = 0; y < xl[1]; y++)
1626         {
1627             if (plan->flags & FFT5D_DEBUG)
1628             {
1629                 printf("%d %d: ", coor[0], coor[1]);
1630             }
1631             for (x = 0; x < xl[0]; x++)
1632             {
1633                 for (l = 0; l < ll; l++) /*loop over real/complex parts*/
1634                 {
1635                     real a, b;
1636                     a = reinterpret_cast<const real*>(lin)[(z * xs[2] + y * xs[1]) * 2 + x * xs[0] * ll + l];
1637                     if (normalize)
1638                     {
1639                         a /= plan->rC[0] * plan->rC[1] * plan->rC[2];
1640                     }
1641                     if (!bothLocal)
1642                     {
1643                         b = reinterpret_cast<const real*>(
1644                                 in)[((z + xc[2]) * NG[0] * NG[1] + (y + xc[1]) * NG[0]) * 2 + (x + xc[0]) * ll + l];
1645                     }
1646                     else
1647                     {
1648                         b = reinterpret_cast<const real*>(
1649                                 in)[(z * xs[2] + y * xs[1]) * 2 + x * xs[0] * ll + l];
1650                     }
1651                     if (plan->flags & FFT5D_DEBUG)
1652                     {
1653                         printf("%f %f, ", a, b);
1654                     }
1655                     else
1656                     {
1657                         if (std::fabs(a - b) > 2 * NG[0] * NG[1] * NG[2] * GMX_REAL_EPS)
1658                         {
1659                             printf("result incorrect on %d,%d at %d,%d,%d: FFT5D:%f reference:%f\n",
1660                                    coor[0],
1661                                    coor[1],
1662                                    x,
1663                                    y,
1664                                    z,
1665                                    a,
1666                                    b);
1667                         }
1668                         /*                        assert(fabs(a-b)<2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS);*/
1669                     }
1670                 }
1671                 if (plan->flags & FFT5D_DEBUG)
1672                 {
1673                     printf(",");
1674                 }
1675             }
1676             if (plan->flags & FFT5D_DEBUG)
1677             {
1678                 printf("\n");
1679             }
1680         }
1681     }
1682 }