Rework -Weverything
[alexxy/gromacs.git] / src / gromacs / fft / fft_mkl.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2003 David van der Spoel, Erik Lindahl, University of Groningen.
5  * Copyright (c) 2013,2014,2015,2016,2019,2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36 #include "gmxpre.h"
37
38 #include <errno.h>
39 #include <stdlib.h>
40
41 #include <mkl_dfti.h>
42 #include <mkl_service.h>
43
44 #include "gromacs/fft/fft.h"
45 #include "gromacs/utility/fatalerror.h"
46
47
48 /* For MKL version (<10.0), we should define MKL_LONG. */
49 #ifndef MKL_LONG
50 #    define MKL_LONG long int
51 #endif
52
53
54 #if GMX_DOUBLE
55 #    define GMX_DFTI_PREC DFTI_DOUBLE
56 #else
57 #    define GMX_DFTI_PREC DFTI_SINGLE
58 #endif
59
60 /*! \internal
61  * \brief
62  * Contents of the Intel MKL FFT fft datatype.
63  *
64  * Note that this is one of several possible implementations of gmx_fft_t.
65  *
66  *  The MKL _API_ supports 1D,2D, and 3D transforms, including real-to-complex.
67  *  Unfortunately the actual library implementation does not support 3D real
68  *  transforms as of version 7.2, and versions before 7.0 don't support 2D real
69  *  either. In addition, the multi-dimensional storage format for real data
70  *  is not compatible with our padding.
71  *
72  *  To work around this we roll our own 2D and 3D real-to-complex transforms,
73  *  using separate X/Y/Z handles defined to perform (ny*nz), (nx*nz), and
74  *  (nx*ny) transforms at once when necessary. To perform strided multiple
75  *  transforms out-of-place (i.e., without padding in the last dimension)
76  *  on the fly we also need to separate the forward and backward
77  *  handles for real-to-complex/complex-to-real data permutation.
78  *
79  *  This makes it necessary to define 3 handles for in-place FFTs, and 4 for
80  *  the out-of-place transforms. Still, whenever possible we try to use
81  *  a single 3D-transform handle instead.
82  *
83  *  So, the handles are enumerated as follows:
84  *
85  *  1D FFT (real too):    Index 0 is the handle for the entire FFT
86  *  2D complex FFT:       Index 0 is the handle for the entire FFT
87  *  3D complex FFT:       Index 0 is the handle for the entire FFT
88  *  2D, inplace real FFT: 0=FFTx, 1=FFTy handle
89  *  2D, ooplace real FFT: 0=FFTx, 1=real-to-complex FFTy, 2=complex-to-real FFTy
90  *  3D, inplace real FFT: 0=FFTx, 1=FFTy, 2=FFTz handle
91  *  3D, ooplace real FFT: 0=FFTx, 1=FFTy, 2=r2c FFTz, 3=c2r FFTz
92  *
93  *  Intel people reading this: Learn from FFTW what a good interface looks like :-)
94  */
95 #ifdef DOXYGEN
96 struct gmx_fft_mkl
97 #else
98 struct gmx_fft
99 #endif
100 {
101     int              ndim;       /**< Number of dimensions in FFT  */
102     int              nx;         /**< Length of X transform        */
103     int              ny;         /**< Length of Y transform        */
104     int              nz;         /**< Length of Z transform        */
105     int              real_fft;   /**< 1 if real FFT, otherwise 0   */
106     DFTI_DESCRIPTOR* inplace[3]; /**< in-place FFT                 */
107     DFTI_DESCRIPTOR* ooplace[4]; /**< out-of-place FFT             */
108     t_complex*       work;       /**< Enable out-of-place c2r FFT  */
109 };
110
111
112 int gmx_fft_init_1d(gmx_fft_t* pfft, int nxInt, gmx_fft_flag gmx_unused flags)
113 {
114     gmx_fft_t fft;
115     int       d;
116     int       status;
117
118     if (pfft == nullptr)
119     {
120         gmx_fatal(FARGS, "Invalid opaque FFT datatype pointer.");
121         return EINVAL;
122     }
123     *pfft = nullptr;
124
125     if ((fft = reinterpret_cast<gmx_fft_t>(malloc(sizeof(struct gmx_fft)))) == nullptr)
126     {
127         return ENOMEM;
128     }
129
130     /* Mark all handles invalid */
131     for (d = 0; d < 3; d++)
132     {
133         fft->inplace[d] = fft->ooplace[d] = nullptr;
134     }
135     fft->ooplace[3] = nullptr;
136
137
138     MKL_LONG nx = nxInt;
139     status      = DftiCreateDescriptor(&fft->inplace[0], GMX_DFTI_PREC, DFTI_COMPLEX, 1, nx);
140
141     if (status == 0)
142     {
143         status = DftiSetValue(fft->inplace[0], DFTI_PLACEMENT, DFTI_INPLACE);
144     }
145
146     if (status == 0)
147     {
148         status = DftiCommitDescriptor(fft->inplace[0]);
149     }
150
151
152     if (status == 0)
153     {
154         status = DftiCreateDescriptor(&fft->ooplace[0], GMX_DFTI_PREC, DFTI_COMPLEX, 1, nx);
155     }
156
157     if (status == 0)
158     {
159         DftiSetValue(fft->ooplace[0], DFTI_PLACEMENT, DFTI_NOT_INPLACE);
160     }
161
162     if (status == 0)
163     {
164         DftiCommitDescriptor(fft->ooplace[0]);
165     }
166
167
168     if (status != 0)
169     {
170         gmx_fatal(FARGS, "Error initializing Intel MKL FFT; status=%d", status);
171     }
172
173     fft->ndim     = 1;
174     fft->nx       = nx;
175     fft->real_fft = 0;
176     fft->work     = nullptr;
177
178     *pfft = fft;
179     return 0;
180 }
181
182
183 int gmx_fft_init_1d_real(gmx_fft_t* pfft, int nxInt, gmx_fft_flag gmx_unused flags)
184 {
185     gmx_fft_t fft;
186     int       d;
187     int       status;
188
189     if (pfft == nullptr)
190     {
191         gmx_fatal(FARGS, "Invalid opaque FFT datatype pointer.");
192         return EINVAL;
193     }
194     *pfft = nullptr;
195
196     if ((fft = reinterpret_cast<gmx_fft_t>(malloc(sizeof(struct gmx_fft)))) == nullptr)
197     {
198         return ENOMEM;
199     }
200
201     /* Mark all handles invalid */
202     for (d = 0; d < 3; d++)
203     {
204         fft->inplace[d] = fft->ooplace[d] = nullptr;
205     }
206     fft->ooplace[3] = nullptr;
207
208     MKL_LONG nx = nxInt;
209     status = DftiCreateDescriptor(&fft->inplace[0], GMX_DFTI_PREC, DFTI_REAL, 1, nx);
210
211     if (status == 0)
212     {
213         status = DftiSetValue(fft->inplace[0], DFTI_PLACEMENT, DFTI_INPLACE);
214     }
215
216     if (status == 0)
217     {
218         status = DftiCommitDescriptor(fft->inplace[0]);
219     }
220
221
222     if (status == 0)
223     {
224         status = DftiCreateDescriptor(&fft->ooplace[0], GMX_DFTI_PREC, DFTI_REAL, 1, nx);
225     }
226
227     if (status == 0)
228     {
229         status = DftiSetValue(fft->ooplace[0], DFTI_PLACEMENT, DFTI_NOT_INPLACE);
230     }
231
232     if (status == 0)
233     {
234         status = DftiCommitDescriptor(fft->ooplace[0]);
235     }
236
237
238     if (status == DFTI_UNIMPLEMENTED)
239     {
240         gmx_fatal(FARGS, "The linked Intel MKL version (<6.0?) cannot do real FFTs.");
241     }
242
243
244     if (status != 0)
245     {
246         gmx_fatal(FARGS, "Error initializing Intel MKL FFT; status=%d", status);
247     }
248
249     fft->ndim     = 1;
250     fft->nx       = nx;
251     fft->real_fft = 1;
252     fft->work     = nullptr;
253
254     *pfft = fft;
255     return 0;
256 }
257
258
259 int gmx_fft_init_2d_real(gmx_fft_t* pfft, int nxInt, int nyInt, gmx_fft_flag gmx_unused flags)
260 {
261     gmx_fft_t fft;
262     int       d;
263     int       status;
264     MKL_LONG  stride[2];
265     MKL_LONG  nyc;
266
267     if (pfft == nullptr)
268     {
269         gmx_fatal(FARGS, "Invalid opaque FFT datatype pointer.");
270         return EINVAL;
271     }
272     *pfft = nullptr;
273
274     if ((fft = reinterpret_cast<gmx_fft_t>(malloc(sizeof(struct gmx_fft)))) == nullptr)
275     {
276         return ENOMEM;
277     }
278
279     nyc = (nyInt / 2 + 1);
280
281     /* Mark all handles invalid */
282     for (d = 0; d < 3; d++)
283     {
284         fft->inplace[d] = fft->ooplace[d] = nullptr;
285     }
286     fft->ooplace[3] = nullptr;
287
288     /* Roll our own 2D real transform using multiple transforms in MKL,
289      * since the current MKL versions does not support our storage format,
290      * and all but the most recent don't even have 2D real FFTs.
291      */
292
293     /* In-place X FFT */
294     MKL_LONG nx = nxInt;
295     status      = DftiCreateDescriptor(&fft->inplace[0], GMX_DFTI_PREC, DFTI_COMPLEX, 1, nx);
296
297     if (status == 0)
298     {
299         stride[0] = 0;
300         stride[1] = nyc;
301
302         status = (DftiSetValue(fft->inplace[0], DFTI_PLACEMENT, DFTI_INPLACE)
303                   || DftiSetValue(fft->inplace[0], DFTI_NUMBER_OF_TRANSFORMS, nyc)
304                   || DftiSetValue(fft->inplace[0], DFTI_INPUT_DISTANCE, 1)
305                   || DftiSetValue(fft->inplace[0], DFTI_INPUT_STRIDES, stride)
306                   || DftiSetValue(fft->inplace[0], DFTI_OUTPUT_DISTANCE, 1)
307                   || DftiSetValue(fft->inplace[0], DFTI_OUTPUT_STRIDES, stride));
308     }
309
310     if (status == 0)
311     {
312         status = DftiCommitDescriptor(fft->inplace[0]);
313     }
314
315     /* Out-of-place X FFT */
316     if (status == 0)
317     {
318         status = DftiCreateDescriptor(&(fft->ooplace[0]), GMX_DFTI_PREC, DFTI_COMPLEX, 1, nx);
319     }
320
321     if (status == 0)
322     {
323         stride[0] = 0;
324         stride[1] = nyc;
325
326         status = (DftiSetValue(fft->ooplace[0], DFTI_PLACEMENT, DFTI_NOT_INPLACE)
327                   || DftiSetValue(fft->ooplace[0], DFTI_NUMBER_OF_TRANSFORMS, nyc)
328                   || DftiSetValue(fft->ooplace[0], DFTI_INPUT_DISTANCE, 1)
329                   || DftiSetValue(fft->ooplace[0], DFTI_INPUT_STRIDES, stride)
330                   || DftiSetValue(fft->ooplace[0], DFTI_OUTPUT_DISTANCE, 1)
331                   || DftiSetValue(fft->ooplace[0], DFTI_OUTPUT_STRIDES, stride));
332     }
333
334     if (status == 0)
335     {
336         status = DftiCommitDescriptor(fft->ooplace[0]);
337     }
338
339
340     /* In-place Y FFT  */
341     MKL_LONG ny = nyInt;
342     if (status == 0)
343     {
344         status = DftiCreateDescriptor(&fft->inplace[1], GMX_DFTI_PREC, DFTI_REAL, 1, ny);
345     }
346
347     if (status == 0)
348     {
349         stride[0] = 0;
350         stride[1] = 1;
351
352         status = (DftiSetValue(fft->inplace[1], DFTI_PLACEMENT, DFTI_INPLACE)
353                   || DftiSetValue(fft->inplace[1], DFTI_NUMBER_OF_TRANSFORMS, nx)
354                   || DftiSetValue(fft->inplace[1], DFTI_INPUT_DISTANCE, 2 * nyc)
355                   || DftiSetValue(fft->inplace[1], DFTI_INPUT_STRIDES, stride)
356                   || DftiSetValue(fft->inplace[1], DFTI_OUTPUT_DISTANCE, 2 * nyc)
357                   || DftiSetValue(fft->inplace[1], DFTI_OUTPUT_STRIDES, stride)
358                   || DftiCommitDescriptor(fft->inplace[1]));
359     }
360
361
362     /* Out-of-place real-to-complex (affects output distance) Y FFT */
363     if (status == 0)
364     {
365         status = DftiCreateDescriptor(&fft->ooplace[1], GMX_DFTI_PREC, DFTI_REAL, 1, ny);
366     }
367
368     if (status == 0)
369     {
370         stride[0] = 0;
371         stride[1] = 1;
372
373         status = (DftiSetValue(fft->ooplace[1], DFTI_PLACEMENT, DFTI_NOT_INPLACE)
374                   || DftiSetValue(fft->ooplace[1], DFTI_NUMBER_OF_TRANSFORMS, nx)
375                   || DftiSetValue(fft->ooplace[1], DFTI_INPUT_DISTANCE, ny)
376                   || DftiSetValue(fft->ooplace[1], DFTI_INPUT_STRIDES, stride)
377                   || DftiSetValue(fft->ooplace[1], DFTI_OUTPUT_DISTANCE, 2 * nyc)
378                   || DftiSetValue(fft->ooplace[1], DFTI_OUTPUT_STRIDES, stride)
379                   || DftiCommitDescriptor(fft->ooplace[1]));
380     }
381
382
383     /* Out-of-place complex-to-real (affects output distance) Y FFT */
384     if (status == 0)
385     {
386         status = DftiCreateDescriptor(&fft->ooplace[2], GMX_DFTI_PREC, DFTI_REAL, 1, ny);
387     }
388
389     if (status == 0)
390     {
391         stride[0] = 0;
392         stride[1] = 1;
393
394         status = (DftiSetValue(fft->ooplace[2], DFTI_PLACEMENT, DFTI_NOT_INPLACE)
395                   || DftiSetValue(fft->ooplace[2], DFTI_NUMBER_OF_TRANSFORMS, nx)
396                   || DftiSetValue(fft->ooplace[2], DFTI_INPUT_DISTANCE, 2 * nyc)
397                   || DftiSetValue(fft->ooplace[2], DFTI_INPUT_STRIDES, stride)
398                   || DftiSetValue(fft->ooplace[2], DFTI_OUTPUT_DISTANCE, ny)
399                   || DftiSetValue(fft->ooplace[2], DFTI_OUTPUT_STRIDES, stride)
400                   || DftiCommitDescriptor(fft->ooplace[2]));
401     }
402
403
404     if (status == 0)
405     {
406         void* memory = malloc(sizeof(t_complex) * (nx * (ny / 2 + 1)));
407         if (nullptr == memory)
408         {
409             status = ENOMEM;
410         }
411         fft->work = static_cast<t_complex*>(memory);
412     }
413
414     if (status != 0)
415     {
416         gmx_fatal(FARGS, "Error initializing Intel MKL FFT; status=%d", status);
417     }
418
419     fft->ndim     = 2;
420     fft->nx       = nx;
421     fft->ny       = ny;
422     fft->real_fft = 1;
423
424     *pfft = fft;
425     return 0;
426 }
427
428 int gmx_fft_1d(gmx_fft_t fft, enum gmx_fft_direction dir, void* in_data, void* out_data)
429 {
430     int inplace = (in_data == out_data);
431     int status  = 0;
432
433     if ((fft->real_fft == 1) || (fft->ndim != 1) || ((dir != GMX_FFT_FORWARD) && (dir != GMX_FFT_BACKWARD)))
434     {
435         gmx_fatal(FARGS, "FFT plan mismatch - bad plan or direction.");
436         return EINVAL;
437     }
438
439     if (dir == GMX_FFT_FORWARD)
440     {
441         if (inplace)
442         {
443             status = DftiComputeForward(fft->inplace[0], in_data);
444         }
445         else
446         {
447             status = DftiComputeForward(fft->ooplace[0], in_data, out_data);
448         }
449     }
450     else
451     {
452         if (inplace)
453         {
454             status = DftiComputeBackward(fft->inplace[0], in_data);
455         }
456         else
457         {
458             status = DftiComputeBackward(fft->ooplace[0], in_data, out_data);
459         }
460     }
461
462     if (status != 0)
463     {
464         gmx_fatal(FARGS, "Error executing Intel MKL FFT.");
465     }
466
467     return status;
468 }
469
470
471 int gmx_fft_1d_real(gmx_fft_t fft, enum gmx_fft_direction dir, void* in_data, void* out_data)
472 {
473     int inplace = (in_data == out_data);
474     int status  = 0;
475
476     if ((fft->real_fft != 1) || (fft->ndim != 1)
477         || ((dir != GMX_FFT_REAL_TO_COMPLEX) && (dir != GMX_FFT_COMPLEX_TO_REAL)))
478     {
479         gmx_fatal(FARGS, "FFT plan mismatch - bad plan or direction.");
480         return EINVAL;
481     }
482
483     if (dir == GMX_FFT_REAL_TO_COMPLEX)
484     {
485         if (inplace)
486         {
487             status = DftiComputeForward(fft->inplace[0], in_data);
488         }
489         else
490         {
491             status = DftiComputeForward(fft->ooplace[0], in_data, out_data);
492         }
493     }
494     else
495     {
496         if (inplace)
497         {
498             status = DftiComputeBackward(fft->inplace[0], in_data);
499         }
500         else
501         {
502             status = DftiComputeBackward(fft->ooplace[0], in_data, out_data);
503         }
504     }
505
506     if (status != 0)
507     {
508         gmx_fatal(FARGS, "Error executing Intel MKL FFT.");
509     }
510
511     return status;
512 }
513
514
515 int gmx_fft_2d_real(gmx_fft_t fft, enum gmx_fft_direction dir, void* in_data, void* out_data)
516 {
517     int inplace = (in_data == out_data);
518     int status  = 0;
519
520     if ((fft->real_fft != 1) || (fft->ndim != 2)
521         || ((dir != GMX_FFT_REAL_TO_COMPLEX) && (dir != GMX_FFT_COMPLEX_TO_REAL)))
522     {
523         gmx_fatal(FARGS, "FFT plan mismatch - bad plan or direction.");
524     }
525
526     if (dir == GMX_FFT_REAL_TO_COMPLEX)
527     {
528         if (inplace)
529         {
530             /* real-to-complex in Y dimension, in-place */
531             status = DftiComputeForward(fft->inplace[1], in_data);
532
533             /* complex-to-complex in X dimension, in-place */
534             if (status == 0)
535             {
536                 status = DftiComputeForward(fft->inplace[0], in_data);
537             }
538         }
539         else
540         {
541             /* real-to-complex in Y dimension, in_data to out_data */
542             status = DftiComputeForward(fft->ooplace[1], in_data, out_data);
543
544             /* complex-to-complex in X dimension, in-place to out_data */
545             if (status == 0)
546             {
547                 status = DftiComputeForward(fft->inplace[0], out_data);
548             }
549         }
550     }
551     else
552     {
553         /* prior implementation was incorrect. See fft.cpp unit test */
554         gmx_incons("Complex -> Real is not supported by MKL.");
555     }
556
557     if (status != 0)
558     {
559         gmx_fatal(FARGS, "Error executing Intel MKL FFT.");
560     }
561
562     return status;
563 }
564
565 void gmx_fft_destroy(gmx_fft_t fft)
566 {
567     int d;
568
569     if (fft != nullptr)
570     {
571         for (d = 0; d < 3; d++)
572         {
573             if (fft->inplace[d] != nullptr)
574             {
575                 DftiFreeDescriptor(&fft->inplace[d]);
576             }
577             if (fft->ooplace[d] != nullptr)
578             {
579                 DftiFreeDescriptor(&fft->ooplace[d]);
580             }
581         }
582         if (fft->ooplace[3] != nullptr)
583         {
584             DftiFreeDescriptor(&fft->ooplace[3]);
585         }
586         if (fft->work != nullptr)
587         {
588             free(fft->work);
589         }
590         free(fft);
591     }
592 }
593
594 void gmx_fft_cleanup()
595 {
596     mkl_free_buffers();
597 }