29618ee086aadf4bd269dec351bdc0542fa9f892
[alexxy/gromacs.git] / src / gromacs / fft / fft5d.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2009-2018, The GROMACS development team.
5  * Copyright (c) 2019,2020, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 #include "gmxpre.h"
37
38 #include "fft5d.h"
39
40 #include "config.h"
41
42 #include <cassert>
43 #include <cfloat>
44 #include <cmath>
45 #include <cstdio>
46 #include <cstdlib>
47 #include <cstring>
48
49 #include <algorithm>
50
51 #include "gromacs/gpu_utils/gpu_utils.h"
52 #include "gromacs/gpu_utils/hostallocator.h"
53 #include "gromacs/gpu_utils/pinning.h"
54 #include "gromacs/utility/alignedallocator.h"
55 #include "gromacs/utility/exceptions.h"
56 #include "gromacs/utility/fatalerror.h"
57 #include "gromacs/utility/gmxassert.h"
58 #include "gromacs/utility/gmxmpi.h"
59 #include "gromacs/utility/smalloc.h"
60
61 #ifdef NOGMX
62 #    define GMX_PARALLEL_ENV_INITIALIZED 1
63 #else
64 #    if GMX_MPI
65 #        define GMX_PARALLEL_ENV_INITIALIZED 1
66 #    else
67 #        define GMX_PARALLEL_ENV_INITIALIZED 0
68 #    endif
69 #endif
70
71 #if GMX_OPENMP
72 /* TODO: Do we still need this? Are we still planning ot use fftw + OpenMP? */
73 #    define FFT5D_THREADS
74 /* requires fftw compiled with openmp */
75 /* #define FFT5D_FFTW_THREADS (now set by cmake) */
76 #endif
77
78 #ifndef __FLT_EPSILON__
79 #    define __FLT_EPSILON__ FLT_EPSILON
80 #    define __DBL_EPSILON__ DBL_EPSILON
81 #endif
82
83 #ifdef NOGMX
84 FILE* debug = 0;
85 #endif
86
87 #if GMX_FFT_FFTW3
88
89 #    include "gromacs/utility/exceptions.h"
90 #    include "gromacs/utility/mutex.h"
91 /* none of the fftw3 calls, except execute(), are thread-safe, so
92    we need to serialize them with this mutex. */
93 static gmx::Mutex big_fftw_mutex;
94 #    define FFTW_LOCK              \
95         try                        \
96         {                          \
97             big_fftw_mutex.lock(); \
98         }                          \
99         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
100 #    define FFTW_UNLOCK              \
101         try                          \
102         {                            \
103             big_fftw_mutex.unlock(); \
104         }                            \
105         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
106 #endif /* GMX_FFT_FFTW3 */
107
108 #if GMX_MPI
109 /* largest factor smaller than sqrt */
110 static int lfactor(int z)
111 {
112     int i = static_cast<int>(sqrt(static_cast<double>(z)));
113     while (z % i != 0)
114     {
115         i--;
116     }
117     return i;
118 }
119 #endif
120
121 #if !GMX_MPI
122 #    if HAVE_GETTIMEOFDAY
123 #        include <sys/time.h>
124 double MPI_Wtime()
125 {
126     struct timeval tv;
127     gettimeofday(&tv, nullptr);
128     return tv.tv_sec + tv.tv_usec * 1e-6;
129 }
130 #    else
131 double MPI_Wtime()
132 {
133     return 0.0;
134 }
135 #    endif
136 #endif
137
138 static int vmax(const int* a, int s)
139 {
140     int i, max = 0;
141     for (i = 0; i < s; i++)
142     {
143         if (a[i] > max)
144         {
145             max = a[i];
146         }
147     }
148     return max;
149 }
150
151
152 /* NxMxK the size of the data
153  * comm communicator to use for fft5d
154  * P0 number of processor in 1st axes (can be null for automatic)
155  * lin is allocated by fft5d because size of array is only known after planning phase
156  * rlout2 is only used as intermediate buffer - only returned after allocation to reuse for back transform - should not be used by caller
157  */
158 fft5d_plan fft5d_plan_3d(int                NG,
159                          int                MG,
160                          int                KG,
161                          MPI_Comm           comm[2],
162                          int                flags,
163                          t_complex**        rlin,
164                          t_complex**        rlout,
165                          t_complex**        rlout2,
166                          t_complex**        rlout3,
167                          int                nthreads,
168                          gmx::PinningPolicy realGridAllocationPinningPolicy)
169 {
170
171     int  P[2], prank[2], i;
172     bool bMaster;
173     int  rNG, rMG, rKG;
174     int *N0 = nullptr, *N1 = nullptr, *M0 = nullptr, *M1 = nullptr, *K0 = nullptr, *K1 = nullptr,
175         *oN0 = nullptr, *oN1 = nullptr, *oM0 = nullptr, *oM1 = nullptr, *oK0 = nullptr, *oK1 = nullptr;
176     int N[3], M[3], K[3], pN[3], pM[3], pK[3], oM[3], oK[3],
177             *iNin[3] = { nullptr }, *oNin[3] = { nullptr }, *iNout[3] = { nullptr },
178             *oNout[3] = { nullptr };
179     int        C[3], rC[3], nP[2];
180     int        lsize;
181     t_complex *lin = nullptr, *lout = nullptr, *lout2 = nullptr, *lout3 = nullptr;
182     fft5d_plan plan;
183     int        s;
184
185     /* comm, prank and P are in the order of the decomposition (plan->cart is in the order of transposes) */
186 #if GMX_MPI
187     if (GMX_PARALLEL_ENV_INITIALIZED && comm[0] != MPI_COMM_NULL)
188     {
189         MPI_Comm_size(comm[0], &P[0]);
190         MPI_Comm_rank(comm[0], &prank[0]);
191     }
192     else
193 #endif
194     {
195         P[0]     = 1;
196         prank[0] = 0;
197     }
198 #if GMX_MPI
199     if (GMX_PARALLEL_ENV_INITIALIZED && comm[1] != MPI_COMM_NULL)
200     {
201         MPI_Comm_size(comm[1], &P[1]);
202         MPI_Comm_rank(comm[1], &prank[1]);
203     }
204     else
205 #endif
206     {
207         P[1]     = 1;
208         prank[1] = 0;
209     }
210
211     bMaster = prank[0] == 0 && prank[1] == 0;
212
213
214     if (debug)
215     {
216         fprintf(debug, "FFT5D: Using %dx%d rank grid, rank %d,%d\n", P[0], P[1], prank[0], prank[1]);
217     }
218
219     if (bMaster)
220     {
221         if (debug)
222         {
223             fprintf(debug,
224                     "FFT5D: N: %d, M: %d, K: %d, P: %dx%d, real2complex: %d, backward: %d, order "
225                     "yz: %d, debug %d\n",
226                     NG, MG, KG, P[0], P[1], int((flags & FFT5D_REALCOMPLEX) > 0),
227                     int((flags & FFT5D_BACKWARD) > 0), int((flags & FFT5D_ORDER_YZ) > 0),
228                     int((flags & FFT5D_DEBUG) > 0));
229         }
230         /* The check below is not correct, one prime factor 11 or 13 is ok.
231            if (fft5d_fmax(fft5d_fmax(lpfactor(NG),lpfactor(MG)),lpfactor(KG))>7) {
232             printf("WARNING: FFT very slow with prime factors larger 7\n");
233             printf("Change FFT size or in case you cannot change it look at\n");
234             printf("http://www.fftw.org/fftw3_doc/Generating-your-own-code.html\n");
235            }
236          */
237     }
238
239     if (NG == 0 || MG == 0 || KG == 0)
240     {
241         if (bMaster)
242         {
243             printf("FFT5D: FATAL: Datasize cannot be zero in any dimension\n");
244         }
245         return nullptr;
246     }
247
248     rNG = NG;
249     rMG = MG;
250     rKG = KG;
251
252     if (flags & FFT5D_REALCOMPLEX)
253     {
254         if (!(flags & FFT5D_BACKWARD))
255         {
256             NG = NG / 2 + 1;
257         }
258         else
259         {
260             if (!(flags & FFT5D_ORDER_YZ))
261             {
262                 MG = MG / 2 + 1;
263             }
264             else
265             {
266                 KG = KG / 2 + 1;
267             }
268         }
269     }
270
271
272     /*for transpose we need to know the size for each processor not only our own size*/
273
274     N0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
275     N1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
276     M0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
277     M1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
278     K0  = static_cast<int*>(malloc(P[0] * sizeof(int)));
279     K1  = static_cast<int*>(malloc(P[1] * sizeof(int)));
280     oN0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
281     oN1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
282     oM0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
283     oM1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
284     oK0 = static_cast<int*>(malloc(P[0] * sizeof(int)));
285     oK1 = static_cast<int*>(malloc(P[1] * sizeof(int)));
286
287     for (i = 0; i < P[0]; i++)
288     {
289 #define EVENDIST
290 #ifndef EVENDIST
291         oN0[i] = i * ceil((double)NG / P[0]);
292         oM0[i] = i * ceil((double)MG / P[0]);
293         oK0[i] = i * ceil((double)KG / P[0]);
294 #else
295         oN0[i] = (NG * i) / P[0];
296         oM0[i] = (MG * i) / P[0];
297         oK0[i] = (KG * i) / P[0];
298 #endif
299     }
300     for (i = 0; i < P[1]; i++)
301     {
302 #ifndef EVENDIST
303         oN1[i] = i * ceil((double)NG / P[1]);
304         oM1[i] = i * ceil((double)MG / P[1]);
305         oK1[i] = i * ceil((double)KG / P[1]);
306 #else
307         oN1[i] = (NG * i) / P[1];
308         oM1[i] = (MG * i) / P[1];
309         oK1[i] = (KG * i) / P[1];
310 #endif
311     }
312     for (i = 0; P[0] > 0 && i < P[0] - 1; i++)
313     {
314         N0[i] = oN0[i + 1] - oN0[i];
315         M0[i] = oM0[i + 1] - oM0[i];
316         K0[i] = oK0[i + 1] - oK0[i];
317     }
318     N0[P[0] - 1] = NG - oN0[P[0] - 1];
319     M0[P[0] - 1] = MG - oM0[P[0] - 1];
320     K0[P[0] - 1] = KG - oK0[P[0] - 1];
321     for (i = 0; P[1] > 0 && i < P[1] - 1; i++)
322     {
323         N1[i] = oN1[i + 1] - oN1[i];
324         M1[i] = oM1[i + 1] - oM1[i];
325         K1[i] = oK1[i + 1] - oK1[i];
326     }
327     N1[P[1] - 1] = NG - oN1[P[1] - 1];
328     M1[P[1] - 1] = MG - oM1[P[1] - 1];
329     K1[P[1] - 1] = KG - oK1[P[1] - 1];
330
331     /* for step 1-3 the local N,M,K sizes of the transposed system
332        C: contiguous dimension, and nP: number of processor in subcommunicator
333        for that step */
334
335     GMX_ASSERT(prank[0] < P[0], "Must have valid rank within communicator size");
336     GMX_ASSERT(prank[1] < P[1], "Must have valid rank within communicator size");
337     pM[0] = M0[prank[0]];
338     oM[0] = oM0[prank[0]];
339     pK[0] = K1[prank[1]];
340     oK[0] = oK1[prank[1]];
341     C[0]  = NG;
342     rC[0] = rNG;
343     if (!(flags & FFT5D_ORDER_YZ))
344     {
345         N[0]     = vmax(N1, P[1]);
346         M[0]     = M0[prank[0]];
347         K[0]     = vmax(K1, P[1]);
348         pN[0]    = N1[prank[1]];
349         iNout[0] = N1;
350         oNout[0] = oN1;
351         nP[0]    = P[1];
352         C[1]     = KG;
353         rC[1]    = rKG;
354         N[1]     = vmax(K0, P[0]);
355         pN[1]    = K0[prank[0]];
356         iNin[1]  = K1;
357         oNin[1]  = oK1;
358         iNout[1] = K0;
359         oNout[1] = oK0;
360         M[1]     = vmax(M0, P[0]);
361         pM[1]    = M0[prank[0]];
362         oM[1]    = oM0[prank[0]];
363         K[1]     = N1[prank[1]];
364         pK[1]    = N1[prank[1]];
365         oK[1]    = oN1[prank[1]];
366         nP[1]    = P[0];
367         C[2]     = MG;
368         rC[2]    = rMG;
369         iNin[2]  = M0;
370         oNin[2]  = oM0;
371         M[2]     = vmax(K0, P[0]);
372         pM[2]    = K0[prank[0]];
373         oM[2]    = oK0[prank[0]];
374         K[2]     = vmax(N1, P[1]);
375         pK[2]    = N1[prank[1]];
376         oK[2]    = oN1[prank[1]];
377         free(N0);
378         free(oN0); /*these are not used for this order*/
379         free(M1);
380         free(oM1); /*the rest is freed in destroy*/
381     }
382     else
383     {
384         N[0]     = vmax(N0, P[0]);
385         M[0]     = vmax(M0, P[0]);
386         K[0]     = K1[prank[1]];
387         pN[0]    = N0[prank[0]];
388         iNout[0] = N0;
389         oNout[0] = oN0;
390         nP[0]    = P[0];
391         C[1]     = MG;
392         rC[1]    = rMG;
393         N[1]     = vmax(M1, P[1]);
394         pN[1]    = M1[prank[1]];
395         iNin[1]  = M0;
396         oNin[1]  = oM0;
397         iNout[1] = M1;
398         oNout[1] = oM1;
399         M[1]     = N0[prank[0]];
400         pM[1]    = N0[prank[0]];
401         oM[1]    = oN0[prank[0]];
402         K[1]     = vmax(K1, P[1]);
403         pK[1]    = K1[prank[1]];
404         oK[1]    = oK1[prank[1]];
405         nP[1]    = P[1];
406         C[2]     = KG;
407         rC[2]    = rKG;
408         iNin[2]  = K1;
409         oNin[2]  = oK1;
410         M[2]     = vmax(N0, P[0]);
411         pM[2]    = N0[prank[0]];
412         oM[2]    = oN0[prank[0]];
413         K[2]     = vmax(M1, P[1]);
414         pK[2]    = M1[prank[1]];
415         oK[2]    = oM1[prank[1]];
416         free(N1);
417         free(oN1); /*these are not used for this order*/
418         free(K0);
419         free(oK0); /*the rest is freed in destroy*/
420     }
421     N[2] = pN[2] = -1; /*not used*/
422
423     /*
424        Difference between x-y-z regarding 2d decomposition is whether they are
425        distributed along axis 1, 2 or both
426      */
427
428     /* int lsize = fmax(N[0]*M[0]*K[0]*nP[0],N[1]*M[1]*K[1]*nP[1]); */
429     lsize = std::max(N[0] * M[0] * K[0] * nP[0], std::max(N[1] * M[1] * K[1] * nP[1], C[2] * M[2] * K[2]));
430     /* int lsize = fmax(C[0]*M[0]*K[0],fmax(C[1]*M[1]*K[1],C[2]*M[2]*K[2])); */
431     if (!(flags & FFT5D_NOMALLOC))
432     {
433         // only needed for PME GPU mixed mode
434         if (realGridAllocationPinningPolicy == gmx::PinningPolicy::PinnedIfSupported && GMX_GPU == GMX_GPU_CUDA)
435         {
436             const std::size_t numBytes = lsize * sizeof(t_complex);
437             lin = static_cast<t_complex*>(gmx::PageAlignedAllocationPolicy::malloc(numBytes));
438             gmx::pinBuffer(lin, numBytes);
439         }
440         else
441         {
442             snew_aligned(lin, lsize, 32);
443         }
444         snew_aligned(lout, lsize, 32);
445         if (nthreads > 1)
446         {
447             /* We need extra transpose buffers to avoid OpenMP barriers */
448             snew_aligned(lout2, lsize, 32);
449             snew_aligned(lout3, lsize, 32);
450         }
451         else
452         {
453             /* We can reuse the buffers to avoid cache misses */
454             lout2 = lin;
455             lout3 = lout;
456         }
457     }
458     else
459     {
460         lin  = *rlin;
461         lout = *rlout;
462         if (nthreads > 1)
463         {
464             lout2 = *rlout2;
465             lout3 = *rlout3;
466         }
467         else
468         {
469             lout2 = lin;
470             lout3 = lout;
471         }
472     }
473
474     plan = static_cast<fft5d_plan>(calloc(1, sizeof(struct fft5d_plan_t)));
475
476
477     if (debug)
478     {
479         fprintf(debug, "Running on %d threads\n", nthreads);
480     }
481
482 #if GMX_FFT_FFTW3
483     /* Don't add more stuff here! We have already had at least one bug because we are reimplementing
484      * the low-level FFT interface instead of using the Gromacs FFT module. If we need more
485      * generic functionality it is far better to extend the interface so we can use it for
486      * all FFT libraries instead of writing FFTW-specific code here.
487      */
488
489     /*if not FFTW - then we don't do a 3d plan but instead use only 1D plans */
490     /* It is possible to use the 3d plan with OMP threads - but in that case it is not allowed to be
491      * called from within a parallel region. For now deactivated. If it should be supported it has
492      * to made sure that that the execute of the 3d plan is in a master/serial block (since it
493      * contains it own parallel region) and that the 3d plan is faster than the 1d plan.
494      */
495     if ((!(flags & FFT5D_INPLACE)) && (!(P[0] > 1 || P[1] > 1))
496         && nthreads == 1) /*don't do 3d plan in parallel or if in_place requested */
497     {
498         int fftwflags = FFTW_DESTROY_INPUT;
499         FFTW(iodim) dims[3];
500         int inNG = NG, outMG = MG, outKG = KG;
501
502         FFTW_LOCK
503
504         fftwflags |= (flags & FFT5D_NOMEASURE) ? FFTW_ESTIMATE : FFTW_MEASURE;
505
506         if (flags & FFT5D_REALCOMPLEX)
507         {
508             if (!(flags & FFT5D_BACKWARD)) /*input pointer is not complex*/
509             {
510                 inNG *= 2;
511             }
512             else /*output pointer is not complex*/
513             {
514                 if (!(flags & FFT5D_ORDER_YZ))
515                 {
516                     outMG *= 2;
517                 }
518                 else
519                 {
520                     outKG *= 2;
521                 }
522             }
523         }
524
525         if (!(flags & FFT5D_BACKWARD))
526         {
527             dims[0].n = KG;
528             dims[1].n = MG;
529             dims[2].n = rNG;
530
531             dims[0].is = inNG * MG; /*N M K*/
532             dims[1].is = inNG;
533             dims[2].is = 1;
534             if (!(flags & FFT5D_ORDER_YZ))
535             {
536                 dims[0].os = MG; /*M K N*/
537                 dims[1].os = 1;
538                 dims[2].os = MG * KG;
539             }
540             else
541             {
542                 dims[0].os = 1; /*K N M*/
543                 dims[1].os = KG * NG;
544                 dims[2].os = KG;
545             }
546         }
547         else
548         {
549             if (!(flags & FFT5D_ORDER_YZ))
550             {
551                 dims[0].n = NG;
552                 dims[1].n = KG;
553                 dims[2].n = rMG;
554
555                 dims[0].is = 1;
556                 dims[1].is = NG * MG;
557                 dims[2].is = NG;
558
559                 dims[0].os = outMG * KG;
560                 dims[1].os = outMG;
561                 dims[2].os = 1;
562             }
563             else
564             {
565                 dims[0].n = MG;
566                 dims[1].n = NG;
567                 dims[2].n = rKG;
568
569                 dims[0].is = NG;
570                 dims[1].is = 1;
571                 dims[2].is = NG * MG;
572
573                 dims[0].os = outKG * NG;
574                 dims[1].os = outKG;
575                 dims[2].os = 1;
576             }
577         }
578 #    ifdef FFT5D_THREADS
579 #        ifdef FFT5D_FFTW_THREADS
580         FFTW(plan_with_nthreads)(nthreads);
581 #        endif
582 #    endif
583         if ((flags & FFT5D_REALCOMPLEX) && !(flags & FFT5D_BACKWARD))
584         {
585             plan->p3d = FFTW(plan_guru_dft_r2c)(/*rank*/ 3, dims,
586                                                 /*howmany*/ 0, /*howmany_dims*/ nullptr,
587                                                 reinterpret_cast<real*>(lin),
588                                                 reinterpret_cast<FFTW(complex)*>(lout),
589                                                 /*flags*/ fftwflags);
590         }
591         else if ((flags & FFT5D_REALCOMPLEX) && (flags & FFT5D_BACKWARD))
592         {
593             plan->p3d = FFTW(plan_guru_dft_c2r)(/*rank*/ 3, dims,
594                                                 /*howmany*/ 0, /*howmany_dims*/ nullptr,
595                                                 reinterpret_cast<FFTW(complex)*>(lin),
596                                                 reinterpret_cast<real*>(lout),
597                                                 /*flags*/ fftwflags);
598         }
599         else
600         {
601             plan->p3d = FFTW(plan_guru_dft)(
602                     /*rank*/ 3, dims,
603                     /*howmany*/ 0, /*howmany_dims*/ nullptr, reinterpret_cast<FFTW(complex)*>(lin),
604                     reinterpret_cast<FFTW(complex)*>(lout),
605                     /*sign*/ (flags & FFT5D_BACKWARD) ? 1 : -1, /*flags*/ fftwflags);
606         }
607 #    ifdef FFT5D_THREADS
608 #        ifdef FFT5D_FFTW_THREADS
609         FFTW(plan_with_nthreads)(1);
610 #        endif
611 #    endif
612         FFTW_UNLOCK
613     }
614     if (!plan->p3d) /* for decomposition and if 3d plan did not work */
615     {
616 #endif /* GMX_FFT_FFTW3 */
617         for (s = 0; s < 3; s++)
618         {
619             if (debug)
620             {
621                 fprintf(debug, "FFT5D: Plan s %d rC %d M %d pK %d C %d lsize %d\n", s, rC[s], M[s],
622                         pK[s], C[s], lsize);
623             }
624             plan->p1d[s] = static_cast<gmx_fft_t*>(malloc(sizeof(gmx_fft_t) * nthreads));
625
626             /* Make sure that the init routines are only called by one thread at a time and in order
627                (later is only important to not confuse valgrind)
628              */
629 #pragma omp parallel for num_threads(nthreads) schedule(static) ordered
630             for (int t = 0; t < nthreads; t++)
631             {
632 #pragma omp ordered
633                 {
634                     try
635                     {
636                         int tsize = ((t + 1) * pM[s] * pK[s] / nthreads) - (t * pM[s] * pK[s] / nthreads);
637
638                         if ((flags & FFT5D_REALCOMPLEX)
639                             && ((!(flags & FFT5D_BACKWARD) && s == 0)
640                                 || ((flags & FFT5D_BACKWARD) && s == 2)))
641                         {
642                             gmx_fft_init_many_1d_real(
643                                     &plan->p1d[s][t], rC[s], tsize,
644                                     (flags & FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0);
645                         }
646                         else
647                         {
648                             gmx_fft_init_many_1d(&plan->p1d[s][t], C[s], tsize,
649                                                  (flags & FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0);
650                         }
651                     }
652                     GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
653                 }
654             }
655         }
656
657 #if GMX_FFT_FFTW3
658     }
659 #endif
660     if ((flags & FFT5D_ORDER_YZ)) /*plan->cart is in the order of transposes */
661     {
662         plan->cart[0] = comm[0];
663         plan->cart[1] = comm[1];
664     }
665     else
666     {
667         plan->cart[1] = comm[0];
668         plan->cart[0] = comm[1];
669     }
670 #ifdef FFT5D_MPI_TRANSPOSE
671     FFTW_LOCK;
672     for (s = 0; s < 2; s++)
673     {
674         if ((s == 0 && !(flags & FFT5D_ORDER_YZ)) || (s == 1 && (flags & FFT5D_ORDER_YZ)))
675         {
676             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s] * K[s] * pM[s] * 2, 1,
677                                                           1, (real*)lout2, (real*)lout3,
678                                                           plan->cart[s], FFTW_PATIENT);
679         }
680         else
681         {
682             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s] * pK[s] * M[s] * 2, 1,
683                                                           1, (real*)lout2, (real*)lout3,
684                                                           plan->cart[s], FFTW_PATIENT);
685         }
686     }
687     FFTW_UNLOCK;
688 #endif
689
690
691     plan->lin   = lin;
692     plan->lout  = lout;
693     plan->lout2 = lout2;
694     plan->lout3 = lout3;
695
696     plan->NG = NG;
697     plan->MG = MG;
698     plan->KG = KG;
699
700     for (s = 0; s < 3; s++)
701     {
702         plan->N[s]     = N[s];
703         plan->M[s]     = M[s];
704         plan->K[s]     = K[s];
705         plan->pN[s]    = pN[s];
706         plan->pM[s]    = pM[s];
707         plan->pK[s]    = pK[s];
708         plan->oM[s]    = oM[s];
709         plan->oK[s]    = oK[s];
710         plan->C[s]     = C[s];
711         plan->rC[s]    = rC[s];
712         plan->iNin[s]  = iNin[s];
713         plan->oNin[s]  = oNin[s];
714         plan->iNout[s] = iNout[s];
715         plan->oNout[s] = oNout[s];
716     }
717     for (s = 0; s < 2; s++)
718     {
719         plan->P[s]    = nP[s];
720         plan->coor[s] = prank[s];
721     }
722
723     /*    plan->fftorder=fftorder;
724         plan->direction=direction;
725         plan->realcomplex=realcomplex;
726      */
727     plan->flags         = flags;
728     plan->nthreads      = nthreads;
729     plan->pinningPolicy = realGridAllocationPinningPolicy;
730     *rlin               = lin;
731     *rlout              = lout;
732     *rlout2             = lout2;
733     *rlout3             = lout3;
734     return plan;
735 }
736
737
738 enum order
739 {
740     XYZ,
741     XZY,
742     YXZ,
743     YZX,
744     ZXY,
745     ZYX
746 };
747
748
749 /*here x,y,z and N,M,K is in rotated coordinate system!!
750    x (and N) is mayor (consecutive) dimension, y (M) middle and z (K) major
751    maxN,maxM,maxK is max size of local data
752    pN, pM, pK is local size specific to current processor (only different to max if not divisible)
753    NG, MG, KG is size of global data*/
754 static void splitaxes(t_complex*       lout,
755                       const t_complex* lin,
756                       int              maxN,
757                       int              maxM,
758                       int              maxK,
759                       int              pM,
760                       int              P,
761                       int              NG,
762                       const int*       N,
763                       const int*       oN,
764                       int              starty,
765                       int              startz,
766                       int              endy,
767                       int              endz)
768 {
769     int x, y, z, i;
770     int in_i, out_i, in_z, out_z, in_y, out_y;
771     int s_y, e_y;
772
773     for (z = startz; z < endz + 1; z++) /*3. z l*/
774     {
775         if (z == startz)
776         {
777             s_y = starty;
778         }
779         else
780         {
781             s_y = 0;
782         }
783         if (z == endz)
784         {
785             e_y = endy;
786         }
787         else
788         {
789             e_y = pM;
790         }
791         out_z = z * maxN * maxM;
792         in_z  = z * NG * pM;
793
794         for (i = 0; i < P; i++) /*index cube along long axis*/
795         {
796             out_i = out_z + i * maxN * maxM * maxK;
797             in_i  = in_z + oN[i];
798             for (y = s_y; y < e_y; y++) /*2. y k*/
799             {
800                 out_y = out_i + y * maxN;
801                 in_y  = in_i + y * NG;
802                 for (x = 0; x < N[i]; x++) /*1. x j*/
803                 {
804                     lout[out_y + x] = lin[in_y + x]; /*in=z*NG*pM+oN[i]+y*NG+x*/
805                     /*after split important that each processor chunk i has size maxN*maxM*maxK and thus being the same size*/
806                     /*before split data contiguos - thus if different processor get different amount oN is different*/
807                 }
808             }
809         }
810     }
811 }
812
813 /*make axis contiguous again (after AllToAll) and also do local transpose*/
814 /*transpose mayor and major dimension
815    variables see above
816    the major, middle, minor order is only correct for x,y,z (N,M,K) for the input
817    N,M,K local dimensions
818    KG global size*/
819 static void joinAxesTrans13(t_complex*       lout,
820                             const t_complex* lin,
821                             int              maxN,
822                             int              maxM,
823                             int              maxK,
824                             int              pM,
825                             int              P,
826                             int              KG,
827                             const int*       K,
828                             const int*       oK,
829                             int              starty,
830                             int              startx,
831                             int              endy,
832                             int              endx)
833 {
834     int i, x, y, z;
835     int out_i, in_i, out_x, in_x, out_z, in_z;
836     int s_y, e_y;
837
838     for (x = startx; x < endx + 1; x++) /*1.j*/
839     {
840         if (x == startx)
841         {
842             s_y = starty;
843         }
844         else
845         {
846             s_y = 0;
847         }
848         if (x == endx)
849         {
850             e_y = endy;
851         }
852         else
853         {
854             e_y = pM;
855         }
856
857         out_x = x * KG * pM;
858         in_x  = x;
859
860         for (i = 0; i < P; i++) /*index cube along long axis*/
861         {
862             out_i = out_x + oK[i];
863             in_i  = in_x + i * maxM * maxN * maxK;
864             for (z = 0; z < K[i]; z++) /*3.l*/
865             {
866                 out_z = out_i + z;
867                 in_z  = in_i + z * maxM * maxN;
868                 for (y = s_y; y < e_y; y++) /*2.k*/
869                 {
870                     lout[out_z + y * KG] = lin[in_z + y * maxN]; /*out=x*KG*pM+oK[i]+z+y*KG*/
871                 }
872             }
873         }
874     }
875 }
876
877 /*make axis contiguous again (after AllToAll) and also do local transpose
878    tranpose mayor and middle dimension
879    variables see above
880    the minor, middle, major order is only correct for x,y,z (N,M,K) for the input
881    N,M,K local size
882    MG, global size*/
883 static void joinAxesTrans12(t_complex*       lout,
884                             const t_complex* lin,
885                             int              maxN,
886                             int              maxM,
887                             int              maxK,
888                             int              pN,
889                             int              P,
890                             int              MG,
891                             const int*       M,
892                             const int*       oM,
893                             int              startx,
894                             int              startz,
895                             int              endx,
896                             int              endz)
897 {
898     int i, z, y, x;
899     int out_i, in_i, out_z, in_z, out_x, in_x;
900     int s_x, e_x;
901
902     for (z = startz; z < endz + 1; z++)
903     {
904         if (z == startz)
905         {
906             s_x = startx;
907         }
908         else
909         {
910             s_x = 0;
911         }
912         if (z == endz)
913         {
914             e_x = endx;
915         }
916         else
917         {
918             e_x = pN;
919         }
920         out_z = z * MG * pN;
921         in_z  = z * maxM * maxN;
922
923         for (i = 0; i < P; i++) /*index cube along long axis*/
924         {
925             out_i = out_z + oM[i];
926             in_i  = in_z + i * maxM * maxN * maxK;
927             for (x = s_x; x < e_x; x++)
928             {
929                 out_x = out_i + x * MG;
930                 in_x  = in_i + x;
931                 for (y = 0; y < M[i]; y++)
932                 {
933                     lout[out_x + y] = lin[in_x + y * maxN]; /*out=z*MG*pN+oM[i]+x*MG+y*/
934                 }
935             }
936         }
937     }
938 }
939
940
941 static void rotate_offsets(int x[])
942 {
943     int t = x[0];
944     /*    x[0]=x[2];
945         x[2]=x[1];
946         x[1]=t;*/
947     x[0] = x[1];
948     x[1] = x[2];
949     x[2] = t;
950 }
951
952 /*compute the offset to compare or print transposed local data in original input coordinates
953    xs matrix dimension size, xl dimension length, xc decomposition offset
954    s: step in computation = number of transposes*/
955 static void compute_offsets(fft5d_plan plan, int xs[], int xl[], int xc[], int NG[], int s)
956 {
957     /*    int direction = plan->direction;
958         int fftorder = plan->fftorder;*/
959
960     int  o = 0;
961     int  pos[3], i;
962     int *pM = plan->pM, *pK = plan->pK, *oM = plan->oM, *oK = plan->oK, *C = plan->C, *rC = plan->rC;
963
964     NG[0] = plan->NG;
965     NG[1] = plan->MG;
966     NG[2] = plan->KG;
967
968     if (!(plan->flags & FFT5D_ORDER_YZ))
969     {
970         switch (s)
971         {
972             case 0: o = XYZ; break;
973             case 1: o = ZYX; break;
974             case 2: o = YZX; break;
975             default: assert(0);
976         }
977     }
978     else
979     {
980         switch (s)
981         {
982             case 0: o = XYZ; break;
983             case 1: o = YXZ; break;
984             case 2: o = ZXY; break;
985             default: assert(0);
986         }
987     }
988
989     switch (o)
990     {
991         case XYZ:
992             pos[0] = 1;
993             pos[1] = 2;
994             pos[2] = 3;
995             break;
996         case XZY:
997             pos[0] = 1;
998             pos[1] = 3;
999             pos[2] = 2;
1000             break;
1001         case YXZ:
1002             pos[0] = 2;
1003             pos[1] = 1;
1004             pos[2] = 3;
1005             break;
1006         case YZX:
1007             pos[0] = 3;
1008             pos[1] = 1;
1009             pos[2] = 2;
1010             break;
1011         case ZXY:
1012             pos[0] = 2;
1013             pos[1] = 3;
1014             pos[2] = 1;
1015             break;
1016         case ZYX:
1017             pos[0] = 3;
1018             pos[1] = 2;
1019             pos[2] = 1;
1020             break;
1021     }
1022     /*if (debug) printf("pos: %d %d %d\n",pos[0],pos[1],pos[2]);*/
1023
1024     /*xs, xl give dimension size and data length in local transposed coordinate system
1025        for 0(/1/2): x(/y/z) in original coordinate system*/
1026     for (i = 0; i < 3; i++)
1027     {
1028         switch (pos[i])
1029         {
1030             case 1:
1031                 xs[i] = 1;
1032                 xc[i] = 0;
1033                 xl[i] = C[s];
1034                 break;
1035             case 2:
1036                 xs[i] = C[s];
1037                 xc[i] = oM[s];
1038                 xl[i] = pM[s];
1039                 break;
1040             case 3:
1041                 xs[i] = C[s] * pM[s];
1042                 xc[i] = oK[s];
1043                 xl[i] = pK[s];
1044                 break;
1045         }
1046     }
1047     /*input order is different for test program to match FFTW order
1048        (important for complex to real)*/
1049     if (plan->flags & FFT5D_BACKWARD)
1050     {
1051         rotate_offsets(xs);
1052         rotate_offsets(xl);
1053         rotate_offsets(xc);
1054         rotate_offsets(NG);
1055         if (plan->flags & FFT5D_ORDER_YZ)
1056         {
1057             rotate_offsets(xs);
1058             rotate_offsets(xl);
1059             rotate_offsets(xc);
1060             rotate_offsets(NG);
1061         }
1062     }
1063     if ((plan->flags & FFT5D_REALCOMPLEX)
1064         && ((!(plan->flags & FFT5D_BACKWARD) && s == 0) || ((plan->flags & FFT5D_BACKWARD) && s == 2)))
1065     {
1066         xl[0] = rC[s];
1067     }
1068 }
1069
1070 static void print_localdata(const t_complex* lin, const char* txt, int s, fft5d_plan plan)
1071 {
1072     int  x, y, z, l;
1073     int* coor = plan->coor;
1074     int  xs[3], xl[3], xc[3], NG[3];
1075     int  ll = (plan->flags & FFT5D_REALCOMPLEX) ? 1 : 2;
1076     compute_offsets(plan, xs, xl, xc, NG, s);
1077     fprintf(debug, txt, coor[0], coor[1]);
1078     /*printf("xs: %d %d %d, xl: %d %d %d\n",xs[0],xs[1],xs[2],xl[0],xl[1],xl[2]);*/
1079     for (z = 0; z < xl[2]; z++)
1080     {
1081         for (y = 0; y < xl[1]; y++)
1082         {
1083             fprintf(debug, "%d %d: ", coor[0], coor[1]);
1084             for (x = 0; x < xl[0]; x++)
1085             {
1086                 for (l = 0; l < ll; l++)
1087                 {
1088                     fprintf(debug, "%f ",
1089                             reinterpret_cast<const real*>(
1090                                     lin)[(z * xs[2] + y * xs[1]) * 2 + (x * xs[0]) * ll + l]);
1091                 }
1092                 fprintf(debug, ",");
1093             }
1094             fprintf(debug, "\n");
1095         }
1096     }
1097 }
1098
1099 void fft5d_execute(fft5d_plan plan, int thread, fft5d_time times)
1100 {
1101     t_complex* lin   = plan->lin;
1102     t_complex* lout  = plan->lout;
1103     t_complex* lout2 = plan->lout2;
1104     t_complex* lout3 = plan->lout3;
1105     t_complex *fftout, *joinin;
1106
1107     gmx_fft_t** p1d = plan->p1d;
1108 #ifdef FFT5D_MPI_TRANSPOSE
1109     FFTW(plan)* mpip = plan->mpip;
1110 #endif
1111 #if GMX_MPI
1112     MPI_Comm* cart = plan->cart;
1113 #endif
1114 #ifdef NOGMX
1115     double time_fft = 0, time_local = 0, time_mpi[2] = { 0 }, time = 0;
1116 #endif
1117     int *N = plan->N, *M = plan->M, *K = plan->K, *pN = plan->pN, *pM = plan->pM, *pK = plan->pK,
1118         *C = plan->C, *P = plan->P, **iNin = plan->iNin, **oNin = plan->oNin, **iNout = plan->iNout,
1119         **oNout = plan->oNout;
1120     int s       = 0, tstart, tend, bParallelDim;
1121
1122
1123 #if GMX_FFT_FFTW3
1124     if (plan->p3d)
1125     {
1126         if (thread == 0)
1127         {
1128 #    ifdef NOGMX
1129             if (times != 0)
1130             {
1131                 time = MPI_Wtime();
1132             }
1133 #    endif
1134             FFTW(execute)(plan->p3d);
1135 #    ifdef NOGMX
1136             if (times != 0)
1137             {
1138                 times->fft += MPI_Wtime() - time;
1139             }
1140 #    endif
1141         }
1142         return;
1143     }
1144 #endif
1145
1146     s = 0;
1147
1148     /*lin: x,y,z*/
1149     if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1150     {
1151         print_localdata(lin, "%d %d: copy in lin\n", s, plan);
1152     }
1153
1154     for (s = 0; s < 2; s++) /*loop over first two FFT steps (corner rotations)*/
1155
1156     {
1157 #if GMX_MPI
1158         if (GMX_PARALLEL_ENV_INITIALIZED && cart[s] != MPI_COMM_NULL && P[s] > 1)
1159         {
1160             bParallelDim = 1;
1161         }
1162         else
1163 #endif
1164         {
1165             bParallelDim = 0;
1166         }
1167
1168         /* ---------- START FFT ------------ */
1169 #ifdef NOGMX
1170         if (times != 0 && thread == 0)
1171         {
1172             time = MPI_Wtime();
1173         }
1174 #endif
1175
1176         if (bParallelDim || plan->nthreads == 1)
1177         {
1178             fftout = lout;
1179         }
1180         else
1181         {
1182             if (s == 0)
1183             {
1184                 fftout = lout3;
1185             }
1186             else
1187             {
1188                 fftout = lout2;
1189             }
1190         }
1191
1192         tstart = (thread * pM[s] * pK[s] / plan->nthreads) * C[s];
1193         if ((plan->flags & FFT5D_REALCOMPLEX) && !(plan->flags & FFT5D_BACKWARD) && s == 0)
1194         {
1195             gmx_fft_many_1d_real(p1d[s][thread],
1196                                  (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL
1197                                                                 : GMX_FFT_REAL_TO_COMPLEX,
1198                                  lin + tstart, fftout + tstart);
1199         }
1200         else
1201         {
1202             gmx_fft_many_1d(p1d[s][thread],
1203                             (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,
1204                             lin + tstart, fftout + tstart);
1205         }
1206
1207 #ifdef NOGMX
1208         if (times != NULL && thread == 0)
1209         {
1210             time_fft += MPI_Wtime() - time;
1211         }
1212 #endif
1213         if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1214         {
1215             print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1216         }
1217         /* ---------- END FFT ------------ */
1218
1219         /* ---------- START SPLIT + TRANSPOSE------------ (if parallel in in this dimension)*/
1220         if (bParallelDim)
1221         {
1222 #ifdef NOGMX
1223             if (times != NULL && thread == 0)
1224             {
1225                 time = MPI_Wtime();
1226             }
1227 #endif
1228             /*prepare for A
1229                llToAll
1230                1. (most outer) axes (x) is split into P[s] parts of size N[s]
1231                for sending*/
1232             if (pM[s] > 0)
1233             {
1234                 tend = ((thread + 1) * pM[s] * pK[s] / plan->nthreads);
1235                 tstart /= C[s];
1236                 splitaxes(lout2, lout, N[s], M[s], K[s], pM[s], P[s], C[s], iNout[s], oNout[s],
1237                           tstart % pM[s], tstart / pM[s], tend % pM[s], tend / pM[s]);
1238             }
1239 #pragma omp barrier /*barrier required before AllToAll (all input has to be their) - before timing to make timing more acurate*/
1240 #ifdef NOGMX
1241             if (times != NULL && thread == 0)
1242             {
1243                 time_local += MPI_Wtime() - time;
1244             }
1245 #endif
1246
1247             /* ---------- END SPLIT , START TRANSPOSE------------ */
1248
1249             if (thread == 0)
1250             {
1251 #ifdef NOGMX
1252                 if (times != 0)
1253                 {
1254                     time = MPI_Wtime();
1255                 }
1256 #else
1257                 wallcycle_start(times, ewcPME_FFTCOMM);
1258 #endif
1259 #ifdef FFT5D_MPI_TRANSPOSE
1260                 FFTW(execute)(mpip[s]);
1261 #else
1262 #    if GMX_MPI
1263                 if ((s == 0 && !(plan->flags & FFT5D_ORDER_YZ))
1264                     || (s == 1 && (plan->flags & FFT5D_ORDER_YZ)))
1265                 {
1266                     MPI_Alltoall(reinterpret_cast<real*>(lout2),
1267                                  N[s] * pM[s] * K[s] * sizeof(t_complex) / sizeof(real),
1268                                  GMX_MPI_REAL, reinterpret_cast<real*>(lout3),
1269                                  N[s] * pM[s] * K[s] * sizeof(t_complex) / sizeof(real),
1270                                  GMX_MPI_REAL, cart[s]);
1271                 }
1272                 else
1273                 {
1274                     MPI_Alltoall(reinterpret_cast<real*>(lout2),
1275                                  N[s] * M[s] * pK[s] * sizeof(t_complex) / sizeof(real),
1276                                  GMX_MPI_REAL, reinterpret_cast<real*>(lout3),
1277                                  N[s] * M[s] * pK[s] * sizeof(t_complex) / sizeof(real),
1278                                  GMX_MPI_REAL, cart[s]);
1279                 }
1280 #    else
1281                 GMX_RELEASE_ASSERT(false, "Invalid call to fft5d_execute");
1282 #    endif /*GMX_MPI*/
1283 #endif     /*FFT5D_MPI_TRANSPOSE*/
1284 #ifdef NOGMX
1285                 if (times != 0)
1286                 {
1287                     time_mpi[s] = MPI_Wtime() - time;
1288                 }
1289 #else
1290                 wallcycle_stop(times, ewcPME_FFTCOMM);
1291 #endif
1292             }       /*master*/
1293         }           /* bPrallelDim */
1294 #pragma omp barrier /*both needed for parallel and non-parallel dimension (either have to wait on data from AlltoAll or from last FFT*/
1295
1296         /* ---------- END SPLIT + TRANSPOSE------------ */
1297
1298         /* ---------- START JOIN ------------ */
1299 #ifdef NOGMX
1300         if (times != NULL && thread == 0)
1301         {
1302             time = MPI_Wtime();
1303         }
1304 #endif
1305
1306         if (bParallelDim)
1307         {
1308             joinin = lout3;
1309         }
1310         else
1311         {
1312             joinin = fftout;
1313         }
1314         /*bring back in matrix form
1315            thus make  new 1. axes contiguos
1316            also local transpose 1 and 2/3
1317            runs on thread used for following FFT (thus needing a barrier before but not afterwards)
1318          */
1319         if ((s == 0 && !(plan->flags & FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags & FFT5D_ORDER_YZ)))
1320         {
1321             if (pM[s] > 0)
1322             {
1323                 tstart = (thread * pM[s] * pN[s] / plan->nthreads);
1324                 tend   = ((thread + 1) * pM[s] * pN[s] / plan->nthreads);
1325                 joinAxesTrans13(lin, joinin, N[s], pM[s], K[s], pM[s], P[s], C[s + 1], iNin[s + 1],
1326                                 oNin[s + 1], tstart % pM[s], tstart / pM[s], tend % pM[s], tend / pM[s]);
1327             }
1328         }
1329         else
1330         {
1331             if (pN[s] > 0)
1332             {
1333                 tstart = (thread * pK[s] * pN[s] / plan->nthreads);
1334                 tend   = ((thread + 1) * pK[s] * pN[s] / plan->nthreads);
1335                 joinAxesTrans12(lin, joinin, N[s], M[s], pK[s], pN[s], P[s], C[s + 1], iNin[s + 1],
1336                                 oNin[s + 1], tstart % pN[s], tstart / pN[s], tend % pN[s], tend / pN[s]);
1337             }
1338         }
1339
1340 #ifdef NOGMX
1341         if (times != NULL && thread == 0)
1342         {
1343             time_local += MPI_Wtime() - time;
1344         }
1345 #endif
1346         if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1347         {
1348             print_localdata(lin, "%d %d: tranposed %d\n", s + 1, plan);
1349         }
1350         /* ---------- END JOIN ------------ */
1351
1352         /*if (debug) print_localdata(lin, "%d %d: transposed x-z\n", N1, M0, K, ZYX, coor);*/
1353     } /* for(s=0;s<2;s++) */
1354 #ifdef NOGMX
1355     if (times != NULL && thread == 0)
1356     {
1357         time = MPI_Wtime();
1358     }
1359 #endif
1360
1361     if (plan->flags & FFT5D_INPLACE)
1362     {
1363         lout = lin; /*in place currently not supported*/
1364     }
1365     /*  ----------- FFT ----------- */
1366     tstart = (thread * pM[s] * pK[s] / plan->nthreads) * C[s];
1367     if ((plan->flags & FFT5D_REALCOMPLEX) && (plan->flags & FFT5D_BACKWARD))
1368     {
1369         gmx_fft_many_1d_real(p1d[s][thread],
1370                              (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX,
1371                              lin + tstart, lout + tstart);
1372     }
1373     else
1374     {
1375         gmx_fft_many_1d(p1d[s][thread], (plan->flags & FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,
1376                         lin + tstart, lout + tstart);
1377     }
1378     /* ------------ END FFT ---------*/
1379
1380 #ifdef NOGMX
1381     if (times != NULL && thread == 0)
1382     {
1383         time_fft += MPI_Wtime() - time;
1384
1385         times->fft += time_fft;
1386         times->local += time_local;
1387         times->mpi2 += time_mpi[1];
1388         times->mpi1 += time_mpi[0];
1389     }
1390 #endif
1391
1392     if ((plan->flags & FFT5D_DEBUG) && thread == 0)
1393     {
1394         print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1395     }
1396 }
1397
1398 void fft5d_destroy(fft5d_plan plan)
1399 {
1400     int s, t;
1401
1402     for (s = 0; s < 3; s++)
1403     {
1404         if (plan->p1d[s])
1405         {
1406             for (t = 0; t < plan->nthreads; t++)
1407             {
1408                 gmx_many_fft_destroy(plan->p1d[s][t]);
1409             }
1410             free(plan->p1d[s]);
1411         }
1412         if (plan->iNin[s])
1413         {
1414             free(plan->iNin[s]);
1415             plan->iNin[s] = nullptr;
1416         }
1417         if (plan->oNin[s])
1418         {
1419             free(plan->oNin[s]);
1420             plan->oNin[s] = nullptr;
1421         }
1422         if (plan->iNout[s])
1423         {
1424             free(plan->iNout[s]);
1425             plan->iNout[s] = nullptr;
1426         }
1427         if (plan->oNout[s])
1428         {
1429             free(plan->oNout[s]);
1430             plan->oNout[s] = nullptr;
1431         }
1432     }
1433 #if GMX_FFT_FFTW3
1434     FFTW_LOCK
1435 #    ifdef FFT5D_MPI_TRANSPOS
1436     for (s = 0; s < 2; s++)
1437     {
1438         FFTW(destroy_plan)(plan->mpip[s]);
1439     }
1440 #    endif /* FFT5D_MPI_TRANSPOS */
1441     if (plan->p3d)
1442     {
1443         FFTW(destroy_plan)(plan->p3d);
1444     }
1445     FFTW_UNLOCK
1446 #endif /* GMX_FFT_FFTW3 */
1447
1448     if (!(plan->flags & FFT5D_NOMALLOC))
1449     {
1450         // only needed for PME GPU mixed mode
1451         if (plan->pinningPolicy == gmx::PinningPolicy::PinnedIfSupported && isHostMemoryPinned(plan->lin))
1452         {
1453             gmx::unpinBuffer(plan->lin);
1454         }
1455         sfree_aligned(plan->lin);
1456         sfree_aligned(plan->lout);
1457         if (plan->nthreads > 1)
1458         {
1459             sfree_aligned(plan->lout2);
1460             sfree_aligned(plan->lout3);
1461         }
1462     }
1463
1464 #ifdef FFT5D_THREADS
1465 #    ifdef FFT5D_FFTW_THREADS
1466     /*FFTW(cleanup_threads)();*/
1467 #    endif
1468 #endif
1469
1470     free(plan);
1471 }
1472
1473 /*Is this better than direct access of plan? enough data?
1474    here 0,1 reference divided by which processor grid dimension (not FFT step!)*/
1475 void fft5d_local_size(fft5d_plan plan, int* N1, int* M0, int* K0, int* K1, int** coor)
1476 {
1477     *N1 = plan->N[0];
1478     *M0 = plan->M[0];
1479     *K1 = plan->K[0];
1480     *K0 = plan->N[1];
1481
1482     *coor = plan->coor;
1483 }
1484
1485
1486 /*same as fft5d_plan_3d but with cartesian coordinator and automatic splitting
1487    of processor dimensions*/
1488 fft5d_plan fft5d_plan_3d_cart(int         NG,
1489                               int         MG,
1490                               int         KG,
1491                               MPI_Comm    comm,
1492                               int         P0,
1493                               int         flags,
1494                               t_complex** rlin,
1495                               t_complex** rlout,
1496                               t_complex** rlout2,
1497                               t_complex** rlout3,
1498                               int         nthreads)
1499 {
1500     MPI_Comm cart[2] = { MPI_COMM_NULL, MPI_COMM_NULL };
1501 #if GMX_MPI
1502     int      size = 1, prank = 0;
1503     int      P[2];
1504     int      coor[2];
1505     int      wrap[] = { 0, 0 };
1506     MPI_Comm gcart;
1507     int      rdim1[] = { 0, 1 }, rdim2[] = { 1, 0 };
1508
1509     MPI_Comm_size(comm, &size);
1510     MPI_Comm_rank(comm, &prank);
1511
1512     if (P0 == 0)
1513     {
1514         P0 = lfactor(size);
1515     }
1516     if (size % P0 != 0)
1517     {
1518         if (prank == 0)
1519         {
1520             printf("FFT5D: WARNING: Number of ranks %d not evenly divisible by %d\n", size, P0);
1521         }
1522         P0 = lfactor(size);
1523     }
1524
1525     P[0] = P0;
1526     P[1] = size / P0; /*number of processors in the two dimensions*/
1527
1528     /*Difference between x-y-z regarding 2d decomposition is whether they are
1529        distributed along axis 1, 2 or both*/
1530
1531     MPI_Cart_create(comm, 2, P, wrap, 1, &gcart); /*parameter 4: value 1: reorder*/
1532     MPI_Cart_get(gcart, 2, P, wrap, coor);
1533     MPI_Cart_sub(gcart, rdim1, &cart[0]);
1534     MPI_Cart_sub(gcart, rdim2, &cart[1]);
1535 #else
1536     (void)P0;
1537     (void)comm;
1538 #endif
1539     return fft5d_plan_3d(NG, MG, KG, cart, flags, rlin, rlout, rlout2, rlout3, nthreads);
1540 }
1541
1542
1543 /*prints in original coordinate system of data (as the input to FFT)*/
1544 void fft5d_compare_data(const t_complex* lin, const t_complex* in, fft5d_plan plan, int bothLocal, int normalize)
1545 {
1546     int  xs[3], xl[3], xc[3], NG[3];
1547     int  x, y, z, l;
1548     int* coor = plan->coor;
1549     int  ll   = 2; /*compare ll values per element (has to be 2 for complex)*/
1550     if ((plan->flags & FFT5D_REALCOMPLEX) && (plan->flags & FFT5D_BACKWARD))
1551     {
1552         ll = 1;
1553     }
1554
1555     compute_offsets(plan, xs, xl, xc, NG, 2);
1556     if (plan->flags & FFT5D_DEBUG)
1557     {
1558         printf("Compare2\n");
1559     }
1560     for (z = 0; z < xl[2]; z++)
1561     {
1562         for (y = 0; y < xl[1]; y++)
1563         {
1564             if (plan->flags & FFT5D_DEBUG)
1565             {
1566                 printf("%d %d: ", coor[0], coor[1]);
1567             }
1568             for (x = 0; x < xl[0]; x++)
1569             {
1570                 for (l = 0; l < ll; l++) /*loop over real/complex parts*/
1571                 {
1572                     real a, b;
1573                     a = reinterpret_cast<const real*>(lin)[(z * xs[2] + y * xs[1]) * 2 + x * xs[0] * ll + l];
1574                     if (normalize)
1575                     {
1576                         a /= plan->rC[0] * plan->rC[1] * plan->rC[2];
1577                     }
1578                     if (!bothLocal)
1579                     {
1580                         b = reinterpret_cast<const real*>(
1581                                 in)[((z + xc[2]) * NG[0] * NG[1] + (y + xc[1]) * NG[0]) * 2 + (x + xc[0]) * ll + l];
1582                     }
1583                     else
1584                     {
1585                         b = reinterpret_cast<const real*>(
1586                                 in)[(z * xs[2] + y * xs[1]) * 2 + x * xs[0] * ll + l];
1587                     }
1588                     if (plan->flags & FFT5D_DEBUG)
1589                     {
1590                         printf("%f %f, ", a, b);
1591                     }
1592                     else
1593                     {
1594                         if (std::fabs(a - b) > 2 * NG[0] * NG[1] * NG[2] * GMX_REAL_EPS)
1595                         {
1596                             printf("result incorrect on %d,%d at %d,%d,%d: FFT5D:%f reference:%f\n",
1597                                    coor[0], coor[1], x, y, z, a, b);
1598                         }
1599                         /*                        assert(fabs(a-b)<2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS);*/
1600                     }
1601                 }
1602                 if (plan->flags & FFT5D_DEBUG)
1603                 {
1604                     printf(",");
1605                 }
1606             }
1607             if (plan->flags & FFT5D_DEBUG)
1608             {
1609                 printf("\n");
1610             }
1611         }
1612     }
1613 }