3e0cb7857777ffbd7af8105123c4e509fd31fd59
[alexxy/gromacs.git] / src / gromacs / fft / fft5d.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2009,2010,2012,2013,2014,2015,2016,2017,2018, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 #include "gmxpre.h"
36
37 #include "fft5d.h"
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cfloat>
43 #include <cmath>
44 #include <cstdio>
45 #include <cstdlib>
46 #include <cstring>
47
48 #include <algorithm>
49
50 #include "gromacs/gpu_utils/gpu_utils.h"
51 #include "gromacs/gpu_utils/hostallocator.h"
52 #include "gromacs/gpu_utils/pinning.h"
53 #include "gromacs/utility/alignedallocator.h"
54 #include "gromacs/utility/exceptions.h"
55 #include "gromacs/utility/fatalerror.h"
56 #include "gromacs/utility/gmxmpi.h"
57 #include "gromacs/utility/smalloc.h"
58
59 #ifdef NOGMX
60 #define GMX_PARALLEL_ENV_INITIALIZED 1
61 #else
62 #if GMX_MPI
63 #define GMX_PARALLEL_ENV_INITIALIZED 1
64 #else
65 #define GMX_PARALLEL_ENV_INITIALIZED 0
66 #endif
67 #endif
68
69 #if GMX_OPENMP
70 /* TODO: Do we still need this? Are we still planning ot use fftw + OpenMP? */
71 #define FFT5D_THREADS
72 /* requires fftw compiled with openmp */
73 /* #define FFT5D_FFTW_THREADS (now set by cmake) */
74 #endif
75
76 #ifndef __FLT_EPSILON__
77 #define __FLT_EPSILON__ FLT_EPSILON
78 #define __DBL_EPSILON__ DBL_EPSILON
79 #endif
80
81 #ifdef NOGMX
82 FILE* debug = 0;
83 #endif
84
85 #if GMX_FFT_FFTW3
86
87 #include "gromacs/utility/exceptions.h"
88 #include "gromacs/utility/mutex.h"
89 /* none of the fftw3 calls, except execute(), are thread-safe, so
90    we need to serialize them with this mutex. */
91 static gmx::Mutex big_fftw_mutex;
92 #define FFTW_LOCK try { big_fftw_mutex.lock(); } GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
93 #define FFTW_UNLOCK try { big_fftw_mutex.unlock(); } GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
94 #endif /* GMX_FFT_FFTW3 */
95
96 #if GMX_MPI
97 /* largest factor smaller than sqrt */
98 static int lfactor(int z)
99 {
100     int i = static_cast<int>(sqrt(static_cast<double>(z)));
101     while (z%i != 0)
102     {
103         i--;
104     }
105     return i;
106 }
107 #endif
108
109 #if !GMX_MPI
110 #if HAVE_GETTIMEOFDAY
111 #include <sys/time.h>
112 double MPI_Wtime()
113 {
114     struct timeval tv;
115     gettimeofday(&tv, 0);
116     return tv.tv_sec+tv.tv_usec*1e-6;
117 }
118 #else
119 double MPI_Wtime()
120 {
121     return 0.0;
122 }
123 #endif
124 #endif
125
126 static int vmax(const int* a, int s)
127 {
128     int i, max = 0;
129     for (i = 0; i < s; i++)
130     {
131         if (a[i] > max)
132         {
133             max = a[i];
134         }
135     }
136     return max;
137 }
138
139
140 /* NxMxK the size of the data
141  * comm communicator to use for fft5d
142  * P0 number of processor in 1st axes (can be null for automatic)
143  * lin is allocated by fft5d because size of array is only known after planning phase
144  * rlout2 is only used as intermediate buffer - only returned after allocation to reuse for back transform - should not be used by caller
145  */
146 fft5d_plan fft5d_plan_3d(int NG, int MG, int KG, MPI_Comm comm[2], int flags, t_complex** rlin, t_complex** rlout, t_complex** rlout2, t_complex** rlout3, int nthreads, gmx::PinningPolicy realGridAllocationPinningPolicy)
147 {
148
149     int        P[2], bMaster, prank[2], i, t;
150     int        rNG, rMG, rKG;
151     int       *N0 = nullptr, *N1 = nullptr, *M0 = nullptr, *M1 = nullptr, *K0 = nullptr, *K1 = nullptr, *oN0 = nullptr, *oN1 = nullptr, *oM0 = nullptr, *oM1 = nullptr, *oK0 = nullptr, *oK1 = nullptr;
152     int        N[3], M[3], K[3], pN[3], pM[3], pK[3], oM[3], oK[3], *iNin[3] = {nullptr}, *oNin[3] = {nullptr}, *iNout[3] = {nullptr}, *oNout[3] = {nullptr};
153     int        C[3], rC[3], nP[2];
154     int        lsize;
155     t_complex *lin = nullptr, *lout = nullptr, *lout2 = nullptr, *lout3 = nullptr;
156     fft5d_plan plan;
157     int        s;
158
159     /* comm, prank and P are in the order of the decomposition (plan->cart is in the order of transposes) */
160 #if GMX_MPI
161     if (GMX_PARALLEL_ENV_INITIALIZED && comm[0] != MPI_COMM_NULL)
162     {
163         MPI_Comm_size(comm[0], &P[0]);
164         MPI_Comm_rank(comm[0], &prank[0]);
165     }
166     else
167 #endif
168     {
169         P[0]     = 1;
170         prank[0] = 0;
171     }
172 #if GMX_MPI
173     if (GMX_PARALLEL_ENV_INITIALIZED && comm[1] != MPI_COMM_NULL)
174     {
175         MPI_Comm_size(comm[1], &P[1]);
176         MPI_Comm_rank(comm[1], &prank[1]);
177     }
178     else
179 #endif
180     {
181         P[1]     = 1;
182         prank[1] = 0;
183     }
184
185     bMaster = (prank[0] == 0 && prank[1] == 0);
186
187
188     if (debug)
189     {
190         fprintf(debug, "FFT5D: Using %dx%d rank grid, rank %d,%d\n",
191                 P[0], P[1], prank[0], prank[1]);
192     }
193
194     if (bMaster)
195     {
196         if (debug)
197         {
198             fprintf(debug, "FFT5D: N: %d, M: %d, K: %d, P: %dx%d, real2complex: %d, backward: %d, order yz: %d, debug %d\n",
199                     NG, MG, KG, P[0], P[1], int((flags&FFT5D_REALCOMPLEX) > 0), int((flags&FFT5D_BACKWARD) > 0), int((flags&FFT5D_ORDER_YZ) > 0), int((flags&FFT5D_DEBUG) > 0));
200         }
201         /* The check below is not correct, one prime factor 11 or 13 is ok.
202            if (fft5d_fmax(fft5d_fmax(lpfactor(NG),lpfactor(MG)),lpfactor(KG))>7) {
203             printf("WARNING: FFT very slow with prime factors larger 7\n");
204             printf("Change FFT size or in case you cannot change it look at\n");
205             printf("http://www.fftw.org/fftw3_doc/Generating-your-own-code.html\n");
206            }
207          */
208     }
209
210     if (NG == 0 || MG == 0 || KG == 0)
211     {
212         if (bMaster)
213         {
214             printf("FFT5D: FATAL: Datasize cannot be zero in any dimension\n");
215         }
216         return nullptr;
217     }
218
219     rNG = NG; rMG = MG; rKG = KG;
220
221     if (flags&FFT5D_REALCOMPLEX)
222     {
223         if (!(flags&FFT5D_BACKWARD))
224         {
225             NG = NG/2+1;
226         }
227         else
228         {
229             if (!(flags&FFT5D_ORDER_YZ))
230             {
231                 MG = MG/2+1;
232             }
233             else
234             {
235                 KG = KG/2+1;
236             }
237         }
238     }
239
240
241     /*for transpose we need to know the size for each processor not only our own size*/
242
243     N0  = static_cast<int*>(malloc(P[0]*sizeof(int))); N1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
244     M0  = static_cast<int*>(malloc(P[0]*sizeof(int))); M1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
245     K0  = static_cast<int*>(malloc(P[0]*sizeof(int))); K1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
246     oN0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oN1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
247     oM0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oM1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
248     oK0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oK1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
249
250     for (i = 0; i < P[0]; i++)
251     {
252         #define EVENDIST
253         #ifndef EVENDIST
254         oN0[i] = i*ceil((double)NG/P[0]);
255         oM0[i] = i*ceil((double)MG/P[0]);
256         oK0[i] = i*ceil((double)KG/P[0]);
257         #else
258         oN0[i] = (NG*i)/P[0];
259         oM0[i] = (MG*i)/P[0];
260         oK0[i] = (KG*i)/P[0];
261         #endif
262     }
263     for (i = 0; i < P[1]; i++)
264     {
265         #ifndef EVENDIST
266         oN1[i] = i*ceil((double)NG/P[1]);
267         oM1[i] = i*ceil((double)MG/P[1]);
268         oK1[i] = i*ceil((double)KG/P[1]);
269         #else
270         oN1[i] = (NG*i)/P[1];
271         oM1[i] = (MG*i)/P[1];
272         oK1[i] = (KG*i)/P[1];
273         #endif
274     }
275     for (i = 0; i < P[0]-1; i++)
276     {
277         N0[i] = oN0[i+1]-oN0[i];
278         M0[i] = oM0[i+1]-oM0[i];
279         K0[i] = oK0[i+1]-oK0[i];
280     }
281     N0[P[0]-1] = NG-oN0[P[0]-1];
282     M0[P[0]-1] = MG-oM0[P[0]-1];
283     K0[P[0]-1] = KG-oK0[P[0]-1];
284     for (i = 0; i < P[1]-1; i++)
285     {
286         N1[i] = oN1[i+1]-oN1[i];
287         M1[i] = oM1[i+1]-oM1[i];
288         K1[i] = oK1[i+1]-oK1[i];
289     }
290     N1[P[1]-1] = NG-oN1[P[1]-1];
291     M1[P[1]-1] = MG-oM1[P[1]-1];
292     K1[P[1]-1] = KG-oK1[P[1]-1];
293
294     /* for step 1-3 the local N,M,K sizes of the transposed system
295        C: contiguous dimension, and nP: number of processor in subcommunicator
296        for that step */
297
298
299     pM[0] = M0[prank[0]];
300     oM[0] = oM0[prank[0]];
301     pK[0] = K1[prank[1]];
302     oK[0] = oK1[prank[1]];
303     C[0]  = NG;
304     rC[0] = rNG;
305     if (!(flags&FFT5D_ORDER_YZ))
306     {
307         N[0]     = vmax(N1, P[1]);
308         M[0]     = M0[prank[0]];
309         K[0]     = vmax(K1, P[1]);
310         pN[0]    = N1[prank[1]];
311         iNout[0] = N1;
312         oNout[0] = oN1;
313         nP[0]    = P[1];
314         C[1]     = KG;
315         rC[1]    = rKG;
316         N[1]     = vmax(K0, P[0]);
317         pN[1]    = K0[prank[0]];
318         iNin[1]  = K1;
319         oNin[1]  = oK1;
320         iNout[1] = K0;
321         oNout[1] = oK0;
322         M[1]     = vmax(M0, P[0]);
323         pM[1]    = M0[prank[0]];
324         oM[1]    = oM0[prank[0]];
325         K[1]     = N1[prank[1]];
326         pK[1]    = N1[prank[1]];
327         oK[1]    = oN1[prank[1]];
328         nP[1]    = P[0];
329         C[2]     = MG;
330         rC[2]    = rMG;
331         iNin[2]  = M0;
332         oNin[2]  = oM0;
333         M[2]     = vmax(K0, P[0]);
334         pM[2]    = K0[prank[0]];
335         oM[2]    = oK0[prank[0]];
336         K[2]     = vmax(N1, P[1]);
337         pK[2]    = N1[prank[1]];
338         oK[2]    = oN1[prank[1]];
339         free(N0); free(oN0); /*these are not used for this order*/
340         free(M1); free(oM1); /*the rest is freed in destroy*/
341     }
342     else
343     {
344         N[0]     = vmax(N0, P[0]);
345         M[0]     = vmax(M0, P[0]);
346         K[0]     = K1[prank[1]];
347         pN[0]    = N0[prank[0]];
348         iNout[0] = N0;
349         oNout[0] = oN0;
350         nP[0]    = P[0];
351         C[1]     = MG;
352         rC[1]    = rMG;
353         N[1]     = vmax(M1, P[1]);
354         pN[1]    = M1[prank[1]];
355         iNin[1]  = M0;
356         oNin[1]  = oM0;
357         iNout[1] = M1;
358         oNout[1] = oM1;
359         M[1]     = N0[prank[0]];
360         pM[1]    = N0[prank[0]];
361         oM[1]    = oN0[prank[0]];
362         K[1]     = vmax(K1, P[1]);
363         pK[1]    = K1[prank[1]];
364         oK[1]    = oK1[prank[1]];
365         nP[1]    = P[1];
366         C[2]     = KG;
367         rC[2]    = rKG;
368         iNin[2]  = K1;
369         oNin[2]  = oK1;
370         M[2]     = vmax(N0, P[0]);
371         pM[2]    = N0[prank[0]];
372         oM[2]    = oN0[prank[0]];
373         K[2]     = vmax(M1, P[1]);
374         pK[2]    = M1[prank[1]];
375         oK[2]    = oM1[prank[1]];
376         free(N1); free(oN1); /*these are not used for this order*/
377         free(K0); free(oK0); /*the rest is freed in destroy*/
378     }
379     N[2] = pN[2] = -1;       /*not used*/
380
381     /*
382        Difference between x-y-z regarding 2d decomposition is whether they are
383        distributed along axis 1, 2 or both
384      */
385
386     /* int lsize = fmax(N[0]*M[0]*K[0]*nP[0],N[1]*M[1]*K[1]*nP[1]); */
387     lsize = std::max(N[0]*M[0]*K[0]*nP[0], std::max(N[1]*M[1]*K[1]*nP[1], C[2]*M[2]*K[2]));
388     /* int lsize = fmax(C[0]*M[0]*K[0],fmax(C[1]*M[1]*K[1],C[2]*M[2]*K[2])); */
389     if (!(flags&FFT5D_NOMALLOC))
390     {
391         // only needed for PME GPU mixed mode
392         if (realGridAllocationPinningPolicy == gmx::PinningPolicy::CanBePinned)
393         {
394             const std::size_t numBytes = lsize * sizeof(t_complex);
395             lin = static_cast<t_complex *>(gmx::PageAlignedAllocationPolicy::malloc(numBytes));
396             gmx::pinBuffer(lin, numBytes);
397         }
398         else
399         {
400             snew_aligned(lin, lsize, 32);
401         }
402         snew_aligned(lout, lsize, 32);
403         if (nthreads > 1)
404         {
405             /* We need extra transpose buffers to avoid OpenMP barriers */
406             snew_aligned(lout2, lsize, 32);
407             snew_aligned(lout3, lsize, 32);
408         }
409         else
410         {
411             /* We can reuse the buffers to avoid cache misses */
412             lout2 = lin;
413             lout3 = lout;
414         }
415     }
416     else
417     {
418         lin  = *rlin;
419         lout = *rlout;
420         if (nthreads > 1)
421         {
422             lout2 = *rlout2;
423             lout3 = *rlout3;
424         }
425         else
426         {
427             lout2 = lin;
428             lout3 = lout;
429         }
430     }
431
432     plan = static_cast<fft5d_plan>(calloc(1, sizeof(struct fft5d_plan_t)));
433
434
435     if (debug)
436     {
437         fprintf(debug, "Running on %d threads\n", nthreads);
438     }
439
440 #if GMX_FFT_FFTW3
441     /* Don't add more stuff here! We have already had at least one bug because we are reimplementing
442      * the low-level FFT interface instead of using the Gromacs FFT module. If we need more
443      * generic functionality it is far better to extend the interface so we can use it for
444      * all FFT libraries instead of writing FFTW-specific code here.
445      */
446
447     /*if not FFTW - then we don't do a 3d plan but instead use only 1D plans */
448     /* It is possible to use the 3d plan with OMP threads - but in that case it is not allowed to be called from
449      * within a parallel region. For now deactivated. If it should be supported it has to made sure that
450      * that the execute of the 3d plan is in a master/serial block (since it contains it own parallel region)
451      * and that the 3d plan is faster than the 1d plan.
452      */
453     if ((!(flags&FFT5D_INPLACE)) && (!(P[0] > 1 || P[1] > 1)) && nthreads == 1) /*don't do 3d plan in parallel or if in_place requested */
454     {
455         int fftwflags = FFTW_DESTROY_INPUT;
456         FFTW(iodim) dims[3];
457         int inNG = NG, outMG = MG, outKG = KG;
458
459         FFTW_LOCK;
460
461         fftwflags |= (flags & FFT5D_NOMEASURE) ? FFTW_ESTIMATE : FFTW_MEASURE;
462
463         if (flags&FFT5D_REALCOMPLEX)
464         {
465             if (!(flags&FFT5D_BACKWARD))        /*input pointer is not complex*/
466             {
467                 inNG *= 2;
468             }
469             else                                /*output pointer is not complex*/
470             {
471                 if (!(flags&FFT5D_ORDER_YZ))
472                 {
473                     outMG *= 2;
474                 }
475                 else
476                 {
477                     outKG *= 2;
478                 }
479             }
480         }
481
482         if (!(flags&FFT5D_BACKWARD))
483         {
484             dims[0].n  = KG;
485             dims[1].n  = MG;
486             dims[2].n  = rNG;
487
488             dims[0].is = inNG*MG;         /*N M K*/
489             dims[1].is = inNG;
490             dims[2].is = 1;
491             if (!(flags&FFT5D_ORDER_YZ))
492             {
493                 dims[0].os = MG;           /*M K N*/
494                 dims[1].os = 1;
495                 dims[2].os = MG*KG;
496             }
497             else
498             {
499                 dims[0].os = 1;           /*K N M*/
500                 dims[1].os = KG*NG;
501                 dims[2].os = KG;
502             }
503         }
504         else
505         {
506             if (!(flags&FFT5D_ORDER_YZ))
507             {
508                 dims[0].n  = NG;
509                 dims[1].n  = KG;
510                 dims[2].n  = rMG;
511
512                 dims[0].is = 1;
513                 dims[1].is = NG*MG;
514                 dims[2].is = NG;
515
516                 dims[0].os = outMG*KG;
517                 dims[1].os = outMG;
518                 dims[2].os = 1;
519             }
520             else
521             {
522                 dims[0].n  = MG;
523                 dims[1].n  = NG;
524                 dims[2].n  = rKG;
525
526                 dims[0].is = NG;
527                 dims[1].is = 1;
528                 dims[2].is = NG*MG;
529
530                 dims[0].os = outKG*NG;
531                 dims[1].os = outKG;
532                 dims[2].os = 1;
533             }
534         }
535 #ifdef FFT5D_THREADS
536 #ifdef FFT5D_FFTW_THREADS
537         FFTW(plan_with_nthreads)(nthreads);
538 #endif
539 #endif
540         if ((flags&FFT5D_REALCOMPLEX) && !(flags&FFT5D_BACKWARD))
541         {
542             plan->p3d = FFTW(plan_guru_dft_r2c)(/*rank*/ 3, dims,
543                                                          /*howmany*/ 0, /*howmany_dims*/ nullptr,
544                                                          reinterpret_cast<real*>(lin), reinterpret_cast<FFTW(complex) *>(lout),
545                                                          /*flags*/ fftwflags);
546         }
547         else if ((flags&FFT5D_REALCOMPLEX) && (flags&FFT5D_BACKWARD))
548         {
549             plan->p3d = FFTW(plan_guru_dft_c2r)(/*rank*/ 3, dims,
550                                                          /*howmany*/ 0, /*howmany_dims*/ nullptr,
551                                                          reinterpret_cast<FFTW(complex) *>(lin), reinterpret_cast<real*>(lout),
552                                                          /*flags*/ fftwflags);
553         }
554         else
555         {
556             plan->p3d = FFTW(plan_guru_dft)(/*rank*/ 3, dims,
557                                                      /*howmany*/ 0, /*howmany_dims*/ nullptr,
558                                                      reinterpret_cast<FFTW(complex) *>(lin), reinterpret_cast<FFTW(complex) *>(lout),
559                                                      /*sign*/ (flags&FFT5D_BACKWARD) ? 1 : -1, /*flags*/ fftwflags);
560         }
561 #ifdef FFT5D_THREADS
562 #ifdef FFT5D_FFTW_THREADS
563         FFTW(plan_with_nthreads)(1);
564 #endif
565 #endif
566         FFTW_UNLOCK;
567     }
568     if (!plan->p3d) /* for decomposition and if 3d plan did not work */
569     {
570 #endif              /* GMX_FFT_FFTW3 */
571     for (s = 0; s < 3; s++)
572     {
573         if (debug)
574         {
575             fprintf(debug, "FFT5D: Plan s %d rC %d M %d pK %d C %d lsize %d\n",
576                     s, rC[s], M[s], pK[s], C[s], lsize);
577         }
578         plan->p1d[s] = static_cast<gmx_fft_t*>(malloc(sizeof(gmx_fft_t)*nthreads));
579
580         /* Make sure that the init routines are only called by one thread at a time and in order
581            (later is only important to not confuse valgrind)
582          */
583 #pragma omp parallel for num_threads(nthreads) schedule(static) ordered
584         for (t = 0; t < nthreads; t++)
585         {
586 #pragma omp ordered
587             {
588                 try
589                 {
590                     int tsize = ((t+1)*pM[s]*pK[s]/nthreads)-(t*pM[s]*pK[s]/nthreads);
591
592                     if ((flags&FFT5D_REALCOMPLEX) && ((!(flags&FFT5D_BACKWARD) && s == 0) || ((flags&FFT5D_BACKWARD) && s == 2)))
593                     {
594                         gmx_fft_init_many_1d_real( &plan->p1d[s][t], rC[s], tsize, (flags&FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0 );
595                     }
596                     else
597                     {
598                         gmx_fft_init_many_1d     ( &plan->p1d[s][t],  C[s], tsize, (flags&FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0 );
599                     }
600                 }
601                 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
602             }
603         }
604     }
605
606 #if GMX_FFT_FFTW3
607 }
608 #endif
609     if ((flags&FFT5D_ORDER_YZ))   /*plan->cart is in the order of transposes */
610     {
611         plan->cart[0] = comm[0]; plan->cart[1] = comm[1];
612     }
613     else
614     {
615         plan->cart[1] = comm[0]; plan->cart[0] = comm[1];
616     }
617 #ifdef FFT5D_MPI_TRANSPOSE
618     FFTW_LOCK;
619     for (s = 0; s < 2; s++)
620     {
621         if ((s == 0 && !(flags&FFT5D_ORDER_YZ)) || (s == 1 && (flags&FFT5D_ORDER_YZ)))
622         {
623             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s]*K[s]*pM[s]*2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
624         }
625         else
626         {
627             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s]*pK[s]*M[s]*2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
628         }
629     }
630     FFTW_UNLOCK;
631 #endif
632
633
634     plan->lin   = lin;
635     plan->lout  = lout;
636     plan->lout2 = lout2;
637     plan->lout3 = lout3;
638
639     plan->NG = NG; plan->MG = MG; plan->KG = KG;
640
641     for (s = 0; s < 3; s++)
642     {
643         plan->N[s]    = N[s]; plan->M[s] = M[s]; plan->K[s] = K[s]; plan->pN[s] = pN[s]; plan->pM[s] = pM[s]; plan->pK[s] = pK[s];
644         plan->oM[s]   = oM[s]; plan->oK[s] = oK[s];
645         plan->C[s]    = C[s]; plan->rC[s] = rC[s];
646         plan->iNin[s] = iNin[s]; plan->oNin[s] = oNin[s]; plan->iNout[s] = iNout[s]; plan->oNout[s] = oNout[s];
647     }
648     for (s = 0; s < 2; s++)
649     {
650         plan->P[s] = nP[s]; plan->coor[s] = prank[s];
651     }
652
653 /*    plan->fftorder=fftorder;
654     plan->direction=direction;
655     plan->realcomplex=realcomplex;
656  */
657     plan->flags         = flags;
658     plan->nthreads      = nthreads;
659     plan->pinningPolicy = realGridAllocationPinningPolicy;
660     *rlin               = lin;
661     *rlout              = lout;
662     *rlout2             = lout2;
663     *rlout3             = lout3;
664     return plan;
665 }
666
667
668 enum order {
669     XYZ,
670     XZY,
671     YXZ,
672     YZX,
673     ZXY,
674     ZYX
675 };
676
677
678
679 /*here x,y,z and N,M,K is in rotated coordinate system!!
680    x (and N) is mayor (consecutive) dimension, y (M) middle and z (K) major
681    maxN,maxM,maxK is max size of local data
682    pN, pM, pK is local size specific to current processor (only different to max if not divisible)
683    NG, MG, KG is size of global data*/
684 static void splitaxes(t_complex* lout, const t_complex* lin,
685                       int maxN, int maxM, int maxK, int pM,
686                       int P, int NG, const int *N, const int* oN, int starty, int startz, int endy, int endz)
687 {
688     int x, y, z, i;
689     int in_i, out_i, in_z, out_z, in_y, out_y;
690     int s_y, e_y;
691
692     for (z = startz; z < endz+1; z++) /*3. z l*/
693     {
694         if (z == startz)
695         {
696             s_y = starty;
697         }
698         else
699         {
700             s_y = 0;
701         }
702         if (z == endz)
703         {
704             e_y = endy;
705         }
706         else
707         {
708             e_y = pM;
709         }
710         out_z  = z*maxN*maxM;
711         in_z   = z*NG*pM;
712
713         for (i = 0; i < P; i++) /*index cube along long axis*/
714         {
715             out_i  = out_z  + i*maxN*maxM*maxK;
716             in_i   = in_z + oN[i];
717             for (y = s_y; y < e_y; y++)   /*2. y k*/
718             {
719                 out_y  = out_i  + y*maxN;
720                 in_y   = in_i + y*NG;
721                 for (x = 0; x < N[i]; x++)       /*1. x j*/
722                 {
723                     lout[out_y+x] = lin[in_y+x]; /*in=z*NG*pM+oN[i]+y*NG+x*/
724                     /*after split important that each processor chunk i has size maxN*maxM*maxK and thus being the same size*/
725                     /*before split data contiguos - thus if different processor get different amount oN is different*/
726                 }
727             }
728         }
729     }
730 }
731
732 /*make axis contiguous again (after AllToAll) and also do local transpose*/
733 /*transpose mayor and major dimension
734    variables see above
735    the major, middle, minor order is only correct for x,y,z (N,M,K) for the input
736    N,M,K local dimensions
737    KG global size*/
738 static void joinAxesTrans13(t_complex* lout, const t_complex* lin,
739                             int maxN, int maxM, int maxK, int pM,
740                             int P, int KG, const int* K, const int* oK, int starty, int startx, int endy, int endx)
741 {
742     int i, x, y, z;
743     int out_i, in_i, out_x, in_x, out_z, in_z;
744     int s_y, e_y;
745
746     for (x = startx; x < endx+1; x++) /*1.j*/
747     {
748         if (x == startx)
749         {
750             s_y = starty;
751         }
752         else
753         {
754             s_y = 0;
755         }
756         if (x == endx)
757         {
758             e_y = endy;
759         }
760         else
761         {
762             e_y = pM;
763         }
764
765         out_x  = x*KG*pM;
766         in_x   = x;
767
768         for (i = 0; i < P; i++) /*index cube along long axis*/
769         {
770             out_i  = out_x  + oK[i];
771             in_i   = in_x + i*maxM*maxN*maxK;
772             for (z = 0; z < K[i]; z++) /*3.l*/
773             {
774                 out_z  = out_i  + z;
775                 in_z   = in_i + z*maxM*maxN;
776                 for (y = s_y; y < e_y; y++)              /*2.k*/
777                 {
778                     lout[out_z+y*KG] = lin[in_z+y*maxN]; /*out=x*KG*pM+oK[i]+z+y*KG*/
779                 }
780             }
781         }
782     }
783 }
784
785 /*make axis contiguous again (after AllToAll) and also do local transpose
786    tranpose mayor and middle dimension
787    variables see above
788    the minor, middle, major order is only correct for x,y,z (N,M,K) for the input
789    N,M,K local size
790    MG, global size*/
791 static void joinAxesTrans12(t_complex* lout, const t_complex* lin, int maxN, int maxM, int maxK, int pN,
792                             int P, int MG, const int* M, const int* oM, int startx, int startz, int endx, int endz)
793 {
794     int i, z, y, x;
795     int out_i, in_i, out_z, in_z, out_x, in_x;
796     int s_x, e_x;
797
798     for (z = startz; z < endz+1; z++)
799     {
800         if (z == startz)
801         {
802             s_x = startx;
803         }
804         else
805         {
806             s_x = 0;
807         }
808         if (z == endz)
809         {
810             e_x = endx;
811         }
812         else
813         {
814             e_x = pN;
815         }
816         out_z  = z*MG*pN;
817         in_z   = z*maxM*maxN;
818
819         for (i = 0; i < P; i++) /*index cube along long axis*/
820         {
821             out_i  = out_z  + oM[i];
822             in_i   = in_z + i*maxM*maxN*maxK;
823             for (x = s_x; x < e_x; x++)
824             {
825                 out_x  = out_i  + x*MG;
826                 in_x   = in_i + x;
827                 for (y = 0; y < M[i]; y++)
828                 {
829                     lout[out_x+y] = lin[in_x+y*maxN]; /*out=z*MG*pN+oM[i]+x*MG+y*/
830                 }
831             }
832         }
833     }
834 }
835
836
837 static void rotate_offsets(int x[])
838 {
839     int t = x[0];
840 /*    x[0]=x[2];
841     x[2]=x[1];
842     x[1]=t;*/
843     x[0] = x[1];
844     x[1] = x[2];
845     x[2] = t;
846 }
847
848 /*compute the offset to compare or print transposed local data in original input coordinates
849    xs matrix dimension size, xl dimension length, xc decomposition offset
850    s: step in computation = number of transposes*/
851 static void compute_offsets(fft5d_plan plan, int xs[], int xl[], int xc[], int NG[], int s)
852 {
853 /*    int direction = plan->direction;
854     int fftorder = plan->fftorder;*/
855
856     int  o = 0;
857     int  pos[3], i;
858     int *pM = plan->pM, *pK = plan->pK, *oM = plan->oM, *oK = plan->oK,
859     *C      = plan->C, *rC = plan->rC;
860
861     NG[0] = plan->NG; NG[1] = plan->MG; NG[2] = plan->KG;
862
863     if (!(plan->flags&FFT5D_ORDER_YZ))
864     {
865         switch (s)
866         {
867             case 0: o = XYZ; break;
868             case 1: o = ZYX; break;
869             case 2: o = YZX; break;
870             default: assert(0);
871         }
872     }
873     else
874     {
875         switch (s)
876         {
877             case 0: o = XYZ; break;
878             case 1: o = YXZ; break;
879             case 2: o = ZXY; break;
880             default: assert(0);
881         }
882     }
883
884     switch (o)
885     {
886         case XYZ: pos[0] = 1; pos[1] = 2; pos[2] = 3; break;
887         case XZY: pos[0] = 1; pos[1] = 3; pos[2] = 2; break;
888         case YXZ: pos[0] = 2; pos[1] = 1; pos[2] = 3; break;
889         case YZX: pos[0] = 3; pos[1] = 1; pos[2] = 2; break;
890         case ZXY: pos[0] = 2; pos[1] = 3; pos[2] = 1; break;
891         case ZYX: pos[0] = 3; pos[1] = 2; pos[2] = 1; break;
892     }
893     /*if (debug) printf("pos: %d %d %d\n",pos[0],pos[1],pos[2]);*/
894
895     /*xs, xl give dimension size and data length in local transposed coordinate system
896        for 0(/1/2): x(/y/z) in original coordinate system*/
897     for (i = 0; i < 3; i++)
898     {
899         switch (pos[i])
900         {
901             case 1: xs[i] = 1;         xc[i] = 0;     xl[i] = C[s]; break;
902             case 2: xs[i] = C[s];      xc[i] = oM[s]; xl[i] = pM[s]; break;
903             case 3: xs[i] = C[s]*pM[s]; xc[i] = oK[s]; xl[i] = pK[s]; break;
904         }
905     }
906     /*input order is different for test program to match FFTW order
907        (important for complex to real)*/
908     if (plan->flags&FFT5D_BACKWARD)
909     {
910         rotate_offsets(xs);
911         rotate_offsets(xl);
912         rotate_offsets(xc);
913         rotate_offsets(NG);
914         if (plan->flags&FFT5D_ORDER_YZ)
915         {
916             rotate_offsets(xs);
917             rotate_offsets(xl);
918             rotate_offsets(xc);
919             rotate_offsets(NG);
920         }
921     }
922     if ((plan->flags&FFT5D_REALCOMPLEX) && ((!(plan->flags&FFT5D_BACKWARD) && s == 0) || ((plan->flags&FFT5D_BACKWARD) && s == 2)))
923     {
924         xl[0] = rC[s];
925     }
926 }
927
928 static void print_localdata(const t_complex* lin, const char* txt, int s, fft5d_plan plan)
929 {
930     int  x, y, z, l;
931     int *coor = plan->coor;
932     int  xs[3], xl[3], xc[3], NG[3];
933     int  ll = (plan->flags&FFT5D_REALCOMPLEX) ? 1 : 2;
934     compute_offsets(plan, xs, xl, xc, NG, s);
935     fprintf(debug, txt, coor[0], coor[1]);
936     /*printf("xs: %d %d %d, xl: %d %d %d\n",xs[0],xs[1],xs[2],xl[0],xl[1],xl[2]);*/
937     for (z = 0; z < xl[2]; z++)
938     {
939         for (y = 0; y < xl[1]; y++)
940         {
941             fprintf(debug, "%d %d: ", coor[0], coor[1]);
942             for (x = 0; x < xl[0]; x++)
943             {
944                 for (l = 0; l < ll; l++)
945                 {
946                     fprintf(debug, "%f ", reinterpret_cast<const real*>(lin)[(z*xs[2]+y*xs[1])*2+(x*xs[0])*ll+l]);
947                 }
948                 fprintf(debug, ",");
949             }
950             fprintf(debug, "\n");
951         }
952     }
953 }
954
955 void fft5d_execute(fft5d_plan plan, int thread, fft5d_time times)
956 {
957     t_complex  *lin   = plan->lin;
958     t_complex  *lout  = plan->lout;
959     t_complex  *lout2 = plan->lout2;
960     t_complex  *lout3 = plan->lout3;
961     t_complex  *fftout, *joinin;
962
963     gmx_fft_t **p1d = plan->p1d;
964 #ifdef FFT5D_MPI_TRANSPOSE
965     FFTW(plan) *mpip = plan->mpip;
966 #endif
967 #if GMX_MPI
968     MPI_Comm *cart = plan->cart;
969 #endif
970 #ifdef NOGMX
971     double time_fft = 0, time_local = 0, time_mpi[2] = {0}, time = 0;
972 #endif
973     int   *N = plan->N, *M = plan->M, *K = plan->K, *pN = plan->pN, *pM = plan->pM, *pK = plan->pK,
974     *C       = plan->C, *P = plan->P, **iNin = plan->iNin, **oNin = plan->oNin, **iNout = plan->iNout, **oNout = plan->oNout;
975     int    s = 0, tstart, tend, bParallelDim;
976
977
978 #if GMX_FFT_FFTW3
979     if (plan->p3d)
980     {
981         if (thread == 0)
982         {
983 #ifdef NOGMX
984             if (times != 0)
985             {
986                 time = MPI_Wtime();
987             }
988 #endif
989             FFTW(execute)(plan->p3d);
990 #ifdef NOGMX
991             if (times != 0)
992             {
993                 times->fft += MPI_Wtime()-time;
994             }
995 #endif
996         }
997         return;
998     }
999 #endif
1000
1001     s = 0;
1002
1003     /*lin: x,y,z*/
1004     if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1005     {
1006         print_localdata(lin, "%d %d: copy in lin\n", s, plan);
1007     }
1008
1009     for (s = 0; s < 2; s++)  /*loop over first two FFT steps (corner rotations)*/
1010
1011     {
1012 #if GMX_MPI
1013         if (GMX_PARALLEL_ENV_INITIALIZED && cart[s] != MPI_COMM_NULL && P[s] > 1)
1014         {
1015             bParallelDim = 1;
1016         }
1017         else
1018 #endif
1019         {
1020             bParallelDim = 0;
1021         }
1022
1023         /* ---------- START FFT ------------ */
1024 #ifdef NOGMX
1025         if (times != 0 && thread == 0)
1026         {
1027             time = MPI_Wtime();
1028         }
1029 #endif
1030
1031         if (bParallelDim || plan->nthreads == 1)
1032         {
1033             fftout = lout;
1034         }
1035         else
1036         {
1037             if (s == 0)
1038             {
1039                 fftout = lout3;
1040             }
1041             else
1042             {
1043                 fftout = lout2;
1044             }
1045         }
1046
1047         tstart = (thread*pM[s]*pK[s]/plan->nthreads)*C[s];
1048         if ((plan->flags&FFT5D_REALCOMPLEX) && !(plan->flags&FFT5D_BACKWARD) && s == 0)
1049         {
1050             gmx_fft_many_1d_real(p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX, lin+tstart, fftout+tstart);
1051         }
1052         else
1053         {
1054             gmx_fft_many_1d(     p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,               lin+tstart, fftout+tstart);
1055
1056         }
1057
1058 #ifdef NOGMX
1059         if (times != NULL && thread == 0)
1060         {
1061             time_fft += MPI_Wtime()-time;
1062         }
1063 #endif
1064         if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1065         {
1066             print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1067         }
1068         /* ---------- END FFT ------------ */
1069
1070         /* ---------- START SPLIT + TRANSPOSE------------ (if parallel in in this dimension)*/
1071         if (bParallelDim)
1072         {
1073 #ifdef NOGMX
1074             if (times != NULL && thread == 0)
1075             {
1076                 time = MPI_Wtime();
1077             }
1078 #endif
1079             /*prepare for A
1080                llToAll
1081                1. (most outer) axes (x) is split into P[s] parts of size N[s]
1082                for sending*/
1083             if (pM[s] > 0)
1084             {
1085                 tend    = ((thread+1)*pM[s]*pK[s]/plan->nthreads);
1086                 tstart /= C[s];
1087                 splitaxes(lout2, lout, N[s], M[s], K[s], pM[s], P[s], C[s], iNout[s], oNout[s], tstart%pM[s], tstart/pM[s], tend%pM[s], tend/pM[s]);
1088             }
1089 #pragma omp barrier /*barrier required before AllToAll (all input has to be their) - before timing to make timing more acurate*/
1090 #ifdef NOGMX
1091             if (times != NULL && thread == 0)
1092             {
1093                 time_local += MPI_Wtime()-time;
1094             }
1095 #endif
1096
1097             /* ---------- END SPLIT , START TRANSPOSE------------ */
1098
1099             if (thread == 0)
1100             {
1101 #ifdef NOGMX
1102                 if (times != 0)
1103                 {
1104                     time = MPI_Wtime();
1105                 }
1106 #else
1107                 wallcycle_start(times, ewcPME_FFTCOMM);
1108 #endif
1109 #ifdef FFT5D_MPI_TRANSPOSE
1110                 FFTW(execute)(mpip[s]);
1111 #else
1112 #if GMX_MPI
1113                 if ((s == 0 && !(plan->flags&FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags&FFT5D_ORDER_YZ)))
1114                 {
1115                     MPI_Alltoall(reinterpret_cast<real *>(lout2), N[s]*pM[s]*K[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, reinterpret_cast<real *>(lout3), N[s]*pM[s]*K[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, cart[s]);
1116                 }
1117                 else
1118                 {
1119                     MPI_Alltoall(reinterpret_cast<real *>(lout2), N[s]*M[s]*pK[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, reinterpret_cast<real *>(lout3), N[s]*M[s]*pK[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, cart[s]);
1120                 }
1121 #else
1122                 gmx_incons("fft5d MPI call without MPI configuration");
1123 #endif /*GMX_MPI*/
1124 #endif /*FFT5D_MPI_TRANSPOSE*/
1125 #ifdef NOGMX
1126                 if (times != 0)
1127                 {
1128                     time_mpi[s] = MPI_Wtime()-time;
1129                 }
1130 #else
1131                 wallcycle_stop(times, ewcPME_FFTCOMM);
1132 #endif
1133             } /*master*/
1134         }     /* bPrallelDim */
1135 #pragma omp barrier  /*both needed for parallel and non-parallel dimension (either have to wait on data from AlltoAll or from last FFT*/
1136
1137         /* ---------- END SPLIT + TRANSPOSE------------ */
1138
1139         /* ---------- START JOIN ------------ */
1140 #ifdef NOGMX
1141         if (times != NULL && thread == 0)
1142         {
1143             time = MPI_Wtime();
1144         }
1145 #endif
1146
1147         if (bParallelDim)
1148         {
1149             joinin = lout3;
1150         }
1151         else
1152         {
1153             joinin = fftout;
1154         }
1155         /*bring back in matrix form
1156            thus make  new 1. axes contiguos
1157            also local transpose 1 and 2/3
1158            runs on thread used for following FFT (thus needing a barrier before but not afterwards)
1159          */
1160         if ((s == 0 && !(plan->flags&FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags&FFT5D_ORDER_YZ)))
1161         {
1162             if (pM[s] > 0)
1163             {
1164                 tstart = ( thread   *pM[s]*pN[s]/plan->nthreads);
1165                 tend   = ((thread+1)*pM[s]*pN[s]/plan->nthreads);
1166                 joinAxesTrans13(lin, joinin, N[s], pM[s], K[s], pM[s], P[s], C[s+1], iNin[s+1], oNin[s+1], tstart%pM[s], tstart/pM[s], tend%pM[s], tend/pM[s]);
1167             }
1168         }
1169         else
1170         {
1171             if (pN[s] > 0)
1172             {
1173                 tstart = ( thread   *pK[s]*pN[s]/plan->nthreads);
1174                 tend   = ((thread+1)*pK[s]*pN[s]/plan->nthreads);
1175                 joinAxesTrans12(lin, joinin, N[s], M[s], pK[s], pN[s], P[s], C[s+1], iNin[s+1], oNin[s+1], tstart%pN[s], tstart/pN[s], tend%pN[s], tend/pN[s]);
1176             }
1177         }
1178
1179 #ifdef NOGMX
1180         if (times != NULL && thread == 0)
1181         {
1182             time_local += MPI_Wtime()-time;
1183         }
1184 #endif
1185         if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1186         {
1187             print_localdata(lin, "%d %d: tranposed %d\n", s+1, plan);
1188         }
1189         /* ---------- END JOIN ------------ */
1190
1191         /*if (debug) print_localdata(lin, "%d %d: transposed x-z\n", N1, M0, K, ZYX, coor);*/
1192     }  /* for(s=0;s<2;s++) */
1193 #ifdef NOGMX
1194     if (times != NULL && thread == 0)
1195     {
1196         time = MPI_Wtime();
1197     }
1198 #endif
1199
1200     if (plan->flags&FFT5D_INPLACE)
1201     {
1202         lout = lin;                          /*in place currently not supported*/
1203
1204     }
1205     /*  ----------- FFT ----------- */
1206     tstart = (thread*pM[s]*pK[s]/plan->nthreads)*C[s];
1207     if ((plan->flags&FFT5D_REALCOMPLEX) && (plan->flags&FFT5D_BACKWARD))
1208     {
1209         gmx_fft_many_1d_real(p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX, lin+tstart, lout+tstart);
1210     }
1211     else
1212     {
1213         gmx_fft_many_1d(     p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,               lin+tstart, lout+tstart);
1214     }
1215     /* ------------ END FFT ---------*/
1216
1217 #ifdef NOGMX
1218     if (times != NULL && thread == 0)
1219     {
1220         time_fft += MPI_Wtime()-time;
1221
1222         times->fft   += time_fft;
1223         times->local += time_local;
1224         times->mpi2  += time_mpi[1];
1225         times->mpi1  += time_mpi[0];
1226     }
1227 #endif
1228
1229     if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1230     {
1231         print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1232     }
1233 }
1234
1235 void fft5d_destroy(fft5d_plan plan)
1236 {
1237     int s, t;
1238
1239     for (s = 0; s < 3; s++)
1240     {
1241         if (plan->p1d[s])
1242         {
1243             for (t = 0; t < plan->nthreads; t++)
1244             {
1245                 gmx_many_fft_destroy(plan->p1d[s][t]);
1246             }
1247             free(plan->p1d[s]);
1248         }
1249         if (plan->iNin[s])
1250         {
1251             free(plan->iNin[s]);
1252             plan->iNin[s] = nullptr;
1253         }
1254         if (plan->oNin[s])
1255         {
1256             free(plan->oNin[s]);
1257             plan->oNin[s] = nullptr;
1258         }
1259         if (plan->iNout[s])
1260         {
1261             free(plan->iNout[s]);
1262             plan->iNout[s] = nullptr;
1263         }
1264         if (plan->oNout[s])
1265         {
1266             free(plan->oNout[s]);
1267             plan->oNout[s] = nullptr;
1268         }
1269     }
1270 #if GMX_FFT_FFTW3
1271     FFTW_LOCK;
1272 #ifdef FFT5D_MPI_TRANSPOS
1273     for (s = 0; s < 2; s++)
1274     {
1275         FFTW(destroy_plan)(plan->mpip[s]);
1276     }
1277 #endif /* FFT5D_MPI_TRANSPOS */
1278     if (plan->p3d)
1279     {
1280         FFTW(destroy_plan)(plan->p3d);
1281     }
1282     FFTW_UNLOCK;
1283 #endif /* GMX_FFT_FFTW3 */
1284
1285     if (!(plan->flags&FFT5D_NOMALLOC))
1286     {
1287         // only needed for PME GPU mixed mode
1288         if (plan->pinningPolicy == gmx::PinningPolicy::CanBePinned &&
1289             isHostMemoryPinned(plan->lin))
1290         {
1291             gmx::unpinBuffer(plan->lin);
1292         }
1293         sfree_aligned(plan->lin);
1294         sfree_aligned(plan->lout);
1295         if (plan->nthreads > 1)
1296         {
1297             sfree_aligned(plan->lout2);
1298             sfree_aligned(plan->lout3);
1299         }
1300     }
1301
1302 #ifdef FFT5D_THREADS
1303 #ifdef FFT5D_FFTW_THREADS
1304     /*FFTW(cleanup_threads)();*/
1305 #endif
1306 #endif
1307
1308     free(plan);
1309 }
1310
1311 /*Is this better than direct access of plan? enough data?
1312    here 0,1 reference divided by which processor grid dimension (not FFT step!)*/
1313 void fft5d_local_size(fft5d_plan plan, int* N1, int* M0, int* K0, int* K1, int** coor)
1314 {
1315     *N1 = plan->N[0];
1316     *M0 = plan->M[0];
1317     *K1 = plan->K[0];
1318     *K0 = plan->N[1];
1319
1320     *coor = plan->coor;
1321 }
1322
1323
1324 /*same as fft5d_plan_3d but with cartesian coordinator and automatic splitting
1325    of processor dimensions*/
1326 fft5d_plan fft5d_plan_3d_cart(int NG, int MG, int KG, MPI_Comm comm, int P0, int flags, t_complex** rlin, t_complex** rlout, t_complex** rlout2, t_complex** rlout3, int nthreads)
1327 {
1328     MPI_Comm cart[2] = {MPI_COMM_NULL, MPI_COMM_NULL};
1329 #if GMX_MPI
1330     int      size = 1, prank = 0;
1331     int      P[2];
1332     int      coor[2];
1333     int      wrap[] = {0, 0};
1334     MPI_Comm gcart;
1335     int      rdim1[] = {0, 1}, rdim2[] = {1, 0};
1336
1337     MPI_Comm_size(comm, &size);
1338     MPI_Comm_rank(comm, &prank);
1339
1340     if (P0 == 0)
1341     {
1342         P0 = lfactor(size);
1343     }
1344     if (size%P0 != 0)
1345     {
1346         if (prank == 0)
1347         {
1348             printf("FFT5D: WARNING: Number of ranks %d not evenly divisible by %d\n", size, P0);
1349         }
1350         P0 = lfactor(size);
1351     }
1352
1353     P[0] = P0; P[1] = size/P0; /*number of processors in the two dimensions*/
1354
1355     /*Difference between x-y-z regarding 2d decomposition is whether they are
1356        distributed along axis 1, 2 or both*/
1357
1358     MPI_Cart_create(comm, 2, P, wrap, 1, &gcart); /*parameter 4: value 1: reorder*/
1359     MPI_Cart_get(gcart, 2, P, wrap, coor);
1360     MPI_Cart_sub(gcart, rdim1, &cart[0]);
1361     MPI_Cart_sub(gcart, rdim2, &cart[1]);
1362 #else
1363     (void)P0;
1364     (void)comm;
1365 #endif
1366     return fft5d_plan_3d(NG, MG, KG, cart, flags, rlin, rlout, rlout2, rlout3, nthreads);
1367 }
1368
1369
1370
1371 /*prints in original coordinate system of data (as the input to FFT)*/
1372 void fft5d_compare_data(const t_complex* lin, const t_complex* in, fft5d_plan plan, int bothLocal, int normalize)
1373 {
1374     int  xs[3], xl[3], xc[3], NG[3];
1375     int  x, y, z, l;
1376     int *coor = plan->coor;
1377     int  ll   = 2; /*compare ll values per element (has to be 2 for complex)*/
1378     if ((plan->flags&FFT5D_REALCOMPLEX) && (plan->flags&FFT5D_BACKWARD))
1379     {
1380         ll = 1;
1381     }
1382
1383     compute_offsets(plan, xs, xl, xc, NG, 2);
1384     if (plan->flags&FFT5D_DEBUG)
1385     {
1386         printf("Compare2\n");
1387     }
1388     for (z = 0; z < xl[2]; z++)
1389     {
1390         for (y = 0; y < xl[1]; y++)
1391         {
1392             if (plan->flags&FFT5D_DEBUG)
1393             {
1394                 printf("%d %d: ", coor[0], coor[1]);
1395             }
1396             for (x = 0; x < xl[0]; x++)
1397             {
1398                 for (l = 0; l < ll; l++)   /*loop over real/complex parts*/
1399                 {
1400                     real a, b;
1401                     a = reinterpret_cast<const real*>(lin)[(z*xs[2]+y*xs[1])*2+x*xs[0]*ll+l];
1402                     if (normalize)
1403                     {
1404                         a /= plan->rC[0]*plan->rC[1]*plan->rC[2];
1405                     }
1406                     if (!bothLocal)
1407                     {
1408                         b = reinterpret_cast<const real*>(in)[((z+xc[2])*NG[0]*NG[1]+(y+xc[1])*NG[0])*2+(x+xc[0])*ll+l];
1409                     }
1410                     else
1411                     {
1412                         b = reinterpret_cast<const real*>(in)[(z*xs[2]+y*xs[1])*2+x*xs[0]*ll+l];
1413                     }
1414                     if (plan->flags&FFT5D_DEBUG)
1415                     {
1416                         printf("%f %f, ", a, b);
1417                     }
1418                     else
1419                     {
1420                         if (std::fabs(a-b) > 2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS)
1421                         {
1422                             printf("result incorrect on %d,%d at %d,%d,%d: FFT5D:%f reference:%f\n", coor[0], coor[1], x, y, z, a, b);
1423                         }
1424 /*                        assert(fabs(a-b)<2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS);*/
1425                     }
1426                 }
1427                 if (plan->flags&FFT5D_DEBUG)
1428                 {
1429                     printf(",");
1430                 }
1431             }
1432             if (plan->flags&FFT5D_DEBUG)
1433             {
1434                 printf("\n");
1435             }
1436         }
1437     }
1438
1439 }