12c1e010f800de5bdcac40c154cf98fb1849079c
[alexxy/gromacs.git] / src / gromacs / fft / fft5d.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2009,2010,2012,2013,2014,2015,2016,2017,2018, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 #include "gmxpre.h"
36
37 #include "fft5d.h"
38
39 #include "config.h"
40
41 #include <assert.h>
42 #include <float.h>
43 #include <stdio.h>
44 #include <stdlib.h>
45 #include <string.h>
46
47 #include <cmath>
48
49 #include <algorithm>
50
51 #include "gromacs/gpu_utils/gpu_utils.h"
52 #include "gromacs/gpu_utils/hostallocator.h"
53 #include "gromacs/gpu_utils/pinning.h"
54 #include "gromacs/utility/alignedallocator.h"
55 #include "gromacs/utility/exceptions.h"
56 #include "gromacs/utility/fatalerror.h"
57 #include "gromacs/utility/gmxmpi.h"
58 #include "gromacs/utility/smalloc.h"
59
60 #ifdef NOGMX
61 #define GMX_PARALLEL_ENV_INITIALIZED 1
62 #else
63 #if GMX_MPI
64 #define GMX_PARALLEL_ENV_INITIALIZED 1
65 #else
66 #define GMX_PARALLEL_ENV_INITIALIZED 0
67 #endif
68 #endif
69
70 #if GMX_OPENMP
71 /* TODO: Do we still need this? Are we still planning ot use fftw + OpenMP? */
72 #define FFT5D_THREADS
73 /* requires fftw compiled with openmp */
74 /* #define FFT5D_FFTW_THREADS (now set by cmake) */
75 #endif
76
77 #ifndef __FLT_EPSILON__
78 #define __FLT_EPSILON__ FLT_EPSILON
79 #define __DBL_EPSILON__ DBL_EPSILON
80 #endif
81
82 #ifdef NOGMX
83 FILE* debug = 0;
84 #endif
85
86 #if GMX_FFT_FFTW3
87
88 #include "gromacs/utility/exceptions.h"
89 #include "gromacs/utility/mutex.h"
90 /* none of the fftw3 calls, except execute(), are thread-safe, so
91    we need to serialize them with this mutex. */
92 static gmx::Mutex big_fftw_mutex;
93 #define FFTW_LOCK try { big_fftw_mutex.lock(); } GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
94 #define FFTW_UNLOCK try { big_fftw_mutex.unlock(); } GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
95 #endif /* GMX_FFT_FFTW3 */
96
97 #if GMX_MPI
98 /* largest factor smaller than sqrt */
99 static int lfactor(int z)
100 {
101     int i = static_cast<int>(sqrt(static_cast<double>(z)));
102     while (z%i != 0)
103     {
104         i--;
105     }
106     return i;
107 }
108 #endif
109
110 #if !GMX_MPI
111 #if HAVE_GETTIMEOFDAY
112 #include <sys/time.h>
113 double MPI_Wtime()
114 {
115     struct timeval tv;
116     gettimeofday(&tv, 0);
117     return tv.tv_sec+tv.tv_usec*1e-6;
118 }
119 #else
120 double MPI_Wtime()
121 {
122     return 0.0;
123 }
124 #endif
125 #endif
126
127 static int vmax(const int* a, int s)
128 {
129     int i, max = 0;
130     for (i = 0; i < s; i++)
131     {
132         if (a[i] > max)
133         {
134             max = a[i];
135         }
136     }
137     return max;
138 }
139
140
141 /* NxMxK the size of the data
142  * comm communicator to use for fft5d
143  * P0 number of processor in 1st axes (can be null for automatic)
144  * lin is allocated by fft5d because size of array is only known after planning phase
145  * rlout2 is only used as intermediate buffer - only returned after allocation to reuse for back transform - should not be used by caller
146  */
147 fft5d_plan fft5d_plan_3d(int NG, int MG, int KG, MPI_Comm comm[2], int flags, t_complex** rlin, t_complex** rlout, t_complex** rlout2, t_complex** rlout3, int nthreads, gmx::PinningPolicy realGridAllocationPinningPolicy)
148 {
149
150     int        P[2], bMaster, prank[2], i, t;
151     int        rNG, rMG, rKG;
152     int       *N0 = nullptr, *N1 = nullptr, *M0 = nullptr, *M1 = nullptr, *K0 = nullptr, *K1 = nullptr, *oN0 = nullptr, *oN1 = nullptr, *oM0 = nullptr, *oM1 = nullptr, *oK0 = nullptr, *oK1 = nullptr;
153     int        N[3], M[3], K[3], pN[3], pM[3], pK[3], oM[3], oK[3], *iNin[3] = {nullptr}, *oNin[3] = {nullptr}, *iNout[3] = {nullptr}, *oNout[3] = {nullptr};
154     int        C[3], rC[3], nP[2];
155     int        lsize;
156     t_complex *lin = nullptr, *lout = nullptr, *lout2 = nullptr, *lout3 = nullptr;
157     fft5d_plan plan;
158     int        s;
159
160     /* comm, prank and P are in the order of the decomposition (plan->cart is in the order of transposes) */
161 #if GMX_MPI
162     if (GMX_PARALLEL_ENV_INITIALIZED && comm[0] != MPI_COMM_NULL)
163     {
164         MPI_Comm_size(comm[0], &P[0]);
165         MPI_Comm_rank(comm[0], &prank[0]);
166     }
167     else
168 #endif
169     {
170         P[0]     = 1;
171         prank[0] = 0;
172     }
173 #if GMX_MPI
174     if (GMX_PARALLEL_ENV_INITIALIZED && comm[1] != MPI_COMM_NULL)
175     {
176         MPI_Comm_size(comm[1], &P[1]);
177         MPI_Comm_rank(comm[1], &prank[1]);
178     }
179     else
180 #endif
181     {
182         P[1]     = 1;
183         prank[1] = 0;
184     }
185
186     bMaster = (prank[0] == 0 && prank[1] == 0);
187
188
189     if (debug)
190     {
191         fprintf(debug, "FFT5D: Using %dx%d rank grid, rank %d,%d\n",
192                 P[0], P[1], prank[0], prank[1]);
193     }
194
195     if (bMaster)
196     {
197         if (debug)
198         {
199             fprintf(debug, "FFT5D: N: %d, M: %d, K: %d, P: %dx%d, real2complex: %d, backward: %d, order yz: %d, debug %d\n",
200                     NG, MG, KG, P[0], P[1], int((flags&FFT5D_REALCOMPLEX) > 0), int((flags&FFT5D_BACKWARD) > 0), int((flags&FFT5D_ORDER_YZ) > 0), int((flags&FFT5D_DEBUG) > 0));
201         }
202         /* The check below is not correct, one prime factor 11 or 13 is ok.
203            if (fft5d_fmax(fft5d_fmax(lpfactor(NG),lpfactor(MG)),lpfactor(KG))>7) {
204             printf("WARNING: FFT very slow with prime factors larger 7\n");
205             printf("Change FFT size or in case you cannot change it look at\n");
206             printf("http://www.fftw.org/fftw3_doc/Generating-your-own-code.html\n");
207            }
208          */
209     }
210
211     if (NG == 0 || MG == 0 || KG == 0)
212     {
213         if (bMaster)
214         {
215             printf("FFT5D: FATAL: Datasize cannot be zero in any dimension\n");
216         }
217         return nullptr;
218     }
219
220     rNG = NG; rMG = MG; rKG = KG;
221
222     if (flags&FFT5D_REALCOMPLEX)
223     {
224         if (!(flags&FFT5D_BACKWARD))
225         {
226             NG = NG/2+1;
227         }
228         else
229         {
230             if (!(flags&FFT5D_ORDER_YZ))
231             {
232                 MG = MG/2+1;
233             }
234             else
235             {
236                 KG = KG/2+1;
237             }
238         }
239     }
240
241
242     /*for transpose we need to know the size for each processor not only our own size*/
243
244     N0  = static_cast<int*>(malloc(P[0]*sizeof(int))); N1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
245     M0  = static_cast<int*>(malloc(P[0]*sizeof(int))); M1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
246     K0  = static_cast<int*>(malloc(P[0]*sizeof(int))); K1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
247     oN0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oN1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
248     oM0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oM1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
249     oK0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oK1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
250
251     for (i = 0; i < P[0]; i++)
252     {
253         #define EVENDIST
254         #ifndef EVENDIST
255         oN0[i] = i*ceil((double)NG/P[0]);
256         oM0[i] = i*ceil((double)MG/P[0]);
257         oK0[i] = i*ceil((double)KG/P[0]);
258         #else
259         oN0[i] = (NG*i)/P[0];
260         oM0[i] = (MG*i)/P[0];
261         oK0[i] = (KG*i)/P[0];
262         #endif
263     }
264     for (i = 0; i < P[1]; i++)
265     {
266         #ifndef EVENDIST
267         oN1[i] = i*ceil((double)NG/P[1]);
268         oM1[i] = i*ceil((double)MG/P[1]);
269         oK1[i] = i*ceil((double)KG/P[1]);
270         #else
271         oN1[i] = (NG*i)/P[1];
272         oM1[i] = (MG*i)/P[1];
273         oK1[i] = (KG*i)/P[1];
274         #endif
275     }
276     for (i = 0; i < P[0]-1; i++)
277     {
278         N0[i] = oN0[i+1]-oN0[i];
279         M0[i] = oM0[i+1]-oM0[i];
280         K0[i] = oK0[i+1]-oK0[i];
281     }
282     N0[P[0]-1] = NG-oN0[P[0]-1];
283     M0[P[0]-1] = MG-oM0[P[0]-1];
284     K0[P[0]-1] = KG-oK0[P[0]-1];
285     for (i = 0; i < P[1]-1; i++)
286     {
287         N1[i] = oN1[i+1]-oN1[i];
288         M1[i] = oM1[i+1]-oM1[i];
289         K1[i] = oK1[i+1]-oK1[i];
290     }
291     N1[P[1]-1] = NG-oN1[P[1]-1];
292     M1[P[1]-1] = MG-oM1[P[1]-1];
293     K1[P[1]-1] = KG-oK1[P[1]-1];
294
295     /* for step 1-3 the local N,M,K sizes of the transposed system
296        C: contiguous dimension, and nP: number of processor in subcommunicator
297        for that step */
298
299
300     pM[0] = M0[prank[0]];
301     oM[0] = oM0[prank[0]];
302     pK[0] = K1[prank[1]];
303     oK[0] = oK1[prank[1]];
304     C[0]  = NG;
305     rC[0] = rNG;
306     if (!(flags&FFT5D_ORDER_YZ))
307     {
308         N[0]     = vmax(N1, P[1]);
309         M[0]     = M0[prank[0]];
310         K[0]     = vmax(K1, P[1]);
311         pN[0]    = N1[prank[1]];
312         iNout[0] = N1;
313         oNout[0] = oN1;
314         nP[0]    = P[1];
315         C[1]     = KG;
316         rC[1]    = rKG;
317         N[1]     = vmax(K0, P[0]);
318         pN[1]    = K0[prank[0]];
319         iNin[1]  = K1;
320         oNin[1]  = oK1;
321         iNout[1] = K0;
322         oNout[1] = oK0;
323         M[1]     = vmax(M0, P[0]);
324         pM[1]    = M0[prank[0]];
325         oM[1]    = oM0[prank[0]];
326         K[1]     = N1[prank[1]];
327         pK[1]    = N1[prank[1]];
328         oK[1]    = oN1[prank[1]];
329         nP[1]    = P[0];
330         C[2]     = MG;
331         rC[2]    = rMG;
332         iNin[2]  = M0;
333         oNin[2]  = oM0;
334         M[2]     = vmax(K0, P[0]);
335         pM[2]    = K0[prank[0]];
336         oM[2]    = oK0[prank[0]];
337         K[2]     = vmax(N1, P[1]);
338         pK[2]    = N1[prank[1]];
339         oK[2]    = oN1[prank[1]];
340         free(N0); free(oN0); /*these are not used for this order*/
341         free(M1); free(oM1); /*the rest is freed in destroy*/
342     }
343     else
344     {
345         N[0]     = vmax(N0, P[0]);
346         M[0]     = vmax(M0, P[0]);
347         K[0]     = K1[prank[1]];
348         pN[0]    = N0[prank[0]];
349         iNout[0] = N0;
350         oNout[0] = oN0;
351         nP[0]    = P[0];
352         C[1]     = MG;
353         rC[1]    = rMG;
354         N[1]     = vmax(M1, P[1]);
355         pN[1]    = M1[prank[1]];
356         iNin[1]  = M0;
357         oNin[1]  = oM0;
358         iNout[1] = M1;
359         oNout[1] = oM1;
360         M[1]     = N0[prank[0]];
361         pM[1]    = N0[prank[0]];
362         oM[1]    = oN0[prank[0]];
363         K[1]     = vmax(K1, P[1]);
364         pK[1]    = K1[prank[1]];
365         oK[1]    = oK1[prank[1]];
366         nP[1]    = P[1];
367         C[2]     = KG;
368         rC[2]    = rKG;
369         iNin[2]  = K1;
370         oNin[2]  = oK1;
371         M[2]     = vmax(N0, P[0]);
372         pM[2]    = N0[prank[0]];
373         oM[2]    = oN0[prank[0]];
374         K[2]     = vmax(M1, P[1]);
375         pK[2]    = M1[prank[1]];
376         oK[2]    = oM1[prank[1]];
377         free(N1); free(oN1); /*these are not used for this order*/
378         free(K0); free(oK0); /*the rest is freed in destroy*/
379     }
380     N[2] = pN[2] = -1;       /*not used*/
381
382     /*
383        Difference between x-y-z regarding 2d decomposition is whether they are
384        distributed along axis 1, 2 or both
385      */
386
387     /* int lsize = fmax(N[0]*M[0]*K[0]*nP[0],N[1]*M[1]*K[1]*nP[1]); */
388     lsize = std::max(N[0]*M[0]*K[0]*nP[0], std::max(N[1]*M[1]*K[1]*nP[1], C[2]*M[2]*K[2]));
389     /* int lsize = fmax(C[0]*M[0]*K[0],fmax(C[1]*M[1]*K[1],C[2]*M[2]*K[2])); */
390     if (!(flags&FFT5D_NOMALLOC))
391     {
392         // only needed for PME GPU mixed mode
393         if (realGridAllocationPinningPolicy == gmx::PinningPolicy::CanBePinned)
394         {
395             const std::size_t numBytes = lsize * sizeof(t_complex);
396             lin = static_cast<t_complex *>(gmx::PageAlignedAllocationPolicy::malloc(numBytes));
397             gmx::pinBuffer(lin, numBytes);
398         }
399         else
400         {
401             snew_aligned(lin, lsize, 32);
402         }
403         snew_aligned(lout, lsize, 32);
404         if (nthreads > 1)
405         {
406             /* We need extra transpose buffers to avoid OpenMP barriers */
407             snew_aligned(lout2, lsize, 32);
408             snew_aligned(lout3, lsize, 32);
409         }
410         else
411         {
412             /* We can reuse the buffers to avoid cache misses */
413             lout2 = lin;
414             lout3 = lout;
415         }
416     }
417     else
418     {
419         lin  = *rlin;
420         lout = *rlout;
421         if (nthreads > 1)
422         {
423             lout2 = *rlout2;
424             lout3 = *rlout3;
425         }
426         else
427         {
428             lout2 = lin;
429             lout3 = lout;
430         }
431     }
432
433     plan = static_cast<fft5d_plan>(calloc(1, sizeof(struct fft5d_plan_t)));
434
435
436     if (debug)
437     {
438         fprintf(debug, "Running on %d threads\n", nthreads);
439     }
440
441 #if GMX_FFT_FFTW3
442     /* Don't add more stuff here! We have already had at least one bug because we are reimplementing
443      * the low-level FFT interface instead of using the Gromacs FFT module. If we need more
444      * generic functionality it is far better to extend the interface so we can use it for
445      * all FFT libraries instead of writing FFTW-specific code here.
446      */
447
448     /*if not FFTW - then we don't do a 3d plan but instead use only 1D plans */
449     /* It is possible to use the 3d plan with OMP threads - but in that case it is not allowed to be called from
450      * within a parallel region. For now deactivated. If it should be supported it has to made sure that
451      * that the execute of the 3d plan is in a master/serial block (since it contains it own parallel region)
452      * and that the 3d plan is faster than the 1d plan.
453      */
454     if ((!(flags&FFT5D_INPLACE)) && (!(P[0] > 1 || P[1] > 1)) && nthreads == 1) /*don't do 3d plan in parallel or if in_place requested */
455     {
456         int fftwflags = FFTW_DESTROY_INPUT;
457         FFTW(iodim) dims[3];
458         int inNG = NG, outMG = MG, outKG = KG;
459
460         FFTW_LOCK;
461
462         fftwflags |= (flags & FFT5D_NOMEASURE) ? FFTW_ESTIMATE : FFTW_MEASURE;
463
464         if (flags&FFT5D_REALCOMPLEX)
465         {
466             if (!(flags&FFT5D_BACKWARD))        /*input pointer is not complex*/
467             {
468                 inNG *= 2;
469             }
470             else                                /*output pointer is not complex*/
471             {
472                 if (!(flags&FFT5D_ORDER_YZ))
473                 {
474                     outMG *= 2;
475                 }
476                 else
477                 {
478                     outKG *= 2;
479                 }
480             }
481         }
482
483         if (!(flags&FFT5D_BACKWARD))
484         {
485             dims[0].n  = KG;
486             dims[1].n  = MG;
487             dims[2].n  = rNG;
488
489             dims[0].is = inNG*MG;         /*N M K*/
490             dims[1].is = inNG;
491             dims[2].is = 1;
492             if (!(flags&FFT5D_ORDER_YZ))
493             {
494                 dims[0].os = MG;           /*M K N*/
495                 dims[1].os = 1;
496                 dims[2].os = MG*KG;
497             }
498             else
499             {
500                 dims[0].os = 1;           /*K N M*/
501                 dims[1].os = KG*NG;
502                 dims[2].os = KG;
503             }
504         }
505         else
506         {
507             if (!(flags&FFT5D_ORDER_YZ))
508             {
509                 dims[0].n  = NG;
510                 dims[1].n  = KG;
511                 dims[2].n  = rMG;
512
513                 dims[0].is = 1;
514                 dims[1].is = NG*MG;
515                 dims[2].is = NG;
516
517                 dims[0].os = outMG*KG;
518                 dims[1].os = outMG;
519                 dims[2].os = 1;
520             }
521             else
522             {
523                 dims[0].n  = MG;
524                 dims[1].n  = NG;
525                 dims[2].n  = rKG;
526
527                 dims[0].is = NG;
528                 dims[1].is = 1;
529                 dims[2].is = NG*MG;
530
531                 dims[0].os = outKG*NG;
532                 dims[1].os = outKG;
533                 dims[2].os = 1;
534             }
535         }
536 #ifdef FFT5D_THREADS
537 #ifdef FFT5D_FFTW_THREADS
538         FFTW(plan_with_nthreads)(nthreads);
539 #endif
540 #endif
541         if ((flags&FFT5D_REALCOMPLEX) && !(flags&FFT5D_BACKWARD))
542         {
543             plan->p3d = FFTW(plan_guru_dft_r2c)(/*rank*/ 3, dims,
544                                                          /*howmany*/ 0, /*howmany_dims*/ nullptr,
545                                                          reinterpret_cast<real*>(lin), reinterpret_cast<FFTW(complex) *>(lout),
546                                                          /*flags*/ fftwflags);
547         }
548         else if ((flags&FFT5D_REALCOMPLEX) && (flags&FFT5D_BACKWARD))
549         {
550             plan->p3d = FFTW(plan_guru_dft_c2r)(/*rank*/ 3, dims,
551                                                          /*howmany*/ 0, /*howmany_dims*/ nullptr,
552                                                          reinterpret_cast<FFTW(complex) *>(lin), reinterpret_cast<real*>(lout),
553                                                          /*flags*/ fftwflags);
554         }
555         else
556         {
557             plan->p3d = FFTW(plan_guru_dft)(/*rank*/ 3, dims,
558                                                      /*howmany*/ 0, /*howmany_dims*/ nullptr,
559                                                      reinterpret_cast<FFTW(complex) *>(lin), reinterpret_cast<FFTW(complex) *>(lout),
560                                                      /*sign*/ (flags&FFT5D_BACKWARD) ? 1 : -1, /*flags*/ fftwflags);
561         }
562 #ifdef FFT5D_THREADS
563 #ifdef FFT5D_FFTW_THREADS
564         FFTW(plan_with_nthreads)(1);
565 #endif
566 #endif
567         FFTW_UNLOCK;
568     }
569     if (!plan->p3d) /* for decomposition and if 3d plan did not work */
570     {
571 #endif              /* GMX_FFT_FFTW3 */
572     for (s = 0; s < 3; s++)
573     {
574         if (debug)
575         {
576             fprintf(debug, "FFT5D: Plan s %d rC %d M %d pK %d C %d lsize %d\n",
577                     s, rC[s], M[s], pK[s], C[s], lsize);
578         }
579         plan->p1d[s] = static_cast<gmx_fft_t*>(malloc(sizeof(gmx_fft_t)*nthreads));
580
581         /* Make sure that the init routines are only called by one thread at a time and in order
582            (later is only important to not confuse valgrind)
583          */
584 #pragma omp parallel for num_threads(nthreads) schedule(static) ordered
585         for (t = 0; t < nthreads; t++)
586         {
587 #pragma omp ordered
588             {
589                 try
590                 {
591                     int tsize = ((t+1)*pM[s]*pK[s]/nthreads)-(t*pM[s]*pK[s]/nthreads);
592
593                     if ((flags&FFT5D_REALCOMPLEX) && ((!(flags&FFT5D_BACKWARD) && s == 0) || ((flags&FFT5D_BACKWARD) && s == 2)))
594                     {
595                         gmx_fft_init_many_1d_real( &plan->p1d[s][t], rC[s], tsize, (flags&FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0 );
596                     }
597                     else
598                     {
599                         gmx_fft_init_many_1d     ( &plan->p1d[s][t],  C[s], tsize, (flags&FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0 );
600                     }
601                 }
602                 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
603             }
604         }
605     }
606
607 #if GMX_FFT_FFTW3
608 }
609 #endif
610     if ((flags&FFT5D_ORDER_YZ))   /*plan->cart is in the order of transposes */
611     {
612         plan->cart[0] = comm[0]; plan->cart[1] = comm[1];
613     }
614     else
615     {
616         plan->cart[1] = comm[0]; plan->cart[0] = comm[1];
617     }
618 #ifdef FFT5D_MPI_TRANSPOSE
619     FFTW_LOCK;
620     for (s = 0; s < 2; s++)
621     {
622         if ((s == 0 && !(flags&FFT5D_ORDER_YZ)) || (s == 1 && (flags&FFT5D_ORDER_YZ)))
623         {
624             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s]*K[s]*pM[s]*2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
625         }
626         else
627         {
628             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s]*pK[s]*M[s]*2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
629         }
630     }
631     FFTW_UNLOCK;
632 #endif
633
634
635     plan->lin   = lin;
636     plan->lout  = lout;
637     plan->lout2 = lout2;
638     plan->lout3 = lout3;
639
640     plan->NG = NG; plan->MG = MG; plan->KG = KG;
641
642     for (s = 0; s < 3; s++)
643     {
644         plan->N[s]    = N[s]; plan->M[s] = M[s]; plan->K[s] = K[s]; plan->pN[s] = pN[s]; plan->pM[s] = pM[s]; plan->pK[s] = pK[s];
645         plan->oM[s]   = oM[s]; plan->oK[s] = oK[s];
646         plan->C[s]    = C[s]; plan->rC[s] = rC[s];
647         plan->iNin[s] = iNin[s]; plan->oNin[s] = oNin[s]; plan->iNout[s] = iNout[s]; plan->oNout[s] = oNout[s];
648     }
649     for (s = 0; s < 2; s++)
650     {
651         plan->P[s] = nP[s]; plan->coor[s] = prank[s];
652     }
653
654 /*    plan->fftorder=fftorder;
655     plan->direction=direction;
656     plan->realcomplex=realcomplex;
657  */
658     plan->flags         = flags;
659     plan->nthreads      = nthreads;
660     plan->pinningPolicy = realGridAllocationPinningPolicy;
661     *rlin               = lin;
662     *rlout              = lout;
663     *rlout2             = lout2;
664     *rlout3             = lout3;
665     return plan;
666 }
667
668
669 enum order {
670     XYZ,
671     XZY,
672     YXZ,
673     YZX,
674     ZXY,
675     ZYX
676 };
677
678
679
680 /*here x,y,z and N,M,K is in rotated coordinate system!!
681    x (and N) is mayor (consecutive) dimension, y (M) middle and z (K) major
682    maxN,maxM,maxK is max size of local data
683    pN, pM, pK is local size specific to current processor (only different to max if not divisible)
684    NG, MG, KG is size of global data*/
685 static void splitaxes(t_complex* lout, const t_complex* lin,
686                       int maxN, int maxM, int maxK, int pM,
687                       int P, int NG, const int *N, const int* oN, int starty, int startz, int endy, int endz)
688 {
689     int x, y, z, i;
690     int in_i, out_i, in_z, out_z, in_y, out_y;
691     int s_y, e_y;
692
693     for (z = startz; z < endz+1; z++) /*3. z l*/
694     {
695         if (z == startz)
696         {
697             s_y = starty;
698         }
699         else
700         {
701             s_y = 0;
702         }
703         if (z == endz)
704         {
705             e_y = endy;
706         }
707         else
708         {
709             e_y = pM;
710         }
711         out_z  = z*maxN*maxM;
712         in_z   = z*NG*pM;
713
714         for (i = 0; i < P; i++) /*index cube along long axis*/
715         {
716             out_i  = out_z  + i*maxN*maxM*maxK;
717             in_i   = in_z + oN[i];
718             for (y = s_y; y < e_y; y++)   /*2. y k*/
719             {
720                 out_y  = out_i  + y*maxN;
721                 in_y   = in_i + y*NG;
722                 for (x = 0; x < N[i]; x++)       /*1. x j*/
723                 {
724                     lout[out_y+x] = lin[in_y+x]; /*in=z*NG*pM+oN[i]+y*NG+x*/
725                     /*after split important that each processor chunk i has size maxN*maxM*maxK and thus being the same size*/
726                     /*before split data contiguos - thus if different processor get different amount oN is different*/
727                 }
728             }
729         }
730     }
731 }
732
733 /*make axis contiguous again (after AllToAll) and also do local transpose*/
734 /*transpose mayor and major dimension
735    variables see above
736    the major, middle, minor order is only correct for x,y,z (N,M,K) for the input
737    N,M,K local dimensions
738    KG global size*/
739 static void joinAxesTrans13(t_complex* lout, const t_complex* lin,
740                             int maxN, int maxM, int maxK, int pM,
741                             int P, int KG, const int* K, const int* oK, int starty, int startx, int endy, int endx)
742 {
743     int i, x, y, z;
744     int out_i, in_i, out_x, in_x, out_z, in_z;
745     int s_y, e_y;
746
747     for (x = startx; x < endx+1; x++) /*1.j*/
748     {
749         if (x == startx)
750         {
751             s_y = starty;
752         }
753         else
754         {
755             s_y = 0;
756         }
757         if (x == endx)
758         {
759             e_y = endy;
760         }
761         else
762         {
763             e_y = pM;
764         }
765
766         out_x  = x*KG*pM;
767         in_x   = x;
768
769         for (i = 0; i < P; i++) /*index cube along long axis*/
770         {
771             out_i  = out_x  + oK[i];
772             in_i   = in_x + i*maxM*maxN*maxK;
773             for (z = 0; z < K[i]; z++) /*3.l*/
774             {
775                 out_z  = out_i  + z;
776                 in_z   = in_i + z*maxM*maxN;
777                 for (y = s_y; y < e_y; y++)              /*2.k*/
778                 {
779                     lout[out_z+y*KG] = lin[in_z+y*maxN]; /*out=x*KG*pM+oK[i]+z+y*KG*/
780                 }
781             }
782         }
783     }
784 }
785
786 /*make axis contiguous again (after AllToAll) and also do local transpose
787    tranpose mayor and middle dimension
788    variables see above
789    the minor, middle, major order is only correct for x,y,z (N,M,K) for the input
790    N,M,K local size
791    MG, global size*/
792 static void joinAxesTrans12(t_complex* lout, const t_complex* lin, int maxN, int maxM, int maxK, int pN,
793                             int P, int MG, const int* M, const int* oM, int startx, int startz, int endx, int endz)
794 {
795     int i, z, y, x;
796     int out_i, in_i, out_z, in_z, out_x, in_x;
797     int s_x, e_x;
798
799     for (z = startz; z < endz+1; z++)
800     {
801         if (z == startz)
802         {
803             s_x = startx;
804         }
805         else
806         {
807             s_x = 0;
808         }
809         if (z == endz)
810         {
811             e_x = endx;
812         }
813         else
814         {
815             e_x = pN;
816         }
817         out_z  = z*MG*pN;
818         in_z   = z*maxM*maxN;
819
820         for (i = 0; i < P; i++) /*index cube along long axis*/
821         {
822             out_i  = out_z  + oM[i];
823             in_i   = in_z + i*maxM*maxN*maxK;
824             for (x = s_x; x < e_x; x++)
825             {
826                 out_x  = out_i  + x*MG;
827                 in_x   = in_i + x;
828                 for (y = 0; y < M[i]; y++)
829                 {
830                     lout[out_x+y] = lin[in_x+y*maxN]; /*out=z*MG*pN+oM[i]+x*MG+y*/
831                 }
832             }
833         }
834     }
835 }
836
837
838 static void rotate_offsets(int x[])
839 {
840     int t = x[0];
841 /*    x[0]=x[2];
842     x[2]=x[1];
843     x[1]=t;*/
844     x[0] = x[1];
845     x[1] = x[2];
846     x[2] = t;
847 }
848
849 /*compute the offset to compare or print transposed local data in original input coordinates
850    xs matrix dimension size, xl dimension length, xc decomposition offset
851    s: step in computation = number of transposes*/
852 static void compute_offsets(fft5d_plan plan, int xs[], int xl[], int xc[], int NG[], int s)
853 {
854 /*    int direction = plan->direction;
855     int fftorder = plan->fftorder;*/
856
857     int  o = 0;
858     int  pos[3], i;
859     int *pM = plan->pM, *pK = plan->pK, *oM = plan->oM, *oK = plan->oK,
860     *C      = plan->C, *rC = plan->rC;
861
862     NG[0] = plan->NG; NG[1] = plan->MG; NG[2] = plan->KG;
863
864     if (!(plan->flags&FFT5D_ORDER_YZ))
865     {
866         switch (s)
867         {
868             case 0: o = XYZ; break;
869             case 1: o = ZYX; break;
870             case 2: o = YZX; break;
871             default: assert(0);
872         }
873     }
874     else
875     {
876         switch (s)
877         {
878             case 0: o = XYZ; break;
879             case 1: o = YXZ; break;
880             case 2: o = ZXY; break;
881             default: assert(0);
882         }
883     }
884
885     switch (o)
886     {
887         case XYZ: pos[0] = 1; pos[1] = 2; pos[2] = 3; break;
888         case XZY: pos[0] = 1; pos[1] = 3; pos[2] = 2; break;
889         case YXZ: pos[0] = 2; pos[1] = 1; pos[2] = 3; break;
890         case YZX: pos[0] = 3; pos[1] = 1; pos[2] = 2; break;
891         case ZXY: pos[0] = 2; pos[1] = 3; pos[2] = 1; break;
892         case ZYX: pos[0] = 3; pos[1] = 2; pos[2] = 1; break;
893     }
894     /*if (debug) printf("pos: %d %d %d\n",pos[0],pos[1],pos[2]);*/
895
896     /*xs, xl give dimension size and data length in local transposed coordinate system
897        for 0(/1/2): x(/y/z) in original coordinate system*/
898     for (i = 0; i < 3; i++)
899     {
900         switch (pos[i])
901         {
902             case 1: xs[i] = 1;         xc[i] = 0;     xl[i] = C[s]; break;
903             case 2: xs[i] = C[s];      xc[i] = oM[s]; xl[i] = pM[s]; break;
904             case 3: xs[i] = C[s]*pM[s]; xc[i] = oK[s]; xl[i] = pK[s]; break;
905         }
906     }
907     /*input order is different for test program to match FFTW order
908        (important for complex to real)*/
909     if (plan->flags&FFT5D_BACKWARD)
910     {
911         rotate_offsets(xs);
912         rotate_offsets(xl);
913         rotate_offsets(xc);
914         rotate_offsets(NG);
915         if (plan->flags&FFT5D_ORDER_YZ)
916         {
917             rotate_offsets(xs);
918             rotate_offsets(xl);
919             rotate_offsets(xc);
920             rotate_offsets(NG);
921         }
922     }
923     if ((plan->flags&FFT5D_REALCOMPLEX) && ((!(plan->flags&FFT5D_BACKWARD) && s == 0) || ((plan->flags&FFT5D_BACKWARD) && s == 2)))
924     {
925         xl[0] = rC[s];
926     }
927 }
928
929 static void print_localdata(const t_complex* lin, const char* txt, int s, fft5d_plan plan)
930 {
931     int  x, y, z, l;
932     int *coor = plan->coor;
933     int  xs[3], xl[3], xc[3], NG[3];
934     int  ll = (plan->flags&FFT5D_REALCOMPLEX) ? 1 : 2;
935     compute_offsets(plan, xs, xl, xc, NG, s);
936     fprintf(debug, txt, coor[0], coor[1]);
937     /*printf("xs: %d %d %d, xl: %d %d %d\n",xs[0],xs[1],xs[2],xl[0],xl[1],xl[2]);*/
938     for (z = 0; z < xl[2]; z++)
939     {
940         for (y = 0; y < xl[1]; y++)
941         {
942             fprintf(debug, "%d %d: ", coor[0], coor[1]);
943             for (x = 0; x < xl[0]; x++)
944             {
945                 for (l = 0; l < ll; l++)
946                 {
947                     fprintf(debug, "%f ", ((real*)lin)[(z*xs[2]+y*xs[1])*2+(x*xs[0])*ll+l]);
948                 }
949                 fprintf(debug, ",");
950             }
951             fprintf(debug, "\n");
952         }
953     }
954 }
955
956 void fft5d_execute(fft5d_plan plan, int thread, fft5d_time times)
957 {
958     t_complex  *lin   = plan->lin;
959     t_complex  *lout  = plan->lout;
960     t_complex  *lout2 = plan->lout2;
961     t_complex  *lout3 = plan->lout3;
962     t_complex  *fftout, *joinin;
963
964     gmx_fft_t **p1d = plan->p1d;
965 #ifdef FFT5D_MPI_TRANSPOSE
966     FFTW(plan) *mpip = plan->mpip;
967 #endif
968 #if GMX_MPI
969     MPI_Comm *cart = plan->cart;
970 #endif
971 #ifdef NOGMX
972     double time_fft = 0, time_local = 0, time_mpi[2] = {0}, time = 0;
973 #endif
974     int   *N = plan->N, *M = plan->M, *K = plan->K, *pN = plan->pN, *pM = plan->pM, *pK = plan->pK,
975     *C       = plan->C, *P = plan->P, **iNin = plan->iNin, **oNin = plan->oNin, **iNout = plan->iNout, **oNout = plan->oNout;
976     int    s = 0, tstart, tend, bParallelDim;
977
978
979 #if GMX_FFT_FFTW3
980     if (plan->p3d)
981     {
982         if (thread == 0)
983         {
984 #ifdef NOGMX
985             if (times != 0)
986             {
987                 time = MPI_Wtime();
988             }
989 #endif
990             FFTW(execute)(plan->p3d);
991 #ifdef NOGMX
992             if (times != 0)
993             {
994                 times->fft += MPI_Wtime()-time;
995             }
996 #endif
997         }
998         return;
999     }
1000 #endif
1001
1002     s = 0;
1003
1004     /*lin: x,y,z*/
1005     if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1006     {
1007         print_localdata(lin, "%d %d: copy in lin\n", s, plan);
1008     }
1009
1010     for (s = 0; s < 2; s++)  /*loop over first two FFT steps (corner rotations)*/
1011
1012     {
1013 #if GMX_MPI
1014         if (GMX_PARALLEL_ENV_INITIALIZED && cart[s] != MPI_COMM_NULL && P[s] > 1)
1015         {
1016             bParallelDim = 1;
1017         }
1018         else
1019 #endif
1020         {
1021             bParallelDim = 0;
1022         }
1023
1024         /* ---------- START FFT ------------ */
1025 #ifdef NOGMX
1026         if (times != 0 && thread == 0)
1027         {
1028             time = MPI_Wtime();
1029         }
1030 #endif
1031
1032         if (bParallelDim || plan->nthreads == 1)
1033         {
1034             fftout = lout;
1035         }
1036         else
1037         {
1038             if (s == 0)
1039             {
1040                 fftout = lout3;
1041             }
1042             else
1043             {
1044                 fftout = lout2;
1045             }
1046         }
1047
1048         tstart = (thread*pM[s]*pK[s]/plan->nthreads)*C[s];
1049         if ((plan->flags&FFT5D_REALCOMPLEX) && !(plan->flags&FFT5D_BACKWARD) && s == 0)
1050         {
1051             gmx_fft_many_1d_real(p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX, lin+tstart, fftout+tstart);
1052         }
1053         else
1054         {
1055             gmx_fft_many_1d(     p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,               lin+tstart, fftout+tstart);
1056
1057         }
1058
1059 #ifdef NOGMX
1060         if (times != NULL && thread == 0)
1061         {
1062             time_fft += MPI_Wtime()-time;
1063         }
1064 #endif
1065         if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1066         {
1067             print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1068         }
1069         /* ---------- END FFT ------------ */
1070
1071         /* ---------- START SPLIT + TRANSPOSE------------ (if parallel in in this dimension)*/
1072         if (bParallelDim)
1073         {
1074 #ifdef NOGMX
1075             if (times != NULL && thread == 0)
1076             {
1077                 time = MPI_Wtime();
1078             }
1079 #endif
1080             /*prepare for A
1081                llToAll
1082                1. (most outer) axes (x) is split into P[s] parts of size N[s]
1083                for sending*/
1084             if (pM[s] > 0)
1085             {
1086                 tend    = ((thread+1)*pM[s]*pK[s]/plan->nthreads);
1087                 tstart /= C[s];
1088                 splitaxes(lout2, lout, N[s], M[s], K[s], pM[s], P[s], C[s], iNout[s], oNout[s], tstart%pM[s], tstart/pM[s], tend%pM[s], tend/pM[s]);
1089             }
1090 #pragma omp barrier /*barrier required before AllToAll (all input has to be their) - before timing to make timing more acurate*/
1091 #ifdef NOGMX
1092             if (times != NULL && thread == 0)
1093             {
1094                 time_local += MPI_Wtime()-time;
1095             }
1096 #endif
1097
1098             /* ---------- END SPLIT , START TRANSPOSE------------ */
1099
1100             if (thread == 0)
1101             {
1102 #ifdef NOGMX
1103                 if (times != 0)
1104                 {
1105                     time = MPI_Wtime();
1106                 }
1107 #else
1108                 wallcycle_start(times, ewcPME_FFTCOMM);
1109 #endif
1110 #ifdef FFT5D_MPI_TRANSPOSE
1111                 FFTW(execute)(mpip[s]);
1112 #else
1113 #if GMX_MPI
1114                 if ((s == 0 && !(plan->flags&FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags&FFT5D_ORDER_YZ)))
1115                 {
1116                     MPI_Alltoall(reinterpret_cast<real *>(lout2), N[s]*pM[s]*K[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, reinterpret_cast<real *>(lout3), N[s]*pM[s]*K[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, cart[s]);
1117                 }
1118                 else
1119                 {
1120                     MPI_Alltoall(reinterpret_cast<real *>(lout2), N[s]*M[s]*pK[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, reinterpret_cast<real *>(lout3), N[s]*M[s]*pK[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, cart[s]);
1121                 }
1122 #else
1123                 gmx_incons("fft5d MPI call without MPI configuration");
1124 #endif /*GMX_MPI*/
1125 #endif /*FFT5D_MPI_TRANSPOSE*/
1126 #ifdef NOGMX
1127                 if (times != 0)
1128                 {
1129                     time_mpi[s] = MPI_Wtime()-time;
1130                 }
1131 #else
1132                 wallcycle_stop(times, ewcPME_FFTCOMM);
1133 #endif
1134             } /*master*/
1135         }     /* bPrallelDim */
1136 #pragma omp barrier  /*both needed for parallel and non-parallel dimension (either have to wait on data from AlltoAll or from last FFT*/
1137
1138         /* ---------- END SPLIT + TRANSPOSE------------ */
1139
1140         /* ---------- START JOIN ------------ */
1141 #ifdef NOGMX
1142         if (times != NULL && thread == 0)
1143         {
1144             time = MPI_Wtime();
1145         }
1146 #endif
1147
1148         if (bParallelDim)
1149         {
1150             joinin = lout3;
1151         }
1152         else
1153         {
1154             joinin = fftout;
1155         }
1156         /*bring back in matrix form
1157            thus make  new 1. axes contiguos
1158            also local transpose 1 and 2/3
1159            runs on thread used for following FFT (thus needing a barrier before but not afterwards)
1160          */
1161         if ((s == 0 && !(plan->flags&FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags&FFT5D_ORDER_YZ)))
1162         {
1163             if (pM[s] > 0)
1164             {
1165                 tstart = ( thread   *pM[s]*pN[s]/plan->nthreads);
1166                 tend   = ((thread+1)*pM[s]*pN[s]/plan->nthreads);
1167                 joinAxesTrans13(lin, joinin, N[s], pM[s], K[s], pM[s], P[s], C[s+1], iNin[s+1], oNin[s+1], tstart%pM[s], tstart/pM[s], tend%pM[s], tend/pM[s]);
1168             }
1169         }
1170         else
1171         {
1172             if (pN[s] > 0)
1173             {
1174                 tstart = ( thread   *pK[s]*pN[s]/plan->nthreads);
1175                 tend   = ((thread+1)*pK[s]*pN[s]/plan->nthreads);
1176                 joinAxesTrans12(lin, joinin, N[s], M[s], pK[s], pN[s], P[s], C[s+1], iNin[s+1], oNin[s+1], tstart%pN[s], tstart/pN[s], tend%pN[s], tend/pN[s]);
1177             }
1178         }
1179
1180 #ifdef NOGMX
1181         if (times != NULL && thread == 0)
1182         {
1183             time_local += MPI_Wtime()-time;
1184         }
1185 #endif
1186         if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1187         {
1188             print_localdata(lin, "%d %d: tranposed %d\n", s+1, plan);
1189         }
1190         /* ---------- END JOIN ------------ */
1191
1192         /*if (debug) print_localdata(lin, "%d %d: transposed x-z\n", N1, M0, K, ZYX, coor);*/
1193     }  /* for(s=0;s<2;s++) */
1194 #ifdef NOGMX
1195     if (times != NULL && thread == 0)
1196     {
1197         time = MPI_Wtime();
1198     }
1199 #endif
1200
1201     if (plan->flags&FFT5D_INPLACE)
1202     {
1203         lout = lin;                          /*in place currently not supported*/
1204
1205     }
1206     /*  ----------- FFT ----------- */
1207     tstart = (thread*pM[s]*pK[s]/plan->nthreads)*C[s];
1208     if ((plan->flags&FFT5D_REALCOMPLEX) && (plan->flags&FFT5D_BACKWARD))
1209     {
1210         gmx_fft_many_1d_real(p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX, lin+tstart, lout+tstart);
1211     }
1212     else
1213     {
1214         gmx_fft_many_1d(     p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,               lin+tstart, lout+tstart);
1215     }
1216     /* ------------ END FFT ---------*/
1217
1218 #ifdef NOGMX
1219     if (times != NULL && thread == 0)
1220     {
1221         time_fft += MPI_Wtime()-time;
1222
1223         times->fft   += time_fft;
1224         times->local += time_local;
1225         times->mpi2  += time_mpi[1];
1226         times->mpi1  += time_mpi[0];
1227     }
1228 #endif
1229
1230     if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1231     {
1232         print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1233     }
1234 }
1235
1236 void fft5d_destroy(fft5d_plan plan)
1237 {
1238     int s, t;
1239
1240     for (s = 0; s < 3; s++)
1241     {
1242         if (plan->p1d[s])
1243         {
1244             for (t = 0; t < plan->nthreads; t++)
1245             {
1246                 gmx_many_fft_destroy(plan->p1d[s][t]);
1247             }
1248             free(plan->p1d[s]);
1249         }
1250         if (plan->iNin[s])
1251         {
1252             free(plan->iNin[s]);
1253             plan->iNin[s] = nullptr;
1254         }
1255         if (plan->oNin[s])
1256         {
1257             free(plan->oNin[s]);
1258             plan->oNin[s] = nullptr;
1259         }
1260         if (plan->iNout[s])
1261         {
1262             free(plan->iNout[s]);
1263             plan->iNout[s] = nullptr;
1264         }
1265         if (plan->oNout[s])
1266         {
1267             free(plan->oNout[s]);
1268             plan->oNout[s] = nullptr;
1269         }
1270     }
1271 #if GMX_FFT_FFTW3
1272     FFTW_LOCK;
1273 #ifdef FFT5D_MPI_TRANSPOS
1274     for (s = 0; s < 2; s++)
1275     {
1276         FFTW(destroy_plan)(plan->mpip[s]);
1277     }
1278 #endif /* FFT5D_MPI_TRANSPOS */
1279     if (plan->p3d)
1280     {
1281         FFTW(destroy_plan)(plan->p3d);
1282     }
1283     FFTW_UNLOCK;
1284 #endif /* GMX_FFT_FFTW3 */
1285
1286     if (!(plan->flags&FFT5D_NOMALLOC))
1287     {
1288         // only needed for PME GPU mixed mode
1289         if (plan->pinningPolicy == gmx::PinningPolicy::CanBePinned &&
1290             isHostMemoryPinned(plan->lin))
1291         {
1292             gmx::unpinBuffer(plan->lin);
1293         }
1294         sfree_aligned(plan->lin);
1295         sfree_aligned(plan->lout);
1296         if (plan->nthreads > 1)
1297         {
1298             sfree_aligned(plan->lout2);
1299             sfree_aligned(plan->lout3);
1300         }
1301     }
1302
1303 #ifdef FFT5D_THREADS
1304 #ifdef FFT5D_FFTW_THREADS
1305     /*FFTW(cleanup_threads)();*/
1306 #endif
1307 #endif
1308
1309     free(plan);
1310 }
1311
1312 /*Is this better than direct access of plan? enough data?
1313    here 0,1 reference divided by which processor grid dimension (not FFT step!)*/
1314 void fft5d_local_size(fft5d_plan plan, int* N1, int* M0, int* K0, int* K1, int** coor)
1315 {
1316     *N1 = plan->N[0];
1317     *M0 = plan->M[0];
1318     *K1 = plan->K[0];
1319     *K0 = plan->N[1];
1320
1321     *coor = plan->coor;
1322 }
1323
1324
1325 /*same as fft5d_plan_3d but with cartesian coordinator and automatic splitting
1326    of processor dimensions*/
1327 fft5d_plan fft5d_plan_3d_cart(int NG, int MG, int KG, MPI_Comm comm, int P0, int flags, t_complex** rlin, t_complex** rlout, t_complex** rlout2, t_complex** rlout3, int nthreads)
1328 {
1329     MPI_Comm cart[2] = {MPI_COMM_NULL, MPI_COMM_NULL};
1330 #if GMX_MPI
1331     int      size = 1, prank = 0;
1332     int      P[2];
1333     int      coor[2];
1334     int      wrap[] = {0, 0};
1335     MPI_Comm gcart;
1336     int      rdim1[] = {0, 1}, rdim2[] = {1, 0};
1337
1338     MPI_Comm_size(comm, &size);
1339     MPI_Comm_rank(comm, &prank);
1340
1341     if (P0 == 0)
1342     {
1343         P0 = lfactor(size);
1344     }
1345     if (size%P0 != 0)
1346     {
1347         if (prank == 0)
1348         {
1349             printf("FFT5D: WARNING: Number of ranks %d not evenly divisible by %d\n", size, P0);
1350         }
1351         P0 = lfactor(size);
1352     }
1353
1354     P[0] = P0; P[1] = size/P0; /*number of processors in the two dimensions*/
1355
1356     /*Difference between x-y-z regarding 2d decomposition is whether they are
1357        distributed along axis 1, 2 or both*/
1358
1359     MPI_Cart_create(comm, 2, P, wrap, 1, &gcart); /*parameter 4: value 1: reorder*/
1360     MPI_Cart_get(gcart, 2, P, wrap, coor);
1361     MPI_Cart_sub(gcart, rdim1, &cart[0]);
1362     MPI_Cart_sub(gcart, rdim2, &cart[1]);
1363 #else
1364     (void)P0;
1365     (void)comm;
1366 #endif
1367     return fft5d_plan_3d(NG, MG, KG, cart, flags, rlin, rlout, rlout2, rlout3, nthreads);
1368 }
1369
1370
1371
1372 /*prints in original coordinate system of data (as the input to FFT)*/
1373 void fft5d_compare_data(const t_complex* lin, const t_complex* in, fft5d_plan plan, int bothLocal, int normalize)
1374 {
1375     int  xs[3], xl[3], xc[3], NG[3];
1376     int  x, y, z, l;
1377     int *coor = plan->coor;
1378     int  ll   = 2; /*compare ll values per element (has to be 2 for complex)*/
1379     if ((plan->flags&FFT5D_REALCOMPLEX) && (plan->flags&FFT5D_BACKWARD))
1380     {
1381         ll = 1;
1382     }
1383
1384     compute_offsets(plan, xs, xl, xc, NG, 2);
1385     if (plan->flags&FFT5D_DEBUG)
1386     {
1387         printf("Compare2\n");
1388     }
1389     for (z = 0; z < xl[2]; z++)
1390     {
1391         for (y = 0; y < xl[1]; y++)
1392         {
1393             if (plan->flags&FFT5D_DEBUG)
1394             {
1395                 printf("%d %d: ", coor[0], coor[1]);
1396             }
1397             for (x = 0; x < xl[0]; x++)
1398             {
1399                 for (l = 0; l < ll; l++)   /*loop over real/complex parts*/
1400                 {
1401                     real a, b;
1402                     a = ((real*)lin)[(z*xs[2]+y*xs[1])*2+x*xs[0]*ll+l];
1403                     if (normalize)
1404                     {
1405                         a /= plan->rC[0]*plan->rC[1]*plan->rC[2];
1406                     }
1407                     if (!bothLocal)
1408                     {
1409                         b = ((real*)in)[((z+xc[2])*NG[0]*NG[1]+(y+xc[1])*NG[0])*2+(x+xc[0])*ll+l];
1410                     }
1411                     else
1412                     {
1413                         b = ((real*)in)[(z*xs[2]+y*xs[1])*2+x*xs[0]*ll+l];
1414                     }
1415                     if (plan->flags&FFT5D_DEBUG)
1416                     {
1417                         printf("%f %f, ", a, b);
1418                     }
1419                     else
1420                     {
1421                         if (std::fabs(a-b) > 2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS)
1422                         {
1423                             printf("result incorrect on %d,%d at %d,%d,%d: FFT5D:%f reference:%f\n", coor[0], coor[1], x, y, z, a, b);
1424                         }
1425 /*                        assert(fabs(a-b)<2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS);*/
1426                     }
1427                 }
1428                 if (plan->flags&FFT5D_DEBUG)
1429                 {
1430                     printf(",");
1431                 }
1432             }
1433             if (plan->flags&FFT5D_DEBUG)
1434             {
1435                 printf("\n");
1436             }
1437         }
1438     }
1439
1440 }