Change CanBePinned to PinnedIfSupported
[alexxy/gromacs.git] / src / gromacs / fft / fft5d.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2009,2010,2012,2013,2014,2015,2016,2017,2018, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 #include "gmxpre.h"
36
37 #include "fft5d.h"
38
39 #include "config.h"
40
41 #include <cassert>
42 #include <cfloat>
43 #include <cmath>
44 #include <cstdio>
45 #include <cstdlib>
46 #include <cstring>
47
48 #include <algorithm>
49
50 #include "gromacs/gpu_utils/gpu_utils.h"
51 #include "gromacs/gpu_utils/hostallocator.h"
52 #include "gromacs/gpu_utils/pinning.h"
53 #include "gromacs/utility/alignedallocator.h"
54 #include "gromacs/utility/exceptions.h"
55 #include "gromacs/utility/fatalerror.h"
56 #include "gromacs/utility/gmxmpi.h"
57 #include "gromacs/utility/smalloc.h"
58
59 #ifdef NOGMX
60 #define GMX_PARALLEL_ENV_INITIALIZED 1
61 #else
62 #if GMX_MPI
63 #define GMX_PARALLEL_ENV_INITIALIZED 1
64 #else
65 #define GMX_PARALLEL_ENV_INITIALIZED 0
66 #endif
67 #endif
68
69 #if GMX_OPENMP
70 /* TODO: Do we still need this? Are we still planning ot use fftw + OpenMP? */
71 #define FFT5D_THREADS
72 /* requires fftw compiled with openmp */
73 /* #define FFT5D_FFTW_THREADS (now set by cmake) */
74 #endif
75
76 #ifndef __FLT_EPSILON__
77 #define __FLT_EPSILON__ FLT_EPSILON
78 #define __DBL_EPSILON__ DBL_EPSILON
79 #endif
80
81 #ifdef NOGMX
82 FILE* debug = 0;
83 #endif
84
85 #if GMX_FFT_FFTW3
86
87 #include "gromacs/utility/exceptions.h"
88 #include "gromacs/utility/mutex.h"
89 /* none of the fftw3 calls, except execute(), are thread-safe, so
90    we need to serialize them with this mutex. */
91 static gmx::Mutex big_fftw_mutex;
92 #define FFTW_LOCK try { big_fftw_mutex.lock(); } GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
93 #define FFTW_UNLOCK try { big_fftw_mutex.unlock(); } GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
94 #endif /* GMX_FFT_FFTW3 */
95
96 #if GMX_MPI
97 /* largest factor smaller than sqrt */
98 static int lfactor(int z)
99 {
100     int i = static_cast<int>(sqrt(static_cast<double>(z)));
101     while (z%i != 0)
102     {
103         i--;
104     }
105     return i;
106 }
107 #endif
108
109 #if !GMX_MPI
110 #if HAVE_GETTIMEOFDAY
111 #include <sys/time.h>
112 double MPI_Wtime()
113 {
114     struct timeval tv;
115     gettimeofday(&tv, 0);
116     return tv.tv_sec+tv.tv_usec*1e-6;
117 }
118 #else
119 double MPI_Wtime()
120 {
121     return 0.0;
122 }
123 #endif
124 #endif
125
126 static int vmax(const int* a, int s)
127 {
128     int i, max = 0;
129     for (i = 0; i < s; i++)
130     {
131         if (a[i] > max)
132         {
133             max = a[i];
134         }
135     }
136     return max;
137 }
138
139
140 /* NxMxK the size of the data
141  * comm communicator to use for fft5d
142  * P0 number of processor in 1st axes (can be null for automatic)
143  * lin is allocated by fft5d because size of array is only known after planning phase
144  * rlout2 is only used as intermediate buffer - only returned after allocation to reuse for back transform - should not be used by caller
145  */
146 fft5d_plan fft5d_plan_3d(int NG, int MG, int KG, MPI_Comm comm[2], int flags, t_complex** rlin, t_complex** rlout, t_complex** rlout2, t_complex** rlout3, int nthreads, gmx::PinningPolicy realGridAllocationPinningPolicy)
147 {
148
149     int        P[2], prank[2], i, t;
150     bool       bMaster;
151     int        rNG, rMG, rKG;
152     int       *N0 = nullptr, *N1 = nullptr, *M0 = nullptr, *M1 = nullptr, *K0 = nullptr, *K1 = nullptr, *oN0 = nullptr, *oN1 = nullptr, *oM0 = nullptr, *oM1 = nullptr, *oK0 = nullptr, *oK1 = nullptr;
153     int        N[3], M[3], K[3], pN[3], pM[3], pK[3], oM[3], oK[3], *iNin[3] = {nullptr}, *oNin[3] = {nullptr}, *iNout[3] = {nullptr}, *oNout[3] = {nullptr};
154     int        C[3], rC[3], nP[2];
155     int        lsize;
156     t_complex *lin = nullptr, *lout = nullptr, *lout2 = nullptr, *lout3 = nullptr;
157     fft5d_plan plan;
158     int        s;
159
160     /* comm, prank and P are in the order of the decomposition (plan->cart is in the order of transposes) */
161 #if GMX_MPI
162     if (GMX_PARALLEL_ENV_INITIALIZED && comm[0] != MPI_COMM_NULL)
163     {
164         MPI_Comm_size(comm[0], &P[0]);
165         MPI_Comm_rank(comm[0], &prank[0]);
166     }
167     else
168 #endif
169     {
170         P[0]     = 1;
171         prank[0] = 0;
172     }
173 #if GMX_MPI
174     if (GMX_PARALLEL_ENV_INITIALIZED && comm[1] != MPI_COMM_NULL)
175     {
176         MPI_Comm_size(comm[1], &P[1]);
177         MPI_Comm_rank(comm[1], &prank[1]);
178     }
179     else
180 #endif
181     {
182         P[1]     = 1;
183         prank[1] = 0;
184     }
185
186     bMaster = prank[0] == 0 && prank[1] == 0;
187
188
189     if (debug)
190     {
191         fprintf(debug, "FFT5D: Using %dx%d rank grid, rank %d,%d\n",
192                 P[0], P[1], prank[0], prank[1]);
193     }
194
195     if (bMaster)
196     {
197         if (debug)
198         {
199             fprintf(debug, "FFT5D: N: %d, M: %d, K: %d, P: %dx%d, real2complex: %d, backward: %d, order yz: %d, debug %d\n",
200                     NG, MG, KG, P[0], P[1], int((flags&FFT5D_REALCOMPLEX) > 0), int((flags&FFT5D_BACKWARD) > 0), int((flags&FFT5D_ORDER_YZ) > 0), int((flags&FFT5D_DEBUG) > 0));
201         }
202         /* The check below is not correct, one prime factor 11 or 13 is ok.
203            if (fft5d_fmax(fft5d_fmax(lpfactor(NG),lpfactor(MG)),lpfactor(KG))>7) {
204             printf("WARNING: FFT very slow with prime factors larger 7\n");
205             printf("Change FFT size or in case you cannot change it look at\n");
206             printf("http://www.fftw.org/fftw3_doc/Generating-your-own-code.html\n");
207            }
208          */
209     }
210
211     if (NG == 0 || MG == 0 || KG == 0)
212     {
213         if (bMaster)
214         {
215             printf("FFT5D: FATAL: Datasize cannot be zero in any dimension\n");
216         }
217         return nullptr;
218     }
219
220     rNG = NG; rMG = MG; rKG = KG;
221
222     if (flags&FFT5D_REALCOMPLEX)
223     {
224         if (!(flags&FFT5D_BACKWARD))
225         {
226             NG = NG/2+1;
227         }
228         else
229         {
230             if (!(flags&FFT5D_ORDER_YZ))
231             {
232                 MG = MG/2+1;
233             }
234             else
235             {
236                 KG = KG/2+1;
237             }
238         }
239     }
240
241
242     /*for transpose we need to know the size for each processor not only our own size*/
243
244     N0  = static_cast<int*>(malloc(P[0]*sizeof(int))); N1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
245     M0  = static_cast<int*>(malloc(P[0]*sizeof(int))); M1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
246     K0  = static_cast<int*>(malloc(P[0]*sizeof(int))); K1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
247     oN0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oN1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
248     oM0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oM1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
249     oK0 = static_cast<int*>(malloc(P[0]*sizeof(int))); oK1 = static_cast<int*>(malloc(P[1]*sizeof(int)));
250
251     for (i = 0; i < P[0]; i++)
252     {
253         #define EVENDIST
254         #ifndef EVENDIST
255         oN0[i] = i*ceil((double)NG/P[0]);
256         oM0[i] = i*ceil((double)MG/P[0]);
257         oK0[i] = i*ceil((double)KG/P[0]);
258         #else
259         oN0[i] = (NG*i)/P[0];
260         oM0[i] = (MG*i)/P[0];
261         oK0[i] = (KG*i)/P[0];
262         #endif
263     }
264     for (i = 0; i < P[1]; i++)
265     {
266         #ifndef EVENDIST
267         oN1[i] = i*ceil((double)NG/P[1]);
268         oM1[i] = i*ceil((double)MG/P[1]);
269         oK1[i] = i*ceil((double)KG/P[1]);
270         #else
271         oN1[i] = (NG*i)/P[1];
272         oM1[i] = (MG*i)/P[1];
273         oK1[i] = (KG*i)/P[1];
274         #endif
275     }
276     for (i = 0; i < P[0]-1; i++)
277     {
278         N0[i] = oN0[i+1]-oN0[i];
279         M0[i] = oM0[i+1]-oM0[i];
280         K0[i] = oK0[i+1]-oK0[i];
281     }
282     N0[P[0]-1] = NG-oN0[P[0]-1];
283     M0[P[0]-1] = MG-oM0[P[0]-1];
284     K0[P[0]-1] = KG-oK0[P[0]-1];
285     for (i = 0; i < P[1]-1; i++)
286     {
287         N1[i] = oN1[i+1]-oN1[i];
288         M1[i] = oM1[i+1]-oM1[i];
289         K1[i] = oK1[i+1]-oK1[i];
290     }
291     N1[P[1]-1] = NG-oN1[P[1]-1];
292     M1[P[1]-1] = MG-oM1[P[1]-1];
293     K1[P[1]-1] = KG-oK1[P[1]-1];
294
295     /* for step 1-3 the local N,M,K sizes of the transposed system
296        C: contiguous dimension, and nP: number of processor in subcommunicator
297        for that step */
298
299
300     pM[0] = M0[prank[0]];
301     oM[0] = oM0[prank[0]];
302     pK[0] = K1[prank[1]];
303     oK[0] = oK1[prank[1]];
304     C[0]  = NG;
305     rC[0] = rNG;
306     if (!(flags&FFT5D_ORDER_YZ))
307     {
308         N[0]     = vmax(N1, P[1]);
309         M[0]     = M0[prank[0]];
310         K[0]     = vmax(K1, P[1]);
311         pN[0]    = N1[prank[1]];
312         iNout[0] = N1;
313         oNout[0] = oN1;
314         nP[0]    = P[1];
315         C[1]     = KG;
316         rC[1]    = rKG;
317         N[1]     = vmax(K0, P[0]);
318         pN[1]    = K0[prank[0]];
319         iNin[1]  = K1;
320         oNin[1]  = oK1;
321         iNout[1] = K0;
322         oNout[1] = oK0;
323         M[1]     = vmax(M0, P[0]);
324         pM[1]    = M0[prank[0]];
325         oM[1]    = oM0[prank[0]];
326         K[1]     = N1[prank[1]];
327         pK[1]    = N1[prank[1]];
328         oK[1]    = oN1[prank[1]];
329         nP[1]    = P[0];
330         C[2]     = MG;
331         rC[2]    = rMG;
332         iNin[2]  = M0;
333         oNin[2]  = oM0;
334         M[2]     = vmax(K0, P[0]);
335         pM[2]    = K0[prank[0]];
336         oM[2]    = oK0[prank[0]];
337         K[2]     = vmax(N1, P[1]);
338         pK[2]    = N1[prank[1]];
339         oK[2]    = oN1[prank[1]];
340         free(N0); free(oN0); /*these are not used for this order*/
341         free(M1); free(oM1); /*the rest is freed in destroy*/
342     }
343     else
344     {
345         N[0]     = vmax(N0, P[0]);
346         M[0]     = vmax(M0, P[0]);
347         K[0]     = K1[prank[1]];
348         pN[0]    = N0[prank[0]];
349         iNout[0] = N0;
350         oNout[0] = oN0;
351         nP[0]    = P[0];
352         C[1]     = MG;
353         rC[1]    = rMG;
354         N[1]     = vmax(M1, P[1]);
355         pN[1]    = M1[prank[1]];
356         iNin[1]  = M0;
357         oNin[1]  = oM0;
358         iNout[1] = M1;
359         oNout[1] = oM1;
360         M[1]     = N0[prank[0]];
361         pM[1]    = N0[prank[0]];
362         oM[1]    = oN0[prank[0]];
363         K[1]     = vmax(K1, P[1]);
364         pK[1]    = K1[prank[1]];
365         oK[1]    = oK1[prank[1]];
366         nP[1]    = P[1];
367         C[2]     = KG;
368         rC[2]    = rKG;
369         iNin[2]  = K1;
370         oNin[2]  = oK1;
371         M[2]     = vmax(N0, P[0]);
372         pM[2]    = N0[prank[0]];
373         oM[2]    = oN0[prank[0]];
374         K[2]     = vmax(M1, P[1]);
375         pK[2]    = M1[prank[1]];
376         oK[2]    = oM1[prank[1]];
377         free(N1); free(oN1); /*these are not used for this order*/
378         free(K0); free(oK0); /*the rest is freed in destroy*/
379     }
380     N[2] = pN[2] = -1;       /*not used*/
381
382     /*
383        Difference between x-y-z regarding 2d decomposition is whether they are
384        distributed along axis 1, 2 or both
385      */
386
387     /* int lsize = fmax(N[0]*M[0]*K[0]*nP[0],N[1]*M[1]*K[1]*nP[1]); */
388     lsize = std::max(N[0]*M[0]*K[0]*nP[0], std::max(N[1]*M[1]*K[1]*nP[1], C[2]*M[2]*K[2]));
389     /* int lsize = fmax(C[0]*M[0]*K[0],fmax(C[1]*M[1]*K[1],C[2]*M[2]*K[2])); */
390     if (!(flags&FFT5D_NOMALLOC))
391     {
392         // only needed for PME GPU mixed mode
393         if (realGridAllocationPinningPolicy == gmx::PinningPolicy::PinnedIfSupported &&
394             GMX_GPU == GMX_GPU_CUDA)
395         {
396             const std::size_t numBytes = lsize * sizeof(t_complex);
397             lin = static_cast<t_complex *>(gmx::PageAlignedAllocationPolicy::malloc(numBytes));
398             gmx::pinBuffer(lin, numBytes);
399         }
400         else
401         {
402             snew_aligned(lin, lsize, 32);
403         }
404         snew_aligned(lout, lsize, 32);
405         if (nthreads > 1)
406         {
407             /* We need extra transpose buffers to avoid OpenMP barriers */
408             snew_aligned(lout2, lsize, 32);
409             snew_aligned(lout3, lsize, 32);
410         }
411         else
412         {
413             /* We can reuse the buffers to avoid cache misses */
414             lout2 = lin;
415             lout3 = lout;
416         }
417     }
418     else
419     {
420         lin  = *rlin;
421         lout = *rlout;
422         if (nthreads > 1)
423         {
424             lout2 = *rlout2;
425             lout3 = *rlout3;
426         }
427         else
428         {
429             lout2 = lin;
430             lout3 = lout;
431         }
432     }
433
434     plan = static_cast<fft5d_plan>(calloc(1, sizeof(struct fft5d_plan_t)));
435
436
437     if (debug)
438     {
439         fprintf(debug, "Running on %d threads\n", nthreads);
440     }
441
442 #if GMX_FFT_FFTW3
443     /* Don't add more stuff here! We have already had at least one bug because we are reimplementing
444      * the low-level FFT interface instead of using the Gromacs FFT module. If we need more
445      * generic functionality it is far better to extend the interface so we can use it for
446      * all FFT libraries instead of writing FFTW-specific code here.
447      */
448
449     /*if not FFTW - then we don't do a 3d plan but instead use only 1D plans */
450     /* It is possible to use the 3d plan with OMP threads - but in that case it is not allowed to be called from
451      * within a parallel region. For now deactivated. If it should be supported it has to made sure that
452      * that the execute of the 3d plan is in a master/serial block (since it contains it own parallel region)
453      * and that the 3d plan is faster than the 1d plan.
454      */
455     if ((!(flags&FFT5D_INPLACE)) && (!(P[0] > 1 || P[1] > 1)) && nthreads == 1) /*don't do 3d plan in parallel or if in_place requested */
456     {
457         int fftwflags = FFTW_DESTROY_INPUT;
458         FFTW(iodim) dims[3];
459         int inNG = NG, outMG = MG, outKG = KG;
460
461         FFTW_LOCK;
462
463         fftwflags |= (flags & FFT5D_NOMEASURE) ? FFTW_ESTIMATE : FFTW_MEASURE;
464
465         if (flags&FFT5D_REALCOMPLEX)
466         {
467             if (!(flags&FFT5D_BACKWARD))        /*input pointer is not complex*/
468             {
469                 inNG *= 2;
470             }
471             else                                /*output pointer is not complex*/
472             {
473                 if (!(flags&FFT5D_ORDER_YZ))
474                 {
475                     outMG *= 2;
476                 }
477                 else
478                 {
479                     outKG *= 2;
480                 }
481             }
482         }
483
484         if (!(flags&FFT5D_BACKWARD))
485         {
486             dims[0].n  = KG;
487             dims[1].n  = MG;
488             dims[2].n  = rNG;
489
490             dims[0].is = inNG*MG;         /*N M K*/
491             dims[1].is = inNG;
492             dims[2].is = 1;
493             if (!(flags&FFT5D_ORDER_YZ))
494             {
495                 dims[0].os = MG;           /*M K N*/
496                 dims[1].os = 1;
497                 dims[2].os = MG*KG;
498             }
499             else
500             {
501                 dims[0].os = 1;           /*K N M*/
502                 dims[1].os = KG*NG;
503                 dims[2].os = KG;
504             }
505         }
506         else
507         {
508             if (!(flags&FFT5D_ORDER_YZ))
509             {
510                 dims[0].n  = NG;
511                 dims[1].n  = KG;
512                 dims[2].n  = rMG;
513
514                 dims[0].is = 1;
515                 dims[1].is = NG*MG;
516                 dims[2].is = NG;
517
518                 dims[0].os = outMG*KG;
519                 dims[1].os = outMG;
520                 dims[2].os = 1;
521             }
522             else
523             {
524                 dims[0].n  = MG;
525                 dims[1].n  = NG;
526                 dims[2].n  = rKG;
527
528                 dims[0].is = NG;
529                 dims[1].is = 1;
530                 dims[2].is = NG*MG;
531
532                 dims[0].os = outKG*NG;
533                 dims[1].os = outKG;
534                 dims[2].os = 1;
535             }
536         }
537 #ifdef FFT5D_THREADS
538 #ifdef FFT5D_FFTW_THREADS
539         FFTW(plan_with_nthreads)(nthreads);
540 #endif
541 #endif
542         if ((flags&FFT5D_REALCOMPLEX) && !(flags&FFT5D_BACKWARD))
543         {
544             plan->p3d = FFTW(plan_guru_dft_r2c)(/*rank*/ 3, dims,
545                                                          /*howmany*/ 0, /*howmany_dims*/ nullptr,
546                                                          reinterpret_cast<real*>(lin), reinterpret_cast<FFTW(complex) *>(lout),
547                                                          /*flags*/ fftwflags);
548         }
549         else if ((flags&FFT5D_REALCOMPLEX) && (flags&FFT5D_BACKWARD))
550         {
551             plan->p3d = FFTW(plan_guru_dft_c2r)(/*rank*/ 3, dims,
552                                                          /*howmany*/ 0, /*howmany_dims*/ nullptr,
553                                                          reinterpret_cast<FFTW(complex) *>(lin), reinterpret_cast<real*>(lout),
554                                                          /*flags*/ fftwflags);
555         }
556         else
557         {
558             plan->p3d = FFTW(plan_guru_dft)(/*rank*/ 3, dims,
559                                                      /*howmany*/ 0, /*howmany_dims*/ nullptr,
560                                                      reinterpret_cast<FFTW(complex) *>(lin), reinterpret_cast<FFTW(complex) *>(lout),
561                                                      /*sign*/ (flags&FFT5D_BACKWARD) ? 1 : -1, /*flags*/ fftwflags);
562         }
563 #ifdef FFT5D_THREADS
564 #ifdef FFT5D_FFTW_THREADS
565         FFTW(plan_with_nthreads)(1);
566 #endif
567 #endif
568         FFTW_UNLOCK;
569     }
570     if (!plan->p3d) /* for decomposition and if 3d plan did not work */
571     {
572 #endif              /* GMX_FFT_FFTW3 */
573     for (s = 0; s < 3; s++)
574     {
575         if (debug)
576         {
577             fprintf(debug, "FFT5D: Plan s %d rC %d M %d pK %d C %d lsize %d\n",
578                     s, rC[s], M[s], pK[s], C[s], lsize);
579         }
580         plan->p1d[s] = static_cast<gmx_fft_t*>(malloc(sizeof(gmx_fft_t)*nthreads));
581
582         /* Make sure that the init routines are only called by one thread at a time and in order
583            (later is only important to not confuse valgrind)
584          */
585 #pragma omp parallel for num_threads(nthreads) schedule(static) ordered
586         for (t = 0; t < nthreads; t++)
587         {
588 #pragma omp ordered
589             {
590                 try
591                 {
592                     int tsize = ((t+1)*pM[s]*pK[s]/nthreads)-(t*pM[s]*pK[s]/nthreads);
593
594                     if ((flags&FFT5D_REALCOMPLEX) && ((!(flags&FFT5D_BACKWARD) && s == 0) || ((flags&FFT5D_BACKWARD) && s == 2)))
595                     {
596                         gmx_fft_init_many_1d_real( &plan->p1d[s][t], rC[s], tsize, (flags&FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0 );
597                     }
598                     else
599                     {
600                         gmx_fft_init_many_1d     ( &plan->p1d[s][t],  C[s], tsize, (flags&FFT5D_NOMEASURE) ? GMX_FFT_FLAG_CONSERVATIVE : 0 );
601                     }
602                 }
603                 GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
604             }
605         }
606     }
607
608 #if GMX_FFT_FFTW3
609 }
610 #endif
611     if ((flags&FFT5D_ORDER_YZ))   /*plan->cart is in the order of transposes */
612     {
613         plan->cart[0] = comm[0]; plan->cart[1] = comm[1];
614     }
615     else
616     {
617         plan->cart[1] = comm[0]; plan->cart[0] = comm[1];
618     }
619 #ifdef FFT5D_MPI_TRANSPOSE
620     FFTW_LOCK;
621     for (s = 0; s < 2; s++)
622     {
623         if ((s == 0 && !(flags&FFT5D_ORDER_YZ)) || (s == 1 && (flags&FFT5D_ORDER_YZ)))
624         {
625             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s]*K[s]*pM[s]*2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
626         }
627         else
628         {
629             plan->mpip[s] = FFTW(mpi_plan_many_transpose)(nP[s], nP[s], N[s]*pK[s]*M[s]*2, 1, 1, (real*)lout2, (real*)lout3, plan->cart[s], FFTW_PATIENT);
630         }
631     }
632     FFTW_UNLOCK;
633 #endif
634
635
636     plan->lin   = lin;
637     plan->lout  = lout;
638     plan->lout2 = lout2;
639     plan->lout3 = lout3;
640
641     plan->NG = NG; plan->MG = MG; plan->KG = KG;
642
643     for (s = 0; s < 3; s++)
644     {
645         plan->N[s]    = N[s]; plan->M[s] = M[s]; plan->K[s] = K[s]; plan->pN[s] = pN[s]; plan->pM[s] = pM[s]; plan->pK[s] = pK[s];
646         plan->oM[s]   = oM[s]; plan->oK[s] = oK[s];
647         plan->C[s]    = C[s]; plan->rC[s] = rC[s];
648         plan->iNin[s] = iNin[s]; plan->oNin[s] = oNin[s]; plan->iNout[s] = iNout[s]; plan->oNout[s] = oNout[s];
649     }
650     for (s = 0; s < 2; s++)
651     {
652         plan->P[s] = nP[s]; plan->coor[s] = prank[s];
653     }
654
655 /*    plan->fftorder=fftorder;
656     plan->direction=direction;
657     plan->realcomplex=realcomplex;
658  */
659     plan->flags         = flags;
660     plan->nthreads      = nthreads;
661     plan->pinningPolicy = realGridAllocationPinningPolicy;
662     *rlin               = lin;
663     *rlout              = lout;
664     *rlout2             = lout2;
665     *rlout3             = lout3;
666     return plan;
667 }
668
669
670 enum order {
671     XYZ,
672     XZY,
673     YXZ,
674     YZX,
675     ZXY,
676     ZYX
677 };
678
679
680
681 /*here x,y,z and N,M,K is in rotated coordinate system!!
682    x (and N) is mayor (consecutive) dimension, y (M) middle and z (K) major
683    maxN,maxM,maxK is max size of local data
684    pN, pM, pK is local size specific to current processor (only different to max if not divisible)
685    NG, MG, KG is size of global data*/
686 static void splitaxes(t_complex* lout, const t_complex* lin,
687                       int maxN, int maxM, int maxK, int pM,
688                       int P, int NG, const int *N, const int* oN, int starty, int startz, int endy, int endz)
689 {
690     int x, y, z, i;
691     int in_i, out_i, in_z, out_z, in_y, out_y;
692     int s_y, e_y;
693
694     for (z = startz; z < endz+1; z++) /*3. z l*/
695     {
696         if (z == startz)
697         {
698             s_y = starty;
699         }
700         else
701         {
702             s_y = 0;
703         }
704         if (z == endz)
705         {
706             e_y = endy;
707         }
708         else
709         {
710             e_y = pM;
711         }
712         out_z  = z*maxN*maxM;
713         in_z   = z*NG*pM;
714
715         for (i = 0; i < P; i++) /*index cube along long axis*/
716         {
717             out_i  = out_z  + i*maxN*maxM*maxK;
718             in_i   = in_z + oN[i];
719             for (y = s_y; y < e_y; y++)   /*2. y k*/
720             {
721                 out_y  = out_i  + y*maxN;
722                 in_y   = in_i + y*NG;
723                 for (x = 0; x < N[i]; x++)       /*1. x j*/
724                 {
725                     lout[out_y+x] = lin[in_y+x]; /*in=z*NG*pM+oN[i]+y*NG+x*/
726                     /*after split important that each processor chunk i has size maxN*maxM*maxK and thus being the same size*/
727                     /*before split data contiguos - thus if different processor get different amount oN is different*/
728                 }
729             }
730         }
731     }
732 }
733
734 /*make axis contiguous again (after AllToAll) and also do local transpose*/
735 /*transpose mayor and major dimension
736    variables see above
737    the major, middle, minor order is only correct for x,y,z (N,M,K) for the input
738    N,M,K local dimensions
739    KG global size*/
740 static void joinAxesTrans13(t_complex* lout, const t_complex* lin,
741                             int maxN, int maxM, int maxK, int pM,
742                             int P, int KG, const int* K, const int* oK, int starty, int startx, int endy, int endx)
743 {
744     int i, x, y, z;
745     int out_i, in_i, out_x, in_x, out_z, in_z;
746     int s_y, e_y;
747
748     for (x = startx; x < endx+1; x++) /*1.j*/
749     {
750         if (x == startx)
751         {
752             s_y = starty;
753         }
754         else
755         {
756             s_y = 0;
757         }
758         if (x == endx)
759         {
760             e_y = endy;
761         }
762         else
763         {
764             e_y = pM;
765         }
766
767         out_x  = x*KG*pM;
768         in_x   = x;
769
770         for (i = 0; i < P; i++) /*index cube along long axis*/
771         {
772             out_i  = out_x  + oK[i];
773             in_i   = in_x + i*maxM*maxN*maxK;
774             for (z = 0; z < K[i]; z++) /*3.l*/
775             {
776                 out_z  = out_i  + z;
777                 in_z   = in_i + z*maxM*maxN;
778                 for (y = s_y; y < e_y; y++)              /*2.k*/
779                 {
780                     lout[out_z+y*KG] = lin[in_z+y*maxN]; /*out=x*KG*pM+oK[i]+z+y*KG*/
781                 }
782             }
783         }
784     }
785 }
786
787 /*make axis contiguous again (after AllToAll) and also do local transpose
788    tranpose mayor and middle dimension
789    variables see above
790    the minor, middle, major order is only correct for x,y,z (N,M,K) for the input
791    N,M,K local size
792    MG, global size*/
793 static void joinAxesTrans12(t_complex* lout, const t_complex* lin, int maxN, int maxM, int maxK, int pN,
794                             int P, int MG, const int* M, const int* oM, int startx, int startz, int endx, int endz)
795 {
796     int i, z, y, x;
797     int out_i, in_i, out_z, in_z, out_x, in_x;
798     int s_x, e_x;
799
800     for (z = startz; z < endz+1; z++)
801     {
802         if (z == startz)
803         {
804             s_x = startx;
805         }
806         else
807         {
808             s_x = 0;
809         }
810         if (z == endz)
811         {
812             e_x = endx;
813         }
814         else
815         {
816             e_x = pN;
817         }
818         out_z  = z*MG*pN;
819         in_z   = z*maxM*maxN;
820
821         for (i = 0; i < P; i++) /*index cube along long axis*/
822         {
823             out_i  = out_z  + oM[i];
824             in_i   = in_z + i*maxM*maxN*maxK;
825             for (x = s_x; x < e_x; x++)
826             {
827                 out_x  = out_i  + x*MG;
828                 in_x   = in_i + x;
829                 for (y = 0; y < M[i]; y++)
830                 {
831                     lout[out_x+y] = lin[in_x+y*maxN]; /*out=z*MG*pN+oM[i]+x*MG+y*/
832                 }
833             }
834         }
835     }
836 }
837
838
839 static void rotate_offsets(int x[])
840 {
841     int t = x[0];
842 /*    x[0]=x[2];
843     x[2]=x[1];
844     x[1]=t;*/
845     x[0] = x[1];
846     x[1] = x[2];
847     x[2] = t;
848 }
849
850 /*compute the offset to compare or print transposed local data in original input coordinates
851    xs matrix dimension size, xl dimension length, xc decomposition offset
852    s: step in computation = number of transposes*/
853 static void compute_offsets(fft5d_plan plan, int xs[], int xl[], int xc[], int NG[], int s)
854 {
855 /*    int direction = plan->direction;
856     int fftorder = plan->fftorder;*/
857
858     int  o = 0;
859     int  pos[3], i;
860     int *pM = plan->pM, *pK = plan->pK, *oM = plan->oM, *oK = plan->oK,
861     *C      = plan->C, *rC = plan->rC;
862
863     NG[0] = plan->NG; NG[1] = plan->MG; NG[2] = plan->KG;
864
865     if (!(plan->flags&FFT5D_ORDER_YZ))
866     {
867         switch (s)
868         {
869             case 0: o = XYZ; break;
870             case 1: o = ZYX; break;
871             case 2: o = YZX; break;
872             default: assert(0);
873         }
874     }
875     else
876     {
877         switch (s)
878         {
879             case 0: o = XYZ; break;
880             case 1: o = YXZ; break;
881             case 2: o = ZXY; break;
882             default: assert(0);
883         }
884     }
885
886     switch (o)
887     {
888         case XYZ: pos[0] = 1; pos[1] = 2; pos[2] = 3; break;
889         case XZY: pos[0] = 1; pos[1] = 3; pos[2] = 2; break;
890         case YXZ: pos[0] = 2; pos[1] = 1; pos[2] = 3; break;
891         case YZX: pos[0] = 3; pos[1] = 1; pos[2] = 2; break;
892         case ZXY: pos[0] = 2; pos[1] = 3; pos[2] = 1; break;
893         case ZYX: pos[0] = 3; pos[1] = 2; pos[2] = 1; break;
894     }
895     /*if (debug) printf("pos: %d %d %d\n",pos[0],pos[1],pos[2]);*/
896
897     /*xs, xl give dimension size and data length in local transposed coordinate system
898        for 0(/1/2): x(/y/z) in original coordinate system*/
899     for (i = 0; i < 3; i++)
900     {
901         switch (pos[i])
902         {
903             case 1: xs[i] = 1;         xc[i] = 0;     xl[i] = C[s]; break;
904             case 2: xs[i] = C[s];      xc[i] = oM[s]; xl[i] = pM[s]; break;
905             case 3: xs[i] = C[s]*pM[s]; xc[i] = oK[s]; xl[i] = pK[s]; break;
906         }
907     }
908     /*input order is different for test program to match FFTW order
909        (important for complex to real)*/
910     if (plan->flags&FFT5D_BACKWARD)
911     {
912         rotate_offsets(xs);
913         rotate_offsets(xl);
914         rotate_offsets(xc);
915         rotate_offsets(NG);
916         if (plan->flags&FFT5D_ORDER_YZ)
917         {
918             rotate_offsets(xs);
919             rotate_offsets(xl);
920             rotate_offsets(xc);
921             rotate_offsets(NG);
922         }
923     }
924     if ((plan->flags&FFT5D_REALCOMPLEX) && ((!(plan->flags&FFT5D_BACKWARD) && s == 0) || ((plan->flags&FFT5D_BACKWARD) && s == 2)))
925     {
926         xl[0] = rC[s];
927     }
928 }
929
930 static void print_localdata(const t_complex* lin, const char* txt, int s, fft5d_plan plan)
931 {
932     int  x, y, z, l;
933     int *coor = plan->coor;
934     int  xs[3], xl[3], xc[3], NG[3];
935     int  ll = (plan->flags&FFT5D_REALCOMPLEX) ? 1 : 2;
936     compute_offsets(plan, xs, xl, xc, NG, s);
937     fprintf(debug, txt, coor[0], coor[1]);
938     /*printf("xs: %d %d %d, xl: %d %d %d\n",xs[0],xs[1],xs[2],xl[0],xl[1],xl[2]);*/
939     for (z = 0; z < xl[2]; z++)
940     {
941         for (y = 0; y < xl[1]; y++)
942         {
943             fprintf(debug, "%d %d: ", coor[0], coor[1]);
944             for (x = 0; x < xl[0]; x++)
945             {
946                 for (l = 0; l < ll; l++)
947                 {
948                     fprintf(debug, "%f ", reinterpret_cast<const real*>(lin)[(z*xs[2]+y*xs[1])*2+(x*xs[0])*ll+l]);
949                 }
950                 fprintf(debug, ",");
951             }
952             fprintf(debug, "\n");
953         }
954     }
955 }
956
957 void fft5d_execute(fft5d_plan plan, int thread, fft5d_time times)
958 {
959     t_complex  *lin   = plan->lin;
960     t_complex  *lout  = plan->lout;
961     t_complex  *lout2 = plan->lout2;
962     t_complex  *lout3 = plan->lout3;
963     t_complex  *fftout, *joinin;
964
965     gmx_fft_t **p1d = plan->p1d;
966 #ifdef FFT5D_MPI_TRANSPOSE
967     FFTW(plan) *mpip = plan->mpip;
968 #endif
969 #if GMX_MPI
970     MPI_Comm *cart = plan->cart;
971 #endif
972 #ifdef NOGMX
973     double time_fft = 0, time_local = 0, time_mpi[2] = {0}, time = 0;
974 #endif
975     int   *N = plan->N, *M = plan->M, *K = plan->K, *pN = plan->pN, *pM = plan->pM, *pK = plan->pK,
976     *C       = plan->C, *P = plan->P, **iNin = plan->iNin, **oNin = plan->oNin, **iNout = plan->iNout, **oNout = plan->oNout;
977     int    s = 0, tstart, tend, bParallelDim;
978
979
980 #if GMX_FFT_FFTW3
981     if (plan->p3d)
982     {
983         if (thread == 0)
984         {
985 #ifdef NOGMX
986             if (times != 0)
987             {
988                 time = MPI_Wtime();
989             }
990 #endif
991             FFTW(execute)(plan->p3d);
992 #ifdef NOGMX
993             if (times != 0)
994             {
995                 times->fft += MPI_Wtime()-time;
996             }
997 #endif
998         }
999         return;
1000     }
1001 #endif
1002
1003     s = 0;
1004
1005     /*lin: x,y,z*/
1006     if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1007     {
1008         print_localdata(lin, "%d %d: copy in lin\n", s, plan);
1009     }
1010
1011     for (s = 0; s < 2; s++)  /*loop over first two FFT steps (corner rotations)*/
1012
1013     {
1014 #if GMX_MPI
1015         if (GMX_PARALLEL_ENV_INITIALIZED && cart[s] != MPI_COMM_NULL && P[s] > 1)
1016         {
1017             bParallelDim = 1;
1018         }
1019         else
1020 #endif
1021         {
1022             bParallelDim = 0;
1023         }
1024
1025         /* ---------- START FFT ------------ */
1026 #ifdef NOGMX
1027         if (times != 0 && thread == 0)
1028         {
1029             time = MPI_Wtime();
1030         }
1031 #endif
1032
1033         if (bParallelDim || plan->nthreads == 1)
1034         {
1035             fftout = lout;
1036         }
1037         else
1038         {
1039             if (s == 0)
1040             {
1041                 fftout = lout3;
1042             }
1043             else
1044             {
1045                 fftout = lout2;
1046             }
1047         }
1048
1049         tstart = (thread*pM[s]*pK[s]/plan->nthreads)*C[s];
1050         if ((plan->flags&FFT5D_REALCOMPLEX) && !(plan->flags&FFT5D_BACKWARD) && s == 0)
1051         {
1052             gmx_fft_many_1d_real(p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX, lin+tstart, fftout+tstart);
1053         }
1054         else
1055         {
1056             gmx_fft_many_1d(     p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,               lin+tstart, fftout+tstart);
1057
1058         }
1059
1060 #ifdef NOGMX
1061         if (times != NULL && thread == 0)
1062         {
1063             time_fft += MPI_Wtime()-time;
1064         }
1065 #endif
1066         if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1067         {
1068             print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1069         }
1070         /* ---------- END FFT ------------ */
1071
1072         /* ---------- START SPLIT + TRANSPOSE------------ (if parallel in in this dimension)*/
1073         if (bParallelDim)
1074         {
1075 #ifdef NOGMX
1076             if (times != NULL && thread == 0)
1077             {
1078                 time = MPI_Wtime();
1079             }
1080 #endif
1081             /*prepare for A
1082                llToAll
1083                1. (most outer) axes (x) is split into P[s] parts of size N[s]
1084                for sending*/
1085             if (pM[s] > 0)
1086             {
1087                 tend    = ((thread+1)*pM[s]*pK[s]/plan->nthreads);
1088                 tstart /= C[s];
1089                 splitaxes(lout2, lout, N[s], M[s], K[s], pM[s], P[s], C[s], iNout[s], oNout[s], tstart%pM[s], tstart/pM[s], tend%pM[s], tend/pM[s]);
1090             }
1091 #pragma omp barrier /*barrier required before AllToAll (all input has to be their) - before timing to make timing more acurate*/
1092 #ifdef NOGMX
1093             if (times != NULL && thread == 0)
1094             {
1095                 time_local += MPI_Wtime()-time;
1096             }
1097 #endif
1098
1099             /* ---------- END SPLIT , START TRANSPOSE------------ */
1100
1101             if (thread == 0)
1102             {
1103 #ifdef NOGMX
1104                 if (times != 0)
1105                 {
1106                     time = MPI_Wtime();
1107                 }
1108 #else
1109                 wallcycle_start(times, ewcPME_FFTCOMM);
1110 #endif
1111 #ifdef FFT5D_MPI_TRANSPOSE
1112                 FFTW(execute)(mpip[s]);
1113 #else
1114 #if GMX_MPI
1115                 if ((s == 0 && !(plan->flags&FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags&FFT5D_ORDER_YZ)))
1116                 {
1117                     MPI_Alltoall(reinterpret_cast<real *>(lout2), N[s]*pM[s]*K[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, reinterpret_cast<real *>(lout3), N[s]*pM[s]*K[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, cart[s]);
1118                 }
1119                 else
1120                 {
1121                     MPI_Alltoall(reinterpret_cast<real *>(lout2), N[s]*M[s]*pK[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, reinterpret_cast<real *>(lout3), N[s]*M[s]*pK[s]*sizeof(t_complex)/sizeof(real), GMX_MPI_REAL, cart[s]);
1122                 }
1123 #else
1124                 gmx_incons("fft5d MPI call without MPI configuration");
1125 #endif /*GMX_MPI*/
1126 #endif /*FFT5D_MPI_TRANSPOSE*/
1127 #ifdef NOGMX
1128                 if (times != 0)
1129                 {
1130                     time_mpi[s] = MPI_Wtime()-time;
1131                 }
1132 #else
1133                 wallcycle_stop(times, ewcPME_FFTCOMM);
1134 #endif
1135             } /*master*/
1136         }     /* bPrallelDim */
1137 #pragma omp barrier  /*both needed for parallel and non-parallel dimension (either have to wait on data from AlltoAll or from last FFT*/
1138
1139         /* ---------- END SPLIT + TRANSPOSE------------ */
1140
1141         /* ---------- START JOIN ------------ */
1142 #ifdef NOGMX
1143         if (times != NULL && thread == 0)
1144         {
1145             time = MPI_Wtime();
1146         }
1147 #endif
1148
1149         if (bParallelDim)
1150         {
1151             joinin = lout3;
1152         }
1153         else
1154         {
1155             joinin = fftout;
1156         }
1157         /*bring back in matrix form
1158            thus make  new 1. axes contiguos
1159            also local transpose 1 and 2/3
1160            runs on thread used for following FFT (thus needing a barrier before but not afterwards)
1161          */
1162         if ((s == 0 && !(plan->flags&FFT5D_ORDER_YZ)) || (s == 1 && (plan->flags&FFT5D_ORDER_YZ)))
1163         {
1164             if (pM[s] > 0)
1165             {
1166                 tstart = ( thread   *pM[s]*pN[s]/plan->nthreads);
1167                 tend   = ((thread+1)*pM[s]*pN[s]/plan->nthreads);
1168                 joinAxesTrans13(lin, joinin, N[s], pM[s], K[s], pM[s], P[s], C[s+1], iNin[s+1], oNin[s+1], tstart%pM[s], tstart/pM[s], tend%pM[s], tend/pM[s]);
1169             }
1170         }
1171         else
1172         {
1173             if (pN[s] > 0)
1174             {
1175                 tstart = ( thread   *pK[s]*pN[s]/plan->nthreads);
1176                 tend   = ((thread+1)*pK[s]*pN[s]/plan->nthreads);
1177                 joinAxesTrans12(lin, joinin, N[s], M[s], pK[s], pN[s], P[s], C[s+1], iNin[s+1], oNin[s+1], tstart%pN[s], tstart/pN[s], tend%pN[s], tend/pN[s]);
1178             }
1179         }
1180
1181 #ifdef NOGMX
1182         if (times != NULL && thread == 0)
1183         {
1184             time_local += MPI_Wtime()-time;
1185         }
1186 #endif
1187         if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1188         {
1189             print_localdata(lin, "%d %d: tranposed %d\n", s+1, plan);
1190         }
1191         /* ---------- END JOIN ------------ */
1192
1193         /*if (debug) print_localdata(lin, "%d %d: transposed x-z\n", N1, M0, K, ZYX, coor);*/
1194     }  /* for(s=0;s<2;s++) */
1195 #ifdef NOGMX
1196     if (times != NULL && thread == 0)
1197     {
1198         time = MPI_Wtime();
1199     }
1200 #endif
1201
1202     if (plan->flags&FFT5D_INPLACE)
1203     {
1204         lout = lin;                          /*in place currently not supported*/
1205
1206     }
1207     /*  ----------- FFT ----------- */
1208     tstart = (thread*pM[s]*pK[s]/plan->nthreads)*C[s];
1209     if ((plan->flags&FFT5D_REALCOMPLEX) && (plan->flags&FFT5D_BACKWARD))
1210     {
1211         gmx_fft_many_1d_real(p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_COMPLEX_TO_REAL : GMX_FFT_REAL_TO_COMPLEX, lin+tstart, lout+tstart);
1212     }
1213     else
1214     {
1215         gmx_fft_many_1d(     p1d[s][thread], (plan->flags&FFT5D_BACKWARD) ? GMX_FFT_BACKWARD : GMX_FFT_FORWARD,               lin+tstart, lout+tstart);
1216     }
1217     /* ------------ END FFT ---------*/
1218
1219 #ifdef NOGMX
1220     if (times != NULL && thread == 0)
1221     {
1222         time_fft += MPI_Wtime()-time;
1223
1224         times->fft   += time_fft;
1225         times->local += time_local;
1226         times->mpi2  += time_mpi[1];
1227         times->mpi1  += time_mpi[0];
1228     }
1229 #endif
1230
1231     if ((plan->flags&FFT5D_DEBUG) && thread == 0)
1232     {
1233         print_localdata(lout, "%d %d: FFT %d\n", s, plan);
1234     }
1235 }
1236
1237 void fft5d_destroy(fft5d_plan plan)
1238 {
1239     int s, t;
1240
1241     for (s = 0; s < 3; s++)
1242     {
1243         if (plan->p1d[s])
1244         {
1245             for (t = 0; t < plan->nthreads; t++)
1246             {
1247                 gmx_many_fft_destroy(plan->p1d[s][t]);
1248             }
1249             free(plan->p1d[s]);
1250         }
1251         if (plan->iNin[s])
1252         {
1253             free(plan->iNin[s]);
1254             plan->iNin[s] = nullptr;
1255         }
1256         if (plan->oNin[s])
1257         {
1258             free(plan->oNin[s]);
1259             plan->oNin[s] = nullptr;
1260         }
1261         if (plan->iNout[s])
1262         {
1263             free(plan->iNout[s]);
1264             plan->iNout[s] = nullptr;
1265         }
1266         if (plan->oNout[s])
1267         {
1268             free(plan->oNout[s]);
1269             plan->oNout[s] = nullptr;
1270         }
1271     }
1272 #if GMX_FFT_FFTW3
1273     FFTW_LOCK;
1274 #ifdef FFT5D_MPI_TRANSPOS
1275     for (s = 0; s < 2; s++)
1276     {
1277         FFTW(destroy_plan)(plan->mpip[s]);
1278     }
1279 #endif /* FFT5D_MPI_TRANSPOS */
1280     if (plan->p3d)
1281     {
1282         FFTW(destroy_plan)(plan->p3d);
1283     }
1284     FFTW_UNLOCK;
1285 #endif /* GMX_FFT_FFTW3 */
1286
1287     if (!(plan->flags&FFT5D_NOMALLOC))
1288     {
1289         // only needed for PME GPU mixed mode
1290         if (plan->pinningPolicy == gmx::PinningPolicy::PinnedIfSupported &&
1291             isHostMemoryPinned(plan->lin))
1292         {
1293             gmx::unpinBuffer(plan->lin);
1294         }
1295         sfree_aligned(plan->lin);
1296         sfree_aligned(plan->lout);
1297         if (plan->nthreads > 1)
1298         {
1299             sfree_aligned(plan->lout2);
1300             sfree_aligned(plan->lout3);
1301         }
1302     }
1303
1304 #ifdef FFT5D_THREADS
1305 #ifdef FFT5D_FFTW_THREADS
1306     /*FFTW(cleanup_threads)();*/
1307 #endif
1308 #endif
1309
1310     free(plan);
1311 }
1312
1313 /*Is this better than direct access of plan? enough data?
1314    here 0,1 reference divided by which processor grid dimension (not FFT step!)*/
1315 void fft5d_local_size(fft5d_plan plan, int* N1, int* M0, int* K0, int* K1, int** coor)
1316 {
1317     *N1 = plan->N[0];
1318     *M0 = plan->M[0];
1319     *K1 = plan->K[0];
1320     *K0 = plan->N[1];
1321
1322     *coor = plan->coor;
1323 }
1324
1325
1326 /*same as fft5d_plan_3d but with cartesian coordinator and automatic splitting
1327    of processor dimensions*/
1328 fft5d_plan fft5d_plan_3d_cart(int NG, int MG, int KG, MPI_Comm comm, int P0, int flags, t_complex** rlin, t_complex** rlout, t_complex** rlout2, t_complex** rlout3, int nthreads)
1329 {
1330     MPI_Comm cart[2] = {MPI_COMM_NULL, MPI_COMM_NULL};
1331 #if GMX_MPI
1332     int      size = 1, prank = 0;
1333     int      P[2];
1334     int      coor[2];
1335     int      wrap[] = {0, 0};
1336     MPI_Comm gcart;
1337     int      rdim1[] = {0, 1}, rdim2[] = {1, 0};
1338
1339     MPI_Comm_size(comm, &size);
1340     MPI_Comm_rank(comm, &prank);
1341
1342     if (P0 == 0)
1343     {
1344         P0 = lfactor(size);
1345     }
1346     if (size%P0 != 0)
1347     {
1348         if (prank == 0)
1349         {
1350             printf("FFT5D: WARNING: Number of ranks %d not evenly divisible by %d\n", size, P0);
1351         }
1352         P0 = lfactor(size);
1353     }
1354
1355     P[0] = P0; P[1] = size/P0; /*number of processors in the two dimensions*/
1356
1357     /*Difference between x-y-z regarding 2d decomposition is whether they are
1358        distributed along axis 1, 2 or both*/
1359
1360     MPI_Cart_create(comm, 2, P, wrap, 1, &gcart); /*parameter 4: value 1: reorder*/
1361     MPI_Cart_get(gcart, 2, P, wrap, coor);
1362     MPI_Cart_sub(gcart, rdim1, &cart[0]);
1363     MPI_Cart_sub(gcart, rdim2, &cart[1]);
1364 #else
1365     (void)P0;
1366     (void)comm;
1367 #endif
1368     return fft5d_plan_3d(NG, MG, KG, cart, flags, rlin, rlout, rlout2, rlout3, nthreads);
1369 }
1370
1371
1372
1373 /*prints in original coordinate system of data (as the input to FFT)*/
1374 void fft5d_compare_data(const t_complex* lin, const t_complex* in, fft5d_plan plan, int bothLocal, int normalize)
1375 {
1376     int  xs[3], xl[3], xc[3], NG[3];
1377     int  x, y, z, l;
1378     int *coor = plan->coor;
1379     int  ll   = 2; /*compare ll values per element (has to be 2 for complex)*/
1380     if ((plan->flags&FFT5D_REALCOMPLEX) && (plan->flags&FFT5D_BACKWARD))
1381     {
1382         ll = 1;
1383     }
1384
1385     compute_offsets(plan, xs, xl, xc, NG, 2);
1386     if (plan->flags&FFT5D_DEBUG)
1387     {
1388         printf("Compare2\n");
1389     }
1390     for (z = 0; z < xl[2]; z++)
1391     {
1392         for (y = 0; y < xl[1]; y++)
1393         {
1394             if (plan->flags&FFT5D_DEBUG)
1395             {
1396                 printf("%d %d: ", coor[0], coor[1]);
1397             }
1398             for (x = 0; x < xl[0]; x++)
1399             {
1400                 for (l = 0; l < ll; l++)   /*loop over real/complex parts*/
1401                 {
1402                     real a, b;
1403                     a = reinterpret_cast<const real*>(lin)[(z*xs[2]+y*xs[1])*2+x*xs[0]*ll+l];
1404                     if (normalize)
1405                     {
1406                         a /= plan->rC[0]*plan->rC[1]*plan->rC[2];
1407                     }
1408                     if (!bothLocal)
1409                     {
1410                         b = reinterpret_cast<const real*>(in)[((z+xc[2])*NG[0]*NG[1]+(y+xc[1])*NG[0])*2+(x+xc[0])*ll+l];
1411                     }
1412                     else
1413                     {
1414                         b = reinterpret_cast<const real*>(in)[(z*xs[2]+y*xs[1])*2+x*xs[0]*ll+l];
1415                     }
1416                     if (plan->flags&FFT5D_DEBUG)
1417                     {
1418                         printf("%f %f, ", a, b);
1419                     }
1420                     else
1421                     {
1422                         if (std::fabs(a-b) > 2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS)
1423                         {
1424                             printf("result incorrect on %d,%d at %d,%d,%d: FFT5D:%f reference:%f\n", coor[0], coor[1], x, y, z, a, b);
1425                         }
1426 /*                        assert(fabs(a-b)<2*NG[0]*NG[1]*NG[2]*GMX_REAL_EPS);*/
1427                     }
1428                 }
1429                 if (plan->flags&FFT5D_DEBUG)
1430                 {
1431                     printf(",");
1432                 }
1433             }
1434             if (plan->flags&FFT5D_DEBUG)
1435             {
1436                 printf("\n");
1437             }
1438         }
1439     }
1440
1441 }