//
//  CPUStft.cpp
//  MNN
//
//  Created by MNN on 2024/11/26.
//  Copyright © 2018, Alibaba Group Holding Limited
//

/**
 Ref from onnxruntime
 */
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdint> // uint8_t / uint32_t / uint64_t
#include <cstring> // memset
#include <memory>  // std::shared_ptr
#include "backend/cpu/CPUStft.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "core/Concurrency.h"
#include "core/TensorUtils.hpp"
#include "core/Macro.h"
#include "compute/CommonOptFunction.h"
namespace MNN {

// Early-return helpers: propagate any error code produced by `x`, or return
// NOT_SUPPORT when the condition `x` holds (the `y` argument is unused).
#define ___RETURN_IF_ERROR(x) {auto code = (x); if (NO_ERROR != code) {return code;}}
#define ___RETURN_IF(x, y) {if (x) {return NOT_SUPPORT;}}
static bool is_real_valued_signal(const Tensor* shape) {
    return shape->dimensions() == 2 || shape->length(shape->dimensions() - 1) == 1;
}

static bool is_complex_valued_signal(const Tensor* shape) {
    return shape->dimensions() > 2 && shape->length(shape->dimensions() - 1) == 2;
}

static bool is_power_of_2(size_t size) {
    // A power of two has exactly one bit set in its binary representation.
    size_t n_bits = 0;
    while (size != 0) {
        n_bits += size & 1;
        size = size >> 1;
    }
    return n_bits == 1;
}
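
// Layout note (derived from the predicates above): a [batch, signal] or
// [batch, signal, 1] tensor is dispatched to the real-valued path, while a
// trailing dimension of length 2 marks interleaved (real, imaginary) pairs,
// e.g. shape [1, 8, 2], and selects the complex-valued path.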

// Lookup table: entry i is the byte i with its 8 bits reversed (e.g. entry
// 0x01 is 0x80). bit_reverse() uses it to reverse 32-bit indices a byte at a time.
static const unsigned char BitReverseTable256[] = {
    0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0, 0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0, 0x08, 0x88, 0x48,
    0xC8, 0x28, 0xA8, 0x68, 0xE8, 0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8, 0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4,
    0x64, 0xE4, 0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4, 0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC, 0x1C,
    0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC, 0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2, 0x12, 0x92, 0x52, 0xD2,
    0x32, 0xB2, 0x72, 0xF2, 0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA, 0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A,
    0xFA, 0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6, 0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6, 0x0E, 0x8E,
    0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE, 0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE, 0x01, 0x81, 0x41, 0xC1, 0x21,
    0xA1, 0x61, 0xE1, 0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1, 0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9,
    0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9, 0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5, 0x15, 0x95, 0x55,
    0xD5, 0x35, 0xB5, 0x75, 0xF5, 0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED, 0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD,
    0x7D, 0xFD, 0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3, 0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3, 0x0B,
    0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB, 0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB, 0x07, 0x87, 0x47, 0xC7,
    0x27, 0xA7, 0x67, 0xE7, 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7, 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F,
    0xEF, 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF};
template <typename T>
static inline T bit_reverse(T num, unsigned significant_bits) {
    if (significant_bits > 32) {
        MNN_ERROR("Unsupported bit size.");
    }
    uint32_t num_32 = static_cast<uint32_t>(num);
    // Reverse all 32 bits byte-by-byte, then shift away the bits outside the
    // significant range. The casts to uint32_t keep the promoted table entries
    // from overflowing a signed int when shifted into bit 31.
    uint32_t rev = (static_cast<uint32_t>(BitReverseTable256[num_32 & 0xff]) << 24) |
                   (static_cast<uint32_t>(BitReverseTable256[(num_32 >> 8) & 0xff]) << 16) |
                   (static_cast<uint32_t>(BitReverseTable256[(num_32 >> 16) & 0xff]) << 8) |
                   (static_cast<uint32_t>(BitReverseTable256[(num_32 >> 24) & 0xff]));
    return static_cast<T>(((uint64_t)rev) >> (32 - significant_bits));
}
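
// Worked example: with significant_bits == 3, bit_reverse(1, 3) maps 0b001 to
// 0b100 == 4, and bit_reverse(6, 3) maps 0b110 to 0b011 == 3. The radix-2 FFT
// below uses this permutation to reorder its inputs so the butterfly stages
// can run in place.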

template <typename T>
static T compute_angular_velocity(size_t number_of_samples, bool inverse) {
    // Fundamental angular step between successive DFT samples:
    // -2*pi/N for the forward transform, +2*pi/N for the inverse.
    static const T pi = static_cast<T>(M_PI);
    static const T tau = 2 * pi;
    T inverse_switch = inverse ? 1.f : -1.f;
    T angular_velocity = inverse_switch * tau / number_of_samples;
    return angular_velocity;
}

template <typename T>
static std::complex<T> compute_exponential(size_t index, const T angular_velocity) {
    // Euler's formula: e^{i*theta} = cos(theta) + i*sin(theta).
    const T angle = static_cast<T>(index) * angular_velocity;
    return std::complex<T>(std::cos(angle), std::sin(angle));
}
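
// The values compute_exponential(k, w) are the FFT "twiddle factors". For a
// forward 8-point transform, w == -2*pi/8 and the factors are the 8th roots
// of unity e^{-2*pi*i*k/8}, k = 0..7. fft_radix2 below caches them in V in
// bit-reversed order so each butterfly stage can index them directly.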

template <typename T, typename U>
static ErrorCode fft_radix2(Backend* backend, const Tensor* X, Tensor* Y, size_t X_offset, size_t X_stride,
                            size_t Y_offset, size_t Y_stride, int64_t axis, size_t dft_length, const Tensor* window,
                            bool is_onesided, bool inverse, std::vector<std::complex<T>>& V,
                            std::vector<std::complex<T>>& temp_output) {
    // Get shape and significant bits
    const auto X_shape = X->shape();
    size_t number_of_samples = static_cast<size_t>(X_shape[axis]);
    unsigned significant_bits = static_cast<unsigned>(std::log2(dft_length));

    // Get data
    auto* X_data = const_cast<U*>(reinterpret_cast<const U*>(X->host<void>())) + X_offset;
    // Get window
    float* window_data = nullptr;
    if (window) {
        window_data = const_cast<float*>(reinterpret_cast<const float*>(window->host<void>()));
    }

    // A onesided transform keeps only the first dft_length/2 + 1 bins, so the
    // full-length intermediate result is computed into a scratch buffer first.
    size_t Y_data_stride = 1;
    std::complex<T>* Y_data;
    if (is_onesided) {
        if (temp_output.size() != dft_length) {
            temp_output.resize(dft_length);
        }
        Y_data = temp_output.data();
    } else {
        Y_data = reinterpret_cast<std::complex<T>*>(Y->host<void>()) + Y_offset;
        Y_data_stride = Y_stride;
    }

    auto angular_velocity = compute_angular_velocity<T>(dft_length, inverse);

    // Create vandermonde matrix V ordered with the bit-reversed permutation
    if (V.size() != dft_length) {
        V.resize(dft_length);
        for (size_t i = 0; i < dft_length; i++) {
            size_t bit_reversed_index = bit_reverse(i, significant_bits);
            V[bit_reversed_index] = compute_exponential(i, angular_velocity);
        }
    }

    // Load the input in bit-reversed order, applying the analysis window and
    // zero-padding up to dft_length when the signal is shorter.
    for (size_t i = 0; i < dft_length; i++) {
        size_t bit_reversed_index = bit_reverse(i, significant_bits);
        auto x = (bit_reversed_index < number_of_samples) ? *(X_data + bit_reversed_index * X_stride) : 0;
        auto window_element = window_data ? *(window_data + bit_reversed_index) : 1;
        *(Y_data + i * Y_data_stride) = std::complex<T>(1, 0) * x * window_element;
    }
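
    // Cooley-Tukey decimation-in-time butterflies. The stage with block size i
    // merges pairs of length-(i/2) partial DFTs:
    //   even' = even + w * odd
    //   odd'  = even - w * odd
    // where w is the k-th twiddle factor of that stage. V is stored in
    // bit-reversed order precisely so that V[first_idx] == w and
    // V[second_idx] == -w below; for the first stage (i == 2) the update
    // reduces to the 2-point DFT (even + odd, even - odd).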
    unsigned current_significant_bits = 0;
    for (size_t i = 2; i <= dft_length; i <<= 1) {
        size_t midpoint = i >> 1;
        current_significant_bits++;

        for (size_t k = 0; k < midpoint; k++) {
            auto first_idx = bit_reverse(k, current_significant_bits);
            auto second_idx = bit_reverse(midpoint + k, current_significant_bits);
            for (size_t j = 0; j < dft_length; j += i) {
                auto even_index = k + j;
                auto odd_index = k + j + midpoint;
                std::complex<T>* even = (Y_data + even_index * Y_data_stride);
                std::complex<T>* odd = (Y_data + odd_index * Y_data_stride);
                std::complex<T> first = *even + (V[first_idx] * *odd);
                std::complex<T> second = *even + (V[second_idx] * *odd);
                *even = first;
                *odd = second;
            }
        }
    }

    // Scale the output by 1/N if inverse
    if (inverse) {
        for (size_t i = 0; i < dft_length; i++) {
            std::complex<T>& val = *(Y_data + i * Y_data_stride);
            val /= static_cast<T>(dft_length);
        }
    }

    // For a onesided transform, copy the unique dft_length/2 + 1 bins from the
    // scratch buffer into the strided output tensor.
    if (is_onesided) {
        const size_t output_size = (dft_length >> 1) + 1;
        auto destination = reinterpret_cast<std::complex<T>*>(Y->host<void>()) + Y_offset;
        for (size_t i = 0; i < output_size; i++) {
            *(destination + Y_stride * i) = *(Y_data + i * Y_data_stride);
        }
    }

    return NO_ERROR;
}

// Smallest power of two that is >= in (for in >= 1).
template <typename T>
static T next_power_of_2(T in) {
    in--;
    T out = 1;
    while (out <= in) {
        out <<= 1;
    }
    return out;
}
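
// Example: next_power_of_2(5) == 8 and next_power_of_2(8) == 8. Bluestein's
// algorithm below uses this to pick M = next_power_of_2(2*N - 1), the smallest
// power-of-two length at which the length-N chirp convolution can be evaluated
// with the radix-2 FFT.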

// Bluestein's chirp-z algorithm: rewrites an arbitrary-length-N DFT as a
// circular convolution of power-of-two length M, using the identity
//   n*k = (n^2 + k^2 - (k - n)^2) / 2
// so that, with chirp[n] = e^{+/- i*pi*n^2 / N},
//   X[k] = chirp[k] * sum_n (x[n] * chirp[n]) * conj(chirp[k - n]).
// The convolution is evaluated with three radix-2 FFTs; the FFT of the "b"
// (conjugate-chirp) sequence depends only on N and is cached across calls.
template <typename T, typename U>
static ErrorCode dft_bluestein_z_chirp(Backend* bn, const Tensor* X, Tensor* Y, std::shared_ptr<Tensor>& b_fft_p, std::shared_ptr<Tensor>& chirp_p, size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride,
                                       int64_t axis, size_t dft_length, const Tensor* window, bool inverse, std::vector<std::complex<T>>& V,
                                       std::vector<std::complex<T>>& temp_output) {
    static const T pi = static_cast<T>(M_PI);

    size_t N = static_cast<size_t>(dft_length);
    size_t M = next_power_of_2(2 * N - 1);
    auto dft_input_shape = std::vector<int>({1, (int)M, 2});
    T scale = inverse ? 1.f / N : 1.f;
    T direction = inverse ? 1.f : -1.f;

    // Rebuild the cached chirp and b-FFT tensors when the transform length changes.
    bool should_recreate_b_fft = b_fft_p->elementSize() != M * 2;
    bool should_recreate_chirp = chirp_p->elementSize() != M * 2;
    bool should_recreate = should_recreate_b_fft || should_recreate_chirp;
    if (should_recreate) {
        std::shared_ptr<Tensor> b_p(Tensor::create(dft_input_shape, X->getType()));
        auto& b = *b_p;
        b_fft_p.reset(Tensor::create(dft_input_shape, Y->getType()));
        auto& b_fft = *b_fft_p;
        chirp_p.reset(Tensor::create(dft_input_shape, X->getType()));
        auto& chirp = *chirp_p;

        std::complex<T>* b_data = reinterpret_cast<std::complex<T>*>(b.host<void>());
        std::complex<T>* b_fft_data = reinterpret_cast<std::complex<T>*>(b_fft.host<void>());
        std::complex<T>* chirp_data = reinterpret_cast<std::complex<T>*>(chirp.host<void>());
        memset(reinterpret_cast<void*>(b_data), 0, b.usize());
        memset(reinterpret_cast<void*>(b_fft_data), 0, b_fft.usize());
        memset(reinterpret_cast<void*>(chirp_data), 0, chirp.usize());

        for (size_t n = 0; n < N; n++) {
            std::complex<T>& chirp_n = *(chirp_data + n);
            // chirp[n] = e^{direction * i * pi * n^2 / N}
            auto exponent = direction * pi * n * n / N;
            chirp_n = std::complex<T>(std::cos(exponent), std::sin(exponent));

            // b[n] is the conjugate chirp
            std::complex<T>& b_n = *(b_data + n);
            b_n = std::conj(chirp_n);
        }

        // Mirror the tail so that b is M-periodic: b[n] = b[M - n] for the
        // negative-index part of the circular convolution.
        for (size_t n = M - N + 1; n < M; n++) {
            std::complex<T>& b_n = *(b_data + n);
            std::complex<T>& b_m_minus_n = *(b_data + M - n);
            b_n = b_m_minus_n;
        }

        // Forward FFT radix2 for the "b" signal
        // This will be cached and reused!
        auto code = ((fft_radix2<T, std::complex<T>>(bn, &b, &b_fft, 0, 1, 0, 1, 1, M, nullptr,
                                                     false, false, V, temp_output)));
        if (NO_ERROR != code) {
            FUNC_PRINT(1);
            return code;
        }
    }
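
    // Convolution phase: with B = FFT(b) cached above, build the zero-padded
    // signal a[n] = window[n] * x[n] * chirp[n], take A = FFT(a), multiply
    // pointwise A *= B, and inverse-transform the product. By the convolution
    // theorem the result is the circular convolution a (*) b, from which the
    // length-N DFT is read off after one final multiplication by the chirp.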
    // Get data
    auto* X_data = const_cast<U*>(reinterpret_cast<const U*>(X->host<void>())) + X_offset;
    auto* Y_data = reinterpret_cast<std::complex<T>*>(Y->host<void>()) + Y_offset;
    float* window_data = nullptr;
    if (window) {
        window_data = const_cast<float*>(reinterpret_cast<const float*>(window->host<void>()));
    }
    std::shared_ptr<Tensor> a_p(Tensor::create(dft_input_shape, X->getType()));
    auto& a = *a_p;
    std::shared_ptr<Tensor> a_fft_p(Tensor::create(dft_input_shape, Y->getType()));
    auto& a_fft = *a_fft_p;
    std::complex<T>* a_data = reinterpret_cast<std::complex<T>*>(a.host<void>());
    std::complex<T>* a_fft_data = reinterpret_cast<std::complex<T>*>(a_fft.host<void>());
    std::complex<T>* b_fft_data = reinterpret_cast<std::complex<T>*>(b_fft_p->host<void>());
    std::complex<T>* chirp_data = reinterpret_cast<std::complex<T>*>(chirp_p->host<void>());
    memset(reinterpret_cast<void*>(a_data), 0, a.usize());

    const auto& X_shape = X->shape();
    size_t number_of_samples = static_cast<size_t>(X_shape[axis]);

    // Prepare the windowed, chirp-modulated "a" signal
    for (size_t n = 0; n < number_of_samples; n++) {
        std::complex<T>& a_n = *(a_data + n);
        std::complex<T>& chirp_n = *(chirp_data + n);
        auto window_n = window_data ? *(window_data + n) : 1;
        a_n = *(X_data + n * X_stride); // input
        a_n *= window_n;
        a_n *= chirp_n;
    }

    // Forward FFT radix2 for the "a" signal
    {
        auto code = ((fft_radix2<T, std::complex<T>>(bn, &a, &a_fft, 0, 1, 0, 1, 1, M, nullptr,
                                                     false, false, V, temp_output)));
        if (NO_ERROR != code) {
            return code;
        }
    }

    // Pointwise multiply in the frequency domain: A *= B
    for (size_t i = 0; i < M; i++) {
        std::complex<T>& a_i = *(a_fft_data + i);
        std::complex<T>& b_i = *(b_fft_data + i);
        a_i *= b_i;
    }

    // Inverse FFT radix2 for the "a" signal
    {
        auto code = ((fft_radix2<T, std::complex<T>>(bn, &a_fft, &a, 0, 1, 0, 1, 1, M, nullptr,
                                                     false, true, V, temp_output)));
        if (NO_ERROR != code) {
            return code;
        }
    }
    const auto& Y_shape = Y->shape();
    size_t dft_output_size = static_cast<size_t>(Y_shape[axis]);

    // Demodulate with the chirp and apply the inverse scale
    for (size_t i = 0; i < dft_output_size; i++) {
        std::complex<T>& chirp_i = *(chirp_data + i);
        std::complex<T>& out = *(Y_data + i * Y_stride);
        std::complex<T>& c_i = *(a_data + i);
        if (i > 0) {
            // The inverse fft is computed using the same cached vandermonde matrix (V) created by the
            // forward fft. This reversal causes the output to be reversed as well.
            // Therefore we undo the reversal when writing the output back out.
            c_i = *(a_data + M - i);
        }
        out = c_i * chirp_i * scale;
    }
    return NO_ERROR;
}

template <typename T, typename U>
static ErrorCode discrete_fourier_transform(Backend* ctx, const Tensor* X, Tensor* Y, std::shared_ptr<Tensor>& b_fft, std::shared_ptr<Tensor>& chirp,
                                            int64_t axis, int64_t dft_length, const Tensor* window, bool is_onesided, bool inverse,
                                            std::vector<std::complex<T>>& V,
                                            std::vector<std::complex<T>>& temp_output) {
    // Get shape
    const auto& X_shape = X->shape();

    auto batch_and_signal_rank = X->dimensions();
    auto total_dfts = static_cast<size_t>(X->elementSize() / X->length(axis));

    auto is_input_real = X->dimensions() == 2 || X->length(X->dimensions() - 1) == 1;
    auto complex_input_factor = is_input_real ? 1 : 2;
    if (X->dimensions() > 2) {
        total_dfts /= (X->length(X->dimensions() - 1));
        batch_and_signal_rank -= 1;
    }

    // Run a 1-D DFT along `axis` for every batch index. The flat batch index i
    // is unraveled into per-dimension indices (skipping `axis`) to compute the
    // element offsets into X and Y.
    for (size_t i = 0; i < total_dfts; i++) {
        size_t X_offset = 0;
        size_t X_stride = X->stride(axis) / complex_input_factor;
        size_t cumulative_packed_stride = total_dfts;
        size_t temp = i;
        for (size_t r = 0; r < batch_and_signal_rank; r++) {
            if (r == static_cast<size_t>(axis)) {
                continue;
            }
            cumulative_packed_stride /= (X_shape[r]);
            auto index = temp / cumulative_packed_stride;
            temp -= (index * cumulative_packed_stride);
            X_offset += index * X->stride(r) / complex_input_factor;
        }

        size_t Y_offset = 0;
        size_t Y_stride = Y->stride(axis) / 2; // the output is always complex interleaved
        cumulative_packed_stride = total_dfts;
        temp = i;
        for (size_t r = 0; r < batch_and_signal_rank; r++) {
            if (r == static_cast<size_t>(axis)) {
                continue;
            }
            cumulative_packed_stride /= (X_shape[r]);
            auto index = temp / cumulative_packed_stride;
            temp -= (index * cumulative_packed_stride);
            Y_offset += index * (size_t)(Y->stride(r) / 2);
        }

        // Power-of-two lengths use the radix-2 FFT directly; other lengths
        // fall back to Bluestein's chirp-z transform.
        if (is_power_of_2(dft_length)) {
            ___RETURN_IF_ERROR((fft_radix2<T, U>(ctx, X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, dft_length, window,
                                                 is_onesided, inverse, V, temp_output)));
        } else {
            ___RETURN_IF_ERROR(
                (dft_bluestein_z_chirp<T, U>(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, dft_length, window, inverse, V, temp_output)));
        }
    }

    return NO_ERROR;
}
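
// Offset example: for a complex input X of shape [2, 5, 8, 2] with axis == 2,
// total_dfts == 2 * 5 == 10. Flat index i == 7 unravels to batch indices
// (1, 2), so X_offset == 1 * stride(0)/2 + 2 * stride(1)/2, and each of the
// 10 DFTs walks 8 complex samples with stride X->stride(2)/2.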

static ErrorCode discrete_fourier_transform(Backend* ctx, int64_t axis, bool is_onesided, bool inverse, Tensor* X, Tensor* dft_length, Tensor* Y) {
    // Get input shape
    const auto is_real_valued = is_real_valued_signal(X);
    const auto is_complex_valued = is_complex_valued_signal(X);
    if (axis < 0) {
        axis = axis + X->dimensions();
    }

    // The DFT length defaults to the signal length along `axis` and may be
    // overridden by the optional dft_length tensor.
    int64_t number_of_samples = static_cast<int64_t>(X->length(axis));
    if (dft_length) {
        number_of_samples = dft_length->host<int>()[0];
    }

    std::shared_ptr<Tensor> b_fft(new Tensor), chirp(new Tensor);
    std::vector<std::complex<float>> V;
    std::vector<std::complex<float>> temp_output;
    if (is_real_valued) {
        ___RETURN_IF_ERROR((discrete_fourier_transform<float, float>(ctx, X, Y, b_fft, chirp, axis, number_of_samples, nullptr,
                                                                     is_onesided, inverse, V, temp_output)));
    } else if (is_complex_valued) {
        ___RETURN_IF_ERROR((discrete_fourier_transform<float, std::complex<float>>(
            ctx, X, Y, b_fft, chirp, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output)));
    }

    return NO_ERROR;
}

template <typename T, typename U>
static ErrorCode short_time_fourier_transform(Backend* ctx, Tensor* signal, Tensor* Y, int frame_step, Tensor* window, bool is_onesided, bool /*inverse*/) {
    // Attr("onesided"): default = 1
    // Input(0, "signal") type = T1
    // Input(1, "frame_length") type = T2
    // Input(2, "window") type = T1, optional
    // Input(3, "frame_step") type = T2
    // Output(0, "output") type = T1

    // Get input signal shape
    const auto& signal_shape = signal->shape();
    const auto batch_size = signal_shape[0];
    const auto signal_size = signal_shape[1];
    const auto signal_components = signal_shape.size() == 2 ? 1 : signal_shape[2];
    // The frame length is taken from the window input, so the two always agree.
    int frame_length = window->length(0);
    const auto window_size = frame_length;
    MNN_ASSERT(window_size <= signal_size);
    // Calculate the number of dfts to run
    const auto n_dfts =
        static_cast<int64_t>(std::floor((signal_size - window_size) / static_cast<float>(frame_step))) + 1;

    // Calculate the output spectra length (onesided will return only the unique values)
    // note: x >> 1 == std::floor(x / 2.f) for non-negative x
    const auto dft_output_size = is_onesided ? (window_size >> 1) + 1 : window_size;

    auto Y_data = reinterpret_cast<float*>(Y->host<void>());

    // Get/create the signal mutable data
    auto* signal_data = const_cast<float*>(reinterpret_cast<const float*>(signal->host<void>()));

    // Define tensor shapes for each dft run
    const int output_components = 2;
    auto dft_input_shape = std::vector<int>{1, window_size, signal_components};
    auto dft_output_shape = std::vector<int>{1, dft_output_size, output_components};

    std::shared_ptr<Tensor> b_fft(new Tensor), chirp(new Tensor);
    std::vector<std::complex<T>> V;
    std::vector<std::complex<T>> temp_output;
    // These tensors do not own the backing memory, so no worries on destruction
    std::shared_ptr<Tensor> input(Tensor::createDevice(dft_input_shape, signal->getType()));
    std::shared_ptr<Tensor> output(Tensor::createDevice(dft_output_shape, Y->getType()));

    // Run each dft of each batch as if it were a real-valued, batch-size-1 dft operation
    for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
        for (int64_t i = 0; i < n_dfts; i++) {
            auto input_frame_begin =
                signal_data + (batch_idx * signal_size * signal_components) + (i * frame_step * signal_components);

            auto output_frame_begin = Y_data + (batch_idx * n_dfts * dft_output_size * output_components) + (i * dft_output_size * output_components);
            // Point the scratch tensors at the current frame without copying
            input->buffer().host = (uint8_t*)input_frame_begin;
            output->buffer().host = (uint8_t*)output_frame_begin;
            // Run individual dft
            ___RETURN_IF_ERROR((discrete_fourier_transform<T, U>(ctx, input.get(), output.get(), b_fft, chirp, 1, window_size, window, is_onesided, false, V, temp_output)));
        }
    }

    return NO_ERROR;
}
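
// Framing example: for signal_size == 16, frame_length == 8, frame_step == 4,
// n_dfts == floor((16 - 8) / 4) + 1 == 3 frames starting at samples 0, 4 and 8.
// With is_onesided, each frame yields (8 >> 1) + 1 == 5 complex bins, so the
// output laid out above has shape [batch, 3, 5, 2].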

CPUStft::CPUStft(Backend* backend, bool abs)
    : Execution(backend), mAbs(abs) {
    // nothing to do
}

ErrorCode CPUStft::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    // All scratch buffers are created lazily during execution.
    return NO_ERROR;
}

ErrorCode CPUStft::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    // inputs: 0 = signal, 1 = frame_step, 2 = window
    auto signal = inputs[0];
    const auto is_real_valued = is_real_valued_signal(signal);
    const auto is_complex_valued = is_complex_valued_signal(signal);
    int frameStep = inputs[1]->host<int>()[0];
    if (is_real_valued) {
        ___RETURN_IF_ERROR((short_time_fourier_transform<float, float>(backend(), inputs[0], outputs[0], frameStep, inputs[2], mAbs, false)));
    } else if (is_complex_valued) {
        ___RETURN_IF_ERROR((short_time_fourier_transform<float, std::complex<float>>(backend(), inputs[0], outputs[0], frameStep, inputs[2], mAbs, false)));
    } else {
        MNN_ASSERT(false);
    }
    return NO_ERROR;
}

class CPUStftCreator : public CPUBackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const {
        auto stft = op->main_as_StftParam();
        return new CPUStft(backend, stft->abs());
    }
};

REGISTER_CPU_OP_CREATOR(CPUStftCreator, OpType_Stft);
} // namespace MNN