a64a7574c563af359288cabea32cf1276de8bbd2
[alexxy/gromacs.git] / src / gromacs / fft / fft_mkl.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2003 David van der Spoel, Erik Lindahl, University of Groningen.
5  * Copyright (c) 2013,2014,2015,2016,2019,2020, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 #include "gmxpre.h"
37
38 #include <errno.h>
39 #include <stdlib.h>
40
41 #include <mkl_dfti.h>
42 #include <mkl_service.h>
43
44 #include "gromacs/fft/fft.h"
45 #include "gromacs/utility/fatalerror.h"
46
47
48 /* For MKL version (<10.0), we should define MKL_LONG. */
49 #ifndef MKL_LONG
50 #    define MKL_LONG long int
51 #endif
52
53
54 #if GMX_DOUBLE
55 #    define GMX_DFTI_PREC DFTI_DOUBLE
56 #else
57 #    define GMX_DFTI_PREC DFTI_SINGLE
58 #endif
59
60 /*! \internal
61  * \brief
62  * Contents of the Intel MKL FFT fft datatype.
63  *
64  * Note that this is one of several possible implementations of gmx_fft_t.
65  *
66  *  The MKL _API_ supports 1D,2D, and 3D transforms, including real-to-complex.
67  *  Unfortunately the actual library implementation does not support 3D real
68  *  transforms as of version 7.2, and versions before 7.0 don't support 2D real
69  *  either. In addition, the multi-dimensional storage format for real data
70  *  is not compatible with our padding.
71  *
72  *  To work around this we roll our own 2D and 3D real-to-complex transforms,
73  *  using separate X/Y/Z handles defined to perform (ny*nz), (nx*nz), and
74  *  (nx*ny) transforms at once when necessary. To perform strided multiple
75  *  transforms out-of-place (i.e., without padding in the last dimension)
76  *  on the fly we also need to separate the forward and backward
77  *  handles for real-to-complex/complex-to-real data permutation.
78  *
79  *  This makes it necessary to define 3 handles for in-place FFTs, and 4 for
80  *  the out-of-place transforms. Still, whenever possible we try to use
81  *  a single 3D-transform handle instead.
82  *
83  *  So, the handles are enumerated as follows:
84  *
85  *  1D FFT (real too):    Index 0 is the handle for the entire FFT
86  *  2D complex FFT:       Index 0 is the handle for the entire FFT
87  *  3D complex FFT:       Index 0 is the handle for the entire FFT
88  *  2D, inplace real FFT: 0=FFTx, 1=FFTy handle
89  *  2D, ooplace real FFT: 0=FFTx, 1=real-to-complex FFTy, 2=complex-to-real FFTy
90  *  3D, inplace real FFT: 0=FFTx, 1=FFTy, 2=FFTz handle
91  *  3D, ooplace real FFT: 0=FFTx, 1=FFTy, 2=r2c FFTz, 3=c2r FFTz
92  *
93  *  Intel people reading this: Learn from FFTW what a good interface looks like :-)
94  */
95 #ifdef DOXYGEN
96 struct gmx_fft_mkl
97 #else
98 struct gmx_fft
99 #endif
100 {
101     int              ndim;       /**< Number of dimensions in FFT  */
102     int              nx;         /**< Length of X transform        */
103     int              ny;         /**< Length of Y transform        */
104     int              nz;         /**< Length of Z transform        */
105     int              real_fft;   /**< 1 if real FFT, otherwise 0   */
106     DFTI_DESCRIPTOR* inplace[3]; /**< in-place FFT                 */
107     DFTI_DESCRIPTOR* ooplace[4]; /**< out-of-place FFT             */
108     t_complex*       work;       /**< Enable out-of-place c2r FFT  */
109 };
110
111
112 int gmx_fft_init_1d(gmx_fft_t* pfft, int nx, gmx_fft_flag gmx_unused flags)
113 {
114     gmx_fft_t fft;
115     int       d;
116     int       status;
117
118     if (pfft == nullptr)
119     {
120         gmx_fatal(FARGS, "Invalid opaque FFT datatype pointer.");
121         return EINVAL;
122     }
123     *pfft = nullptr;
124
125     if ((fft = (gmx_fft_t)malloc(sizeof(struct gmx_fft))) == nullptr)
126     {
127         return ENOMEM;
128     }
129
130     /* Mark all handles invalid */
131     for (d = 0; d < 3; d++)
132     {
133         fft->inplace[d] = fft->ooplace[d] = nullptr;
134     }
135     fft->ooplace[3] = nullptr;
136
137
138     status = DftiCreateDescriptor(&fft->inplace[0], GMX_DFTI_PREC, DFTI_COMPLEX, 1, (MKL_LONG)nx);
139
140     if (status == 0)
141     {
142         status = DftiSetValue(fft->inplace[0], DFTI_PLACEMENT, DFTI_INPLACE);
143     }
144
145     if (status == 0)
146     {
147         status = DftiCommitDescriptor(fft->inplace[0]);
148     }
149
150
151     if (status == 0)
152     {
153         status = DftiCreateDescriptor(&fft->ooplace[0], GMX_DFTI_PREC, DFTI_COMPLEX, 1, (MKL_LONG)nx);
154     }
155
156     if (status == 0)
157     {
158         DftiSetValue(fft->ooplace[0], DFTI_PLACEMENT, DFTI_NOT_INPLACE);
159     }
160
161     if (status == 0)
162     {
163         DftiCommitDescriptor(fft->ooplace[0]);
164     }
165
166
167     if (status != 0)
168     {
169         gmx_fatal(FARGS, "Error initializing Intel MKL FFT; status=%d", status);
170     }
171
172     fft->ndim     = 1;
173     fft->nx       = nx;
174     fft->real_fft = 0;
175     fft->work     = nullptr;
176
177     *pfft = fft;
178     return 0;
179 }
180
181
182 int gmx_fft_init_1d_real(gmx_fft_t* pfft, int nx, gmx_fft_flag gmx_unused flags)
183 {
184     gmx_fft_t fft;
185     int       d;
186     int       status;
187
188     if (pfft == nullptr)
189     {
190         gmx_fatal(FARGS, "Invalid opaque FFT datatype pointer.");
191         return EINVAL;
192     }
193     *pfft = nullptr;
194
195     if ((fft = (gmx_fft_t)malloc(sizeof(struct gmx_fft))) == nullptr)
196     {
197         return ENOMEM;
198     }
199
200     /* Mark all handles invalid */
201     for (d = 0; d < 3; d++)
202     {
203         fft->inplace[d] = fft->ooplace[d] = nullptr;
204     }
205     fft->ooplace[3] = nullptr;
206
207     status = DftiCreateDescriptor(&fft->inplace[0], GMX_DFTI_PREC, DFTI_REAL, 1, (MKL_LONG)nx);
208
209     if (status == 0)
210     {
211         status = DftiSetValue(fft->inplace[0], DFTI_PLACEMENT, DFTI_INPLACE);
212     }
213
214     if (status == 0)
215     {
216         status = DftiCommitDescriptor(fft->inplace[0]);
217     }
218
219
220     if (status == 0)
221     {
222         status = DftiCreateDescriptor(&fft->ooplace[0], GMX_DFTI_PREC, DFTI_REAL, 1, (MKL_LONG)nx);
223     }
224
225     if (status == 0)
226     {
227         status = DftiSetValue(fft->ooplace[0], DFTI_PLACEMENT, DFTI_NOT_INPLACE);
228     }
229
230     if (status == 0)
231     {
232         status = DftiCommitDescriptor(fft->ooplace[0]);
233     }
234
235
236     if (status == DFTI_UNIMPLEMENTED)
237     {
238         gmx_fatal(FARGS, "The linked Intel MKL version (<6.0?) cannot do real FFTs.");
239     }
240
241
242     if (status != 0)
243     {
244         gmx_fatal(FARGS, "Error initializing Intel MKL FFT; status=%d", status);
245     }
246
247     fft->ndim     = 1;
248     fft->nx       = nx;
249     fft->real_fft = 1;
250     fft->work     = nullptr;
251
252     *pfft = fft;
253     return 0;
254 }
255
256
257 int gmx_fft_init_2d_real(gmx_fft_t* pfft, int nx, int ny, gmx_fft_flag gmx_unused flags)
258 {
259     gmx_fft_t fft;
260     int       d;
261     int       status;
262     MKL_LONG  stride[2];
263     MKL_LONG  nyc;
264
265     if (pfft == nullptr)
266     {
267         gmx_fatal(FARGS, "Invalid opaque FFT datatype pointer.");
268         return EINVAL;
269     }
270     *pfft = nullptr;
271
272     if ((fft = (gmx_fft_t)malloc(sizeof(struct gmx_fft))) == nullptr)
273     {
274         return ENOMEM;
275     }
276
277     nyc = (ny / 2 + 1);
278
279     /* Mark all handles invalid */
280     for (d = 0; d < 3; d++)
281     {
282         fft->inplace[d] = fft->ooplace[d] = nullptr;
283     }
284     fft->ooplace[3] = nullptr;
285
286     /* Roll our own 2D real transform using multiple transforms in MKL,
287      * since the current MKL versions does not support our storage format,
288      * and all but the most recent don't even have 2D real FFTs.
289      */
290
291     /* In-place X FFT */
292     status = DftiCreateDescriptor(&fft->inplace[0], GMX_DFTI_PREC, DFTI_COMPLEX, 1, (MKL_LONG)nx);
293
294     if (status == 0)
295     {
296         stride[0] = 0;
297         stride[1] = nyc;
298
299         status = (DftiSetValue(fft->inplace[0], DFTI_PLACEMENT, DFTI_INPLACE)
300                   || DftiSetValue(fft->inplace[0], DFTI_NUMBER_OF_TRANSFORMS, nyc)
301                   || DftiSetValue(fft->inplace[0], DFTI_INPUT_DISTANCE, 1)
302                   || DftiSetValue(fft->inplace[0], DFTI_INPUT_STRIDES, stride)
303                   || DftiSetValue(fft->inplace[0], DFTI_OUTPUT_DISTANCE, 1)
304                   || DftiSetValue(fft->inplace[0], DFTI_OUTPUT_STRIDES, stride));
305     }
306
307     if (status == 0)
308     {
309         status = DftiCommitDescriptor(fft->inplace[0]);
310     }
311
312     /* Out-of-place X FFT */
313     if (status == 0)
314     {
315         status = DftiCreateDescriptor(&(fft->ooplace[0]), GMX_DFTI_PREC, DFTI_COMPLEX, 1, (MKL_LONG)nx);
316     }
317
318     if (status == 0)
319     {
320         stride[0] = 0;
321         stride[1] = nyc;
322
323         status = (DftiSetValue(fft->ooplace[0], DFTI_PLACEMENT, DFTI_NOT_INPLACE)
324                   || DftiSetValue(fft->ooplace[0], DFTI_NUMBER_OF_TRANSFORMS, nyc)
325                   || DftiSetValue(fft->ooplace[0], DFTI_INPUT_DISTANCE, 1)
326                   || DftiSetValue(fft->ooplace[0], DFTI_INPUT_STRIDES, stride)
327                   || DftiSetValue(fft->ooplace[0], DFTI_OUTPUT_DISTANCE, 1)
328                   || DftiSetValue(fft->ooplace[0], DFTI_OUTPUT_STRIDES, stride));
329     }
330
331     if (status == 0)
332     {
333         status = DftiCommitDescriptor(fft->ooplace[0]);
334     }
335
336
337     /* In-place Y FFT  */
338     if (status == 0)
339     {
340         status = DftiCreateDescriptor(&fft->inplace[1], GMX_DFTI_PREC, DFTI_REAL, 1, (MKL_LONG)ny);
341     }
342
343     if (status == 0)
344     {
345         stride[0] = 0;
346         stride[1] = 1;
347
348         status = (DftiSetValue(fft->inplace[1], DFTI_PLACEMENT, DFTI_INPLACE)
349                   || DftiSetValue(fft->inplace[1], DFTI_NUMBER_OF_TRANSFORMS, (MKL_LONG)nx)
350                   || DftiSetValue(fft->inplace[1], DFTI_INPUT_DISTANCE, 2 * nyc)
351                   || DftiSetValue(fft->inplace[1], DFTI_INPUT_STRIDES, stride)
352                   || DftiSetValue(fft->inplace[1], DFTI_OUTPUT_DISTANCE, 2 * nyc)
353                   || DftiSetValue(fft->inplace[1], DFTI_OUTPUT_STRIDES, stride)
354                   || DftiCommitDescriptor(fft->inplace[1]));
355     }
356
357
358     /* Out-of-place real-to-complex (affects output distance) Y FFT */
359     if (status == 0)
360     {
361         status = DftiCreateDescriptor(&fft->ooplace[1], GMX_DFTI_PREC, DFTI_REAL, 1, (MKL_LONG)ny);
362     }
363
364     if (status == 0)
365     {
366         stride[0] = 0;
367         stride[1] = 1;
368
369         status = (DftiSetValue(fft->ooplace[1], DFTI_PLACEMENT, DFTI_NOT_INPLACE)
370                   || DftiSetValue(fft->ooplace[1], DFTI_NUMBER_OF_TRANSFORMS, (MKL_LONG)nx)
371                   || DftiSetValue(fft->ooplace[1], DFTI_INPUT_DISTANCE, (MKL_LONG)ny)
372                   || DftiSetValue(fft->ooplace[1], DFTI_INPUT_STRIDES, stride)
373                   || DftiSetValue(fft->ooplace[1], DFTI_OUTPUT_DISTANCE, 2 * nyc)
374                   || DftiSetValue(fft->ooplace[1], DFTI_OUTPUT_STRIDES, stride)
375                   || DftiCommitDescriptor(fft->ooplace[1]));
376     }
377
378
379     /* Out-of-place complex-to-real (affects output distance) Y FFT */
380     if (status == 0)
381     {
382         status = DftiCreateDescriptor(&fft->ooplace[2], GMX_DFTI_PREC, DFTI_REAL, 1, (MKL_LONG)ny);
383     }
384
385     if (status == 0)
386     {
387         stride[0] = 0;
388         stride[1] = 1;
389
390         status = (DftiSetValue(fft->ooplace[2], DFTI_PLACEMENT, DFTI_NOT_INPLACE)
391                   || DftiSetValue(fft->ooplace[2], DFTI_NUMBER_OF_TRANSFORMS, (MKL_LONG)nx)
392                   || DftiSetValue(fft->ooplace[2], DFTI_INPUT_DISTANCE, 2 * nyc)
393                   || DftiSetValue(fft->ooplace[2], DFTI_INPUT_STRIDES, stride)
394                   || DftiSetValue(fft->ooplace[2], DFTI_OUTPUT_DISTANCE, (MKL_LONG)ny)
395                   || DftiSetValue(fft->ooplace[2], DFTI_OUTPUT_STRIDES, stride)
396                   || DftiCommitDescriptor(fft->ooplace[2]));
397     }
398
399
400     if (status == 0)
401     {
402         void* memory = malloc(sizeof(t_complex) * (nx * (ny / 2 + 1)));
403         if (nullptr == memory)
404         {
405             status = ENOMEM;
406         }
407         fft->work = static_cast<t_complex*>(memory);
408     }
409
410     if (status != 0)
411     {
412         gmx_fatal(FARGS, "Error initializing Intel MKL FFT; status=%d", status);
413     }
414
415     fft->ndim     = 2;
416     fft->nx       = nx;
417     fft->ny       = ny;
418     fft->real_fft = 1;
419
420     *pfft = fft;
421     return 0;
422 }
423
424 int gmx_fft_1d(gmx_fft_t fft, enum gmx_fft_direction dir, void* in_data, void* out_data)
425 {
426     int inplace = (in_data == out_data);
427     int status  = 0;
428
429     if ((fft->real_fft == 1) || (fft->ndim != 1) || ((dir != GMX_FFT_FORWARD) && (dir != GMX_FFT_BACKWARD)))
430     {
431         gmx_fatal(FARGS, "FFT plan mismatch - bad plan or direction.");
432         return EINVAL;
433     }
434
435     if (dir == GMX_FFT_FORWARD)
436     {
437         if (inplace)
438         {
439             status = DftiComputeForward(fft->inplace[0], in_data);
440         }
441         else
442         {
443             status = DftiComputeForward(fft->ooplace[0], in_data, out_data);
444         }
445     }
446     else
447     {
448         if (inplace)
449         {
450             status = DftiComputeBackward(fft->inplace[0], in_data);
451         }
452         else
453         {
454             status = DftiComputeBackward(fft->ooplace[0], in_data, out_data);
455         }
456     }
457
458     if (status != 0)
459     {
460         gmx_fatal(FARGS, "Error executing Intel MKL FFT.");
461     }
462
463     return status;
464 }
465
466
467 int gmx_fft_1d_real(gmx_fft_t fft, enum gmx_fft_direction dir, void* in_data, void* out_data)
468 {
469     int inplace = (in_data == out_data);
470     int status  = 0;
471
472     if ((fft->real_fft != 1) || (fft->ndim != 1)
473         || ((dir != GMX_FFT_REAL_TO_COMPLEX) && (dir != GMX_FFT_COMPLEX_TO_REAL)))
474     {
475         gmx_fatal(FARGS, "FFT plan mismatch - bad plan or direction.");
476         return EINVAL;
477     }
478
479     if (dir == GMX_FFT_REAL_TO_COMPLEX)
480     {
481         if (inplace)
482         {
483             status = DftiComputeForward(fft->inplace[0], in_data);
484         }
485         else
486         {
487             status = DftiComputeForward(fft->ooplace[0], in_data, out_data);
488         }
489     }
490     else
491     {
492         if (inplace)
493         {
494             status = DftiComputeBackward(fft->inplace[0], in_data);
495         }
496         else
497         {
498             status = DftiComputeBackward(fft->ooplace[0], in_data, out_data);
499         }
500     }
501
502     if (status != 0)
503     {
504         gmx_fatal(FARGS, "Error executing Intel MKL FFT.");
505     }
506
507     return status;
508 }
509
510
511 int gmx_fft_2d_real(gmx_fft_t fft, enum gmx_fft_direction dir, void* in_data, void* out_data)
512 {
513     int inplace = (in_data == out_data);
514     int status  = 0;
515
516     if ((fft->real_fft != 1) || (fft->ndim != 2)
517         || ((dir != GMX_FFT_REAL_TO_COMPLEX) && (dir != GMX_FFT_COMPLEX_TO_REAL)))
518     {
519         gmx_fatal(FARGS, "FFT plan mismatch - bad plan or direction.");
520     }
521
522     if (dir == GMX_FFT_REAL_TO_COMPLEX)
523     {
524         if (inplace)
525         {
526             /* real-to-complex in Y dimension, in-place */
527             status = DftiComputeForward(fft->inplace[1], in_data);
528
529             /* complex-to-complex in X dimension, in-place */
530             if (status == 0)
531             {
532                 status = DftiComputeForward(fft->inplace[0], in_data);
533             }
534         }
535         else
536         {
537             /* real-to-complex in Y dimension, in_data to out_data */
538             status = DftiComputeForward(fft->ooplace[1], in_data, out_data);
539
540             /* complex-to-complex in X dimension, in-place to out_data */
541             if (status == 0)
542             {
543                 status = DftiComputeForward(fft->inplace[0], out_data);
544             }
545         }
546     }
547     else
548     {
549         /* prior implementation was incorrect. See fft.cpp unit test */
550         gmx_incons("Complex -> Real is not supported by MKL.");
551     }
552
553     if (status != 0)
554     {
555         gmx_fatal(FARGS, "Error executing Intel MKL FFT.");
556     }
557
558     return status;
559 }
560
561 void gmx_fft_destroy(gmx_fft_t fft)
562 {
563     int d;
564
565     if (fft != nullptr)
566     {
567         for (d = 0; d < 3; d++)
568         {
569             if (fft->inplace[d] != nullptr)
570             {
571                 DftiFreeDescriptor(&fft->inplace[d]);
572             }
573             if (fft->ooplace[d] != nullptr)
574             {
575                 DftiFreeDescriptor(&fft->ooplace[d]);
576             }
577         }
578         if (fft->ooplace[3] != nullptr)
579         {
580             DftiFreeDescriptor(&fft->ooplace[3]);
581         }
582         if (fft->work != nullptr)
583         {
584             free(fft->work);
585         }
586         free(fft);
587     }
588 }
589
590 void gmx_fft_cleanup()
591 {
592     mkl_free_buffers();
593 }