[FFT] split FFT API into different layers (#2793)

* fft refactoring

* separate header for STL overloads

* ditto

---------

Co-authored-by: pf <pf@me>
This commit is contained in:
pfeatherstone 2023-05-06 16:43:10 +01:00 committed by GitHub
parent 48fdace271
commit ff331517b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 216 additions and 171 deletions

View File

@ -277,6 +277,7 @@ if (NOT TARGET dlib)
tokenizer/tokenizer_kernel_1.cpp
unicode/unicode.cpp
test_for_odr_violations.cpp
fft/fft.cpp
)
set(dlib_needed_public_libraries)

49
dlib/fft/fft.cpp Normal file
View File

@ -0,0 +1,49 @@
#include "fft.h"
#ifdef DLIB_USE_MKL_FFT
#include "mkl_fft.h"
#else
#include "kiss_fft.h"
#endif
namespace dlib
{
template<typename T>
void fft(const fft_size& dims, const std::complex<T>* in, std::complex<T>* out, bool is_inverse)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft(dims, in, out, is_inverse);
#else
kiss_fft(dims, in, out, is_inverse);
#endif
}
template<typename T>
void fftr(const fft_size& dims, const T* in, std::complex<T>* out)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fftr(dims, in, out);
#else
kiss_fftr(dims, in, out);
#endif
}
template<typename T>
void ifftr(const fft_size& dims, const std::complex<T>* in, T* out)
{
#ifdef DLIB_USE_MKL_FFT
mkl_ifftr(dims, in, out);
#else
kiss_ifftr(dims, in, out);
#endif
}
template void fft<float>(const fft_size& dims, const std::complex<float>* in, std::complex<float>* out, bool is_inverse);
template void fft<double>(const fft_size& dims, const std::complex<double>* in, std::complex<double>* out, bool is_inverse);
template void fftr<float>(const fft_size& dims, const float* in, std::complex<float>* out);
template void fftr<double>(const fft_size& dims, const double* in, std::complex<double>* out);
template void ifftr<float>(const fft_size& dims, const std::complex<float>* in, float* out);
template void ifftr<double>(const fft_size& dims, const std::complex<double>* in, double* out);
}

102
dlib/fft/fft.h Normal file
View File

@ -0,0 +1,102 @@
// Copyright (C) 2023 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_FFT_DETAILS_Hh_
#define DLIB_FFT_DETAILS_Hh_
#include <complex>
#include "fft_size.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
constexpr bool is_power_of_two (
const unsigned long n
)
/*!
ensures
- returns true if value contains a power of two and false otherwise. As a
special case, we also consider 0 to be a power of two.
!*/
{
return n == 0 ? true : (n & (n - 1)) == 0;
}
// ----------------------------------------------------------------------------------------
constexpr long fftr_nc_size(
long nc
)
/*!
ensures
- returns the output dimension of a 1D real FFT
!*/
{
return nc == 0 ? 0 : nc/2+1;
}
// ----------------------------------------------------------------------------------------
constexpr long ifftr_nc_size(
long nc
)
/*!
ensures
- returns the output dimension of an inverse 1D real FFT
!*/
{
return nc == 0 ? 0 : 2*(nc-1);
}
// ----------------------------------------------------------------------------------------
template<typename T>
void fft(const fft_size& dims, const std::complex<T>* in, std::complex<T>* out, bool is_inverse);
/*!
requires
- T must be either float or double
- dims represents the dimensions of both `in` and `out`
- dims.num_dims() > 0
ensures
- performs an FFT on `in` and stores the result in `out`.
- if `is_inverse` is true, a backward FFT is performed,
otherwise a forward FFT is performed.
!*/
// ----------------------------------------------------------------------------------------
template<typename T>
void fftr(const fft_size& dims, const T* in, std::complex<T>* out);
/*!
requires
- T must be either float or double
- dims represent the dimensions of `in`
- `in` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]}
- `out` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- dims.num_dims() > 0
- dims.back() must be even
ensures
- performs a real FFT on `in` and stores the result in `out`.
!*/
// ----------------------------------------------------------------------------------------
template<typename T>
void ifftr(const fft_size& dims, const std::complex<T>* in, T* out);
/*!
requires
- T must be either float or double
- dims represent the dimensions of `out`
- `in` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]/2+1}
- `out` has dimensions {dims[0], dims[1], ..., dims[-2], dims[-1]}
- dims.num_dims() > 0
- dims.back() must be even
ensures
- performs an inverse real FFT on `in` and stores the result in `out`.
!*/
// ----------------------------------------------------------------------------------------
}
#endif //DLIB_FFT_DETAILS_Hh_

54
dlib/fft/fft_stl.h Normal file
View File

@ -0,0 +1,54 @@
// Copyright (C) 2023 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_FFT_STL_Hh_
#define DLIB_FFT_STL_Hh_
#include <vector>
#include "fft.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void fft_inplace (std::vector<std::complex<T>, Alloc>& data)
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- This function is identical to fft() except that it does the FFT in-place.
That is, after this function executes we will have:
- #data == fft(data)
!*/
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
fft({(long)data.size()}, &data[0], &data[0], false);
}
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void ifft_inplace (std::vector<std::complex<T>, Alloc>& data)
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- This function is identical to ifft() except that it does the inverse FFT
in-place. That is, after this function executes we will have:
- #data == ifft(data)*data.size()
- Note that the output needs to be divided by data.size() to complete the
inverse transformation.
!*/
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
fft({(long)data.size()}, &data[0], &data[0], true);
}
// ----------------------------------------------------------------------------------------
}
#endif //DLIB_FFT_STL_Hh_

View File

@ -8,35 +8,11 @@
#include "../hash.h"
#include "../algs.h"
#include "../math.h"
#ifdef DLIB_USE_MKL_FFT
#include "mkl_fft.h"
#else
#include "kiss_fft.h"
#endif
#include "../fft/fft.h"
#include "../fft/fft_stl.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
constexpr bool is_power_of_two (const unsigned long n)
{
return n == 0 ? true : (n & (n - 1)) == 0;
}
// ----------------------------------------------------------------------------------------
constexpr long fftr_nc_size(long nc)
{
return nc == 0 ? 0 : nc/2+1;
}
// ----------------------------------------------------------------------------------------
constexpr long ifftr_nc_size(long nc)
{
return nc == 0 ? 0 : 2*(nc-1);
}
// ----------------------------------------------------------------------------------------
@ -47,13 +23,7 @@ namespace dlib
static_assert(std::is_floating_point<T>::value, "only support floating point types");
matrix<std::complex<T>,0,1> out(in.size());
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)in.size()}, &in[0], &out(0,0), false);
#else
kiss_fft({(long)in.size()}, &in[0], &out(0,0), false);
#endif
}
fft({(long)in.size()}, &in[0], &out(0,0), false);
return out;
}
@ -66,13 +36,7 @@ namespace dlib
static_assert(std::is_floating_point<T>::value, "only support floating point types");
matrix<std::complex<T>,NR,NC,MM,L> out(in.nr(), in.nc());
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), false);
#else
kiss_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), false);
#endif
}
fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), false);
return out;
}
@ -97,11 +61,7 @@ namespace dlib
matrix<std::complex<T>,0,1> out(in.size());
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)in.size()}, &in[0], &out(0,0), true);
#else
kiss_fft({(long)in.size()}, &in[0], &out(0,0), true);
#endif
fft({(long)in.size()}, &in[0], &out(0,0), true);
out /= out.size();
}
return out;
@ -117,11 +77,7 @@ namespace dlib
matrix<std::complex<T>,NR,NC,MM,L> out(in.nr(), in.nc());
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), true);
#else
kiss_fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), true);
#endif
fft({in.nr(),in.nc()}, &in(0,0), &out(0,0), true);
out /= out.size();
}
return out;
@ -148,13 +104,7 @@ namespace dlib
DLIB_ASSERT(in.nc() % 2 == 0, "last dimension " << in.nc() << " needs to be even otherwise ifftr(fftr(data)) won't have matching dimensions");
matrix<std::complex<T>,NR,fftr_nc_size(NC),MM,L> out(in.nr(), fftr_nc_size(in.nc()));
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fftr({in.nr(),in.nc()}, &in(0,0), &out(0,0));
#else
kiss_fftr({in.nr(),in.nc()}, &in(0,0), &out(0,0));
#endif
}
fftr({in.nr(),in.nc()}, &in(0,0), &out(0,0));
return out;
}
@ -179,11 +129,7 @@ namespace dlib
matrix<T,NR,ifftr_nc_size(NC),MM,L> out(in.nr(), ifftr_nc_size(in.nc()));
if (in.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_ifftr({out.nr(),out.nc()}, &in(0,0), &out(0,0));
#else
kiss_ifftr({out.nr(),out.nc()}, &in(0,0), &out(0,0));
#endif
ifftr({out.nr(),out.nc()}, &in(0,0), &out(0,0));
out /= out.size();
}
return out;
@ -200,22 +146,6 @@ namespace dlib
return ifftr(in);
}
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void fft_inplace (std::vector<std::complex<T>, Alloc>& data)
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)data.size()}, &data[0], &data[0], false);
#else
kiss_fft({(long)data.size()}, &data[0], &data[0], false);
#endif
}
}
// ----------------------------------------------------------------------------------------
template < typename T, long NR, long NC, typename MM, typename L >
@ -223,29 +153,7 @@ namespace dlib
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), false);
#else
kiss_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), false);
#endif
}
}
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void ifft_inplace (std::vector<std::complex<T>, Alloc>& data)
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({(long)data.size()}, &data[0], &data[0], true);
#else
kiss_fft({(long)data.size()}, &data[0], &data[0], true);
#endif
}
fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), false);
}
// ----------------------------------------------------------------------------------------
@ -255,13 +163,7 @@ namespace dlib
{
static_assert(std::is_floating_point<T>::value, "only support floating point types");
if (data.size() != 0)
{
#ifdef DLIB_USE_MKL_FFT
mkl_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), true);
#else
kiss_fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), true);
#endif
}
fft({data.nr(),data.nc()}, &data(0,0), &data(0,0), true);
}
// ----------------------------------------------------------------------------------------

View File

@ -9,37 +9,6 @@
namespace dlib
{
// ----------------------------------------------------------------------------------------
constexpr bool is_power_of_two (
const unsigned long value
);
/*!
ensures
- returns true if value contains a power of two and false otherwise. As a
special case, we also consider 0 to be a power of two.
!*/
// ----------------------------------------------------------------------------------------
constexpr long fftr_nc_size(
long nc
);
/*!
ensures
- returns the output dimension of a 1D real FFT
!*/
// ----------------------------------------------------------------------------------------
constexpr long ifftr_nc_size(
long nc
);
/*!
ensures
- returns the output dimension of an inverse 1D real FFT
!*/
// ----------------------------------------------------------------------------------------
template <typename EXP>
@ -176,21 +145,6 @@ namespace dlib
- #data == fft(data)
!*/
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void fft_inplace (
std::vector<std::complex<T>, Alloc>& data
)
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- This function is identical to fft() except that it does the FFT in-place.
That is, after this function executes we will have:
- #data == fft(data)
!*/
// ----------------------------------------------------------------------------------------
template <
@ -214,23 +168,6 @@ namespace dlib
inverse transformation.
!*/
// ----------------------------------------------------------------------------------------
template < typename T, typename Alloc >
void ifft_inplace (
std::vector<std::complex<T>, Alloc>& data
);
/*!
requires
- data contains elements of type std::complex<> that itself contains double, float, or long double.
ensures
- This function is identical to ifft() except that it does the inverse FFT
in-place. That is, after this function executes we will have:
- #data == ifft(data)*data.size()
- Note that the output needs to be divided by data.size() to complete the
inverse transformation.
!*/
// ----------------------------------------------------------------------------------------
// These return function objects with signature double(size_t i, size_t wlen)