MNN/tools/audio/source/audio.cpp

559 lines
21 KiB
C++

//
// audio.cpp
// MNN
//
// Created by MNN on 2024/11/15.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "audio/audio.hpp"
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include "MNN_generated.h"
#include <cmath>
#include <algorithm>
#include <complex>
#include <fstream>
#include <iostream>
#include <limits>
#ifndef M_PI
#define M_PI 3.141592654
#endif
#ifdef _MSC_VER
#define NOMINMAX
#include <intrin.h>
#include <windows.h>
#endif
namespace MNN {
namespace AUDIO {
#ifdef _MSC_VER
inline uint32_t mnn_clz( uint32_t value ) {
DWORD leading_zero = 0;
if (_BitScanReverse(&leading_zero, value)) {
return 31 - leading_zero;
}else {
// Same remarks as above
return 32;
}
}
#else
inline uint32_t mnn_clz( uint32_t value ) {
return __builtin_clz(value);
}
#endif
struct WaveHeader {
void SeekToDataChunk(std::istream &is) {
// a t a d
while (is && subchunk2_id != 0x61746164) {
is.seekg(subchunk2_size, std::istream::cur);
is.read(reinterpret_cast<char *>(&subchunk2_id), sizeof(int32_t));
is.read(reinterpret_cast<char *>(&subchunk2_size), sizeof(int32_t));
}
}
int32_t chunk_id = 0x46464952; // "RIFF"
int32_t chunk_size;
int32_t format = 0x45564157; // "WAVE"
int32_t subchunk1_id = 0x20746d66; // "fmt "
int32_t subchunk1_size = 16; // PCM
int16_t audio_format = 1; // PCM = 1
int16_t num_channels = 1; // Mono
int32_t sample_rate;
int32_t byte_rate;
int16_t block_align;
int16_t bits_per_sample = 16;
int32_t subchunk2_id = 0x61746164; // "data"
int32_t subchunk2_size;
};
std::pair<VARP, int> load(const std::string &filename, int sr, int frame_offset, int num_frames) {
std::ifstream is(filename, std::ifstream::binary);
auto ret = std::make_pair<VARP, int>(nullptr, 0);
if (!is) {
MNN_ERROR("Failed to open file: %s\n", filename.c_str());
return ret;
}
WaveHeader header{};
is.read(reinterpret_cast<char *>(&header.chunk_id), sizeof(header.chunk_id));
if (header.chunk_id != 0x46464952) { // "RIFF"
MNN_ERROR("Expected chunk_id RIFF. Given: 0x%08x\n", header.chunk_id);
return ret;
}
is.read(reinterpret_cast<char *>(&header.chunk_size), sizeof(header.chunk_size));
is.read(reinterpret_cast<char *>(&header.format), sizeof(header.format));
if (header.format != 0x45564157) { // "WAVE"
MNN_ERROR("Expected format WAVE. Given: 0x%08x\n", header.format);
return ret;
}
is.read(reinterpret_cast<char *>(&header.subchunk1_id), sizeof(header.subchunk1_id));
is.read(reinterpret_cast<char *>(&header.subchunk1_size), sizeof(header.subchunk1_size));
if (header.subchunk1_id == 0x4b4e554a) { // "JUNK"
is.seekg(header.subchunk1_size, std::istream::cur);
is.read(reinterpret_cast<char *>(&header.subchunk1_id), sizeof(header.subchunk1_id));
is.read(reinterpret_cast<char *>(&header.subchunk1_size), sizeof(header.subchunk1_size));
}
if (header.subchunk1_id != 0x20746d66) { // "fmt "
MNN_ERROR("Expected subchunk1_id 'fmt '. Given: 0x%08x\n", header.subchunk1_id);
return ret;
}
if (header.subchunk1_size != 16 && header.subchunk1_size != 18) {
MNN_ERROR("Expected subchunk1_size 16 or 18. Given: %d\n", header.subchunk1_size);
return ret;
}
is.read(reinterpret_cast<char *>(&header.audio_format), sizeof(header.audio_format));
if (header.audio_format != 1 && header.audio_format != 3) {
MNN_ERROR("Unsupported audio_format: %d. Only PCM(1) and IEEE Float(3) supported.\n", header.audio_format);
return ret;
}
is.read(reinterpret_cast<char *>(&header.num_channels), sizeof(header.num_channels));
if (header.num_channels != 1) {
MNN_ERROR("Warning: %d channels found. Only the first channel will be used.\n", header.num_channels);
}
is.read(reinterpret_cast<char *>(&header.sample_rate), sizeof(header.sample_rate));
is.read(reinterpret_cast<char *>(&header.byte_rate), sizeof(header.byte_rate));
is.read(reinterpret_cast<char *>(&header.block_align), sizeof(header.block_align));
is.read(reinterpret_cast<char *>(&header.bits_per_sample), sizeof(header.bits_per_sample));
if (header.byte_rate != (header.sample_rate * header.num_channels * header.bits_per_sample / 8)) {
MNN_ERROR("Incorrect byte rate: %d. Expected: %d\n", header.byte_rate,
header.sample_rate * header.num_channels * header.bits_per_sample / 8);
return ret;
}
if (header.block_align != (header.num_channels * header.bits_per_sample / 8)) {
MNN_ERROR("Incorrect block align: %d. Expected: %d\n", header.block_align,
header.num_channels * header.bits_per_sample / 8);
return ret;
}
if (header.bits_per_sample != 8 && header.bits_per_sample != 16 && header.bits_per_sample != 32) {
MNN_ERROR("Unsupported bits_per_sample: %d. Only 8, 16, or 32 bits per sample supported.\n",
header.bits_per_sample);
return ret;
}
if (header.subchunk1_size == 18) {
int16_t extra_size;
is.read(reinterpret_cast<char *>(&extra_size), sizeof(int16_t));
if (extra_size != 0) {
MNN_ERROR("Unexpected extra size: %d. Expected 0.\n", extra_size);
return ret;
}
}
is.read(reinterpret_cast<char *>(&header.subchunk2_id), sizeof(header.subchunk2_id));
is.read(reinterpret_cast<char *>(&header.subchunk2_size), sizeof(header.subchunk2_size));
header.SeekToDataChunk(is);
if (!is) {
MNN_ERROR("Could not locate data chunk.\n");
return ret;
}
int total_frames = header.subchunk2_size / header.block_align;
if (frame_offset < 0 || frame_offset >= total_frames) {
MNN_ERROR("Frame offset out of range.\n");
return ret;
}
if (num_frames <= 0 || frame_offset + num_frames > total_frames) {
num_frames = total_frames - frame_offset;
}
is.seekg(frame_offset * header.block_align, std::istream::cur);
ret.first = _Input({num_frames}, NHWC);
ret.second = header.sample_rate;
auto ans_ptr = ret.first->writeMap<float>();
if (header.bits_per_sample == 16 && header.audio_format == 1) {
std::vector<int16_t> samples(num_frames * header.num_channels);
is.read(reinterpret_cast<char *>(samples.data()), num_frames * header.block_align);
if (!is) {
MNN_ERROR("Failed to read audio data.\n");
return ret;
}
for (int i = 0; i < num_frames; ++i) {
ans_ptr[i] = samples[i * header.num_channels] / 32768.f;
}
} else if (header.bits_per_sample == 8 && header.audio_format == 1) {
std::vector<uint8_t> samples(num_frames * header.num_channels);
is.read(reinterpret_cast<char *>(samples.data()), num_frames * header.block_align);
if (!is) {
MNN_ERROR("Failed to read audio data.\n");
return ret;
}
for (int i = 0; i < num_frames; ++i) {
ans_ptr[i] = static_cast<float>(samples[i * header.num_channels]) / 128.f - 1.f;
}
} else if (header.bits_per_sample == 32 && header.audio_format == 1) {
std::vector<int32_t> samples(num_frames * header.num_channels);
is.read(reinterpret_cast<char *>(samples.data()), num_frames * header.block_align);
if (!is) {
MNN_ERROR("Failed to read audio data.\n");
return ret;
}
for (int i = 0; i < num_frames; ++i) {
ans_ptr[i] = static_cast<float>(samples[i * header.num_channels]) / static_cast<float>(INT32_MAX);
}
} else if (header.bits_per_sample == 32 && header.audio_format == 3) {
std::vector<float> samples(num_frames * header.num_channels);
is.read(reinterpret_cast<char *>(samples.data()), num_frames * header.block_align);
if (!is) {
MNN_ERROR("Failed to read audio data.\n");
return ret;
}
for (int i = 0; i < num_frames; ++i) {
ans_ptr[i] = samples[i * header.num_channels];
}
} else {
MNN_ERROR("Unsupported bits per sample: %d or audio format: %d.\n", header.bits_per_sample,
header.audio_format);
return ret;
}
if (sr > 0 && sr != ret.second) {
// resample
float resample_ratio = static_cast<float>(sr) / header.sample_rate;
int resample_num_frames = static_cast<int>(num_frames * resample_ratio);
auto resampled_data = _Input({resample_num_frames}, NHWC);
auto src = ret.first->readMap<float>();
auto dst = resampled_data->writeMap<float>();
for (int i = 0; i < resample_num_frames; ++i) {
float interp_index = i / resample_ratio;
int low_index = static_cast<int>(interp_index);
int high_index = std::min(low_index + 1, num_frames - 1);
float frac = interp_index - low_index;
dst[i] = (1 - frac) * src[low_index] + frac * src[high_index];
}
ret.first = resampled_data;
ret.second = sr;
}
return ret;
}
bool save(const std::string &filename, VARP audio, int sample_rate) {
std::ofstream os(filename, std::ios::binary);
if (!os) {
MNN_ERROR("Failed to open file for writing: %s\n", filename.c_str());
return false;
}
auto audio_size = audio->getInfo()->size;
auto audio_ptr = audio->readMap<float>();
WaveHeader header;
header.num_channels = 1;
header.sample_rate = sample_rate;
header.byte_rate = sample_rate * header.num_channels * (header.bits_per_sample / 8);
header.block_align = header.num_channels * (header.bits_per_sample / 8);
header.subchunk2_size = audio_size * (header.bits_per_sample / 8);
header.chunk_size = 36 + header.subchunk2_size;
os.write(reinterpret_cast<const char *>(&header), sizeof(WaveHeader));
// Convert float samples to int16 and write to file
for (int i = 0; i < audio_size; i++) {
float sample = audio_ptr[i];
int16_t int_sample = static_cast<int16_t>(std::max(-1.0f, std::min(1.0f, sample)) * 32767);
os.write(reinterpret_cast<const char *>(&int_sample), sizeof(int16_t));
}
if (!os) {
MNN_ERROR("Failed to write audio data to file.\n");
return false;
}
os.close();
return true;
}
template <typename T>
static inline VARP _var(std::vector<T> vec, const std::vector<int> &dims) {
return _Const(vec.data(), dims, NHWC, halide_type_of<T>());
}
unsigned int next_power_of_2(unsigned int x) {
if (x == 0)
return 1;
if ((x & (x - 1)) == 0)
return x;
return 1U << (32 - mnn_clz(x));
}
VARP hamming_window(int n_fft, bool periodic, float alpha, float beta) {
auto window = _Input({n_fft}, NHWC);
auto window_ptr = window->writeMap<float>();
int N = periodic ? n_fft : n_fft - 1;
for (int n = 0; n < n_fft; ++n) {
window_ptr[n] = alpha - beta * std::cos(2.0 * M_PI * n / N);
}
return window;
}
VARP hann_window(int n_fft, bool periodic) {
auto window = _Input({n_fft}, NHWC);
auto window_ptr = window->writeMap<float>();
int N = periodic ? n_fft : n_fft - 1;
for (int n = 0; n < n_fft; ++n) {
window_ptr[n] = 0.5 * (1 - std::cos(2 * M_PI * n / N));
}
return window;
}
float hz_to_mel(float freq, bool htk) {
if (htk) {
return 2595 * std::log10(1 + freq / 700);
} else {
constexpr float f_min = 0.0, f_sp = 200.0 / 3.0, min_log_hz = 1000.0;
constexpr float logstep = 0.06875177742094912;
constexpr float min_log_mel = (min_log_hz - f_min) / f_sp;
float mels = (freq - f_min) / f_sp;
if (freq >= min_log_hz) {
mels = min_log_mel + std::log(freq / min_log_hz) / logstep;
}
return mels;
}
}
float mel_to_hz(float mel, bool htk) {
if (htk) {
return 700 * (std::pow(10, mel / 2595.0) - 1);
} else {
constexpr float f_min = 0.0f, f_sp = 200.0f / 3, min_log_hz = 1000.0f;
constexpr float logstep = 0.06875177742094912;
constexpr float min_log_mel = (min_log_hz - f_min) / f_sp;
float freq = f_min + f_sp * mel;
if (mel >= min_log_mel) {
freq = min_log_hz * std::exp(logstep * (mel - min_log_mel));
}
return freq;
}
}
VARP melscale_fbanks(const MelscaleParams *params) {
int n_mels = 128, n_fft = 400, sample_rate = 16000;
bool htk = true, norm = false;
float f_min = 0.0, f_max = 0.0;
if (params != nullptr) {
n_mels = params->n_mels;
n_fft = params->n_fft;
sample_rate = params->sample_rate;
htk = params->htk;
norm = params->norm;
f_min = params->f_min;
f_max = params->f_max;
}
int n_freqs = n_fft / 2 + 1;
float nyquist = 0.5 * sample_rate;
std::vector<float> all_freqs(n_freqs);
for (int i = 0; i < n_freqs; ++i) {
all_freqs[i] = i * nyquist / (n_freqs - 1);
}
f_max = f_max <= 0.0 ? nyquist : f_max;
float m_min = hz_to_mel(f_min, htk);
float m_max = hz_to_mel(f_max, htk);
float m_delta = (m_max - m_min) / (n_mels + 1);
auto bins = _Input({n_mels, n_freqs}, NHWC);
auto bins_ptr = bins->writeMap<float>();
for (int n = 0; n < n_mels; ++n) {
float left = mel_to_hz(m_min + m_delta * (n + 0), htk);
float curr = mel_to_hz(m_min + m_delta * (n + 1), htk);
float right = mel_to_hz(m_min + m_delta * (n + 2), htk);
float enorm = (htk && norm) ? 1.0 : 2.0 / (right - left);
for (int k = 0; k < n_freqs; ++k) {
float val = 0.f, f_k = all_freqs[k];
if (f_k >= left && f_k <= curr) {
val = (f_k - left) / (curr - left);
} else if (f_k > curr && f_k <= right) {
val = (right - f_k) / (right - curr);
}
bins_ptr[n * n_freqs + k] = val * enorm;
}
}
return bins;
}
VARP spectrogram(VARP waveform, const SpectrogramParams *params) {
int pad_left = 0, pad_right = 0, pad_mode = REFLECT;
int n_fft = 400, hop_length = 0, win_length = 0, window_type = HANNING;
bool center = false, normalized = false;
float power = 2.0;
if (params) {
pad_left = params->pad_left;
pad_right = params->pad_right;
center = params->center;
pad_mode = params->pad_mode;
n_fft = params->n_fft;
hop_length = params->hop_length;
win_length = params->win_length;
window_type = params->window_type;
normalized = params->normalized;
power = params->power;
}
if (pad_left > 1 || pad_right > 1) {
waveform = MNN::Express::_Pad(waveform, _var<int>({pad_left, pad_right}, {2}), MNN::Express::CONSTANT);
}
if (center) {
waveform = MNN::Express::_Pad(waveform, _var<int>({n_fft / 2, n_fft / 2}, {2}), static_cast<MNN::Express::PadValueMode>(pad_mode));
}
waveform = _Reshape(waveform, {1, -1, 1});
hop_length = hop_length ? hop_length : n_fft / 2;
win_length = win_length ? win_length : n_fft;
VARP window;
switch (window_type) {
case HANNING:
window = hann_window(win_length);
break;
case HAMMING:
window = hamming_window(win_length);
break;
default:
window = hann_window(win_length);
break;
}
std::unique_ptr<OpT> op(new OpT);
op->type = OpType_Stft;
op->main.type = OpParameter_StftParam;
auto param = new StftParamT;
param->abs = true;
op->main.value = param;
EXPRP stftexpr = Expr::create(std::move(op), {waveform, _Scalar<int>(hop_length), window});
int nstfts = ((waveform->getInfo()->dim[1] - n_fft) / hop_length) + 1;
int dft_unique_bins = n_fft / 2 + 1;
auto specgram = MNN::Express::Variable::create(stftexpr);
specgram = _Square(specgram);
auto startsDims = std::vector<int>{0, 0, 0, 0};
auto starts1Dims = std::vector<int>{0, 0, 0, 1};
auto sizeDims = std::vector<int>{1, nstfts, dft_unique_bins, 1};
auto startVar = _Const(startsDims.data(), {4}, NCHW, halide_type_of<int>());
auto start1Var = _Const(starts1Dims.data(), {4}, NCHW, halide_type_of<int>());
auto sizeVar = _Const(sizeDims.data(), {4}, NCHW, halide_type_of<int>());
auto specgramReal = _Slice(specgram, startVar, sizeVar);
auto specgramVirt = _Slice(specgram, start1Var, sizeVar);
specgram = specgramReal + specgramVirt;
specgram = _Reshape(specgram, {nstfts, dft_unique_bins});
if (normalized) {
float window_norm = 1.0f / _ReduceSum(_Square(window))->readMap<float>()[0];
specgram = specgram * _Scalar<float>(window_norm);
}
if (power == 1.0f) {
specgram = _Sqrt(specgram);
} else if (power != 2.0f) {
specgram = _Pow(specgram, _Scalar<float>(power / 2.0f));
}
return specgram;
}
VARP mel_spectrogram(VARP waveform, const MelscaleParams *mel_params, const SpectrogramParams *spec_params) {
auto banks = melscale_fbanks(mel_params);
auto specgram = spectrogram(waveform, spec_params);
auto mel_specgram = _MatMul(specgram, banks, false, true);
return mel_specgram;
}
VARP fbank(VARP waveform, int sampling_rate, int n_mels, int n_fft, int hop_length, float dither, float preemphasis) {
int wav_len = waveform->getInfo()->size;
int frame_num = (wav_len - n_fft) / hop_length + 1;
if (frame_num <= 0 || wav_len < n_fft) {
return nullptr; // frame_num is zero
}
// get_strided: sizes: [m, n_fft], strides: [windows_shift, 1]
int m = 1 + (wav_len - n_fft) / hop_length;
std::vector<int> strided_region = {
0, // src offset
wav_len,
hop_length,
1, // src strides
0, // dst offset
m * n_fft,
n_fft,
1, // dst strides
1,
m,
n_fft // dst sizes
};
auto strided_wav = _Raster({waveform}, strided_region, {m, n_fft});
auto wav_dim = strided_wav->getInfo()->dim;
// add_dither
if (dither > 0.f) {
auto rand_dither = _RandomUnifom(_var<int>(wav_dim, {static_cast<int>(wav_dim.size())}),
halide_type_of<float>(), -dither, dither);
strided_wav = strided_wav + rand_dither;
}
// subtract each row/frame by its mean
{
auto row_means = _ReduceMean(strided_wav, {-1}, true);
strided_wav = strided_wav - row_means;
}
if (preemphasis != 0.f) {
std::vector<int> offset_region = {
// region 0
0, // src offset
m * n_fft, n_fft, 1, // src strides
0, // dst offset
m * n_fft, n_fft, 1, // dst strides
1, m, 1, // dst sizes
// region 1
0, // src offset
m * n_fft, n_fft, 1, // src strides
1, // dst offset
m * n_fft, n_fft, 1, // dst strides
1, m, n_fft - 1 // dst sizes
};
auto offset_strided_wav = _Raster({strided_wav, strided_wav}, offset_region, {m, n_fft});
strided_wav = strided_wav - _Scalar<float>(preemphasis) * offset_strided_wav;
}
int padded_n_fft = next_power_of_2(n_fft);
MelscaleParams mel_params;
mel_params.n_mels = n_mels;
mel_params.n_fft = padded_n_fft;
mel_params.sample_rate = sampling_rate;
mel_params.f_min = 20.0;
SpectrogramParams spec_params;
spec_params.n_fft = padded_n_fft;
spec_params.hop_length = n_fft;
auto mel_energies = mel_spectrogram(strided_wav, &mel_params, &spec_params);
mel_energies = _Log(mel_energies);
return mel_energies;
}
VARP whisper_fbank(VARP waveform, int sample_rate, int n_mels, int n_fft, int hop_length, int chunk_len) {
int n_samples = chunk_len * sample_rate;
int pad_right = n_samples - waveform->getInfo()->size;
pad_right = pad_right > 0 ? pad_right : 0;
MelscaleParams mel_params;
mel_params.n_mels = n_mels;
mel_params.n_fft = n_fft;
mel_params.sample_rate = sample_rate;
mel_params.htk = false;
mel_params.norm = true;
SpectrogramParams spec_params;
spec_params.pad_right = pad_right;
spec_params.n_fft = n_fft;
spec_params.hop_length = hop_length;
spec_params.center = true;
auto mel_specgram = mel_spectrogram(waveform, &mel_params, &spec_params);
mel_specgram =
_Slice(mel_specgram, _var<int>({0, 0}, {2}), _var<int>({mel_specgram->getInfo()->dim[0] - 1, -1}, {2}));
auto log_specgram = _Log(mel_specgram) / _Log(_Scalar<float>(10.0));
log_specgram = _Maximum(log_specgram, _ReduceMax(log_specgram) - _Scalar<float>(8.0));
log_specgram = (log_specgram + _Scalar<float>(4.0)) / _Scalar<float>(4.0);
// NHWC -> NCHW
log_specgram = _Unsqueeze(log_specgram, {0, 1});
log_specgram = _Convert(log_specgram, NCHW);
log_specgram = _Squeeze(log_specgram, {2});
return log_specgram;
}
} // namespace AUDIO
} // namespace MNN