diff --git a/apps/frameworks/sherpa-mnn/.gitignore b/apps/frameworks/sherpa-mnn/.gitignore
new file mode 100644
index 00000000..2b547606
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/.gitignore
@@ -0,0 +1,5 @@
+SourcePackages
+build-*
+*.xcworkspace
+!build-*.sh
+*.lock
\ No newline at end of file
diff --git a/apps/frameworks/sherpa-mnn/CHANGELOG.md b/apps/frameworks/sherpa-mnn/CHANGELOG.md
new file mode 100644
index 00000000..326925f2
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/CHANGELOG.md
@@ -0,0 +1,475 @@
+## 1.10.46
+
+* Fix kokoro lexicon. (#1886)
+* speaker-identification-with-vad-non-streaming-asr.py Lack of support for sense_voice. (#1884)
+* Fix generating Chinese lexicon for Kokoro TTS 1.0 (#1888)
+* Reduce vad-whisper-c-api example code. (#1891)
+* JNI Exception Handling (#1452)
+* Fix #1901: UnicodeEncodeError running export_bpe_vocab.py (#1902)
+* Fix publishing pre-built windows libraries (#1905)
+* Fixing Whisper Model Token Normalization (#1904)
+* feat: add mic example for better compatibility (#1909)
+* Add onnxruntime 1.18.1 for Linux aarch64 GPU (#1914)
+* Add C++ API for streaming zipformer ASR on RK NPU (#1908)
+* change [1<<28] to [1<<10], to fix build issues on GOARCH=386 that [1<<28] too large (#1916)
+* Flutter Config toJson/fromJson (#1893)
+* Fix publishing linux pre-built artifacts (#1919)
+* go.mod set to use go 1.17, and use unsafe.Slice to optimize the code (#1920)
+* fix: AddPunct panic for Go(#1921)
+* Fix publishing macos pre-built artifacts (#1922)
+* Minor fixes for rknn (#1925)
+* Build wheels for rknn linux aarch64 (#1928)
+
+## 1.10.45
+
+* [update] fixed bug: create golang instance succeed while the c struct create failed (#1860)
+* fixed typo in RTF calculations (#1861)
+* Export FireRedASR to sherpa-onnx. (#1865)
+* Add C++ and Python API for FireRedASR AED models (#1867)
+* Add Kotlin and Java API for FireRedAsr AED model (#1870)
+* Add C API for FireRedAsr AED model. (#1871)
+* Add CXX API for FireRedAsr (#1872)
+* Add JavaScript API (node-addon) for FireRedAsr (#1873)
+* Add JavaScript API (WebAssembly) for FireRedAsr model. (#1874)
+* Add C# API for FireRedAsr Model (#1875)
+* Add Swift API for FireRedAsr AED Model (#1876)
+* Add Dart API for FireRedAsr AED Model (#1877)
+* Add Go API for FireRedAsr AED Model (#1879)
+* Add Pascal API for FireRedAsr AED Model (#1880)
+
+## 1.10.44
+
+* Export MatchaTTS fa-en model to sherpa-onnx (#1832)
+* Add C++ support for MatchaTTS models not from icefall. (#1834)
+* OfflineRecognizer supports create stream with hotwords (#1833)
+* Add PengChengStarling models to sherpa-onnx (#1835)
+* Support specifying voice in espeak-ng for kokoro tts models. (#1836)
+* Fix: made print sherpa_onnx_loge when it is in debug mode (#1838)
+* Add Go API for audio tagging (#1840)
+* Fix CI (#1841)
+* Update readme to contain links for pre-built Apps (#1853)
+* Modify the model used (#1855)
+* Flutter OnlinePunctuation (#1854)
+* Fix spliting text by languages for kokoro tts. (#1849)
+
+## 1.10.43
+
+* Add MFC example for Kokoro TTS 1.0 (#1815)
+* Update sherpa-onnx-tts.js VitsModelConfig.model can be none (#1817)
+* Fix passing gb2312 encoded strings to tts on Windows (#1819)
+* Support scaling the duration of a pause in TTS. (#1820)
+* Fix building wheels for linux aarch64. (#1821)
+* Fix CI for Linux aarch64.
(#1822) + +## 1.10.42 + +* Fix publishing wheels (#1746) +* Update README to include https://github.com/xinhecuican/QSmartAssistant (#1755) +* Add Kokoro TTS to MFC examples (#1760) +* Refactor node-addon C++ code. (#1768) +* Add keyword spotter C API for HarmonyOS (#1769) +* Add ArkTS API for Keyword spotting. (#1775) +* Add Flutter example for Kokoro TTS (#1776) +* Initialize the audio session for iOS ASR example (#1786) +* Fix: Prepend 0 to tokenization to prevent word skipping for Kokoro. (#1787) +* Export Kokoro 1.0 to sherpa-onnx (#1788) +* Add C++ and Python API for Kokoro 1.0 multilingual TTS model (#1795) +* Add Java and Koltin API for Kokoro TTS 1.0 (#1798) +* Add Android demo for Kokoro TTS 1.0 (#1799) +* Add C API for Kokoro TTS 1.0 (#1801) +* Add CXX API for Kokoro TTS 1.0 (#1802) +* Add Swift API for Kokoro TTS 1.0 (#1803) +* Add Go API for Kokoro TTS 1.0 (#1804) +* Add C# API for Kokoro TTS 1.0 (#1805) +* Add Dart API for Kokoro TTS 1.0 (#1806) +* Add Pascal API for Kokoro TTS 1.0 (#1807) +* Add JavaScript API (node-addon) for Kokoro TTS 1.0 (#1808) +* Add JavaScript API (WebAssembly) for Kokoro TTS 1.0 (#1809) +* Add Flutter example for Kokoro TTS 1.0 (#1810) +* Add iOS demo for Kokoro TTS 1.0 (#1812) +* Add HarmonyOS demo for Kokoro TTS 1.0 (#1813) + +## 1.10.41 + +* Fix UI for Android TTS Engine. (#1735) +* Add iOS TTS example for MatchaTTS (#1736) +* Add iOS example for Kokoro TTS (#1737) +* Fix dither binding in Pybind11 to ensure independence from high_freq in FeatureExtractorConfig (#1739) +* Fix keyword spotting. (#1689) +* Update readme to include https://github.com/hfyydd/sherpa-onnx-server (#1741) +* Reduce vad-moonshine-c-api example code. (#1742) +* Support Kokoro TTS for HarmonyOS. (#1743) + +## 1.10.40 + +* Fix building wheels (#1703) +* Export kokoro to sherpa-onnx (#1713) +* Add C++ and Python API for Kokoro TTS models. (#1715) +* Add C API for Kokoro TTS models (#1717) +* Fix style issues (#1718) +* Add C# API for Kokoro TTS models (#1720) +* Add Swift API for Kokoro TTS models (#1721) +* Add Go API for Kokoro TTS models (#1722) +* Add Dart API for Kokoro TTS models (#1723) +* Add Pascal API for Kokoro TTS models (#1724) +* Add JavaScript API (node-addon) for Kokoro TTS models (#1725) +* Add JavaScript (WebAssembly) API for Kokoro TTS models. (#1726) +* Add Koltin and Java API for Kokoro TTS models (#1728) +* Update README.md for KWS to not use git lfs. (#1729) + + + + +## 1.10.39 + +* Fix building without TTS (#1691) +* Add README for android libs. (#1693) +* Fix: export-onnx.py(expected all tensors to be on the same device) (#1699) +* Fix passing strings from C# to C. (#1701) + +## 1.10.38 + +* Fix initializing TTS in Python. (#1664) +* Remove spaces after punctuations for TTS (#1666) +* Add constructor fromPtr() for all flutter class with factory ctor. (#1667) +* Add Kotlin API for Matcha-TTS models. (#1668) +* Support Matcha-TTS models using espeak-ng (#1672) +* Add Java API for Matcha-TTS models. (#1673) +* Avoid adding tail padding for VAD in generate-subtitles.py (#1674) +* Add C API for MatchaTTS models (#1675) +* Add CXX API for MatchaTTS models (#1676) +* Add JavaScript API (node-addon-api) for MatchaTTS models. (#1677) +* Add HarmonyOS examples for MatchaTTS. (#1678) +* Upgraded to .NET 8 and made code style a little more internally consistent. (#1680) +* Update workflows to use .NET 8.0 also. (#1681) +* Add C# and JavaScript (wasm) API for MatchaTTS models (#1682) +* Add Android demo for MatchaTTS models. 
(#1683) +* Add Swift API for MatchaTTS models. (#1684) +* Add Go API for MatchaTTS models (#1685) +* Add Pascal API for MatchaTTS models. (#1686) +* Add Dart API for MatchaTTS models (#1687) + +## 1.10.37 + +* Add new tts models for Latvia and Persian+English (#1644) +* Add a byte-level BPE Chinese+English non-streaming zipformer model (#1645) +* Support removing invalid utf-8 sequences. (#1648) +* Add TeleSpeech CTC to non_streaming_server.py (#1649) +* Fix building macOS libs (#1656) +* Add Go API for Keyword spotting (#1662) +* Add Swift online punctuation (#1661) +* Add C++ runtime for Matcha-TTS (#1627) + +## 1.10.36 + +* Update AAR version in Android Java demo (#1618) +* Support linking onnxruntime statically for Android (#1619) +* Update readme to include Open-LLM-VTuber (#1622) +* Rename maxNumStences to maxNumSentences (#1625) +* Support using onnxruntime 1.16.0 with CUDA 11.4 on Jetson Orin NX (Linux arm64 GPU). (#1630) +* Update readme to include jetson orin nx and nano b01 (#1631) +* feat: add checksum action (#1632) +* Support decoding with byte-level BPE (bbpe) models. (#1633) +* feat: enable c api for android ci (#1635) +* Update README.md (#1640) +* SherpaOnnxVadAsr: Offload runSecondPass to background thread for improved real-time audio processing (#1638) +* Fix GitHub actions. (#1642) + + +## 1.10.35 + +* Add missing changes about speaker identfication demo for HarmonyOS (#1612) +* Provide sherpa-onnx.aar for Android (#1615) +* Use aar in Android Java demo. (#1616) + +## 1.10.34 + +* Fix building node-addon package (#1598) +* Update doc links for HarmonyOS (#1601) +* Add on-device real-time ASR demo for HarmonyOS (#1606) +* Add speaker identification APIs for HarmonyOS (#1607) +* Add speaker identification demo for HarmonyOS (#1608) +* Add speaker diarization API for HarmonyOS. (#1609) +* Add speaker diarization demo for HarmonyOS (#1610) + +## 1.10.33 + +* Add non-streaming ASR support for HarmonyOS. (#1564) +* Add streaming ASR support for HarmonyOS. (#1565) +* Fix building for Android (#1568) +* Publish `sherpa_onnx.har` for HarmonyOS (#1572) +* Add VAD+ASR demo for HarmonyOS (#1573) +* Fix publishing har packages for HarmonyOS (#1576) +* Add CI to build HAPs for HarmonyOS (#1578) +* Add microphone demo about VAD+ASR for HarmonyOS (#1581) +* Fix getting microphone permission for HarmonyOS VAD+ASR example (#1582) +* Add HarmonyOS support for text-to-speech. (#1584) +* Fix: support both old and new websockets request headers format (#1588) +* Add on-device tex-to-speech (TTS) demo for HarmonyOS (#1590) + +## 1.10.32 + +* Support cross-compiling for HarmonyOS (#1553) +* HarmonyOS support for VAD. (#1561) +* Fix publishing flutter iOS app to appstore (#1563). + +## 1.10.31 + +* Publish pre-built wheels for Python 3.13 (#1485) +* Publish pre-built macos xcframework (#1490) +* Fix reading tokens.txt on Windows. (#1497) +* Add two-pass ASR Android APKs for Moonshine models. (#1499) +* Support building GPU-capable sherpa-onnx on Linux aarch64. (#1500) +* Publish pre-built wheels with CUDA support for Linux aarch64. (#1507) +* Export the English TTS model from MeloTTS (#1509) +* Add Lazarus example for Moonshine models. (#1532) +* Add isolate_tts demo (#1529) +* Add WebAssembly example for VAD + Moonshine models. (#1535) +* Add Android APK for streaming Paraformer ASR (#1538) +* Support static build for windows arm64. (#1539) +* Use xcframework for Flutter iOS plugin to support iOS simulators. + +## 1.10.30 + +* Fix building node-addon for Windows x86. 
(#1469) +* Begin to support https://github.com/usefulsensors/moonshine (#1470) +* Publish pre-built JNI libs for Linux aarch64 (#1472) +* Add C++ runtime and Python APIs for Moonshine models (#1473) +* Add Kotlin and Java API for Moonshine models (#1474) +* Add C and C++ API for Moonshine models (#1476) +* Add Swift API for Moonshine models. (#1477) +* Add Go API examples for adding punctuations to text. (#1478) +* Add Go API for Moonshine models (#1479) +* Add JavaScript API for Moonshine models (#1480) +* Add Dart API for Moonshine models. (#1481) +* Add Pascal API for Moonshine models (#1482) +* Add C# API for Moonshine models. (#1483) + +## 1.10.29 + +* Add Go API for offline punctuation models (#1434) +* Support https://huggingface.co/Revai/reverb-diarization-v1 (#1437) +* Add more models for speaker diarization (#1440) +* Add Java API example for hotwords. (#1442) +* Add java android demo (#1454) +* Add C++ API for streaming ASR. (#1455) +* Add C++ API for non-streaming ASR (#1456) +* Handle NaN embeddings in speaker diarization. (#1461) +* Add speaker identification with VAD and non-streaming ASR using ALSA (#1463) +* Support GigaAM CTC models for Russian ASR (#1464) +* Add GigaAM NeMo transducer model for Russian ASR (#1467) + +## 1.10.28 + +* Fix swift example for generating subtitles. (#1362) +* Allow more online models to load tokens file from the memory (#1352) +* Fix CI errors introduced by supporting loading keywords from buffers (#1366) +* Fix running MeloTTS models on GPU. (#1379) +* Support Parakeet models from NeMo (#1381) +* Export Pyannote speaker segmentation models to onnx (#1382) +* Support Agglomerative clustering. (#1384) +* Add Python API for clustering (#1385) +* support whisper turbo (#1390) +* context_state is not set correctly when previous context is passed after reset (#1393) +* Speaker diarization example with onnxruntime Python API (#1395) +* C++ API for speaker diarization (#1396) +* Python API for speaker diarization. (#1400) +* C API for speaker diarization (#1402) +* docs(nodejs-addon-examples): add guide for pnpm user (#1401) +* Go API for speaker diarization (#1403) +* Swift API for speaker diarization (#1404) +* Update readme to include more external projects using sherpa-onnx (#1405) +* C# API for speaker diarization (#1407) +* JavaScript API (node-addon) for speaker diarization (#1408) +* WebAssembly exmaple for speaker diarization (#1411) +* Handle audio files less than 10s long for speaker diarization. (#1412) +* JavaScript API with WebAssembly for speaker diarization (#1414) +* Kotlin API for speaker diarization (#1415) +* Java API for speaker diarization (#1416) +* Dart API for speaker diarization (#1418) +* Pascal API for speaker diarization (#1420) +* Android JNI support for speaker diarization (#1421) +* Android demo for speaker diarization (#1423) + +## 1.10.27 + +* Add non-streaming ONNX models for Russian ASR (#1358) +* Fix building Flutter TTS examples for Linux (#1356) +* Support passing utf-8 strings from JavaScript to C++. (#1355) +* Fix sherpa_onnx.go to support returning empty recognition results (#1353) + +## 1.10.26 + +* Add links to projects using sherpa-onnx. (#1345) +* Support lang/emotion/event results from SenseVoice in Swift API. (#1346) +* Support specifying max speech duration for VAD. 
(#1348) +* Add APIs about max speech duration in VAD for various programming languages (#1349) + +## 1.10.25 + +* Allow tokens and hotwords to be loaded from buffered string driectly (#1339) +* Fix computing features for CED audio tagging models. (#1341) +* Preserve previous result as context for next segment (#1335) +* Add Python binding for online punctuation models (#1312) +* Fix vad.Flush(). (#1329) +* Fix wasm app for streaming paraformer (#1328) +* Build websocket related binaries for embedded systems. (#1327) +* Fixed the C api calls and created the TTS project file (#1324) +* Re-implement LM rescore for online transducer (#1231) + +## 1.10.24 + +* Add VAD and keyword spotting for the Node package with WebAssembly (#1286) +* Fix releasing npm package and fix building Android VAD+ASR example (#1288) +* add Tokens []string, Timestamps []float32, Lang string, Emotion string, Event string (#1277) +* add vad+sense voice example for C API (#1291) +* ADD VAD+ASR example for dart with CircularBuffer. (#1293) +* Fix VAD+ASR example for Dart API. (#1294) +* Avoid SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches freeing null. (#1296) +* Fix releasing wasm app for vad+asr (#1300) +* remove extra files from linux/macos/windows jni libs (#1301) +* two-pass Android APK for SenseVoice (#1302) +* Downgrade flutter sdk versions. (#1305) +* Reduce onnxruntime log output. (#1306) +* Provide prebuilt .jar files for different java versions. (#1307) + + +## 1.10.23 + +* flutter: add lang, emotion, event to OfflineRecognizerResult (#1268) +* Use a separate thread to initialize models for lazarus examples. (#1270) +* Object pascal examples for recording and playing audio with portaudio. (#1271) +* Text to speech API for Object Pascal. (#1273) +* update kotlin api for better release native object and add user-friendly apis. (#1275) +* Update wave-reader.cc to support 8/16/32-bit waves (#1278) +* Add WebAssembly for VAD (#1281) +* WebAssembly example for VAD + Non-streaming ASR (#1284) + +## 1.10.22 + +* Add Pascal API for reading wave files (#1243) +* Pascal API for streaming ASR (#1246) +* Pascal API for non-streaming ASR (#1247) +* Pascal API for VAD (#1249) +* Add more C API examples (#1255) +* Add emotion, event of SenseVoice. (#1257) +* Support reading multi-channel wave files with 8/16/32-bit encoded samples (#1258) +* Enable IPO only for Release build. (#1261) +* Add Lazarus example for generating subtitles using Silero VAD with non-streaming ASR (#1251) +* Fix looking up OOVs in lexicon.txt for MeloTTS models. (#1266) + + +## 1.10.21 + +* Fix ffmpeg c api example (#1185) +* Fix splitting sentences for MeloTTS (#1186) +* Non-streaming WebSocket client for Java. (#1190) +* Fix copying asset files for flutter examples. (#1191) +* Add Chinese+English tts example for flutter (#1192) +* Add speaker identification and verification exmaple for Dart API (#1194) +* Fix reading non-standard wav files. (#1199) +* Add ReazonSpeech Japanese pre-trained model (#1203) +* Describe how to add new words for MeloTTS models (#1209) +* Remove libonnxruntime_providers_cuda.so as a dependency. (#1210) +* Fix setting SenseVoice language. 
(#1214) +* Support passing TTS callback in Swift API (#1218) +* Add MeloTTS example for ios (#1223) +* Add online punctuation and casing prediction model for English language (#1224) +* Fix python two pass ASR examples (#1230) +* Add blank penalty for various language bindings + +## 1.10.20 + +* Add Dart API for audio tagging +* Add Dart API for adding punctuations to text + +## 1.10.19 + +* Prefix all C API functions with SherpaOnnx + +## 1.10.18 + +* Fix the case when recognition results contain the symbol `"`. It caused + issues when converting results to a json string. + +## 1.10.17 + +* Support SenseVoice CTC models. +* Add Dart API for keyword spotter. + +## 1.10.16 + +* Support zh-en TTS model from MeloTTS. + +## 1.10.15 + +* Downgrade onnxruntime from v1.18.1 to v1.17.1 + +## 1.10.14 + +* Support whisper large v3 +* Update onnxruntime from v1.18.0 to v1.18.1 +* Fix invalid utf8 sequence from Whisper for Dart API. + +## 1.10.13 + +* Update onnxruntime from 1.17.1 to 1.18.0 +* Add C# API for Keyword spotting + +## 1.10.12 + +* Add Flush to VAD so that the last speech segment can be detected. See also + https://github.com/k2-fsa/sherpa-onnx/discussions/1077#discussioncomment-9979740 + +## 1.10.11 + +* Support the iOS platform for Flutter. + +## 1.10.10 + +* Build sherpa-onnx into a single shared library. + +## 1.10.9 + +* Fix released packages. piper-phonemize was not included in v1.10.8. + +## 1.10.8 + +* Fix released packages. There should be a lib directory. + +## 1.10.7 + +* Support Android for Flutter. + +## 1.10.2 + +* Fix passing C# string to C++ + +## 1.10.1 + +* Enable to stop TTS generation + +## 1.10.0 + +* Add inverse text normalization + +## 1.9.30 + +* Add TTS + +## 1.9.29 + +* Publish with CI + +## 0.0.3 + +* Fix path separator on Windows. + +## 0.0.2 + +* Support specifying lib path. + +## 0.0.1 + +* Initial release. diff --git a/apps/frameworks/sherpa-mnn/CMakeLists.txt b/apps/frameworks/sherpa-mnn/CMakeLists.txt new file mode 100644 index 00000000..a5596f3a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/CMakeLists.txt @@ -0,0 +1,452 @@ +cmake_minimum_required(VERSION 3.13 FATAL_ERROR) + +set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14" CACHE STRING "Minimum OS X deployment version. Used only for macOS") + +set(CMAKE_POLICY_DEFAULT_CMP0063 NEW) +set(CMAKE_POLICY_DEFAULT_CMP0069 NEW) + +project(sherpa-mnn) + +message(STATUS "MNN's dir: ${MNN_LIB_DIR}") +include_directories(${MNN_LIB_DIR}/include) +link_directories(${MNN_LIB_DIR}/lib) + +# Remember to update +# ./CHANGELOG.md +# ./new-release.sh +set(SHERPA_MNN_VERSION "1.10.46") + +# Disable warning about +# +# "The DOWNLOAD_EXTRACT_TIMESTAMP option was not given and policy CMP0135 is +# not set. 
+if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
+  cmake_policy(SET CMP0135 NEW)
+endif()
+
+option(SHERPA_MNN_ENABLE_PYTHON "Whether to build Python" OFF)
+option(SHERPA_MNN_ENABLE_TESTS "Whether to build tests" OFF)
+option(SHERPA_MNN_ENABLE_CHECK "Whether to build with assert" OFF)
+option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
+option(SHERPA_MNN_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
+option(SHERPA_MNN_ENABLE_JNI "Whether to build JNI interface" OFF)
+option(SHERPA_MNN_ENABLE_C_API "Whether to build C API" ON)
+option(SHERPA_MNN_ENABLE_WEBSOCKET "Whether to build websocket server/client" ON)
+option(SHERPA_MNN_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
+option(SHERPA_MNN_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
+option(SHERPA_MNN_ENABLE_WASM "Whether to enable WASM" OFF)
+option(SHERPA_MNN_ENABLE_WASM_SPEAKER_DIARIZATION "Whether to enable WASM for speaker diarization" OFF)
+option(SHERPA_MNN_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
+option(SHERPA_MNN_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
+option(SHERPA_MNN_ENABLE_WASM_KWS "Whether to enable WASM for KWS" OFF)
+option(SHERPA_MNN_ENABLE_WASM_VAD "Whether to enable WASM for VAD" OFF)
+option(SHERPA_MNN_ENABLE_WASM_VAD_ASR "Whether to enable WASM for VAD+ASR" OFF)
+option(SHERPA_MNN_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF)
+option(SHERPA_MNN_ENABLE_BINARY "Whether to build binaries" ON)
+option(SHERPA_MNN_ENABLE_TTS "Whether to build TTS related code" ON)
+option(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION "Whether to build speaker diarization related code" ON)
+option(SHERPA_MNN_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON)
+option(SHERPA_MNN_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON)
+option(SHERPA_MNN_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF)
+option(SHERPA_MNN_BUILD_C_API_EXAMPLES "Whether to enable C API examples" ON)
+option(SHERPA_MNN_ENABLE_RKNN "Whether to build for RKNN NPU" OFF)
+
+set(SHERPA_MNN_LINUX_ARM64_GPU_ONNXRUNTIME_VERSION "1.11.0" CACHE STRING "Used only for Linux ARM64 GPU. If you use Jetson nano b01, then please set it to 1.11.0.
If you use Jetson Orin NX, then set it to 1.16.0.If you use NVIDIA Jetson Orin Nano Engineering Reference Developer Kit +Super - Jetpack 6.2 [L4T 36.4.3], then set it to 1.18.1") + + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") + +if(NOT WIN32) + set(CMAKE_SKIP_BUILD_RPATH FALSE) + set(BUILD_RPATH_USE_ORIGIN TRUE) + set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) +endif() + +if(NOT APPLE) + set(SHERPA_MNN_RPATH_ORIGIN "$ORIGIN") +else() + set(SHERPA_MNN_RPATH_ORIGIN "@loader_path") +endif() + +if(NOT WIN32) + set(CMAKE_INSTALL_RPATH ${SHERPA_MNN_RPATH_ORIGIN}) + set(CMAKE_BUILD_RPATH ${SHERPA_MNN_RPATH_ORIGIN}) +endif() + +if(NOT CMAKE_BUILD_TYPE) + message(STATUS "No CMAKE_BUILD_TYPE given, default to Release") + set(CMAKE_BUILD_TYPE Release) +endif() + +if(DEFINED ANDROID_ABI AND NOT SHERPA_MNN_ENABLE_JNI AND NOT SHERPA_MNN_ENABLE_C_API) + message(STATUS "Set SHERPA_MNN_ENABLE_JNI to ON for Android") + set(SHERPA_MNN_ENABLE_JNI ON CACHE BOOL "" FORCE) +endif() + +if(SHERPA_MNN_ENABLE_PYTHON AND NOT BUILD_SHARED_LIBS) + message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_MNN_ENABLE_PYTHON is ON") + set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) +endif() + +if(SHERPA_MNN_ENABLE_GPU) + message(WARNING "\ +Compiling for NVIDIA GPU is enabled. Please make sure cudatoolkit +is installed on your system. Otherwise, you will get errors at runtime. +Hint: You don't need sudo permission to install CUDA toolkit. Please refer to + https://k2-fsa.github.io/k2/installation/cuda-cudnn.html +to install CUDA toolkit if you have not installed it.") + if(NOT BUILD_SHARED_LIBS) + message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_MNN_ENABLE_GPU is ON") + set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) + endif() +endif() + +if(SHERPA_MNN_ENABLE_DIRECTML) + message(WARNING "\ +Compiling with DirectML enabled. Please make sure Windows 10 SDK +is installed on your system. Otherwise, you will get errors at runtime. 
+Please refer to + https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements +to install Windows 10 SDK if you have not installed it.") + if(NOT BUILD_SHARED_LIBS) + message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_MNN_ENABLE_DIRECTML is ON") + set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) + endif() +endif() + +# see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html +# https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake +if(MSVC) + add_compile_options( + $<$:/MT> #---------| + $<$:/MTd> #---|-- Statically link the runtime libraries + $<$:/MT> #--| + $<$:/MT> + $<$:/MT> + ) +endif() + +if(CMAKE_SYSTEM_NAME STREQUAL OHOS) + set(CMAKE_CXX_FLAGS "-Wno-unused-command-line-argument ${CMAKE_CXX_FLAGS}") + set(CMAKE_C_FLAGS "-Wno-unused-command-line-argument ${CMAKE_C_FLAGS}") +endif() + +message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") +message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}") +message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}") +message(STATUS "SHERPA_MNN_ENABLE_PYTHON ${SHERPA_MNN_ENABLE_PYTHON}") +message(STATUS "SHERPA_MNN_ENABLE_TESTS ${SHERPA_MNN_ENABLE_TESTS}") +message(STATUS "SHERPA_MNN_ENABLE_CHECK ${SHERPA_MNN_ENABLE_CHECK}") +message(STATUS "SHERPA_MNN_ENABLE_PORTAUDIO ${SHERPA_MNN_ENABLE_PORTAUDIO}") +message(STATUS "SHERPA_MNN_ENABLE_JNI ${SHERPA_MNN_ENABLE_JNI}") +message(STATUS "SHERPA_MNN_ENABLE_C_API ${SHERPA_MNN_ENABLE_C_API}") +message(STATUS "SHERPA_MNN_ENABLE_WEBSOCKET ${SHERPA_MNN_ENABLE_WEBSOCKET}") +message(STATUS "SHERPA_MNN_ENABLE_GPU ${SHERPA_MNN_ENABLE_GPU}") +message(STATUS "SHERPA_MNN_ENABLE_WASM ${SHERPA_MNN_ENABLE_WASM}") +message(STATUS "SHERPA_MNN_ENABLE_WASM_SPEAKER_DIARIZATION ${SHERPA_MNN_ENABLE_WASM_SPEAKER_DIARIZATION}") +message(STATUS "SHERPA_MNN_ENABLE_WASM_TTS ${SHERPA_MNN_ENABLE_WASM_TTS}") +message(STATUS "SHERPA_MNN_ENABLE_WASM_ASR ${SHERPA_MNN_ENABLE_WASM_ASR}") +message(STATUS "SHERPA_MNN_ENABLE_WASM_KWS ${SHERPA_MNN_ENABLE_WASM_KWS}") +message(STATUS "SHERPA_MNN_ENABLE_WASM_VAD ${SHERPA_MNN_ENABLE_WASM_VAD}") +message(STATUS "SHERPA_MNN_ENABLE_WASM_VAD_ASR ${SHERPA_MNN_ENABLE_WASM_VAD_ASR}") +message(STATUS "SHERPA_MNN_ENABLE_WASM_NODEJS ${SHERPA_MNN_ENABLE_WASM_NODEJS}") +message(STATUS "SHERPA_MNN_ENABLE_BINARY ${SHERPA_MNN_ENABLE_BINARY}") +message(STATUS "SHERPA_MNN_ENABLE_TTS ${SHERPA_MNN_ENABLE_TTS}") +message(STATUS "SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION ${SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION}") +message(STATUS "SHERPA_MNN_LINK_LIBSTDCPP_STATICALLY ${SHERPA_MNN_LINK_LIBSTDCPP_STATICALLY}") +message(STATUS "SHERPA_MNN_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_MNN_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}") +message(STATUS "SHERPA_MNN_ENABLE_SANITIZER: ${SHERPA_MNN_ENABLE_SANITIZER}") +message(STATUS "SHERPA_MNN_BUILD_C_API_EXAMPLES: ${SHERPA_MNN_BUILD_C_API_EXAMPLES}") +message(STATUS "SHERPA_MNN_ENABLE_RKNN: ${SHERPA_MNN_ENABLE_RKNN}") + +if(BUILD_SHARED_LIBS OR SHERPA_MNN_ENABLE_JNI) + set(CMAKE_CXX_VISIBILITY_PRESET hidden) + set(CMAKE_VISIBILITY_INLINES_HIDDEN 1) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) +endif() + +if(BUILD_SHARED_LIBS AND NOT CMAKE_SYSTEM_NAME STREQUAL iOS AND CMAKE_BUILD_TYPE STREQUAL Release) + # Don't use LTO for iOS since it causes the following error + # error: unable to find any architecture information in the binary + # at '/Users/fangjun/open-source/sherpa-onnx/build-ios/build/os64/sherpa-onnx.a': + # Unknown header: 0xb17c0de + # See also 
https://forums.developer.apple.com/forums/thread/714324 + + include(CheckIPOSupported) + check_ipo_supported(RESULT ipo) + if(ipo) + message(STATUS "IPO is enabled") + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON) + else() + message(STATUS "IPO is not available") + endif() +endif() + +if(SHERPA_MNN_ENABLE_TTS) + message(STATUS "TTS is enabled") + add_definitions(-DSHERPA_MNN_ENABLE_TTS=1) +else() + message(WARNING "TTS is disabled") + add_definitions(-DSHERPA_MNN_ENABLE_TTS=0) +endif() + +if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + message(STATUS "speaker diarization is enabled") + add_definitions(-DSHERPA_MNN_ENABLE_SPEAKER_DIARIZATION=1) +else() + message(WARNING "speaker diarization is disabled") + add_definitions(-DSHERPA_MNN_ENABLE_SPEAKER_DIARIZATION=0) +endif() + +if(SHERPA_MNN_ENABLE_DIRECTML) + message(STATUS "DirectML is enabled") + add_definitions(-DSHERPA_MNN_ENABLE_DIRECTML=1) +else() + message(STATUS "DirectML is disabled") + add_definitions(-DSHERPA_MNN_ENABLE_DIRECTML=0) +endif() + +if(SHERPA_MNN_ENABLE_WASM_SPEAKER_DIARIZATION) + if(NOT SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION to ON if you want to build WASM for speaker diarization") + endif() + + if(NOT SHERPA_MNN_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_WASM to ON if you enable WASM for speaker diarization") + endif() +endif() + +if(SHERPA_MNN_ENABLE_WASM_TTS) + if(NOT SHERPA_MNN_ENABLE_TTS) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_TTS to ON if you want to build WASM for TTS") + endif() + + if(NOT SHERPA_MNN_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_WASM to ON if you enable WASM for TTS") + endif() +endif() + +if(SHERPA_MNN_ENABLE_WASM_ASR) + if(NOT SHERPA_MNN_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_WASM to ON if you enable WASM for ASR") + endif() +endif() + +if(SHERPA_MNN_ENABLE_WASM_NODEJS) + if(NOT SHERPA_MNN_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_WASM to ON if you enable WASM for NodeJS") + endif() + add_definitions(-DSHERPA_MNN_ENABLE_WASM_KWS=1) +endif() + +if(SHERPA_MNN_ENABLE_WASM) + add_definitions(-DSHERPA_MNN_ENABLE_WASM=1) +endif() + +if(SHERPA_MNN_ENABLE_WASM_KWS) + if(NOT SHERPA_MNN_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_WASM to ON if you enable WASM for KWS") + endif() + add_definitions(-DSHERPA_MNN_ENABLE_WASM_KWS=1) +endif() + +if(SHERPA_MNN_ENABLE_WASM_VAD) + if(NOT SHERPA_MNN_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_WASM to ON if you enable WASM for VAD") + endif() +endif() + +if(SHERPA_MNN_ENABLE_WASM_VAD_ASR) + if(NOT SHERPA_MNN_ENABLE_WASM) + message(FATAL_ERROR "Please set SHERPA_MNN_ENABLE_WASM to ON if you enable WASM for VAD+ASR") + endif() +endif() + +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ version to be used.") +endif() +set(CMAKE_CXX_EXTENSIONS OFF) +message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}") + +include(CheckIncludeFileCXX) + +if(SHERPA_MNN_ENABLE_RKNN) + add_definitions(-DSHERPA_MNN_ENABLE_RKNN=1) +endif() + +if(UNIX AND NOT APPLE AND NOT SHERPA_MNN_ENABLE_WASM AND NOT CMAKE_SYSTEM_NAME STREQUAL Android AND NOT CMAKE_SYSTEM_NAME STREQUAL OHOS) + check_include_file_cxx(alsa/asoundlib.h SHERPA_MNN_HAS_ALSA) + if(SHERPA_MNN_HAS_ALSA) + message(STATUS "With Alsa") + add_definitions(-DSHERPA_MNN_ENABLE_ALSA=1) + else() + message(WARNING "\ +Could not find alsa/asoundlib.h ! 
+We won't build sherpa-onnx-alsa +To fix that, please do: + (1) sudo apt-get install alsa-utils libasound2-dev + (2) rm -rf build + (3) re-try + ") + endif() +endif() + +check_include_file_cxx(cxxabi.h SHERPA_MNN_HAVE_CXXABI_H) +check_include_file_cxx(execinfo.h SHERPA_MNN_HAVE_EXECINFO_H) + +if(WIN32) + add_definitions(-DNOMINMAX) # Otherwise, std::max() and std::min() won't work +endif() + +if(WIN32 AND MSVC) + # disable various warnings for MSVC + # 4244: 'return': conversion from 'unsigned __int64' to 'int', possible loss of data + # 4267: 'initializing': conversion from 'size_t' to 'int', possible loss of data + # 4305: 'argument': truncation from 'double' to 'const float' + # 4334: '<<': result of 32-bit shift implicitly converted to 64 bits + # 4800: 'int': forcing value to bool 'true' or 'false' + # 4996: 'fopen': This function or variable may be unsafe + set(disabled_warnings + /wd4244 + /wd4267 + /wd4305 + /wd4334 + /wd4800 + /wd4996 + ) + message(STATUS "Disabled warnings: ${disabled_warnings}") + foreach(w IN LISTS disabled_warnings) + string(APPEND CMAKE_CXX_FLAGS " ${w} ") + endforeach() + + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") +endif() + +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules) +list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) + +if(SHERPA_MNN_ENABLE_WASM) + # Enable it for debugging in case there is something wrong. + # string(APPEND CMAKE_CXX_FLAGS " -g4 -s ASSERTIONS=2 -s SAFE_HEAP=1 -s STACK_OVERFLOW_CHECK=1 ") +endif() + +if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) + if(SHERPA_MNN_LINK_LIBSTDCPP_STATICALLY) + message(STATUS "Link libstdc++ statically") + set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -static-libstdc++ -static-libgcc ") + else() + message(STATUS "Link libstdc++ dynamically") + endif() +endif() + +include(kaldi-native-fbank) +include(kaldi-decoder) +include(simple-sentencepiece) +set(ONNXRUNTIME_DIR ${onnxruntime_SOURCE_DIR}) +message(STATUS "ONNXRUNTIME_DIR: ${ONNXRUNTIME_DIR}") + +if(SHERPA_MNN_ENABLE_PORTAUDIO AND SHERPA_MNN_ENABLE_BINARY) + # portaudio is used only in building demo binaries and the sherpa-onnx-core + # library does not depend on it. + include(portaudio) +endif() + +if(SHERPA_MNN_ENABLE_PYTHON) + include(pybind11) +endif() + +if(SHERPA_MNN_ENABLE_TESTS) + enable_testing() + include(googletest) +endif() + +if(SHERPA_MNN_ENABLE_WEBSOCKET) + include(websocketpp) + include(asio) +endif() + +if(SHERPA_MNN_ENABLE_TTS) + include(espeak-ng-for-piper) + set(ESPEAK_NG_DIR ${espeak_ng_SOURCE_DIR}) + message(STATUS "ESPEAK_NG_DIR: ${ESPEAK_NG_DIR}") + include(piper-phonemize) + include(cppjieba) # For Chinese TTS. 
It is a header-only C++ library +endif() + +if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + include(hclust-cpp) +endif() + +# if(NOT MSVC AND CMAKE_BUILD_TYPE STREQUAL Debug AND (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")) +if(SHERPA_MNN_ENABLE_SANITIZER) + message(WARNING "enable ubsan and asan") + set(CMAKE_REQUIRED_LIBRARIES -lubsan -lasan) + include(CheckCCompilerFlag) + + set(flags -fsanitize=undefined ) + string(APPEND flags " -fno-sanitize-recover=undefined ") + string(APPEND flags " -fsanitize=integer ") + string(APPEND flags " -fsanitize=nullability ") + string(APPEND flags " -fsanitize=implicit-conversion ") + string(APPEND flags " -fsanitize=bounds ") + string(APPEND flags " -fsanitize=address ") + + if(OFF) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${flags} -Wall -Wextra") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${flags} -Wall -Wextra") + else() + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${flags}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${flags}") + endif() + + set(CMAKE_EXECUTBLE_LINKER_FLAGS "${CMAKE_EXECUTBLE_LINKER_FLAGS} ${flags}") + + add_compile_options(-fno-omit-frame-pointer) +endif() + +add_subdirectory(sherpa-mnn) + +if(SHERPA_MNN_ENABLE_C_API AND SHERPA_MNN_ENABLE_BINARY AND SHERPA_MNN_BUILD_C_API_EXAMPLES) + set(SHERPA_MNN_PKG_WITH_CARGS "-lcargs") + add_subdirectory(c-api-examples) + add_subdirectory(cxx-api-examples) +endif() + +if(SHERPA_MNN_ENABLE_WASM) + add_subdirectory(wasm) +endif() + +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + +if(NOT BUILD_SHARED_LIBS) + if(APPLE) + set(SHERPA_MNN_PKG_CONFIG_EXTRA_LIBS "-lc++ -framework Foundation") + endif() + + if(UNIX AND NOT APPLE) + set(SHERPA_MNN_PKG_CONFIG_EXTRA_LIBS "-lstdc++ -lm -pthread -ldl") + endif() +endif() + +if(NOT BUILD_SHARED_LIBS) +# See https://people.freedesktop.org/~dbn/pkg-config-guide.html + if(SHERPA_MNN_ENABLE_TTS) + configure_file(cmake/sherpa-onnx-static.pc.in ${PROJECT_BINARY_DIR}/sherpa-onnx.pc @ONLY) + else() + configure_file(cmake/sherpa-onnx-static-no-tts.pc.in ${PROJECT_BINARY_DIR}/sherpa-onnx.pc @ONLY) + endif() +else() + configure_file(cmake/sherpa-onnx-shared.pc.in ${PROJECT_BINARY_DIR}/sherpa-onnx.pc @ONLY) +endif() + +install( + FILES + ${PROJECT_BINARY_DIR}/sherpa-onnx.pc + DESTINATION + ./ +) +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") diff --git a/apps/frameworks/sherpa-mnn/CPPLINT.cfg b/apps/frameworks/sherpa-mnn/CPPLINT.cfg new file mode 100644 index 00000000..714091c3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/CPPLINT.cfg @@ -0,0 +1 @@ +filter=-./mfc-examples diff --git a/apps/frameworks/sherpa-mnn/LICENSE b/apps/frameworks/sherpa-mnn/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/apps/frameworks/sherpa-mnn/MANIFEST.in b/apps/frameworks/sherpa-mnn/MANIFEST.in
new file mode 100644
index 00000000..4372d196
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/MANIFEST.in
@@ -0,0 +1,12 @@
+include LICENSE
+include README.md
+include CMakeLists.txt
+recursive-include c-api-examples *.*
+recursive-include sherpa-onnx *.*
+recursive-include cmake *.*
+prune */__pycache__
+prune android
+prune sherpa-onnx/java-api
+prune ios-swift
+prune ios-swiftui
+
diff --git a/apps/frameworks/sherpa-mnn/NOTICE b/apps/frameworks/sherpa-mnn/NOTICE
new file mode 100644
index 00000000..6d9f2e7a
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/NOTICE
@@ -0,0 +1,19 @@
+# NOTICE
+## Project Info
+
+- **Name**: sherpa-mnn
+- **License**: Apache 2.0
+
+## Dependencies
+
+- [MNN](https://github.com/alibaba/MNN/)
+
+## Modifications
+This project is derived from sherpa-onnx (https://github.com/k2-fsa/sherpa-onnx).
+Key changes include:
+
+- Use MNN instead of onnxruntime for deep-learning model inference
+- Rename sherpa-onnx to sherpa-mnn
+
+## Copyright
+Copyright (c) 2022-2023 Xiaomi Corporation. All rights reserved. Copyright (c) 2025, MNN Team.
diff --git a/apps/frameworks/sherpa-mnn/README.md b/apps/frameworks/sherpa-mnn/README.md
new file mode 100644
index 00000000..5dffb534
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/README.md
@@ -0,0 +1,93 @@
+# sherpa-mnn
+
+This project is adapted from sherpa-onnx, with all onnxruntime calls replaced by MNN.
+
+## Preparing the MNN environment and models
+
+### Building MNN
+
+Download MNN: https://github.com/alibaba/MNN/
+
+When building MNN, additionally pass `-DMNN_SEP_BUILD=OFF` and `-DCMAKE_INSTALL_PREFIX=.`:
+
+```
+mkdir build
+cd build
+cmake .. -DMNN_LOW_MEMORY=ON -DMNN_SEP_BUILD=OFF -DCMAKE_INSTALL_PREFIX=. -DMNN_BUILD_CONVERTER=ON
+make -j4
+make install
+```
+
+### Model conversion
+In the directory where MNNConvert was built (the build directory above), run the following commands to convert each downloaded FP32 ONNX model to MNN. Quantizing the weights during conversion is recommended: it reduces the model size and, when the MNN library is built with `MNN_LOW_MEMORY`, also lowers runtime memory usage and improves performance. Do not convert int8 ONNX models directly.
+```
+mkdir sherpa-mnn-models
+./MNNConvert -f ONNX --modelFile sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx --MNNModel sherpa-mnn-models/encode.mnn --weightQuantBits=8 --weightQuantBlock=64
+./MNNConvert -f ONNX --modelFile sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx --MNNModel sherpa-mnn-models/decode.mnn --weightQuantBits=8 --weightQuantBlock=64
+./MNNConvert -f ONNX --modelFile sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx --MNNModel sherpa-mnn-models/joiner.mnn --weightQuantBits=8 --weightQuantBlock=64
+```
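+
+If a model ships more ONNX files, the same MNNConvert invocation can be scripted. Below is a minimal sketch; the input directory and output file names are only examples, so adjust them to the model you downloaded:
+
+```
+# Convert every *.onnx file in a downloaded model directory to .mnn,
+# reusing the 8-bit weight-quantization settings shown above.
+mkdir -p sherpa-mnn-models
+for f in sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/*.onnx; do
+  ./MNNConvert -f ONNX --modelFile "$f" \
+    --MNNModel "sherpa-mnn-models/$(basename "${f%.onnx}").mnn" \
+    --weightQuantBits=8 --weightQuantBlock=64
+done
+```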
+
+
+## Local build and test
+
+### Build
+Go back to the sherpa-mnn root directory and run the following commands; change the value after `MNN_LIB_DIR` to your own MNN build directory:
+
+```
+mkdir build
+cd build
+cmake .. -DMNN_LIB_DIR=/Users/xtjiang/alicnn/AliNNPrivate/build
+make -j16
+```
+
+### Test
+Back in the sherpa-mnn root directory, using the sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 model as an example:
+
+```
+./build/bin/sherpa-mnn --tokens=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt --encoder=./sherpa-mnn-models/encode.mnn --decoder=./sherpa-mnn-models/decode.mnn --joiner=./sherpa-mnn-models/joiner.mnn ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav
+```
+
+If everything works, output like the following is printed:
+```
+Number of threads: 1, Elapsed seconds: 0.27, Audio duration (s): 5.1, Real time factor (RTF) = 0.27/5.1 = 0.053
+这是第一种第二种叫与 ALWAYS ALWAYS什么意思
+{ "text": "这是第一种第二种叫与 ALWAYS ALWAYS什么意思", "tokens": ["这", "是", "第", "一", "种", "第", "二", "种", "叫", "与", " ALWAYS", " ALWAYS", "什", "么", "意", "思"], "timestamps": [0.96, 1.04, 1.28, 1.40, 1.48, 1.72, 1.84, 2.04, 2.44, 3.64, 3.84, 4.36, 4.72, 4.76, 4.92, 5.04], "ys_probs": [-0.884769, -0.858386, -1.106216, -0.626572, -1.101773, -0.359962, -0.745972, -0.267809, -0.826859, -1.076653, -0.683002, -0.869667, -0.593140, -0.469688, -0.256882, -0.442532], "lm_probs": [], "context_scores": [], "segment": 0, "words": [], "start_time": 0.00, "is_final": false}
+```
+
+## Building for Android
+### Building MNN for Android
+From the MNN directory:
+```
+cd project/android
+mkdir build_64
+cd build_64
+../build_64.sh -DMNN_LOW_MEMORY=ON -DMNN_SEP_BUILD=OFF -DCMAKE_INSTALL_PREFIX=.
+make install
+```
+
+### Building sherpa-mnn for Android
+Edit the build-android-arm64-v8a.sh script and change the value after `MNN_LIB_DIR` to the MNN build directory above.
+
+Then run build-android-arm64-v8a.sh.
+
+If the resulting .so files are large, you can strip them with the Android NDK strip tool.
+
+
+## Building for iOS
+Edit the build-ios.sh script and change the value after `MNN_LIB_DIR` to the MNN root directory (it only needs to be able to find the MNN headers).
+
+Run the build-ios.sh script:
+
+```
+export MNN_LIB_DIR=/path/to/MNN
+sh build-ios.sh
+```
+
+This produces build-ios/sherpa-mnn.xcframework.
+
+## Building a macOS framework
+Similar to the iOS build: edit build-swift-macos.sh and change the value after `MNN_LIB_DIR` to the MNN root directory (it only needs to be able to find the MNN headers).
+Run build-swift-macos.sh.
+This produces build-swift-macos/sherpa-mnn.xcframework/.
diff --git a/apps/frameworks/sherpa-mnn/README_ONNX.md b/apps/frameworks/sherpa-mnn/README_ONNX.md
new file mode 100644
index 00000000..1f08959c
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/README_ONNX.md
@@ -0,0 +1,446 @@
+### Supported functions
+
+|Speech recognition| Speech synthesis |
+|------------------|------------------|
+| ✔️ | ✔️ |
+
+|Speaker identification| Speaker diarization | Speaker verification |
+|----------------------|---------------------|----------------------|
+| ✔️ | ✔️ | ✔️ |
+
+| Spoken Language identification | Audio tagging | Voice activity detection |
+|--------------------------------|---------------|--------------------------|
+| ✔️ | ✔️ | ✔️ |
+
+| Keyword spotting | Add punctuation | Speech enhancement |
+|------------------|-----------------|--------------------|
+| ✔️ | ✔️ | ✔️ |
+
+### Supported platforms
+
+|Architecture| Android | iOS | Windows | macOS | Linux | HarmonyOS |
+|------------|---------|-----|---------|-------|-------|-----------|
+| x64 | ✔️ | | ✔️ | ✔️ | ✔️ | ✔️ |
+| x86 | ✔️ | | ✔️ | | | |
+| arm64 | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
+| arm32 | ✔️ | | | | ✔️ | ✔️ |
+| riscv64 | | | | | ✔️ | |
+
+### Supported programming languages
+
+| 1. C++ | 2. C | 3. Python | 4. JavaScript |
+|--------|------|-----------|---------------|
+| ✔️ | ✔️ | ✔️ | ✔️ |
+
+|5. Java | 6. C# | 7. Kotlin | 8. Swift |
+|--------|-------|-----------|----------|
+| ✔️ | ✔️ | ✔️ | ✔️ |
+
+| 9. Go | 10. Dart | 11. Rust | 12.
Pascal | +|-------|----------|----------|------------| +| ✔️ | ✔️ | ✔️ | ✔️ | + +For Rust support, please see [sherpa-rs][sherpa-rs] + +It also supports WebAssembly. + +## Introduction + +This repository supports running the following functions **locally** + + - Speech-to-text (i.e., ASR); both streaming and non-streaming are supported + - Text-to-speech (i.e., TTS) + - Speaker diarization + - Speaker identification + - Speaker verification + - Spoken language identification + - Audio tagging + - VAD (e.g., [silero-vad][silero-vad]) + - Keyword spotting + +on the following platforms and operating systems: + + - x86, ``x86_64``, 32-bit ARM, 64-bit ARM (arm64, aarch64), RISC-V (riscv64) + - Linux, macOS, Windows, openKylin + - Android, WearOS + - iOS + - HarmonyOS + - NodeJS + - WebAssembly + - [NVIDIA Jetson Orin NX][NVIDIA Jetson Orin NX] (Support running on both CPU and GPU) + - [NVIDIA Jetson Nano B01][NVIDIA Jetson Nano B01] (Support running on both CPU and GPU) + - [Raspberry Pi][Raspberry Pi] + - [RV1126][RV1126] + - [LicheePi4A][LicheePi4A] + - [VisionFive 2][VisionFive 2] + - [旭日X3派][旭日X3派] + - [爱芯派][爱芯派] + - etc + +with the following APIs + + - C++, C, Python, Go, ``C#`` + - Java, Kotlin, JavaScript + - Swift, Rust + - Dart, Object Pascal + +### Links for Huggingface Spaces + +
+You can visit the following Huggingface spaces to try sherpa-onnx without +installing anything. All you need is a browser. + +| Description | URL | +|-------------------------------------------------------|-----------------------------------------| +| Speaker diarization | [Click me][hf-space-speaker-diarization]| +| Speech recognition | [Click me][hf-space-asr] | +| Speech recognition with [Whisper][Whisper] | [Click me][hf-space-asr-whisper] | +| Speech synthesis | [Click me][hf-space-tts] | +| Generate subtitles | [Click me][hf-space-subtitle] | +| Audio tagging | [Click me][hf-space-audio-tagging] | +| Spoken language identification with [Whisper][Whisper]| [Click me][hf-space-slid-whisper] | + +We also have spaces built using WebAssembly. They are listed below: + +| Description | Huggingface space| ModelScope space| +|------------------------------------------------------------------------------------------|------------------|-----------------| +|Voice activity detection with [silero-vad][silero-vad] | [Click me][wasm-hf-vad]|[地址][wasm-ms-vad]| +|Real-time speech recognition (Chinese + English) with Zipformer | [Click me][wasm-hf-streaming-asr-zh-en-zipformer]|[地址][wasm-hf-streaming-asr-zh-en-zipformer]| +|Real-time speech recognition (Chinese + English) with Paraformer |[Click me][wasm-hf-streaming-asr-zh-en-paraformer]| [地址][wasm-ms-streaming-asr-zh-en-paraformer]| +|Real-time speech recognition (Chinese + English + Cantonese) with [Paraformer-large][Paraformer-large]|[Click me][wasm-hf-streaming-asr-zh-en-yue-paraformer]| [地址][wasm-ms-streaming-asr-zh-en-yue-paraformer]| +|Real-time speech recognition (English) |[Click me][wasm-hf-streaming-asr-en-zipformer] |[地址][wasm-ms-streaming-asr-en-zipformer]| +|VAD + speech recognition (Chinese + English + Korean + Japanese + Cantonese) with [SenseVoice][SenseVoice]|[Click me][wasm-hf-vad-asr-zh-en-ko-ja-yue-sense-voice]| [地址][wasm-ms-vad-asr-zh-en-ko-ja-yue-sense-voice]| +|VAD + speech recognition (English) with [Whisper][Whisper] tiny.en|[Click me][wasm-hf-vad-asr-en-whisper-tiny-en]| [地址][wasm-ms-vad-asr-en-whisper-tiny-en]| +|VAD + speech recognition (English) with [Moonshine tiny][Moonshine tiny]|[Click me][wasm-hf-vad-asr-en-moonshine-tiny-en]| [地址][wasm-ms-vad-asr-en-moonshine-tiny-en]| +|VAD + speech recognition (English) with Zipformer trained with [GigaSpeech][GigaSpeech] |[Click me][wasm-hf-vad-asr-en-zipformer-gigaspeech]| [地址][wasm-ms-vad-asr-en-zipformer-gigaspeech]| +|VAD + speech recognition (Chinese) with Zipformer trained with [WenetSpeech][WenetSpeech] |[Click me][wasm-hf-vad-asr-zh-zipformer-wenetspeech]| [地址][wasm-ms-vad-asr-zh-zipformer-wenetspeech]| +|VAD + speech recognition (Japanese) with Zipformer trained with [ReazonSpeech][ReazonSpeech]|[Click me][wasm-hf-vad-asr-ja-zipformer-reazonspeech]| [地址][wasm-ms-vad-asr-ja-zipformer-reazonspeech]| +|VAD + speech recognition (Thai) with Zipformer trained with [GigaSpeech2][GigaSpeech2] |[Click me][wasm-hf-vad-asr-th-zipformer-gigaspeech2]| [地址][wasm-ms-vad-asr-th-zipformer-gigaspeech2]| +|VAD + speech recognition (Chinese 多种方言) with a [TeleSpeech-ASR][TeleSpeech-ASR] CTC model|[Click me][wasm-hf-vad-asr-zh-telespeech]| [地址][wasm-ms-vad-asr-zh-telespeech]| +|VAD + speech recognition (English + Chinese, 及多种中文方言) with Paraformer-large |[Click me][wasm-hf-vad-asr-zh-en-paraformer-large]| [地址][wasm-ms-vad-asr-zh-en-paraformer-large]| +|VAD + speech recognition (English + Chinese, 及多种中文方言) with Paraformer-small |[Click me][wasm-hf-vad-asr-zh-en-paraformer-small]| 
[地址][wasm-ms-vad-asr-zh-en-paraformer-small]| +|Speech synthesis (English) |[Click me][wasm-hf-tts-piper-en]| [地址][wasm-ms-tts-piper-en]| +|Speech synthesis (German) |[Click me][wasm-hf-tts-piper-de]| [地址][wasm-ms-tts-piper-de]| +|Speaker diarization |[Click me][wasm-hf-speaker-diarization]|[地址][wasm-ms-speaker-diarization]| + +
+ +### Links for pre-built Android APKs + +
+ +You can find pre-built Android APKs for this repository in the following table + +| Description | URL | 中国用户 | +|----------------------------------------|------------------------------------|-----------------------------------| +| Speaker diarization | [Address][apk-speaker-diarization] | [点此][apk-speaker-diarization-cn]| +| Streaming speech recognition | [Address][apk-streaming-asr] | [点此][apk-streaming-asr-cn] | +| Text-to-speech | [Address][apk-tts] | [点此][apk-tts-cn] | +| Voice activity detection (VAD) | [Address][apk-vad] | [点此][apk-vad-cn] | +| VAD + non-streaming speech recognition | [Address][apk-vad-asr] | [点此][apk-vad-asr-cn] | +| Two-pass speech recognition | [Address][apk-2pass] | [点此][apk-2pass-cn] | +| Audio tagging | [Address][apk-at] | [点此][apk-at-cn] | +| Audio tagging (WearOS) | [Address][apk-at-wearos] | [点此][apk-at-wearos-cn] | +| Speaker identification | [Address][apk-sid] | [点此][apk-sid-cn] | +| Spoken language identification | [Address][apk-slid] | [点此][apk-slid-cn] | +| Keyword spotting | [Address][apk-kws] | [点此][apk-kws-cn] | + +
+ +### Links for pre-built Flutter APPs + +
+ +#### Real-time speech recognition + +| Description | URL | 中国用户 | +|--------------------------------|-------------------------------------|-------------------------------------| +| Streaming speech recognition | [Address][apk-flutter-streaming-asr]| [点此][apk-flutter-streaming-asr-cn]| + +#### Text-to-speech + +| Description | URL | 中国用户 | +|------------------------------------------|------------------------------------|------------------------------------| +| Android (arm64-v8a, armeabi-v7a, x86_64) | [Address][flutter-tts-android] | [点此][flutter-tts-android-cn] | +| Linux (x64) | [Address][flutter-tts-linux] | [点此][flutter-tts-linux-cn] | +| macOS (x64) | [Address][flutter-tts-macos-x64] | [点此][flutter-tts-macos-arm64-cn] | +| macOS (arm64) | [Address][flutter-tts-macos-arm64] | [点此][flutter-tts-macos-x64-cn] | +| Windows (x64) | [Address][flutter-tts-win-x64] | [点此][flutter-tts-win-x64-cn] | + +> Note: You need to build from source for iOS. + +
+ +### Links for pre-built Lazarus APPs + +
+ +#### Generating subtitles + +| Description | URL | 中国用户 | +|--------------------------------|----------------------------|----------------------------| +| Generate subtitles (生成字幕) | [Address][lazarus-subtitle]| [点此][lazarus-subtitle-cn]| + +
+ +### Links for pre-trained models + +
+ +| Description | URL | +|---------------------------------------------|---------------------------------------------------------------------------------------| +| Speech recognition (speech to text, ASR) | [Address][asr-models] | +| Text-to-speech (TTS) | [Address][tts-models] | +| VAD | [Address][vad-models] | +| Keyword spotting | [Address][kws-models] | +| Audio tagging | [Address][at-models] | +| Speaker identification (Speaker ID) | [Address][sid-models] | +| Spoken language identification (Language ID)| See multi-lingual [Whisper][Whisper] ASR models from [Speech recognition][asr-models]| +| Punctuation | [Address][punct-models] | +| Speaker segmentation | [Address][speaker-segmentation-models] | +| Speech enhancement | [Address][speech-enhancement-models] | + +
+ +#### Some pre-trained ASR models (Streaming) + +
+ +Please see + + - + - + - + +for more models. The following table lists only **SOME** of them. + + +|Name | Supported Languages| Description| +|-----|-----|----| +|[sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20][sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20]| Chinese, English| See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english)| +|[sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16][sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16]| Chinese, English| See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16-bilingual-chinese-english)| +|[sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23][sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23]|Chinese| Suitable for Cortex A7 CPU. See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-zh-14m-2023-02-23)| +|[sherpa-onnx-streaming-zipformer-en-20M-2023-02-17][sherpa-onnx-streaming-zipformer-en-20M-2023-02-17]|English|Suitable for Cortex A7 CPU. See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-en-20m-2023-02-17)| +|[sherpa-onnx-streaming-zipformer-korean-2024-06-16][sherpa-onnx-streaming-zipformer-korean-2024-06-16]|Korean| See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-korean-2024-06-16-korean)| +|[sherpa-onnx-streaming-zipformer-fr-2023-04-14][sherpa-onnx-streaming-zipformer-fr-2023-04-14]|French| See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#shaojieli-sherpa-onnx-streaming-zipformer-fr-2023-04-14-french)| + +
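+
+As a quick check, any of the streaming transducer models above can be decoded
+with the `decode-file-c-api` example under `c-api-examples/` in this package.
+The commands below are only a sketch: the download URL is the one from the
+table, but the exact `*.onnx` and `tokens.txt` file names inside the extracted
+directory are assumptions, so please verify them after unpacking.
+
+```bash
+# Download a small streaming Zipformer model (English)
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
+tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
+
+# Decode a 16 kHz mono wave file with the streaming transducer example.
+# Replace the file names below with those actually found in the extracted directory.
+./bin/decode-file-c-api \
+  --tokens=./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt \
+  --encoder=./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder.onnx \
+  --decoder=./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/decoder.onnx \
+  --joiner=./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/joiner.onnx \
+  ./foo.wav
+```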
+ + +#### Some pre-trained ASR models (Non-Streaming) + +
+ +Please see + + - + - + - + - + - + +for more models. The following table lists only **SOME** of them. + +|Name | Supported Languages| Description| +|-----|-----|----| +|[Whisper tiny.en](https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2)|English| See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html)| +|[Moonshine tiny][Moonshine tiny]|English|See [also](https://github.com/usefulsensors/moonshine)| +|[sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17][sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17]|Chinese, Cantonese, English, Korean, Japanese| 支持多种中文方言. See [also](https://k2-fsa.github.io/sherpa/onnx/sense-voice/index.html)| +|[sherpa-onnx-paraformer-zh-2024-03-09][sherpa-onnx-paraformer-zh-2024-03-09]|Chinese, English| 也支持多种中文方言. See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2024-03-09-chinese-english)| +|[sherpa-onnx-zipformer-ja-reazonspeech-2024-08-01][sherpa-onnx-zipformer-ja-reazonspeech-2024-08-01]|Japanese|See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#sherpa-onnx-zipformer-ja-reazonspeech-2024-08-01-japanese)| +|[sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24][sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24]|Russian|See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/nemo-transducer-models.html#sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24-russian)| +|[sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24][sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24]|Russian| See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/russian.html#sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24)| +|[sherpa-onnx-zipformer-ru-2024-09-18][sherpa-onnx-zipformer-ru-2024-09-18]|Russian|See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#sherpa-onnx-zipformer-ru-2024-09-18-russian)| +|[sherpa-onnx-zipformer-korean-2024-06-24][sherpa-onnx-zipformer-korean-2024-06-24]|Korean|See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#sherpa-onnx-zipformer-korean-2024-06-24-korean)| +|[sherpa-onnx-zipformer-thai-2024-06-20][sherpa-onnx-zipformer-thai-2024-06-20]|Thai| See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#sherpa-onnx-zipformer-thai-2024-06-20-thai)| +|[sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04][sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04]|Chinese| 支持多种方言. See [also](https://k2-fsa.github.io/sherpa/onnx/pretrained_models/telespeech/models.html#sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04)| + +
+
+### Useful links
+
+- Documentation: https://k2-fsa.github.io/sherpa/onnx/
+- Bilibili demo videos (演示视频): https://search.bilibili.com/all?keyword=%E6%96%B0%E4%B8%80%E4%BB%A3Kaldi
+
+### How to reach us
+
+Please see
+https://k2-fsa.github.io/sherpa/social-groups.html
+for the 新一代 Kaldi (next-gen Kaldi) **WeChat group (微信交流群)** and **QQ group (QQ 交流群)**.
+
+## Projects using sherpa-onnx
+
+### [Open-LLM-VTuber](https://github.com/t41372/Open-LLM-VTuber)
+
+Talk to any LLM with hands-free voice interaction, voice interruption, and a Live2D talking
+face, running locally across platforms.
+
+See also
+
+### [voiceapi](https://github.com/ruzhila/voiceapi)
+
+Streaming ASR and TTS based on FastAPI.
+
+It shows how to use the ASR and TTS Python APIs of sherpa-onnx with FastAPI.
+ +### [腾讯会议摸鱼工具 TMSpeech](https://github.com/jxlpzqc/TMSpeech) + +Uses streaming ASR in C# with graphical user interface. + +Video demo in Chinese: [【开源】Windows实时字幕软件(网课/开会必备)](https://www.bilibili.com/video/BV1rX4y1p7Nx) + +### [lol互动助手](https://github.com/l1veIn/lol-wom-electron) + +It uses the JavaScript API of sherpa-onnx along with [Electron](https://electronjs.org/) + +Video demo in Chinese: [爆了!炫神教你开打字挂!真正影响胜率的英雄联盟工具!英雄联盟的最后一块拼图!和游戏中的每个人无障碍沟通!](https://www.bilibili.com/video/BV142tje9E74) + +### [Sherpa-ONNX 语音识别服务器](https://github.com/hfyydd/sherpa-onnx-server) + +A server based on nodejs providing Restful API for speech recognition. + +### [QSmartAssistant](https://github.com/xinhecuican/QSmartAssistant) + +一个模块化,全过程可离线,低占用率的对话机器人/智能音箱 + +It uses QT. Both [ASR](https://github.com/xinhecuican/QSmartAssistant/blob/master/doc/%E5%AE%89%E8%A3%85.md#asr) +and [TTS](https://github.com/xinhecuican/QSmartAssistant/blob/master/doc/%E5%AE%89%E8%A3%85.md#tts) +are used. + + +### [Flutter-EasySpeechRecognition](https://github.com/Jason-chen-coder/Flutter-EasySpeechRecognition) + +It extends [./flutter-examples/streaming_asr](./flutter-examples/streaming_asr) by +downloading models inside the app to reduce the size of the app. + +### [sherpa-onnx-unity](https://github.com/xue-fei/sherpa-onnx-unity) + +sherpa-onnx in Unity. See also [#1695](https://github.com/k2-fsa/sherpa-onnx/issues/1695), +[#1892](https://github.com/k2-fsa/sherpa-onnx/issues/1892), and [#1859](https://github.com/k2-fsa/sherpa-onnx/issues/1859) + +[sherpa-rs]: https://github.com/thewh1teagle/sherpa-rs +[silero-vad]: https://github.com/snakers4/silero-vad +[Raspberry Pi]: https://www.raspberrypi.com/ +[RV1126]: https://www.rock-chips.com/uploads/pdf/2022.8.26/191/RV1126%20Brief%20Datasheet.pdf +[LicheePi4A]: https://sipeed.com/licheepi4a +[VisionFive 2]: https://www.starfivetech.com/en/site/boards +[旭日X3派]: https://developer.horizon.ai/api/v1/fileData/documents_pi/index.html +[爱芯派]: https://wiki.sipeed.com/hardware/zh/maixIII/ax-pi/axpi.html +[hf-space-speaker-diarization]: https://huggingface.co/spaces/k2-fsa/speaker-diarization +[hf-space-asr]: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition +[Whisper]: https://github.com/openai/whisper +[hf-space-asr-whisper]: https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition-with-whisper +[hf-space-tts]: https://huggingface.co/spaces/k2-fsa/text-to-speech +[hf-space-subtitle]: https://huggingface.co/spaces/k2-fsa/generate-subtitles-for-videos +[hf-space-audio-tagging]: https://huggingface.co/spaces/k2-fsa/audio-tagging +[hf-space-slid-whisper]: https://huggingface.co/spaces/k2-fsa/spoken-language-identification +[wasm-hf-vad]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-sherpa-onnx +[wasm-ms-vad]: https://modelscope.cn/studios/csukuangfj/web-assembly-vad-sherpa-onnx +[wasm-hf-streaming-asr-zh-en-zipformer]: https://huggingface.co/spaces/k2-fsa/web-assembly-asr-sherpa-onnx-zh-en +[wasm-ms-streaming-asr-zh-en-zipformer]: https://modelscope.cn/studios/k2-fsa/web-assembly-asr-sherpa-onnx-zh-en +[wasm-hf-streaming-asr-zh-en-paraformer]: https://huggingface.co/spaces/k2-fsa/web-assembly-asr-sherpa-onnx-zh-en-paraformer +[wasm-ms-streaming-asr-zh-en-paraformer]: https://modelscope.cn/studios/k2-fsa/web-assembly-asr-sherpa-onnx-zh-en-paraformer +[Paraformer-large]: https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary +[wasm-hf-streaming-asr-zh-en-yue-paraformer]: 
https://huggingface.co/spaces/k2-fsa/web-assembly-asr-sherpa-onnx-zh-cantonese-en-paraformer +[wasm-ms-streaming-asr-zh-en-yue-paraformer]: https://modelscope.cn/studios/k2-fsa/web-assembly-asr-sherpa-onnx-zh-cantonese-en-paraformer +[wasm-hf-streaming-asr-en-zipformer]: https://huggingface.co/spaces/k2-fsa/web-assembly-asr-sherpa-onnx-en +[wasm-ms-streaming-asr-en-zipformer]: https://modelscope.cn/studios/k2-fsa/web-assembly-asr-sherpa-onnx-en +[SenseVoice]: https://github.com/FunAudioLLM/SenseVoice +[wasm-hf-vad-asr-zh-en-ko-ja-yue-sense-voice]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-en-ja-ko-cantonese-sense-voice +[wasm-ms-vad-asr-zh-en-ko-ja-yue-sense-voice]: https://www.modelscope.cn/studios/csukuangfj/web-assembly-vad-asr-sherpa-onnx-zh-en-jp-ko-cantonese-sense-voice +[wasm-hf-vad-asr-en-whisper-tiny-en]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-en-whisper-tiny +[wasm-ms-vad-asr-en-whisper-tiny-en]: https://www.modelscope.cn/studios/csukuangfj/web-assembly-vad-asr-sherpa-onnx-en-whisper-tiny +[wasm-hf-vad-asr-en-moonshine-tiny-en]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-en-moonshine-tiny +[wasm-ms-vad-asr-en-moonshine-tiny-en]: https://www.modelscope.cn/studios/csukuangfj/web-assembly-vad-asr-sherpa-onnx-en-moonshine-tiny +[wasm-hf-vad-asr-en-zipformer-gigaspeech]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-en-zipformer-gigaspeech +[wasm-ms-vad-asr-en-zipformer-gigaspeech]: https://www.modelscope.cn/studios/k2-fsa/web-assembly-vad-asr-sherpa-onnx-en-zipformer-gigaspeech +[wasm-hf-vad-asr-zh-zipformer-wenetspeech]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-zipformer-wenetspeech +[wasm-ms-vad-asr-zh-zipformer-wenetspeech]: https://www.modelscope.cn/studios/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-zipformer-wenetspeech +[ReazonSpeech]: https://research.reazon.jp/_static/reazonspeech_nlp2023.pdf +[wasm-hf-vad-asr-ja-zipformer-reazonspeech]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-ja-zipformer +[wasm-ms-vad-asr-ja-zipformer-reazonspeech]: https://www.modelscope.cn/studios/csukuangfj/web-assembly-vad-asr-sherpa-onnx-ja-zipformer +[GigaSpeech2]: https://github.com/SpeechColab/GigaSpeech2 +[wasm-hf-vad-asr-th-zipformer-gigaspeech2]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-th-zipformer +[wasm-ms-vad-asr-th-zipformer-gigaspeech2]: https://www.modelscope.cn/studios/csukuangfj/web-assembly-vad-asr-sherpa-onnx-th-zipformer +[TeleSpeech-ASR]: https://github.com/Tele-AI/TeleSpeech-ASR +[wasm-hf-vad-asr-zh-telespeech]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-telespeech +[wasm-ms-vad-asr-zh-telespeech]: https://www.modelscope.cn/studios/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-telespeech +[wasm-hf-vad-asr-zh-en-paraformer-large]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-en-paraformer +[wasm-ms-vad-asr-zh-en-paraformer-large]: https://www.modelscope.cn/studios/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-en-paraformer +[wasm-hf-vad-asr-zh-en-paraformer-small]: https://huggingface.co/spaces/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-en-paraformer-small +[wasm-ms-vad-asr-zh-en-paraformer-small]: https://www.modelscope.cn/studios/k2-fsa/web-assembly-vad-asr-sherpa-onnx-zh-en-paraformer-small +[wasm-hf-tts-piper-en]: https://huggingface.co/spaces/k2-fsa/web-assembly-tts-sherpa-onnx-en +[wasm-ms-tts-piper-en]: 
https://modelscope.cn/studios/k2-fsa/web-assembly-tts-sherpa-onnx-en +[wasm-hf-tts-piper-de]: https://huggingface.co/spaces/k2-fsa/web-assembly-tts-sherpa-onnx-de +[wasm-ms-tts-piper-de]: https://modelscope.cn/studios/k2-fsa/web-assembly-tts-sherpa-onnx-de +[wasm-hf-speaker-diarization]: https://huggingface.co/spaces/k2-fsa/web-assembly-speaker-diarization-sherpa-onnx +[wasm-ms-speaker-diarization]: https://www.modelscope.cn/studios/csukuangfj/web-assembly-speaker-diarization-sherpa-onnx +[apk-speaker-diarization]: https://k2-fsa.github.io/sherpa/onnx/speaker-diarization/apk.html +[apk-speaker-diarization-cn]: https://k2-fsa.github.io/sherpa/onnx/speaker-diarization/apk-cn.html +[apk-streaming-asr]: https://k2-fsa.github.io/sherpa/onnx/android/apk.html +[apk-streaming-asr-cn]: https://k2-fsa.github.io/sherpa/onnx/android/apk-cn.html +[apk-tts]: https://k2-fsa.github.io/sherpa/onnx/tts/apk-engine.html +[apk-tts-cn]: https://k2-fsa.github.io/sherpa/onnx/tts/apk-engine-cn.html +[apk-vad]: https://k2-fsa.github.io/sherpa/onnx/vad/apk.html +[apk-vad-cn]: https://k2-fsa.github.io/sherpa/onnx/vad/apk-cn.html +[apk-vad-asr]: https://k2-fsa.github.io/sherpa/onnx/vad/apk-asr.html +[apk-vad-asr-cn]: https://k2-fsa.github.io/sherpa/onnx/vad/apk-asr-cn.html +[apk-2pass]: https://k2-fsa.github.io/sherpa/onnx/android/apk-2pass.html +[apk-2pass-cn]: https://k2-fsa.github.io/sherpa/onnx/android/apk-2pass-cn.html +[apk-at]: https://k2-fsa.github.io/sherpa/onnx/audio-tagging/apk.html +[apk-at-cn]: https://k2-fsa.github.io/sherpa/onnx/audio-tagging/apk-cn.html +[apk-at-wearos]: https://k2-fsa.github.io/sherpa/onnx/audio-tagging/apk-wearos.html +[apk-at-wearos-cn]: https://k2-fsa.github.io/sherpa/onnx/audio-tagging/apk-wearos-cn.html +[apk-sid]: https://k2-fsa.github.io/sherpa/onnx/speaker-identification/apk.html +[apk-sid-cn]: https://k2-fsa.github.io/sherpa/onnx/speaker-identification/apk-cn.html +[apk-slid]: https://k2-fsa.github.io/sherpa/onnx/spoken-language-identification/apk.html +[apk-slid-cn]: https://k2-fsa.github.io/sherpa/onnx/spoken-language-identification/apk-cn.html +[apk-kws]: https://k2-fsa.github.io/sherpa/onnx/kws/apk.html +[apk-kws-cn]: https://k2-fsa.github.io/sherpa/onnx/kws/apk-cn.html +[apk-flutter-streaming-asr]: https://k2-fsa.github.io/sherpa/onnx/flutter/asr/app.html +[apk-flutter-streaming-asr-cn]: https://k2-fsa.github.io/sherpa/onnx/flutter/asr/app-cn.html +[flutter-tts-android]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-android.html +[flutter-tts-android-cn]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-android-cn.html +[flutter-tts-linux]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-linux.html +[flutter-tts-linux-cn]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-linux-cn.html +[flutter-tts-macos-x64]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-macos-x64.html +[flutter-tts-macos-arm64-cn]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-macos-x64-cn.html +[flutter-tts-macos-arm64]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-macos-arm64.html +[flutter-tts-macos-x64-cn]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-macos-arm64-cn.html +[flutter-tts-win-x64]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-win.html +[flutter-tts-win-x64-cn]: https://k2-fsa.github.io/sherpa/onnx/flutter/tts-win-cn.html +[lazarus-subtitle]: https://k2-fsa.github.io/sherpa/onnx/lazarus/download-generated-subtitles.html +[lazarus-subtitle-cn]: https://k2-fsa.github.io/sherpa/onnx/lazarus/download-generated-subtitles-cn.html +[asr-models]: 
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models +[tts-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models +[vad-models]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx +[kws-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/kws-models +[at-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models +[sid-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +[slid-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +[punct-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models +[speaker-segmentation-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +[GigaSpeech]: https://github.com/SpeechColab/GigaSpeech +[WenetSpeech]: https://github.com/wenet-e2e/WenetSpeech +[sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +[sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-small-bilingual-zh-en-2023-02-16.tar.bz2 +[sherpa-onnx-streaming-zipformer-korean-2024-06-16]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-korean-2024-06-16.tar.bz2 +[sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23.tar.bz2 +[sherpa-onnx-streaming-zipformer-en-20M-2023-02-17]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +[sherpa-onnx-zipformer-ja-reazonspeech-2024-08-01]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-ja-reazonspeech-2024-08-01.tar.bz2 +[sherpa-onnx-zipformer-ru-2024-09-18]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-ru-2024-09-18.tar.bz2 +[sherpa-onnx-zipformer-korean-2024-06-24]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-korean-2024-06-24.tar.bz2 +[sherpa-onnx-zipformer-thai-2024-06-20]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-thai-2024-06-20.tar.bz2 +[sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24.tar.bz2 +[sherpa-onnx-paraformer-zh-2024-03-09]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2024-03-09.tar.bz2 +[sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24.tar.bz2 +[sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 +[sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +[sherpa-onnx-streaming-zipformer-fr-2023-04-14]: https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-fr-2023-04-14.tar.bz2 +[Moonshine tiny]: 
https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +[NVIDIA Jetson Orin NX]: https://developer.download.nvidia.com/assets/embedded/secure/jetson/orin_nx/docs/Jetson_Orin_NX_DS-10712-001_v0.5.pdf?RCPGu9Q6OVAOv7a7vgtwc9-BLScXRIWq6cSLuditMALECJ_dOj27DgnqAPGVnT2VpiNpQan9SyFy-9zRykR58CokzbXwjSA7Gj819e91AXPrWkGZR3oS1VLxiDEpJa_Y0lr7UT-N4GnXtb8NlUkP4GkCkkF_FQivGPrAucCUywL481GH_WpP_p7ziHU1Wg==&t=eyJscyI6ImdzZW8iLCJsc2QiOiJodHRwczovL3d3dy5nb29nbGUuY29tLmhrLyJ9 +[NVIDIA Jetson Nano B01]: https://www.seeedstudio.com/blog/2020/01/16/new-revision-of-jetson-nano-dev-kit-now-supports-new-jetson-nano-module/ +[speech-enhancement-models]: https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models diff --git a/apps/frameworks/sherpa-mnn/build-android-arm64-v8a.sh b/apps/frameworks/sherpa-mnn/build-android-arm64-v8a.sh new file mode 100755 index 00000000..db56a39e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/build-android-arm64-v8a.sh @@ -0,0 +1,158 @@ +#!/usr/bin/env bash +set -ex + +# If BUILD_SHARED_LIBS is ON, we use libonnxruntime.so +# If BUILD_SHARED_LIBS is OFF, we use libonnxruntime.a +# +# In any case, we will have libsherpa-onnx-jni.so +# +# If BUILD_SHARED_LIBS is OFF, then libonnxruntime.a is linked into libsherpa-onnx-jni.so +# and you only need to copy libsherpa-onnx-jni.so to your Android projects. +# +# If BUILD_SHARED_LIBS is ON, then you need to copy both libsherpa-onnx-jni.so +# and libonnxruntime.so to your Android projects +# +BUILD_SHARED_LIBS=ON + +if [ $BUILD_SHARED_LIBS == ON ]; then + dir=$PWD/build-android-arm64-v8a +else + dir=$PWD/build-android-arm64-v8a-static +fi + +mkdir -p $dir +cd $dir + +# Note from https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android +# (optional) remove the hardcoded debug flag in Android NDK android-ndk +# issue: https://github.com/android/ndk/issues/243 +# +# open $ANDROID_NDK/build/cmake/android.toolchain.cmake for ndk < r23 +# or $ANDROID_NDK/build/cmake/android-legacy.toolchain.cmake for ndk >= r23 +# +# delete "-g" line +# +# list(APPEND ANDROID_COMPILER_FLAGS +# -g +# -DANDROID + +if [ -z $ANDROID_NDK ]; then + ANDROID_NDK=/star-fj/fangjun/software/android-sdk/ndk/22.1.7171670 + if [ $BUILD_SHARED_LIBS == OFF ]; then + ANDROID_NDK=/star-fj/fangjun/software/android-sdk/ndk/27.0.11718014 + fi + # or use + # ANDROID_NDK=/star-fj/fangjun/software/android-ndk + # + # Inside the $ANDROID_NDK directory, you can find a binary ndk-build + # and some other files like the file "build/cmake/android.toolchain.cmake" + + if [ ! -d $ANDROID_NDK ]; then + # For macOS, I have installed Android Studio, select the menu + # Tools -> SDK manager -> Android SDK + # and set "Android SDK location" to /Users/fangjun/software/my-android + ANDROID_NDK=/Users/fangjun/software/my-android/ndk/22.1.7171670 + + if [ $BUILD_SHARED_LIBS == OFF ]; then + ANDROID_NDK=/Users/fangjun/software/my-android/ndk/27.0.11718014 + fi + fi +fi + +if [ ! 
-d $ANDROID_NDK ]; then + echo Please set the environment variable ANDROID_NDK before you run this script + exit 1 +fi + +echo "ANDROID_NDK: $ANDROID_NDK" +sleep 1 + +if [ -z $SHERPA_MNN_ENABLE_TTS ]; then + SHERPA_MNN_ENABLE_TTS=ON +fi + +if [ -z $SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION ]; then + SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION=ON +fi + +if [ -z $SHERPA_MNN_ENABLE_BINARY ]; then + SHERPA_MNN_ENABLE_BINARY=OFF +fi + +if [ -z $SHERPA_MNN_ENABLE_C_API ]; then + SHERPA_MNN_ENABLE_C_API=OFF +fi + +if [ -z $SHERPA_MNN_ENABLE_JNI ]; then + SHERPA_MNN_ENABLE_JNI=ON +fi + +cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \ + -DSHERPA_MNN_ENABLE_TTS=$SHERPA_MNN_ENABLE_TTS \ + -DSHERPA_MNN_ENABLE_SPEAKER_DIARIZATION=$SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION \ + -DSHERPA_MNN_ENABLE_BINARY=$SHERPA_MNN_ENABLE_BINARY \ + -DBUILD_PIPER_PHONMIZE_EXE=OFF \ + -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ + -DBUILD_ESPEAK_NG_EXE=OFF \ + -DBUILD_ESPEAK_NG_TESTS=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DMNN_LIB_DIR=/Users/xtjiang/alicnn/AliNNPrivate/project/android/build_64 \ + -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ + -DSHERPA_MNN_ENABLE_PYTHON=OFF \ + -DSHERPA_MNN_ENABLE_TESTS=OFF \ + -DSHERPA_MNN_ENABLE_CHECK=OFF \ + -DSHERPA_MNN_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_MNN_ENABLE_JNI=$SHERPA_MNN_ENABLE_JNI \ + -DSHERPA_MNN_LINK_LIBSTDCPP_STATICALLY=OFF \ + -DSHERPA_MNN_ENABLE_C_API=$SHERPA_MNN_ENABLE_C_API \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_PLATFORM=android-21 .. + + # By default, it links to libc++_static.a + # -DANDROID_STL=c++_shared \ + +# Please use -DANDROID_PLATFORM=android-27 if you want to use Android NNAPI + +# make VERBOSE=1 -j4 +make -j4 +make install/strip +rm -rf install/share +rm -rf install/lib/pkgconfig +rm -rf install/lib/lib*.a +if [ -f install/lib/libsherpa-onnx-c-api.so ]; then + cat >install/lib/README.md < +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + SherpaMnnOfflinePunctuationConfig config; + memset(&config, 0, sizeof(config)); + + // clang-format off + config.model.ct_transformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"; + // clang-format on + config.model.num_threads = 1; + config.model.debug = 1; + config.model.provider = "cpu"; + + const SherpaMnnOfflinePunctuation *punct = + SherpaMnnCreateOfflinePunctuation(&config); + if (!punct) { + fprintf(stderr, + "Failed to create OfflinePunctuation. 
Please check your config"); + return -1; + } + + const char *texts[] = { + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你", + "我们都是木头人不会说话不会动", + ("The African blogosphere is rapidly expanding bringing more voices " + "online in the form of commentaries opinions analyses rants and poetry"), + }; + + int32_t n = sizeof(texts) / sizeof(const char *); + fprintf(stderr, "n: %d\n", n); + + fprintf(stderr, "--------------------\n"); + for (int32_t i = 0; i != n; ++i) { + const char *text_with_punct = + SherpaOfflinePunctuationAddPunct(punct, texts[i]); + + fprintf(stderr, "Input text: %s\n", texts[i]); + fprintf(stderr, "Output text: %s\n", text_with_punct); + SherpaOfflinePunctuationFreeText(text_with_punct); + fprintf(stderr, "--------------------\n"); + } + + SherpaMnnDestroyOfflinePunctuation(punct); + + return 0; +}; diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/CMakeLists.txt b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/CMakeLists.txt new file mode 100644 index 00000000..925d4625 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/CMakeLists.txt @@ -0,0 +1,9 @@ + +add_executable(c-api-alsa c-api-alsa.cc alsa.cc) +target_link_libraries(c-api-alsa sherpa-onnx-c-api cargs) + +if(DEFINED ENV{SHERPA_MNN_ALSA_LIB_DIR}) + target_link_libraries(c-api-alsa -L$ENV{SHERPA_MNN_ALSA_LIB_DIR} -lasound) +else() + target_link_libraries(c-api-alsa asound) +endif() diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/CPPLINT.cfg b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/CPPLINT.cfg new file mode 100644 index 00000000..f1b97ab7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/CPPLINT.cfg @@ -0,0 +1 @@ +exclude_files=alsa.cc|alsa.h diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/README.md b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/README.md new file mode 100644 index 00000000..50e24235 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/README.md @@ -0,0 +1,12 @@ +# Introduction + +This folder contains examples for real-time speech recognition from a microphone +using sherpa-onnx C API. + +**Note**: You can call C API from C++ files. + + +## ./c-api-alsa.cc + +This file uses alsa to read a microphone. It runs only on Linux. This file +does not support macOS or Windows. 
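+
+A minimal example session is sketched below, assuming the project has already
+been built so that `./bin/c-api-alsa` exists, and that a streaming transducer
+model (see the streaming ASR model table in this package's README) has been
+downloaded. The device name `plughw:3,0` is only an example taken from
+`arecord -l` output and will differ on your machine.
+
+```bash
+# List ALSA capture devices to find your microphone,
+# e.g. "card 3: ... device 0: ..."
+arecord -l
+
+# Run real-time recognition from that microphone (card 3, device 0 -> plughw:3,0).
+# Replace the model paths with the files of the streaming model you downloaded.
+./bin/c-api-alsa \
+  --tokens=/path/to/tokens.txt \
+  --encoder=/path/to/encoder.onnx \
+  --decoder=/path/to/decoder.onnx \
+  --joiner=/path/to/joiner.onnx \
+  plughw:3,0
+```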
diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/alsa.cc b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/alsa.cc new file mode 100644 index 00000000..7acd97ce --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/alsa.cc @@ -0,0 +1 @@ +../../sherpa-onnx/csrc/alsa.cc \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/alsa.h b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/alsa.h new file mode 100644 index 00000000..cde29958 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/alsa.h @@ -0,0 +1 @@ +../../sherpa-onnx/csrc/alsa.h \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/c-api-alsa.cc b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/c-api-alsa.cc new file mode 100644 index 00000000..88df8f34 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/asr-microphone-example/c-api-alsa.cc @@ -0,0 +1,259 @@ +// c-api-examples/asr-microphone-example/c-api-alsa.cc +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include +#include +#include +#include + +#include +#include // std::tolower +#include +#include + +#include "c-api-examples/asr-microphone-example/alsa.h" + +// NOTE: You don't need to use cargs.h in your own project. +// We use it in this file to parse commandline arguments +#include "cargs.h" // NOLINT +#include "sherpa-mnn/c-api/c-api.h" + +static struct cag_option options[] = { + {/*.identifier =*/'h', + /*.access_letters =*/"h", + /*.access_name =*/"help", + /*.value_name =*/"help", + /*.description =*/"Show help"}, + {/*.identifier =*/'t', + /*.access_letters =*/NULL, + /*.access_name =*/"tokens", + /*.value_name =*/"tokens", + /*.description =*/"Tokens file"}, + {/*.identifier =*/'e', + /*.access_letters =*/NULL, + /*.access_name =*/"encoder", + /*.value_name =*/"encoder", + /*.description =*/"Encoder ONNX file"}, + {/*.identifier =*/'d', + /*.access_letters =*/NULL, + /*.access_name =*/"decoder", + /*.value_name =*/"decoder", + /*.description =*/"Decoder ONNX file"}, + {/*.identifier =*/'j', + /*.access_letters =*/NULL, + /*.access_name =*/"joiner", + /*.value_name =*/"joiner", + /*.description =*/"Joiner ONNX file"}, + {/*.identifier =*/'n', + /*.access_letters =*/NULL, + /*.access_name =*/"num-threads", + /*.value_name =*/"num-threads", + /*.description =*/"Number of threads"}, + {/*.identifier =*/'p', + /*.access_letters =*/NULL, + /*.access_name =*/"provider", + /*.value_name =*/"provider", + /*.description =*/"Provider: cpu (default), cuda, coreml"}, + {/*.identifier =*/'m', + /*.access_letters =*/NULL, + /*.access_name =*/"decoding-method", + /*.value_name =*/"decoding-method", + /*.description =*/ + "Decoding method: greedy_search (default), modified_beam_search"}, + {/*.identifier =*/'f', + /*.access_letters =*/NULL, + /*.access_name =*/"hotwords-file", + /*.value_name =*/"hotwords-file", + /*.description =*/ + "The file containing hotwords, one words/phrases per line, and for each " + "phrase the bpe/cjkchar are separated by a space. For example: ▁HE LL O " + "▁WORLD, 你 好 世 界"}, + {/*.identifier =*/'s', + /*.access_letters =*/NULL, + /*.access_name =*/"hotwords-score", + /*.value_name =*/"hotwords-score", + /*.description =*/ + "The bonus score for each token in hotwords. 
Used only when " + "decoding_method is modified_beam_search"}, +}; + +const char *kUsage = + R"( +Usage: + ./bin/c-api-alsa \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/decoder.onnx \ + device_name + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. +)"; + +bool stop = false; + +static void Handler(int sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + if (argc < 6) { + fprintf(stderr, "%s\n", kUsage); + exit(0); + } + + signal(SIGINT, Handler); + + SherpaMnnOnlineRecognizerConfig config; + memset(&config, 0, sizeof(config)); + + config.model_config.debug = 0; + config.model_config.num_threads = 1; + config.model_config.provider = "cpu"; + + config.decoding_method = "greedy_search"; + + config.max_active_paths = 4; + + config.feat_config.sample_rate = 16000; + config.feat_config.feature_dim = 80; + + config.enable_endpoint = 1; + config.rule1_min_trailing_silence = 2.4; + config.rule2_min_trailing_silence = 1.2; + config.rule3_min_utterance_length = 300; + + cag_option_context context; + char identifier; + const char *value; + + cag_option_prepare(&context, options, CAG_ARRAY_SIZE(options), argc, argv); + + while (cag_option_fetch(&context)) { + identifier = cag_option_get(&context); + value = cag_option_get_value(&context); + switch (identifier) { + case 't': + config.model_config.tokens = value; + break; + case 'e': + config.model_config.transducer.encoder = value; + break; + case 'd': + config.model_config.transducer.decoder = value; + break; + case 'j': + config.model_config.transducer.joiner = value; + break; + case 'n': + config.model_config.num_threads = atoi(value); + break; + case 'p': + config.model_config.provider = value; + break; + case 'm': + config.decoding_method = value; + break; + case 'f': + config.hotwords_file = value; + break; + case 's': + config.hotwords_score = atof(value); + break; + case 'h': { + fprintf(stderr, "%s\n", kUsage); + exit(0); + break; + } + default: + // do nothing as config already has valid default values + break; + } + } + + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&config); + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + + const char *device_name = argv[context.index]; + sherpa_mnn::Alsa alsa(device_name); + fprintf(stderr, "Use recording device: %s\n", device_name); + fprintf(stderr, + "Please \033[32m\033[1mspeak\033[0m! 
Press \033[31m\033[1mCtrl + " + "C\033[0m to exit\n"); + + int32_t expected_sample_rate = 16000; + + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + expected_sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + + std::string last_text; + + int32_t segment_index = 0; + + while (!stop) { + const std::vector &samples = alsa.Read(chunk); + SherpaMnnOnlineStreamAcceptWaveform(stream, expected_sample_rate, + samples.data(), samples.size()); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + std::string text = r->text; + SherpaMnnDestroyOnlineRecognizerResult(r); + + if (!text.empty() && last_text != text) { + last_text = text; + + std::transform(text.begin(), text.end(), text.begin(), + [](auto c) { return std::tolower(c); }); + + SherpaMnnPrint(display, segment_index, text.c_str()); + fflush(stderr); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (!text.empty()) { + ++segment_index; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + } + + // free allocated resources + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/audio-tagging-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/audio-tagging-c-api.c new file mode 100644 index 00000000..a037a61e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/audio-tagging-c-api.c @@ -0,0 +1,79 @@ +// c-api-examples/audio-tagging-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// We assume you have pre-downloaded the model files for testing +// from https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models +// +// An example is given below: +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +// tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +// rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + SherpaMnnAudioTaggingConfig config; + memset(&config, 0, sizeof(config)); + + config.model.zipformer.model = + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx"; + config.model.num_threads = 1; + config.model.debug = 1; + config.model.provider = "cpu"; + // clang-format off + config.labels = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv"; + // clang-format on + + const SherpaMnnAudioTagging *tagger = SherpaMnnCreateAudioTagging(&config); + if (!tagger) { + fprintf(stderr, "Failed to create audio tagger. 
Please check your config"); + return -1; + } + + // You can find more test waves from + // https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + const char *wav_filename = + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + const SherpaMnnOfflineStream *stream = + SherpaMnnAudioTaggingCreateOfflineStream(tagger); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + int32_t top_k = 5; + const SherpaMnnAudioEvent *const *results = + SherpaMnnAudioTaggingCompute(tagger, stream, top_k); + + fprintf(stderr, "--------------------------------------------------\n"); + fprintf(stderr, "Index\t\tProbability\t\tEvent name\n"); + fprintf(stderr, "--------------------------------------------------\n"); + for (int32_t i = 0; i != top_k; ++i) { + fprintf(stderr, "%d\t\t%.3f\t\t\t%s\n", i, results[i]->prob, + results[i]->name); + } + fprintf(stderr, "--------------------------------------------------\n"); + + SherpaMnnAudioTaggingFreeResults(results); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnFreeWave(wave); + SherpaMnnDestroyAudioTagging(tagger); + + return 0; +}; diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/decode-file-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/decode-file-c-api.c new file mode 100644 index 00000000..39257a77 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/decode-file-c-api.c @@ -0,0 +1,244 @@ +// c-api-examples/decode-file-c-api.c +// +// Copyright (c) 2023 Xiaomi Corporation + +// This file shows how to use sherpa-onnx C API +// to decode a file. + +#include +#include +#include + +#include "cargs.h" +#include "sherpa-mnn/c-api/c-api.h" + +static struct cag_option options[] = { + {.identifier = 'h', + .access_letters = "h", + .access_name = "help", + .description = "Show help"}, + {.identifier = 't', + .access_letters = NULL, + .access_name = "tokens", + .value_name = "tokens", + .description = "Tokens file"}, + {.identifier = 'e', + .access_letters = NULL, + .access_name = "encoder", + .value_name = "encoder", + .description = "Encoder ONNX file"}, + {.identifier = 'd', + .access_letters = NULL, + .access_name = "decoder", + .value_name = "decoder", + .description = "Decoder ONNX file"}, + {.identifier = 'j', + .access_letters = NULL, + .access_name = "joiner", + .value_name = "joiner", + .description = "Joiner ONNX file"}, + {.identifier = 'n', + .access_letters = NULL, + .access_name = "num-threads", + .value_name = "num-threads", + .description = "Number of threads"}, + {.identifier = 'p', + .access_letters = NULL, + .access_name = "provider", + .value_name = "provider", + .description = "Provider: cpu (default), cuda, coreml"}, + {.identifier = 'm', + .access_letters = NULL, + .access_name = "decoding-method", + .value_name = "decoding-method", + .description = + "Decoding method: greedy_search (default), modified_beam_search"}, + {.identifier = 'f', + .access_letters = NULL, + .access_name = "hotwords-file", + .value_name = "hotwords-file", + .description = "The file containing hotwords, one words/phrases per line, " + "and for each phrase the bpe/cjkchar are separated by a " + "space. 
For example: ▁HE LL O ▁WORLD, 你 好 世 界"}, + {.identifier = 's', + .access_letters = NULL, + .access_name = "hotwords-score", + .value_name = "hotwords-score", + .description = "The bonus score for each token in hotwords. Used only " + "when decoding_method is modified_beam_search"}, +}; + +const char *kUsage = + "\n" + "Usage:\n " + " ./bin/decode-file-c-api \\\n" + " --tokens=/path/to/tokens.txt \\\n" + " --encoder=/path/to/encoder.onnx \\\n" + " --decoder=/path/to/decoder.onnx \\\n" + " --joiner=/path/to/joiner.onnx \\\n" + " --provider=cpu \\\n" + " /path/to/foo.wav\n" + "\n\n" + "Default num_threads is 1.\n" + "Valid decoding_method: greedy_search (default), modified_beam_search\n\n" + "Valid provider: cpu (default), cuda, coreml\n\n" + "Please refer to \n" + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/" + "index.html\n" + "for a list of pre-trained models to download.\n" + "\n" + "Note that this file supports only streaming transducer models.\n"; + +int32_t main(int32_t argc, char *argv[]) { + if (argc < 6) { + fprintf(stderr, "%s\n", kUsage); + exit(0); + } + + SherpaMnnOnlineRecognizerConfig config; + memset(&config, 0, sizeof(config)); + + config.model_config.debug = 0; + config.model_config.num_threads = 1; + config.model_config.provider = "cpu"; + + config.decoding_method = "greedy_search"; + + config.max_active_paths = 4; + + config.feat_config.sample_rate = 16000; + config.feat_config.feature_dim = 80; + + config.enable_endpoint = 1; + config.rule1_min_trailing_silence = 2.4; + config.rule2_min_trailing_silence = 1.2; + config.rule3_min_utterance_length = 300; + + cag_option_context context; + char identifier; + const char *value; + + cag_option_prepare(&context, options, CAG_ARRAY_SIZE(options), argc, argv); + + while (cag_option_fetch(&context)) { + identifier = cag_option_get(&context); + value = cag_option_get_value(&context); + switch (identifier) { + case 't': + config.model_config.tokens = value; + break; + case 'e': + config.model_config.transducer.encoder = value; + break; + case 'd': + config.model_config.transducer.decoder = value; + break; + case 'j': + config.model_config.transducer.joiner = value; + break; + case 'n': + config.model_config.num_threads = atoi(value); + break; + case 'p': + config.model_config.provider = value; + break; + case 'm': + config.decoding_method = value; + break; + case 'f': + config.hotwords_file = value; + break; + case 's': + config.hotwords_score = atof(value); + break; + case 'h': { + fprintf(stderr, "%s\n", kUsage); + exit(0); + break; + } + default: + // do nothing as config already has valid default values + break; + } + } + + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&config); + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + + const char *wav_filename = argv[context.index]; + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + // simulate streaming + +#define N 3200 // 0.2 s. Sample rate is fixed to 16 kHz + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? 
wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/fire-red-asr-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/fire-red-asr-c-api.c new file mode 100644 index 00000000..90c6027d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/fire-red-asr-c-api.c @@ -0,0 +1,84 @@ +// c-api-examples/fire-red-asr-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation + +// We assume you have pre-downloaded the FireRedAsr model +// from https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models +// An example is given below: +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +// tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +// rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav"; + const char *encoder_filename = + "sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx"; + const char *decoder_filename = + "sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx"; + const char *tokens_filename = + "sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 1; + offline_model_config.num_threads = 1; + offline_model_config.provider = provider; + offline_model_config.tokens = tokens_filename; + offline_model_config.fire_red_asr.encoder = encoder_filename; + offline_model_config.fire_red_asr.decoder = decoder_filename; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); 
+ recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + + SherpaMnnFreeWave(wave); + + return -1; + } + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnDecodeOfflineStream(recognizer, stream); + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + fprintf(stderr, "Decoded text: %s\n", result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c new file mode 100644 index 00000000..73d85cbc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c @@ -0,0 +1,196 @@ +// c-api-examples/keywords-spotter-buffered-tokens-keywords-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Luo Xiao + +// +// This file demonstrates how to use keywords spotter with sherpa-onnx's C +// API and with tokens and keywords loaded from buffered strings instead of from +// external files API. +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static size_t ReadFile(const char *filename, const char **buffer_out) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + fprintf(stderr, "Failed to open %s\n", filename); + return -1; + } + fseek(file, 0L, SEEK_END); + long size = ftell(file); + rewind(file); + *buffer_out = malloc(size); + if (*buffer_out == NULL) { + fclose(file); + fprintf(stderr, "Memory error\n"); + return -1; + } + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); + if (read_bytes != size) { + printf("Errors occured in reading the file %s\n", filename); + free((void *)*buffer_out); + *buffer_out = NULL; + fclose(file); + return -1; + } + fclose(file); + return read_bytes; +} + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/test_wavs/" + "6.wav"; + const char *encoder_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; + const char *decoder_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + const char *joiner_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; + const char *provider = "cpu"; + const char *tokens_filename = + "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/tokens.txt"; + const char *keywords_filename = + 
"sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/test_wavs/" + "test_keywords.txt"; + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // reading tokens and keywords to buffers + const char *tokens_buf; + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); + if (token_buf_size < 1) { + fprintf(stderr, "Please check your tokens.txt!\n"); + free((void *)tokens_buf); + return -1; + } + const char *keywords_buf; + size_t keywords_buf_size = ReadFile(keywords_filename, &keywords_buf); + if (keywords_buf_size < 1) { + fprintf(stderr, "Please check your keywords.txt!\n"); + free((void *)keywords_buf); + return -1; + } + + // Zipformer config + SherpaMnnOnlineTransducerModelConfig zipformer_config; + memset(&zipformer_config, 0, sizeof(zipformer_config)); + zipformer_config.encoder = encoder_filename; + zipformer_config.decoder = decoder_filename; + zipformer_config.joiner = joiner_filename; + + // Online model config + SherpaMnnOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens_buf = tokens_buf; + online_model_config.tokens_buf_size = token_buf_size; + online_model_config.transducer = zipformer_config; + + // Keywords-spotter config + SherpaMnnKeywordSpotterConfig keywords_spotter_config; + memset(&keywords_spotter_config, 0, sizeof(keywords_spotter_config)); + keywords_spotter_config.max_active_paths = 4; + keywords_spotter_config.keywords_threshold = 0.1; + keywords_spotter_config.keywords_score = 3.0; + keywords_spotter_config.model_config = online_model_config; + keywords_spotter_config.keywords_buf = keywords_buf; + keywords_spotter_config.keywords_buf_size = keywords_buf_size; + + const SherpaMnnKeywordSpotter *keywords_spotter = + SherpaMnnCreateKeywordSpotter(&keywords_spotter_config); + + free((void *)tokens_buf); + tokens_buf = NULL; + free((void *)keywords_buf); + keywords_buf = NULL; + + if (keywords_spotter == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateKeywordStream(keywords_spotter); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? 
wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsKeywordStreamReady(keywords_spotter, stream)) { + SherpaMnnDecodeKeywordStream(keywords_spotter, stream); + } + + const SherpaMnnKeywordResult *r = + SherpaMnnGetKeywordResult(keywords_spotter, stream); + + if (strlen(r->keyword)) { + SherpaMnnPrint(display, segment_id, r->keyword); + } + + SherpaMnnDestroyKeywordResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsKeywordStreamReady(keywords_spotter, stream)) { + SherpaMnnDecodeKeywordStream(keywords_spotter, stream); + } + + const SherpaMnnKeywordResult *r = + SherpaMnnGetKeywordResult(keywords_spotter, stream); + + if (strlen(r->keyword)) { + SherpaMnnPrint(display, segment_id, r->keyword); + } + + SherpaMnnDestroyKeywordResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyKeywordSpotter(keywords_spotter); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/kokoro-tts-en-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/kokoro-tts-en-c-api.c new file mode 100644 index 00000000..8b01820b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/kokoro-tts-en-c-api.c @@ -0,0 +1,84 @@ +// c-api-examples/kokoro-tts-en-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation + +// This file shows how to use sherpa-onnx C API +// for English TTS with Kokoro. +// +// clang-format off +/* +Usage + + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2 +tar xf kokoro-en-v0_19.tar.bz2 +rm kokoro-en-v0_19.tar.bz2 + +./kokoro-tts-en-c-api + + */ +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static int32_t ProgressCallback(const float *samples, int32_t num_samples, + float progress) { + fprintf(stderr, "Progress: %.3f%%\n", progress * 100); + // return 1 to continue generating + // return 0 to stop generating + return 1; +} + +int32_t main(int32_t argc, char *argv[]) { + SherpaMnnOfflineTtsConfig config; + memset(&config, 0, sizeof(config)); + config.model.kokoro.model = "./kokoro-en-v0_19/model.onnx"; + config.model.kokoro.voices = "./kokoro-en-v0_19/voices.bin"; + config.model.kokoro.tokens = "./kokoro-en-v0_19/tokens.txt"; + config.model.kokoro.data_dir = "./kokoro-en-v0_19/espeak-ng-data"; + + config.model.num_threads = 2; + + // If you don't want to see debug messages, please set it to 0 + config.model.debug = 1; + + const char *filename = "./generated-kokoro-en.wav"; + const char *text = + "Today as always, men fall into two groups: slaves and free men. Whoever " + "does not have two-thirds of his day for himself, is a slave, whatever " + "he may be: a statesman, a businessman, an official, or a scholar. " + "Friends fell out often because life was changing so fast. 
The easiest " + "thing in the world was to lose touch with someone."; + + const SherpaMnnOfflineTts *tts = SherpaMnnCreateOfflineTts(&config); + // mapping of sid to voice name + // 0->af, 1->af_bella, 2->af_nicole, 3->af_sarah, 4->af_sky, 5->am_adam + // 6->am_michael, 7->bf_emma, 8->bf_isabella, 9->bm_george, 10->bm_lewis + int32_t sid = 0; + float speed = 1.0; // larger -> faster in speech speed + +#if 0 + // If you don't want to use a callback, then please enable this branch + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerate(tts, text, sid, speed); +#else + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerateWithProgressCallback(tts, text, sid, speed, + ProgressCallback); +#endif + + SherpaMnnWriteWave(audio->samples, audio->n, audio->sample_rate, filename); + + SherpaMnnDestroyOfflineTtsGeneratedAudio(audio); + SherpaMnnDestroyOfflineTts(tts); + + fprintf(stderr, "Input text is: %s\n", text); + fprintf(stderr, "Speaker ID is is: %d\n", sid); + fprintf(stderr, "Saved to: %s\n", filename); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/kokoro-tts-zh-en-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/kokoro-tts-zh-en-c-api.c new file mode 100644 index 00000000..c60f6283 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/kokoro-tts-zh-en-c-api.c @@ -0,0 +1,82 @@ +// c-api-examples/kokoro-tts-zh-en-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation + +// This file shows how to use sherpa-onnx C API +// for English + Chinese TTS with Kokoro. +// +// clang-format off +/* +Usage + + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-multi-lang-v1_0.tar.bz2 +tar xf kokoro-multi-lang-v1_0.tar.bz2 +rm kokoro-multi-lang-v1_0.tar.bz2 + +./kokoro-tts-zh-en-c-api + + */ +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static int32_t ProgressCallback(const float *samples, int32_t num_samples, + float progress) { + fprintf(stderr, "Progress: %.3f%%\n", progress * 100); + // return 1 to continue generating + // return 0 to stop generating + return 1; +} + +int32_t main(int32_t argc, char *argv[]) { + SherpaMnnOfflineTtsConfig config; + memset(&config, 0, sizeof(config)); + config.model.kokoro.model = "./kokoro-multi-lang-v1_0/model.onnx"; + config.model.kokoro.voices = "./kokoro-multi-lang-v1_0/voices.bin"; + config.model.kokoro.tokens = "./kokoro-multi-lang-v1_0/tokens.txt"; + config.model.kokoro.data_dir = "./kokoro-multi-lang-v1_0/espeak-ng-data"; + config.model.kokoro.dict_dir = "./kokoro-multi-lang-v1_0/dict"; + config.model.kokoro.lexicon = + "./kokoro-multi-lang-v1_0/lexicon-us-en.txt,./kokoro-multi-lang-v1_0/" + "lexicon-zh.txt"; + + config.model.num_threads = 2; + + // If you don't want to see debug messages, please set it to 0 + config.model.debug = 1; + + const char *filename = "./generated-kokoro-zh-en.wav"; + const char *text = + "中英文语音合成测试。This is generated by next generation Kaldi using " + "Kokoro without Misaki. 
你觉得中英文说的如何呢?"; + + const SherpaMnnOfflineTts *tts = SherpaMnnCreateOfflineTts(&config); + int32_t sid = 0; // there are 53 speakers + float speed = 1.0; // larger -> faster in speech speed + +#if 0 + // If you don't want to use a callback, then please enable this branch + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerate(tts, text, sid, speed); +#else + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerateWithProgressCallback(tts, text, sid, speed, + ProgressCallback); +#endif + + SherpaMnnWriteWave(audio->samples, audio->n, audio->sample_rate, filename); + + SherpaMnnDestroyOfflineTtsGeneratedAudio(audio); + SherpaMnnDestroyOfflineTts(tts); + + fprintf(stderr, "Input text is: %s\n", text); + fprintf(stderr, "Speaker ID is is: %d\n", sid); + fprintf(stderr, "Saved to: %s\n", filename); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/kws-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/kws-c-api.c new file mode 100644 index 00000000..a5a0d5f4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/kws-c-api.c @@ -0,0 +1,152 @@ +// c-api-examples/kws-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation +// +// This file demonstrates how to use keywords spotter with sherpa-onnx's C +// clang-format off +// +// Usage +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2 +// +// ./kws-c-api +// +// clang-format on +#include +#include // exit +#include // memset + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + SherpaMnnKeywordSpotterConfig config; + + memset(&config, 0, sizeof(config)); + config.model_config.transducer.encoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; + + config.model_config.transducer.decoder = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "decoder-epoch-12-avg-2-chunk-16-left-64.onnx"; + + config.model_config.transducer.joiner = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx"; + + config.model_config.tokens = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "tokens.txt"; + + config.model_config.provider = "cpu"; + config.model_config.num_threads = 1; + config.model_config.debug = 1; + + config.keywords_file = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "test_wavs/test_keywords.txt"; + + const SherpaMnnKeywordSpotter *kws = SherpaMnnCreateKeywordSpotter(&config); + if (!kws) { + fprintf(stderr, "Please check your config"); + exit(-1); + } + + fprintf(stderr, + "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n"); + + const char *wav_filename = + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/" + "test_wavs/3.wav"; + + float tail_paddings[8000] = {0}; // 0.5 seconds + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + exit(-1); + } + + const SherpaMnnOnlineStream *stream = SherpaMnnCreateKeywordStream(kws); + if (!stream) { + fprintf(stderr, "Failed to create stream\n"); + exit(-1); + } + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + 
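+  // Note: the real audio has just been fed to the stream above. The tail
+  // padding fed next is 0.5 s of silence (the tail_paddings[8000] buffer
+  // declared earlier, assuming 16 kHz input as in the bundled test wave);
+  // it flushes the model's internal context so keywords near the end of the
+  // file can still be detected before InputFinished() is called.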
SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + sizeof(tail_paddings) / sizeof(float)); + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsKeywordStreamReady(kws, stream)) { + SherpaMnnDecodeKeywordStream(kws, stream); + const SherpaMnnKeywordResult *r = SherpaMnnGetKeywordResult(kws, stream); + if (r && r->json && strlen(r->keyword)) { + fprintf(stderr, "Detected keyword: %s\n", r->json); + + // Remember to reset the keyword stream right after a keyword is detected + SherpaMnnResetKeywordStream(kws, stream); + } + SherpaMnnDestroyKeywordResult(r); + } + SherpaMnnDestroyOnlineStream(stream); + + // -------------------------------------------------------------------------- + + fprintf(stderr, "--Use pre-defined keywords + add a new keyword--\n"); + + stream = SherpaMnnCreateKeywordStreamWithKeywords(kws, "y ǎn y uán @演员"); + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + sizeof(tail_paddings) / sizeof(float)); + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsKeywordStreamReady(kws, stream)) { + SherpaMnnDecodeKeywordStream(kws, stream); + const SherpaMnnKeywordResult *r = SherpaMnnGetKeywordResult(kws, stream); + if (r && r->json && strlen(r->keyword)) { + fprintf(stderr, "Detected keyword: %s\n", r->json); + + // Remember to reset the keyword stream + SherpaMnnResetKeywordStream(kws, stream); + } + SherpaMnnDestroyKeywordResult(r); + } + SherpaMnnDestroyOnlineStream(stream); + + // -------------------------------------------------------------------------- + + fprintf(stderr, "--Use pre-defined keywords + add two new keywords--\n"); + + stream = SherpaMnnCreateKeywordStreamWithKeywords( + kws, "y ǎn y uán @演员/zh ī m íng @知名"); + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + sizeof(tail_paddings) / sizeof(float)); + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsKeywordStreamReady(kws, stream)) { + SherpaMnnDecodeKeywordStream(kws, stream); + const SherpaMnnKeywordResult *r = SherpaMnnGetKeywordResult(kws, stream); + if (r && r->json && strlen(r->keyword)) { + fprintf(stderr, "Detected keyword: %s\n", r->json); + + // Remember to reset the keyword stream + SherpaMnnResetKeywordStream(kws, stream); + } + SherpaMnnDestroyKeywordResult(r); + } + SherpaMnnDestroyOnlineStream(stream); + + SherpaMnnFreeWave(wave); + SherpaMnnDestroyKeywordSpotter(kws); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/matcha-tts-en-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/matcha-tts-en-c-api.c new file mode 100644 index 00000000..7415646f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/matcha-tts-en-c-api.c @@ -0,0 +1,87 @@ +// c-api-examples/matcha-tts-en-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation + +// This file shows how to use sherpa-onnx C API +// for English TTS with MatchaTTS. 
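+// Unlike the single-model VITS and Kokoro examples, MatchaTTS is a two-stage
+// pipeline: the acoustic model (model-steps-3.onnx below) predicts a mel
+// spectrogram and a separate vocoder (hifigan_v2.onnx) turns it into a
+// waveform, so both files must be downloaded and configured.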
+// +// clang-format off +/* +Usage + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-en_US-ljspeech.tar.bz2 +tar xvf matcha-icefall-en_US-ljspeech.tar.bz2 +rm matcha-icefall-en_US-ljspeech.tar.bz2 + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +./matcha-tts-en-c-api + + */ +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static int32_t ProgressCallback(const float *samples, int32_t num_samples, + float progress) { + fprintf(stderr, "Progress: %.3f%%\n", progress * 100); + // return 1 to continue generating + // return 0 to stop generating + return 1; +} + +int32_t main(int32_t argc, char *argv[]) { + SherpaMnnOfflineTtsConfig config; + memset(&config, 0, sizeof(config)); + config.model.matcha.acoustic_model = + "./matcha-icefall-en_US-ljspeech/model-steps-3.onnx"; + + config.model.matcha.vocoder = "./hifigan_v2.onnx"; + + config.model.matcha.tokens = "./matcha-icefall-en_US-ljspeech/tokens.txt"; + + config.model.matcha.data_dir = + "./matcha-icefall-en_US-ljspeech/espeak-ng-data"; + + config.model.num_threads = 1; + + // If you don't want to see debug messages, please set it to 0 + config.model.debug = 1; + + const char *filename = "./generated-matcha-en.wav"; + const char *text = + "Today as always, men fall into two groups: slaves and free men. Whoever " + "does not have two-thirds of his day for himself, is a slave, whatever " + "he may be: a statesman, a businessman, an official, or a scholar. " + "Friends fell out often because life was changing so fast. The easiest " + "thing in the world was to lose touch with someone."; + + const SherpaMnnOfflineTts *tts = SherpaMnnCreateOfflineTts(&config); + int32_t sid = 0; + float speed = 1.0; // larger -> faster in speech speed + +#if 0 + // If you don't want to use a callback, then please enable this branch + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerate(tts, text, sid, speed); +#else + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerateWithProgressCallback(tts, text, sid, speed, + ProgressCallback); +#endif + + SherpaMnnWriteWave(audio->samples, audio->n, audio->sample_rate, filename); + + SherpaMnnDestroyOfflineTtsGeneratedAudio(audio); + SherpaMnnDestroyOfflineTts(tts); + + fprintf(stderr, "Input text is: %s\n", text); + fprintf(stderr, "Speaker ID is is: %d\n", sid); + fprintf(stderr, "Saved to: %s\n", filename); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/matcha-tts-zh-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/matcha-tts-zh-c-api.c new file mode 100644 index 00000000..0e0932b6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/matcha-tts-zh-c-api.c @@ -0,0 +1,87 @@ +// c-api-examples/matcha-tts-zh-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation + +// This file shows how to use sherpa-onnx C API +// for Chinese TTS with MatchaTTS. 
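+// The Chinese model below uses a lexicon plus a dict_dir for word
+// segmentation instead of espeak-ng data, and rule FSTs (phone.fst,
+// date.fst, number.fst) are configured so that dates, phone numbers and
+// amounts in the input text are verbalized before synthesis.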
+// +// clang-format off +/* +Usage + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +./matcha-tts-zh-c-api + + */ +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static int32_t ProgressCallback(const float *samples, int32_t num_samples, + float progress) { + fprintf(stderr, "Progress: %.3f%%\n", progress * 100); + // return 1 to continue generating + // return 0 to stop generating + return 1; +} + +int32_t main(int32_t argc, char *argv[]) { + SherpaMnnOfflineTtsConfig config; + memset(&config, 0, sizeof(config)); + config.model.matcha.acoustic_model = + "./matcha-icefall-zh-baker/model-steps-3.onnx"; + config.model.matcha.vocoder = "./hifigan_v2.onnx"; + config.model.matcha.lexicon = "./matcha-icefall-zh-baker/lexicon.txt"; + config.model.matcha.tokens = "./matcha-icefall-zh-baker/tokens.txt"; + config.model.matcha.dict_dir = "./matcha-icefall-zh-baker/dict"; + config.model.num_threads = 1; + + // If you don't want to see debug messages, please set it to 0 + config.model.debug = 1; + + // clang-format off + config.rule_fsts = "./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst"; + // clang-format on + + const char *filename = "./generated-matcha-zh.wav"; + const char *text = + "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如" + "涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感" + "受着生命的奇迹与温柔." + "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; " + "经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。"; + + const SherpaMnnOfflineTts *tts = SherpaMnnCreateOfflineTts(&config); + int32_t sid = 0; + float speed = 1.0; // larger -> faster in speech speed + +#if 0 + // If you don't want to use a callback, then please enable this branch + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerate(tts, text, sid, speed); +#else + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerateWithProgressCallback(tts, text, sid, speed, + ProgressCallback); +#endif + + SherpaMnnWriteWave(audio->samples, audio->n, audio->sample_rate, filename); + + SherpaMnnDestroyOfflineTtsGeneratedAudio(audio); + SherpaMnnDestroyOfflineTts(tts); + + fprintf(stderr, "Input text is: %s\n", text); + fprintf(stderr, "Speaker ID is is: %d\n", sid); + fprintf(stderr, "Saved to: %s\n", filename); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/moonshine-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/moonshine-c-api.c new file mode 100644 index 00000000..896ff212 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/moonshine-c-api.c @@ -0,0 +1,83 @@ +// c-api-examples/moonshine-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use Moonshine tiny with sherpa-onnx's C API. 
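+// Moonshine is shipped as four ONNX files (preprocessor, encoder, uncached
+// decoder and cached decoder); all four paths have to be set on
+// offline_model_config.moonshine below.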
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +// tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +// rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav"; + const char *preprocessor = + "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx"; + const char *encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx"; + const char *uncached_decoder = + "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx"; + const char *cached_decoder = + "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx"; + const char *tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 1; + offline_model_config.num_threads = 1; + offline_model_config.provider = "cpu"; + offline_model_config.tokens = tokens; + offline_model_config.moonshine.preprocessor = preprocessor; + offline_model_config.moonshine.encoder = encoder; + offline_model_config.moonshine.uncached_decoder = uncached_decoder; + offline_model_config.moonshine.cached_decoder = cached_decoder; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnDecodeOfflineStream(recognizer, stream); + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + fprintf(stderr, "Decoded text: %s\n", result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/offline-speaker-diarization-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/offline-speaker-diarization-c-api.c new file mode 100644 index 00000000..fec2da8f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/offline-speaker-diarization-c-api.c @@ -0,0 +1,131 @@ +// c-api-examples/offline-sepaker-diarization-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to implement speaker diarization with +// sherpa-onnx's C API. + +// clang-format off +/* +Usage: + +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available models. 
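+(The sherpa-onnx-pyannote-segmentation-3-0 model fetched in the example below
+is the one configured in main() of this file.)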
The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +Step 2: Download a speaker embedding extractor model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +Step 3. Download test wave files + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available test wave files. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +Step 4. Run it + + */ +// clang-format on + +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static int32_t ProgressCallback(int32_t num_processed_chunks, + int32_t num_total_chunks, void *arg) { + float progress = 100.0 * num_processed_chunks / num_total_chunks; + fprintf(stderr, "progress %.2f%%\n", progress); + + // the return value is currently ignored + return 0; +} + +int main() { + // Please see the comments at the start of this file for how to download + // the .onnx file and .wav files below + const char *segmentation_model = + "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"; + + const char *embedding_extractor_model = + "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"; + + const char *wav_filename = "./0-four-speakers-zh.wav"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + SherpaMnnOfflineSpeakerDiarizationConfig config; + memset(&config, 0, sizeof(config)); + + config.segmentation.pyannote.model = segmentation_model; + config.embedding.model = embedding_extractor_model; + + // the test wave ./0-four-speakers-zh.wav has 4 speakers, so + // we set num_clusters to 4 + // + config.clustering.num_clusters = 4; + // If you don't know the number of speakers in the test wave file, please + // use + // config.clustering.threshold = 0.5; // You need to tune this threshold + + const SherpaMnnOfflineSpeakerDiarization *sd = + SherpaMnnCreateOfflineSpeakerDiarization(&config); + + if (!sd) { + fprintf(stderr, "Failed to initialize offline speaker diarization\n"); + return -1; + } + + if (SherpaMnnOfflineSpeakerDiarizationGetSampleRate(sd) != + wave->sample_rate) { + fprintf( + stderr, + "Expected sample rate: %d. 
Actual sample rate from the wave file: %d\n", + SherpaMnnOfflineSpeakerDiarizationGetSampleRate(sd), + wave->sample_rate); + goto failed; + } + + const SherpaMnnOfflineSpeakerDiarizationResult *result = + SherpaMnnOfflineSpeakerDiarizationProcessWithCallback( + sd, wave->samples, wave->num_samples, ProgressCallback, NULL); + if (!result) { + fprintf(stderr, "Failed to do speaker diarization"); + goto failed; + } + + int32_t num_segments = + SherpaMnnOfflineSpeakerDiarizationResultGetNumSegments(result); + + const SherpaMnnOfflineSpeakerDiarizationSegment *segments = + SherpaMnnOfflineSpeakerDiarizationResultSortByStartTime(result); + + for (int32_t i = 0; i != num_segments; ++i) { + fprintf(stderr, "%.3f -- %.3f speaker_%02d\n", segments[i].start, + segments[i].end, segments[i].speaker); + } + +failed: + + SherpaMnnOfflineSpeakerDiarizationDestroySegment(segments); + SherpaMnnOfflineSpeakerDiarizationDestroyResult(result); + SherpaMnnDestroyOfflineSpeakerDiarization(sd); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/offline-tts-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/offline-tts-c-api.c new file mode 100644 index 00000000..417a80c4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/offline-tts-c-api.c @@ -0,0 +1,249 @@ +// c-api-examples/offline-tts-c-api.c +// +// Copyright (c) 2023 Xiaomi Corporation + +// This file shows how to use sherpa-onnx C API +// to convert text to speech using an offline model. + +#include +#include +#include + +#include "cargs.h" +#include "sherpa-mnn/c-api/c-api.h" + +static struct cag_option options[] = { + {.identifier = 'h', + .access_letters = "h", + .access_name = "help", + .description = "Show help"}, + {.access_name = "vits-model", + .value_name = "/path/to/xxx.onnx", + .identifier = '0', + .description = "Path to VITS model"}, + {.access_name = "vits-lexicon", + .value_name = "/path/to/lexicon.txt", + .identifier = '1', + .description = "Path to lexicon.txt for VITS models"}, + {.access_name = "vits-tokens", + .value_name = "/path/to/tokens.txt", + .identifier = '2', + .description = "Path to tokens.txt for VITS models"}, + {.access_name = "vits-noise-scale", + .value_name = "0.667", + .identifier = '3', + .description = "noise_scale for VITS models"}, + {.access_name = "vits-noise-scale-w", + .value_name = "0.8", + .identifier = '4', + .description = "noise_scale_w for VITS models"}, + {.access_name = "vits-length-scale", + .value_name = "1.0", + .identifier = '5', + .description = + "length_scale for VITS models. Default to 1. You can tune it " + "to change the speech speed. small -> faster; large -> slower. "}, + {.access_name = "num-threads", + .value_name = "1", + .identifier = '6', + .description = "Number of threads"}, + {.access_name = "provider", + .value_name = "cpu", + .identifier = '7', + .description = "Provider: cpu (default), cuda, coreml"}, + {.access_name = "debug", + .value_name = "0", + .identifier = '8', + .description = "1 to show debug messages while loading the model"}, + {.access_name = "sid", + .value_name = "0", + .identifier = '9', + .description = "Speaker ID. Default to 0. Note it is not used for " + "single-speaker models."}, + {.access_name = "output-filename", + .value_name = "./generated.wav", + .identifier = 'a', + .description = + "Filename to save the generated audio. 
Default to ./generated.wav"}, + + {.access_name = "tts-rule-fsts", + .value_name = "/path/to/rule.fst", + .identifier = 'b', + .description = "It not empty, it contains a list of rule FST filenames." + "Multiple filenames are separated by a comma and they are " + "applied from left to right. An example value: " + "rule1.fst,rule2,fst,rule3.fst"}, + + {.access_name = "max-num-sentences", + .value_name = "2", + .identifier = 'c', + .description = "Maximum number of sentences that we process at a time. " + "This is to avoid OOM for very long input text. " + "If you set it to -1, then we process all sentences in a " + "single batch."}, + + {.access_name = "vits-data-dir", + .value_name = "/path/to/espeak-ng-data", + .identifier = 'd', + .description = + "Path to espeak-ng-data. If it is given, --vits-lexicon is ignored"}, + +}; + +static void ShowUsage() { + const char *kUsageMessage = + "Offline text-to-speech with sherpa-onnx C API" + "\n" + "./offline-tts-c-api \\\n" + " --vits-model=/path/to/model.onnx \\\n" + " --vits-lexicon=/path/to/lexicon.txt \\\n" + " --vits-tokens=/path/to/tokens.txt \\\n" + " --sid=0 \\\n" + " --output-filename=./generated.wav \\\n" + " 'some text within single quotes on linux/macos or use double quotes on " + "windows'\n" + "\n" + "It will generate a file ./generated.wav as specified by " + "--output-filename.\n" + "\n" + "You can download a test model from\n" + "https://huggingface.co/csukuangfj/vits-ljs\n" + "\n" + "For instance, you can use:\n" + "wget " + "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx\n" + "wget " + "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt\n" + "wget " + "https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt\n" + "\n" + "./offline-tts-c-api \\\n" + " --vits-model=./vits-ljs.onnx \\\n" + " --vits-lexicon=./lexicon.txt \\\n" + " --vits-tokens=./tokens.txt \\\n" + " --sid=0 \\\n" + " --output-filename=./generated.wav \\\n" + " 'liliana, the most beautiful and lovely assistant of our team!'\n" + "\n" + "Please see\n" + "https://k2-fsa.github.io/sherpa/onnx/tts/index.html\n" + "or details.\n\n"; + + fprintf(stderr, "%s", kUsageMessage); + cag_option_print(options, CAG_ARRAY_SIZE(options), stderr); + exit(0); +} + +int32_t main(int32_t argc, char *argv[]) { + cag_option_context context; + char identifier; + const char *value; + + cag_option_prepare(&context, options, CAG_ARRAY_SIZE(options), argc, argv); + + SherpaMnnOfflineTtsConfig config; + memset(&config, 0, sizeof(config)); + + int32_t sid = 0; + const char *filename = strdup("./generated.wav"); + const char *text; + + while (cag_option_fetch(&context)) { + identifier = cag_option_get(&context); + value = cag_option_get_value(&context); + switch (identifier) { + case '0': + config.model.vits.model = value; + break; + case '1': + config.model.vits.lexicon = value; + break; + case '2': + config.model.vits.tokens = value; + break; + case '3': + config.model.vits.noise_scale = atof(value); + break; + case '4': + config.model.vits.noise_scale_w = atof(value); + break; + case '5': + config.model.vits.length_scale = atof(value); + break; + case '6': + config.model.num_threads = atoi(value); + break; + case '7': + config.model.provider = value; + break; + case '8': + config.model.debug = atoi(value); + break; + case '9': + sid = atoi(value); + break; + case 'a': + free((void *)filename); + filename = strdup(value); + break; + case 'b': + config.rule_fsts = value; + break; + case 'c': + config.max_num_sentences = atoi(value); + break; + 
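+      // --vits-data-dir points to espeak-ng-data; when it is set, the
+      // lexicon-based G2P path is not needed (the check after the option
+      // loop only requires one of --vits-data-dir / --vits-lexicon).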
case 'd': + config.model.vits.data_dir = value; + break; + case '?': + fprintf(stderr, "Unknown option\n"); + // fall through + case 'h': + // fall through + default: + ShowUsage(); + } + } + fprintf(stderr, "here\n"); + + if (!config.model.vits.model) { + fprintf(stderr, "Please provide --vits-model\n"); + ShowUsage(); + } + + if (!config.model.vits.tokens) { + fprintf(stderr, "Please provide --vits-tokens\n"); + ShowUsage(); + } + + if (!config.model.vits.data_dir && !config.model.vits.lexicon) { + fprintf(stderr, "Please provide --vits-data-dir or --vits-lexicon\n"); + ShowUsage(); + } + + // the last arg is the text + text = argv[argc - 1]; + if (text[0] == '-') { + fprintf(stderr, "\n***Please input your text!***\n\n"); + fprintf(stderr, "\n---------------Usage---------------\n\n"); + ShowUsage(); + } + + const SherpaMnnOfflineTts *tts = SherpaMnnCreateOfflineTts(&config); + + const SherpaMnnGeneratedAudio *audio = + SherpaMnnOfflineTtsGenerate(tts, text, sid, 1.0); + + SherpaMnnWriteWave(audio->samples, audio->n, audio->sample_rate, filename); + + SherpaMnnDestroyOfflineTtsGeneratedAudio(audio); + SherpaMnnDestroyOfflineTts(tts); + + fprintf(stderr, "Input text is: %s\n", text); + fprintf(stderr, "Speaker ID is is: %d\n", sid); + fprintf(stderr, "Saved to: %s\n", filename); + + free((void *)filename); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/paraformer-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/paraformer-c-api.c new file mode 100644 index 00000000..4249fac2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/paraformer-c-api.c @@ -0,0 +1,83 @@ +// c-api-examples/paraformer-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use non-streaming Paraformer with sherpa-onnx's +// C API. 
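+// Paraformer is a single-file, non-autoregressive model: only
+// model.int8.onnx and tokens.txt are needed, and the recognizer below uses
+// plain greedy_search decoding.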
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-small-2024-03-09.tar.bz2 +// tar xvf sherpa-onnx-paraformer-zh-small-2024-03-09.tar.bz2 +// rm sherpa-onnx-paraformer-zh-small-2024-03-09.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-paraformer-zh-small-2024-03-09/test_wavs/0.wav"; + const char *model_filename = + "sherpa-onnx-paraformer-zh-small-2024-03-09/model.int8.onnx"; + const char *tokens_filename = + "sherpa-onnx-paraformer-zh-small-2024-03-09/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Paraformer config + SherpaMnnOfflineParaformerModelConfig paraformer_config; + memset(¶former_config, 0, sizeof(paraformer_config)); + paraformer_config.model = model_filename; + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 1; + offline_model_config.num_threads = 1; + offline_model_config.provider = provider; + offline_model_config.tokens = tokens_filename; + offline_model_config.paraformer = paraformer_config; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnDecodeOfflineStream(recognizer, stream); + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + fprintf(stderr, "Decoded text: %s\n", result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/run.sh b/apps/frameworks/sherpa-mnn/c-api-examples/run.sh new file mode 100755 index 00000000..02054c2e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/run.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash + +set -ex + +if [ ! -d ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 ]; then + echo "Please download the pre-trained model for testing." + echo "You can refer to" + echo "" + echo "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english" + echo "for help" + exit 1 +fi + +if [[ ! -f ../build/lib/libsherpa-onnx-c-api.a && ! -f ../build/lib/libsherpa-onnx-c-api.dylib && ! -f ../build/lib/libsherpa-onnx-c-api.so ]]; then + echo "Please build sherpa-onnx first. You can use" + echo "" + echo " cd /path/to/sherpa-onnx" + echo " mkdir build" + echo " cd build" + echo " cmake .." + echo " make -j4" + exit 1 +fi + +if [ ! 
-f ./decode-file-c-api ]; then + make +fi + +./decode-file-c-api \ + --tokens=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ + --encoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav + +# Run with hotwords + +echo "礼 拜 二" > hotwords.txt + +./decode-file-c-api \ + --tokens=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ + --encoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner=./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + --hotwords-file=hotwords.txt \ + --hotwords-score=1.5 \ + --decoding-method=modified_beam_search \ + ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/sense-voice-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/sense-voice-c-api.c new file mode 100644 index 00000000..16580fa2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/sense-voice-c-api.c @@ -0,0 +1,85 @@ +// c-api-examples/sense-voice-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use SenseVoice with sherpa-onnx's C API. +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav"; + const char *model_filename = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"; + const char *tokens_filename = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"; + const char *language = "auto"; + const char *provider = "cpu"; + int32_t use_inverse_text_normalization = 1; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + SherpaMnnOfflineSenseVoiceModelConfig sense_voice_config; + memset(&sense_voice_config, 0, sizeof(sense_voice_config)); + sense_voice_config.model = model_filename; + sense_voice_config.language = language; + sense_voice_config.use_itn = use_inverse_text_normalization; + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 1; + offline_model_config.num_threads = 1; + offline_model_config.provider = provider; + offline_model_config.tokens = tokens_filename; + offline_model_config.sense_voice = sense_voice_config; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const 
SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnDecodeOfflineStream(recognizer, stream); + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + fprintf(stderr, "Decoded text: %s\n", result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/speaker-identification-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/speaker-identification-c-api.c new file mode 100644 index 00000000..173ca7e9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/speaker-identification-c-api.c @@ -0,0 +1,257 @@ +// c-api-examples/speaker-identification-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// We assume you have pre-downloaded the speaker embedding extractor model +// from +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +// +// An example command to download +// "3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx" +// is given below: +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx +// +// clang-format on +// +// Also, please download the test wave files from +// +// https://github.com/csukuangfj/sr-data + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static const float *ComputeEmbedding( + const SherpaMnnSpeakerEmbeddingExtractor *ex, const char *wav_filename) { + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + exit(-1); + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnSpeakerEmbeddingExtractorCreateStream(ex); + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnOnlineStreamInputFinished(stream); + + if (!SherpaMnnSpeakerEmbeddingExtractorIsReady(ex, stream)) { + fprintf(stderr, "The input wave file %s is too short!\n", wav_filename); + exit(-1); + } + + // we will free `v` outside of this function + const float *v = + SherpaMnnSpeakerEmbeddingExtractorComputeEmbedding(ex, stream); + + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnFreeWave(wave); + + // Remeber to free v to avoid memory leak + return v; +} + +int32_t main() { + SherpaMnnSpeakerEmbeddingExtractorConfig config; + + memset(&config, 0, sizeof(config)); + + // please download the model from + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models + config.model = "./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx"; + + config.num_threads = 1; + config.debug = 0; + config.provider = "cpu"; + + const SherpaMnnSpeakerEmbeddingExtractor *ex = + SherpaMnnCreateSpeakerEmbeddingExtractor(&config); + if (!ex) { + fprintf(stderr, "Failed to create speaker embedding extractor"); + return -1; + } + + int32_t dim = SherpaMnnSpeakerEmbeddingExtractorDim(ex); + + const SherpaMnnSpeakerEmbeddingManager *manager = + SherpaMnnCreateSpeakerEmbeddingManager(dim); + + // 
Please download the test data from + // https://github.com/csukuangfj/sr-data + const char *spk1_1 = "./sr-data/enroll/fangjun-sr-1.wav"; + const char *spk1_2 = "./sr-data/enroll/fangjun-sr-2.wav"; + const char *spk1_3 = "./sr-data/enroll/fangjun-sr-3.wav"; + + const char *spk2_1 = "./sr-data/enroll/leijun-sr-1.wav"; + const char *spk2_2 = "./sr-data/enroll/leijun-sr-2.wav"; + + const float *spk1_vec[4] = {NULL}; + spk1_vec[0] = ComputeEmbedding(ex, spk1_1); + spk1_vec[1] = ComputeEmbedding(ex, spk1_2); + spk1_vec[2] = ComputeEmbedding(ex, spk1_3); + + const float *spk2_vec[3] = {NULL}; + spk2_vec[0] = ComputeEmbedding(ex, spk2_1); + spk2_vec[1] = ComputeEmbedding(ex, spk2_2); + + if (!SherpaMnnSpeakerEmbeddingManagerAddList(manager, "fangjun", spk1_vec)) { + fprintf(stderr, "Failed to register fangjun\n"); + exit(-1); + } + + if (!SherpaMnnSpeakerEmbeddingManagerContains(manager, "fangjun")) { + fprintf(stderr, "Failed to find fangjun\n"); + exit(-1); + } + + if (!SherpaMnnSpeakerEmbeddingManagerAddList(manager, "leijun", spk2_vec)) { + fprintf(stderr, "Failed to register leijun\n"); + exit(-1); + } + + if (!SherpaMnnSpeakerEmbeddingManagerContains(manager, "leijun")) { + fprintf(stderr, "Failed to find leijun\n"); + exit(-1); + } + + if (SherpaMnnSpeakerEmbeddingManagerNumSpeakers(manager) != 2) { + fprintf(stderr, "There should be two speakers: fangjun and leijun\n"); + exit(-1); + } + + const char *const *all_speakers = + SherpaMnnSpeakerEmbeddingManagerGetAllSpeakers(manager); + const char *const *p = all_speakers; + fprintf(stderr, "list of registered speakers\n-----\n"); + while (p[0]) { + fprintf(stderr, "speaker: %s\n", p[0]); + ++p; + } + fprintf(stderr, "----\n"); + + SherpaMnnSpeakerEmbeddingManagerFreeAllSpeakers(all_speakers); + + const char *test1 = "./sr-data/test/fangjun-test-sr-1.wav"; + const char *test2 = "./sr-data/test/leijun-test-sr-1.wav"; + const char *test3 = "./sr-data/test/liudehua-test-sr-1.wav"; + + const float *v1 = ComputeEmbedding(ex, test1); + const float *v2 = ComputeEmbedding(ex, test2); + const float *v3 = ComputeEmbedding(ex, test3); + + float threshold = 0.6; + + const char *name1 = + SherpaMnnSpeakerEmbeddingManagerSearch(manager, v1, threshold); + if (name1) { + fprintf(stderr, "%s: Found %s\n", test1, name1); + SherpaMnnSpeakerEmbeddingManagerFreeSearch(name1); + } else { + fprintf(stderr, "%s: Not found\n", test1); + } + + const char *name2 = + SherpaMnnSpeakerEmbeddingManagerSearch(manager, v2, threshold); + if (name2) { + fprintf(stderr, "%s: Found %s\n", test2, name2); + SherpaMnnSpeakerEmbeddingManagerFreeSearch(name2); + } else { + fprintf(stderr, "%s: Not found\n", test2); + } + + const char *name3 = + SherpaMnnSpeakerEmbeddingManagerSearch(manager, v3, threshold); + if (name3) { + fprintf(stderr, "%s: Found %s\n", test3, name3); + SherpaMnnSpeakerEmbeddingManagerFreeSearch(name3); + } else { + fprintf(stderr, "%s: Not found\n", test3); + } + + int32_t ok = SherpaMnnSpeakerEmbeddingManagerVerify(manager, "fangjun", v1, + threshold); + if (ok) { + fprintf(stderr, "%s matches fangjun\n", test1); + } else { + fprintf(stderr, "%s does NOT match fangjun\n", test1); + } + + ok = SherpaMnnSpeakerEmbeddingManagerVerify(manager, "fangjun", v2, + threshold); + if (ok) { + fprintf(stderr, "%s matches fangjun\n", test2); + } else { + fprintf(stderr, "%s does NOT match fangjun\n", test2); + } + + fprintf(stderr, "Removing fangjun\n"); + if (!SherpaMnnSpeakerEmbeddingManagerRemove(manager, "fangjun")) { + fprintf(stderr, "Failed to remove fangjun\n"); + 
exit(-1); + } + + if (SherpaMnnSpeakerEmbeddingManagerNumSpeakers(manager) != 1) { + fprintf(stderr, "There should be only 1 speaker left\n"); + exit(-1); + } + + name1 = SherpaMnnSpeakerEmbeddingManagerSearch(manager, v1, threshold); + if (name1) { + fprintf(stderr, "%s: Found %s\n", test1, name1); + SherpaMnnSpeakerEmbeddingManagerFreeSearch(name1); + } else { + fprintf(stderr, "%s: Not found\n", test1); + } + + fprintf(stderr, "Removing leijun\n"); + if (!SherpaMnnSpeakerEmbeddingManagerRemove(manager, "leijun")) { + fprintf(stderr, "Failed to remove leijun\n"); + exit(-1); + } + + if (SherpaMnnSpeakerEmbeddingManagerNumSpeakers(manager) != 0) { + fprintf(stderr, "There should be only 1 speaker left\n"); + exit(-1); + } + + name2 = SherpaMnnSpeakerEmbeddingManagerSearch(manager, v2, threshold); + if (name2) { + fprintf(stderr, "%s: Found %s\n", test2, name2); + SherpaMnnSpeakerEmbeddingManagerFreeSearch(name2); + } else { + fprintf(stderr, "%s: Not found\n", test2); + } + + all_speakers = SherpaMnnSpeakerEmbeddingManagerGetAllSpeakers(manager); + + p = all_speakers; + fprintf(stderr, "list of registered speakers\n-----\n"); + while (p[0]) { + fprintf(stderr, "speaker: %s\n", p[0]); + ++p; + } + fprintf(stderr, "----\n"); + + SherpaMnnSpeakerEmbeddingManagerFreeAllSpeakers(all_speakers); + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(v1); + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(v2); + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(v3); + + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(spk1_vec[0]); + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(spk1_vec[1]); + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(spk1_vec[2]); + + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(spk2_vec[0]); + SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(spk2_vec[1]); + + SherpaMnnDestroySpeakerEmbeddingManager(manager); + SherpaMnnDestroySpeakerEmbeddingExtractor(ex); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/speech-enhancement-gtcrn-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/speech-enhancement-gtcrn-c-api.c new file mode 100644 index 00000000..16a51fa5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/speech-enhancement-gtcrn-c-api.c @@ -0,0 +1,55 @@ +// c-api-examples/speech-enhancement-gtcrn-c-api.c +// +// Copyright (c) 2025 Xiaomi Corporation +// +// We assume you have pre-downloaded model +// from +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models +// +// +// An example command to download +// clang-format off +/* +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav +*/ +// clang-format on +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + SherpaMnnOfflineSpeechDenoiserConfig config; + const char *wav_filename = "./inp_16k.wav"; + const char *out_wave_filename = "./enhanced_16k.wav"; + + memset(&config, 0, sizeof(config)); + config.model.gtcrn.model = "./gtcrn_simple.onnx"; + + const SherpaMnnOfflineSpeechDenoiser *sd = + SherpaMnnCreateOfflineSpeechDenoiser(&config); + if (!sd) { + fprintf(stderr, "Please check your config"); + return -1; + } + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + SherpaMnnDestroyOfflineSpeechDenoiser(sd); + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + const SherpaMnnDenoisedAudio *denoised = 
SherpaMnnOfflineSpeechDenoiserRun( + sd, wave->samples, wave->num_samples, wave->sample_rate); + + SherpaMnnWriteWave(denoised->samples, denoised->n, denoised->sample_rate, + out_wave_filename); + + SherpaMnnDestroyDenoisedAudio(denoised); + SherpaMnnFreeWave(wave); + SherpaMnnDestroyOfflineSpeechDenoiser(sd); + + fprintf(stdout, "Saved to %s\n", out_wave_filename); +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/spoken-language-identification-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/spoken-language-identification-c-api.c new file mode 100644 index 00000000..c6541782 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/spoken-language-identification-c-api.c @@ -0,0 +1,68 @@ +// c-api-examples/spoken-language-identification-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// We assume you have pre-downloaded the whisper multi-lingual models +// from https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models +// An example command to download the "tiny" whisper model is given below: +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 +// tar xvf sherpa-onnx-whisper-tiny.tar.bz2 +// rm sherpa-onnx-whisper-tiny.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + SherpaMnnSpokenLanguageIdentificationConfig config; + + memset(&config, 0, sizeof(config)); + + config.whisper.encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx"; + config.whisper.decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx"; + config.num_threads = 1; + config.debug = 1; + config.provider = "cpu"; + + const SherpaMnnSpokenLanguageIdentification *slid = + SherpaMnnCreateSpokenLanguageIdentification(&config); + if (!slid) { + fprintf(stderr, "Failed to create spoken language identifier"); + return -1; + } + + // You can find more test waves from + // https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs + const char *wav_filename = "./sherpa-onnx-whisper-tiny/test_wavs/0.wav"; + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + SherpaMnnOfflineStream *stream = + SherpaMnnSpokenLanguageIdentificationCreateOfflineStream(slid); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + + const SherpaMnnSpokenLanguageIdentificationResult *result = + SherpaMnnSpokenLanguageIdentificationCompute(slid, stream); + + fprintf(stderr, "wav_filename: %s\n", wav_filename); + fprintf(stderr, "Detected language: %s\n", result->lang); + + SherpaMnnDestroySpokenLanguageIdentificationResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnFreeWave(wave); + SherpaMnnDestroySpokenLanguageIdentification(slid); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/streaming-ctc-buffered-tokens-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-ctc-buffered-tokens-c-api.c new file mode 100644 index 00000000..5a78dc85 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-ctc-buffered-tokens-c-api.c @@ -0,0 +1,180 @@ +// c-api-examples/streaming-ctc-buffered-tokens-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Luo Xiao + +// +// This file demonstrates how to use streaming Zipformer2 Ctc with sherpa-onnx's +// C API and with tokens loaded from buffered strings instead 
of +// from external files API. +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +// tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +// rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static size_t ReadFile(const char *filename, const char **buffer_out) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + fprintf(stderr, "Failed to open %s\n", filename); + return -1; + } + fseek(file, 0L, SEEK_END); + long size = ftell(file); + rewind(file); + *buffer_out = malloc(size); + if (*buffer_out == NULL) { + fclose(file); + fprintf(stderr, "Memory error\n"); + return -1; + } + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); + if (read_bytes != size) { + printf("Errors occured in reading the file %s\n", filename); + free((void *)*buffer_out); + *buffer_out = NULL; + fclose(file); + return -1; + } + fclose(file); + return read_bytes; +} + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/" + "DEV_T0000000000.wav"; + const char *model_filename = + "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/" + "ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx"; + const char *tokens_filename = + "sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // reading tokens to buffers + const char *tokens_buf; + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); + if (token_buf_size < 1) { + fprintf(stderr, "Please check your tokens.txt!\n"); + free((void *)tokens_buf); + return -1; + } + + // Zipformer2Ctc config + SherpaMnnOnlineZipformer2CtcModelConfig zipformer2_ctc_config; + memset(&zipformer2_ctc_config, 0, sizeof(zipformer2_ctc_config)); + zipformer2_ctc_config.model = model_filename; + + // Online model config + SherpaMnnOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens_buf = tokens_buf; + online_model_config.tokens_buf_size = token_buf_size; + online_model_config.zipformer2_ctc = zipformer2_ctc_config; + + // Recognizer config + SherpaMnnOnlineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = online_model_config; + + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&recognizer_config); + + free((void *)tokens_buf); + tokens_buf = NULL; + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. 
You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/streaming-hlg-decode-file-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-hlg-decode-file-c-api.c new file mode 100644 index 00000000..11d227b7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-hlg-decode-file-c-api.c @@ -0,0 +1,130 @@ +// c-api-examples/streaming-hlg-decode-file-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation +/* +We use the following model as an example + +// clang-format off + +Download the model from +https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + +build/bin/streaming-hlg-decode-file-c-api + +(The above model is from https://github.com/k2-fsa/icefall/pull/1557) +*/ +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + // clang-format off + // + // Please download the model from + // https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + const char *model = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx"; + const char *tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt"; + const char *graph = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst"; + const char *wav_filename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/8k.wav"; + // clang-format on + + SherpaMnnOnlineRecognizerConfig config; + + memset(&config, 0, sizeof(config)); + config.feat_config.sample_rate = 16000; + 
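+  // The feature extractor settings must match the model: this streaming
+  // zipformer CTC model expects 16 kHz audio with 80-dim fbank features
+  // (a common icefall setup; adjust both values if your model differs).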
config.feat_config.feature_dim = 80; + config.model_config.zipformer2_ctc.model = model; + config.model_config.tokens = tokens; + config.model_config.num_threads = 1; + config.model_config.provider = "cpu"; + config.model_config.debug = 0; + config.ctc_fst_decoder_config.graph = graph; + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&config); + if (!recognizer) { + fprintf(stderr, "Failed to create recognizer"); + exit(-1); + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + exit(-1); + } + +// simulate streaming. You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c new file mode 100644 index 00000000..79bffb41 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-paraformer-buffered-tokens-c-api.c @@ -0,0 +1,181 @@ +// c-api-examples/streaming-paraformer-buffered-tokens-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Luo Xiao + +// +// This file demonstrates how to use streaming Paraformer with sherpa-onnx's C +// API and with tokens loaded from buffered strings instead of from +// external files API. 
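+// The ReadFile() helper below loads tokens.txt into a heap-allocated buffer.
+// The buffer is only needed while the recognizer is being created, so the
+// example frees it right after SherpaMnnCreateOnlineRecognizer() returns.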
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +// tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +// rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static size_t ReadFile(const char *filename, const char **buffer_out) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + fprintf(stderr, "Failed to open %s\n", filename); + return -1; + } + fseek(file, 0L, SEEK_END); + long size = ftell(file); + rewind(file); + *buffer_out = malloc(size); + if (*buffer_out == NULL) { + fclose(file); + fprintf(stderr, "Memory error\n"); + return -1; + } + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); + if (read_bytes != size) { + printf("Errors occured in reading the file %s\n", filename); + free((void *)*buffer_out); + *buffer_out = NULL; + fclose(file); + return -1; + } + fclose(file); + return read_bytes; +} + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav"; + const char *encoder_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"; + const char *decoder_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"; + const char *tokens_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // reading tokens to buffers + const char *tokens_buf; + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); + if (token_buf_size < 1) { + fprintf(stderr, "Please check your tokens.txt!\n"); + free((void *)tokens_buf); + return -1; + } + + // Paraformer config + SherpaMnnOnlineParaformerModelConfig paraformer_config; + memset(¶former_config, 0, sizeof(paraformer_config)); + paraformer_config.encoder = encoder_filename; + paraformer_config.decoder = decoder_filename; + + // Online model config + SherpaMnnOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens_buf = tokens_buf; + online_model_config.tokens_buf_size = token_buf_size; + online_model_config.paraformer = paraformer_config; + + // Recognizer config + SherpaMnnOnlineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = online_model_config; + + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&recognizer_config); + + free((void *)tokens_buf); + tokens_buf = NULL; + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. 
You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/streaming-paraformer-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-paraformer-c-api.c new file mode 100644 index 00000000..97c1d8ff --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-paraformer-c-api.c @@ -0,0 +1,139 @@ +// c-api-examples/streaming-paraformer-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use streaming Paraformer with sherpa-onnx's C +// API. 
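+// All model paths below are relative, so run the compiled example from the
+// directory where the archive downloaded below was extracted.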
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +// tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +// rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav"; + const char *encoder_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx"; + const char *decoder_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx"; + const char *tokens_filename = + "sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Paraformer config + SherpaMnnOnlineParaformerModelConfig paraformer_config; + memset(¶former_config, 0, sizeof(paraformer_config)); + paraformer_config.encoder = encoder_filename; + paraformer_config.decoder = decoder_filename; + + // Online model config + SherpaMnnOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens = tokens_filename; + online_model_config.paraformer = paraformer_config; + + // Recognizer config + SherpaMnnOnlineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = online_model_config; + + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? 
wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c new file mode 100644 index 00000000..b59011e1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c @@ -0,0 +1,203 @@ +// c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Luo Xiao + +// +// This file demonstrates how to use streaming Zipformer with sherpa-onnx's C +// API and with tokens and hotwords loaded from buffered strings instead of from +// external files API. 
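+// Loading tokens and hotwords from in-memory buffers instead of file paths
+// is handy when those files are bundled inside an application package or
+// produced at runtime rather than shipped as separate files on disk.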
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +static size_t ReadFile(const char *filename, const char **buffer_out) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + fprintf(stderr, "Failed to open %s\n", filename); + return -1; + } + fseek(file, 0L, SEEK_END); + long size = ftell(file); + rewind(file); + *buffer_out = malloc(size); + if (*buffer_out == NULL) { + fclose(file); + fprintf(stderr, "Memory error\n"); + return -1; + } + size_t read_bytes = fread((void *)*buffer_out, 1, size, file); + if (read_bytes != size) { + printf("Errors occured in reading the file %s\n", filename); + free((void *)*buffer_out); + *buffer_out = NULL; + fclose(file); + return -1; + } + fclose(file); + return read_bytes; +} + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"; + const char *encoder_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "encoder-epoch-99-avg-1.onnx"; + const char *decoder_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "decoder-epoch-99-avg-1.onnx"; + const char *joiner_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "joiner-epoch-99-avg-1.onnx"; + const char *provider = "cpu"; + const char *modeling_unit = "bpe"; + const char *tokens_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt"; + const char *hotwords_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/hotwords.txt"; + const char *bpe_vocab = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "bpe.vocab"; + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // reading tokens and hotwords to buffers + const char *tokens_buf; + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); + if (token_buf_size < 1) { + fprintf(stderr, "Please check your tokens.txt!\n"); + free((void *)tokens_buf); + return -1; + } + const char *hotwords_buf; + size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf); + if (hotwords_buf_size < 1) { + fprintf(stderr, "Please check your hotwords.txt!\n"); + free((void *)hotwords_buf); + return -1; + } + + // Zipformer config + SherpaMnnOnlineTransducerModelConfig zipformer_config; + memset(&zipformer_config, 0, sizeof(zipformer_config)); + zipformer_config.encoder = encoder_filename; + zipformer_config.decoder = decoder_filename; + zipformer_config.joiner = joiner_filename; + + // Online model config + SherpaMnnOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens_buf = tokens_buf; + online_model_config.tokens_buf_size = token_buf_size; + online_model_config.transducer = zipformer_config; + + // Recognizer config + SherpaMnnOnlineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "modified_beam_search"; + recognizer_config.model_config = online_model_config; + 
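+  // With modified_beam_search the buffered hotwords are used for contextual
+  // biasing. If your build exposes it, the boosting strength can also be
+  // tuned (field name assumed from the upstream sherpa-onnx C API):
+  // recognizer_config.hotwords_score = 1.5;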
recognizer_config.hotwords_buf = hotwords_buf; + recognizer_config.hotwords_buf_size = hotwords_buf_size; + + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&recognizer_config); + + free((void *)tokens_buf); + tokens_buf = NULL; + free((void *)hotwords_buf); + hotwords_buf = NULL; + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/streaming-zipformer-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-zipformer-c-api.c new file mode 100644 index 00000000..5c453499 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/streaming-zipformer-c-api.c @@ -0,0 +1,145 @@ +// c-api-examples/streaming-zipformer-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use streaming Zipformer with sherpa-onnx's C +// API. 
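+// The decode loop below feeds the wave in fixed-size chunks, drains the
+// decoder whenever the stream reports it is ready, prints partial results,
+// and starts a new display segment each time an endpoint is detected.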
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"; + const char *encoder_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "encoder-epoch-99-avg-1.onnx"; + const char *decoder_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "decoder-epoch-99-avg-1.onnx"; + const char *joiner_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "joiner-epoch-99-avg-1.onnx"; + const char *tokens_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Zipformer config + SherpaMnnOnlineTransducerModelConfig zipformer_config; + memset(&zipformer_config, 0, sizeof(zipformer_config)); + zipformer_config.encoder = encoder_filename; + zipformer_config.decoder = decoder_filename; + zipformer_config.joiner = joiner_filename; + + // Online model config + SherpaMnnOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens = tokens_filename; + online_model_config.transducer = zipformer_config; + + // Recognizer config + SherpaMnnOnlineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = online_model_config; + + const SherpaMnnOnlineRecognizer *recognizer = + SherpaMnnCreateOnlineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOnlineStream *stream = + SherpaMnnCreateOnlineStream(recognizer); + + const SherpaMnnDisplay *display = SherpaMnnCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? 
wave->num_samples : (start + N); + k += N; + + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + if (SherpaMnnOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaMnnOnlineStreamReset(recognizer, stream); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaMnnOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaMnnFreeWave(wave); + + SherpaMnnOnlineStreamInputFinished(stream); + while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { + SherpaMnnDecodeOnlineStream(recognizer, stream); + } + + const SherpaMnnOnlineRecognizerResult *r = + SherpaMnnGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaMnnPrint(display, segment_id, r->text); + } + + SherpaMnnDestroyOnlineRecognizerResult(r); + + SherpaMnnDestroyDisplay(display); + SherpaMnnDestroyOnlineStream(stream); + SherpaMnnDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/telespeech-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/telespeech-c-api.c new file mode 100644 index 00000000..0257e551 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/telespeech-c-api.c @@ -0,0 +1,78 @@ +// c-api-examples/telespeech-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use TeleSpeech-ASR CTC model with sherpa-onnx's +// C API. 
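+// TeleSpeech-ASR is used here as a non-streaming (offline) CTC model: the
+// whole waveform is passed to a single offline stream and decoded in one
+// call.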
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 +// tar xvf sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 +// rm sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/3-sichuan.wav"; + const char *model_filename = + "sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx"; + const char *tokens_filename = + "sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 1; + offline_model_config.num_threads = 1; + offline_model_config.provider = provider; + offline_model_config.tokens = tokens_filename; + offline_model_config.telespeech_ctc = model_filename; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnDecodeOfflineStream(recognizer, stream); + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + fprintf(stderr, "Decoded text: %s\n", result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/vad-moonshine-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/vad-moonshine-c-api.c new file mode 100644 index 00000000..c4223223 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/vad-moonshine-c-api.c @@ -0,0 +1,146 @@ +// c-api-examples/vad-moonshine-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use VAD + Moonshine with sherpa-onnx's C API. 
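+// The VAD splits the long recording into speech segments; each segment is
+// decoded with its own offline stream and printed together with its start
+// and stop times.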
+// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/Obama.wav +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +// tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +// rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = "./Obama.wav"; + const char *vad_filename = "./silero_vad.onnx"; + + const char *preprocessor = + "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx"; + const char *encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx"; + const char *uncached_decoder = + "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx"; + const char *cached_decoder = + "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx"; + const char *tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + if (wave->sample_rate != 16000) { + fprintf(stderr, "Expect the sample rate to be 16000. Given: %d\n", + wave->sample_rate); + SherpaMnnFreeWave(wave); + return -1; + } + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 0; + offline_model_config.num_threads = 1; + offline_model_config.provider = "cpu"; + offline_model_config.tokens = tokens; + offline_model_config.moonshine.preprocessor = preprocessor; + offline_model_config.moonshine.encoder = encoder; + offline_model_config.moonshine.uncached_decoder = uncached_decoder; + offline_model_config.moonshine.cached_decoder = cached_decoder; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your recognizer config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + SherpaMnnVadModelConfig vadConfig; + memset(&vadConfig, 0, sizeof(vadConfig)); + vadConfig.silero_vad.model = vad_filename; + vadConfig.silero_vad.threshold = 0.5; + vadConfig.silero_vad.min_silence_duration = 0.5; + vadConfig.silero_vad.min_speech_duration = 0.5; + vadConfig.silero_vad.max_speech_duration = 10; + vadConfig.silero_vad.window_size = 512; + vadConfig.sample_rate = 16000; + vadConfig.num_threads = 1; + vadConfig.debug = 1; + + SherpaMnnVoiceActivityDetector *vad = + SherpaMnnCreateVoiceActivityDetector(&vadConfig, 30); + + if (vad == NULL) { + fprintf(stderr, "Please check your recognizer config!\n"); + SherpaMnnFreeWave(wave); + SherpaMnnDestroyOfflineRecognizer(recognizer); + return -1; + } + + int32_t window_size = vadConfig.silero_vad.window_size; + int32_t i = 0; + int is_eof = 0; + + while (!is_eof) { + if (i + window_size < wave->num_samples) { + SherpaMnnVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i, + window_size); + } else { + SherpaMnnVoiceActivityDetectorFlush(vad); + is_eof = 1; + } + while (!SherpaMnnVoiceActivityDetectorEmpty(vad)) { 
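+      // Drain every speech segment the VAD has queued: Front() returns the
+      // oldest segment and Pop() removes it once it has been decoded.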
+ const SherpaMnnSpeechSegment *segment = + SherpaMnnVoiceActivityDetectorFront(vad); + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, + segment->samples, segment->n); + + SherpaMnnDecodeOfflineStream(recognizer, stream); + + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + float start = segment->start / 16000.0f; + float duration = segment->n / 16000.0f; + float stop = start + duration; + + fprintf(stderr, "%.3f -- %.3f: %s\n", start, stop, result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + + SherpaMnnDestroySpeechSegment(segment); + SherpaMnnVoiceActivityDetectorPop(vad); + } + i += window_size; + } + + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnDestroyVoiceActivityDetector(vad); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/vad-sense-voice-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/vad-sense-voice-c-api.c new file mode 100644 index 00000000..118a082f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/vad-sense-voice-c-api.c @@ -0,0 +1,148 @@ +// c-api-examples/vad-sense-voice-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use VAD + SenseVoice with sherpa-onnx's C API. +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/lei-jun-test.wav +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = "./lei-jun-test.wav"; + const char *vad_filename = "./silero_vad.onnx"; + const char *model_filename = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"; + const char *tokens_filename = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"; + const char *language = "auto"; + const char *provider = "cpu"; + int32_t use_inverse_text_normalization = 1; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + if (wave->sample_rate != 16000) { + fprintf(stderr, "Expect the sample rate to be 16000. 
Given: %d\n", + wave->sample_rate); + SherpaMnnFreeWave(wave); + return -1; + } + + SherpaMnnOfflineSenseVoiceModelConfig sense_voice_config; + memset(&sense_voice_config, 0, sizeof(sense_voice_config)); + sense_voice_config.model = model_filename; + sense_voice_config.language = language; + sense_voice_config.use_itn = use_inverse_text_normalization; + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 0; + offline_model_config.num_threads = 1; + offline_model_config.provider = provider; + offline_model_config.tokens = tokens_filename; + offline_model_config.sense_voice = sense_voice_config; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your recognizer config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + SherpaMnnVadModelConfig vadConfig; + memset(&vadConfig, 0, sizeof(vadConfig)); + vadConfig.silero_vad.model = vad_filename; + vadConfig.silero_vad.threshold = 0.5; + vadConfig.silero_vad.min_silence_duration = 0.5; + vadConfig.silero_vad.min_speech_duration = 0.5; + vadConfig.silero_vad.max_speech_duration = 5; + vadConfig.silero_vad.window_size = 512; + vadConfig.sample_rate = 16000; + vadConfig.num_threads = 1; + vadConfig.debug = 1; + + SherpaMnnVoiceActivityDetector *vad = + SherpaMnnCreateVoiceActivityDetector(&vadConfig, 30); + + if (vad == NULL) { + fprintf(stderr, "Please check your recognizer config!\n"); + SherpaMnnFreeWave(wave); + SherpaMnnDestroyOfflineRecognizer(recognizer); + return -1; + } + + int32_t window_size = vadConfig.silero_vad.window_size; + int32_t i = 0; + int is_eof = 0; + + while (!is_eof) { + if (i + window_size < wave->num_samples) { + SherpaMnnVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i, + window_size); + } else { + SherpaMnnVoiceActivityDetectorFlush(vad); + is_eof = 1; + } + + while (!SherpaMnnVoiceActivityDetectorEmpty(vad)) { + const SherpaMnnSpeechSegment *segment = + SherpaMnnVoiceActivityDetectorFront(vad); + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, + segment->samples, segment->n); + + SherpaMnnDecodeOfflineStream(recognizer, stream); + + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + float start = segment->start / 16000.0f; + float duration = segment->n / 16000.0f; + float stop = start + duration; + + fprintf(stderr, "%.3f -- %.3f: %s\n", start, stop, result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + + SherpaMnnDestroySpeechSegment(segment); + SherpaMnnVoiceActivityDetectorPop(vad); + } + i += window_size; + } + + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnDestroyVoiceActivityDetector(vad); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/vad-whisper-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/vad-whisper-c-api.c new file mode 100644 index 00000000..ed6477b7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/vad-whisper-c-api.c @@ -0,0 +1,145 
@@ +// c-api-examples/vad-whisper-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use VAD + Whisper tiny.en with +// sherpa-onnx's C API. +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/Obama.wav +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 +// tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 +// rm sherpa-onnx-whisper-tiny.en.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = "./Obama.wav"; + const char *vad_filename = "./silero_vad.onnx"; + + const char *encoder = "sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"; + const char *decoder = "sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"; + const char *tokens = "sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + if (wave->sample_rate != 16000) { + fprintf(stderr, "Expect the sample rate to be 16000. Given: %d\n", + wave->sample_rate); + SherpaMnnFreeWave(wave); + return -1; + } + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 0; + offline_model_config.num_threads = 1; + offline_model_config.provider = "cpu"; + offline_model_config.tokens = tokens; + offline_model_config.whisper.encoder = encoder; + offline_model_config.whisper.decoder = decoder; + offline_model_config.whisper.language = "en"; + offline_model_config.whisper.tail_paddings = 0; + offline_model_config.whisper.task = "transcribe"; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your recognizer config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + SherpaMnnVadModelConfig vadConfig; + memset(&vadConfig, 0, sizeof(vadConfig)); + vadConfig.silero_vad.model = vad_filename; + vadConfig.silero_vad.threshold = 0.5; + vadConfig.silero_vad.min_silence_duration = 0.5; + vadConfig.silero_vad.min_speech_duration = 0.5; + vadConfig.silero_vad.max_speech_duration = 10; + vadConfig.silero_vad.window_size = 512; + vadConfig.sample_rate = 16000; + vadConfig.num_threads = 1; + vadConfig.debug = 1; + + SherpaMnnVoiceActivityDetector *vad = + SherpaMnnCreateVoiceActivityDetector(&vadConfig, 30); + + if (vad == NULL) { + fprintf(stderr, "Please check your recognizer config!\n"); + SherpaMnnFreeWave(wave); + SherpaMnnDestroyOfflineRecognizer(recognizer); + return -1; + } + + int32_t window_size = vadConfig.silero_vad.window_size; + int32_t i = 0; + int is_eof = 0; + + while (!is_eof) { + if (i + window_size < wave->num_samples) { + SherpaMnnVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i, + window_size); + } + else { + SherpaMnnVoiceActivityDetectorFlush(vad); + is_eof = 1; + } + while (!SherpaMnnVoiceActivityDetectorEmpty(vad)) { + const SherpaMnnSpeechSegment 
*segment = + SherpaMnnVoiceActivityDetectorFront(vad); + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, + segment->samples, segment->n); + + SherpaMnnDecodeOfflineStream(recognizer, stream); + + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + float start = segment->start / 16000.0f; + float duration = segment->n / 16000.0f; + float stop = start + duration; + + fprintf(stderr, "%.3f -- %.3f: %s\n", start, stop, result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + + SherpaMnnDestroySpeechSegment(segment); + SherpaMnnVoiceActivityDetectorPop(vad); + } + i += window_size; + } + + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnDestroyVoiceActivityDetector(vad); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/whisper-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/whisper-c-api.c new file mode 100644 index 00000000..f4f1a71e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/whisper-c-api.c @@ -0,0 +1,89 @@ +// c-api-examples/whisper-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// We assume you have pre-downloaded the whisper multi-lingual models +// from https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models +// An example command to download the "tiny" whisper model is given below: +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 +// tar xvf sherpa-onnx-whisper-tiny.tar.bz2 +// rm sherpa-onnx-whisper-tiny.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = "./sherpa-onnx-whisper-tiny/test_wavs/0.wav"; + const char *encoder_filename = "sherpa-onnx-whisper-tiny/tiny-encoder.onnx"; + const char *decoder_filename = "sherpa-onnx-whisper-tiny/tiny-decoder.onnx"; + const char *tokens_filename = "sherpa-onnx-whisper-tiny/tiny-tokens.txt"; + const char *language = "en"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Whisper config + SherpaMnnOfflineWhisperModelConfig whisper_config; + memset(&whisper_config, 0, sizeof(whisper_config)); + whisper_config.decoder = decoder_filename; + whisper_config.encoder = encoder_filename; + whisper_config.language = language; + whisper_config.tail_paddings = 0; + whisper_config.task = "transcribe"; + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 1; + offline_model_config.num_threads = 1; + offline_model_config.provider = provider; + offline_model_config.tokens = tokens_filename; + offline_model_config.whisper = whisper_config; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + + SherpaMnnFreeWave(wave); + + return -1; + } + + 
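+  // Whisper runs here in non-streaming mode: create one offline stream,
+  // feed it the whole waveform, decode, then read the final text from the
+  // result.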
const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnDecodeOfflineStream(recognizer, stream); + const SherpaMnnOfflineRecognizerResult *result = + SherpaMnnGetOfflineStreamResult(stream); + + fprintf(stderr, "Decoded text: %s\n", result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/c-api-examples/zipformer-c-api.c b/apps/frameworks/sherpa-mnn/c-api-examples/zipformer-c-api.c new file mode 100644 index 00000000..c6ba2721 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/c-api-examples/zipformer-c-api.c @@ -0,0 +1,89 @@ +// c-api-examples/zipformer-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use non-streaming Zipformer with sherpa-onnx's +// C API. +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-small-en-2023-06-26.tar.bz2 +// tar xvf sherpa-onnx-zipformer-small-en-2023-06-26.tar.bz2 +// rm sherpa-onnx-zipformer-small-en-2023-06-26.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-zipformer-small-en-2023-06-26/test_wavs/0.wav"; + const char *encoder_filename = + "sherpa-onnx-zipformer-small-en-2023-06-26/encoder-epoch-99-avg-1.onnx"; + const char *decoder_filename = + "sherpa-onnx-zipformer-small-en-2023-06-26/decoder-epoch-99-avg-1.onnx"; + const char *joiner_filename = + "sherpa-onnx-zipformer-small-en-2023-06-26/joiner-epoch-99-avg-1.onnx"; + const char *tokens_filename = + "sherpa-onnx-zipformer-small-en-2023-06-26/tokens.txt"; + const char *provider = "cpu"; + + const SherpaMnnWave *wave = SherpaMnnReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // Zipformer config + SherpaMnnOfflineTransducerModelConfig zipformer_config; + memset(&zipformer_config, 0, sizeof(zipformer_config)); + zipformer_config.encoder = encoder_filename; + zipformer_config.decoder = decoder_filename; + zipformer_config.joiner = joiner_filename; + + // Offline model config + SherpaMnnOfflineModelConfig offline_model_config; + memset(&offline_model_config, 0, sizeof(offline_model_config)); + offline_model_config.debug = 1; + offline_model_config.num_threads = 1; + offline_model_config.provider = provider; + offline_model_config.tokens = tokens_filename; + offline_model_config.transducer = zipformer_config; + + // Recognizer config + SherpaMnnOfflineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "greedy_search"; + recognizer_config.model_config = offline_model_config; + + const SherpaMnnOfflineRecognizer *recognizer = + SherpaMnnCreateOfflineRecognizer(&recognizer_config); + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaMnnFreeWave(wave); + return -1; + } + + const SherpaMnnOfflineStream *stream = + SherpaMnnCreateOfflineStream(recognizer); + + SherpaMnnAcceptWaveformOffline(stream, wave->sample_rate, wave->samples, + wave->num_samples); + SherpaMnnDecodeOfflineStream(recognizer, stream); + const SherpaMnnOfflineRecognizerResult *result = + 
SherpaMnnGetOfflineStreamResult(stream); + + fprintf(stderr, "Decoded text: %s\n", result->text); + + SherpaMnnDestroyOfflineRecognizerResult(result); + SherpaMnnDestroyOfflineStream(stream); + SherpaMnnDestroyOfflineRecognizer(recognizer); + SherpaMnnFreeWave(wave); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/cmake/.gitignore b/apps/frameworks/sherpa-mnn/cmake/.gitignore new file mode 100644 index 00000000..4d7a2315 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/.gitignore @@ -0,0 +1 @@ +!*.cmake diff --git a/apps/frameworks/sherpa-mnn/cmake/__init__.py b/apps/frameworks/sherpa-mnn/cmake/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/apps/frameworks/sherpa-mnn/cmake/asio.cmake b/apps/frameworks/sherpa-mnn/cmake/asio.cmake new file mode 100644 index 00000000..9e3ce8d2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/asio.cmake @@ -0,0 +1,45 @@ +function(download_asio) + include(FetchContent) + + set(asio_URL "https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz") + set(asio_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/asio-asio-1-24-0.tar.gz") + set(asio_HASH "SHA256=cbcaaba0f66722787b1a7c33afe1befb3a012b5af3ad7da7ff0f6b8c9b7a8a5b") + + # If you don't have access to the Internet, + # please pre-download asio + set(possible_file_locations + $ENV{HOME}/Downloads/asio-asio-1-24-0.tar.gz + ${CMAKE_SOURCE_DIR}/asio-asio-1-24-0.tar.gz + ${CMAKE_BINARY_DIR}/asio-asio-1-24-0.tar.gz + /tmp/asio-asio-1-24-0.tar.gz + /star-fj/fangjun/download/github/asio-asio-1-24-0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(asio_URL "${f}") + file(TO_CMAKE_PATH "${asio_URL}" asio_URL) + message(STATUS "Found local downloaded asio: ${asio_URL}") + set(asio_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(asio + URL + ${asio_URL} + ${asio_URL2} + URL_HASH ${asio_HASH} + ) + + FetchContent_GetProperties(asio) + if(NOT asio_POPULATED) + message(STATUS "Downloading asio ${asio_URL}") + FetchContent_Populate(asio) + endif() + message(STATUS "asio is downloaded to ${asio_SOURCE_DIR}") + # add_subdirectory(${asio_SOURCE_DIR} ${asio_BINARY_DIR} EXCLUDE_FROM_ALL) + include_directories(${asio_SOURCE_DIR}/asio/include) +endfunction() + +download_asio() diff --git a/apps/frameworks/sherpa-mnn/cmake/cargs.cmake b/apps/frameworks/sherpa-mnn/cmake/cargs.cmake new file mode 100644 index 00000000..d7c60550 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/cargs.cmake @@ -0,0 +1,50 @@ +function(download_cargs) + include(FetchContent) + + set(cargs_URL "https://github.com/likle/cargs/archive/refs/tags/v1.0.3.tar.gz") + set(cargs_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/cargs-1.0.3.tar.gz") + set(cargs_HASH "SHA256=ddba25bd35e9c6c75bc706c126001b8ce8e084d40ef37050e6aa6963e836eb8b") + + # If you don't have access to the Internet, + # please pre-download cargs + set(possible_file_locations + $ENV{HOME}/Downloads/cargs-1.0.3.tar.gz + ${CMAKE_SOURCE_DIR}/cargs-1.0.3.tar.gz + ${CMAKE_BINARY_DIR}/cargs-1.0.3.tar.gz + /tmp/cargs-1.0.3.tar.gz + /star-fj/fangjun/download/github/cargs-1.0.3.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(cargs_URL "${f}") + file(TO_CMAKE_PATH "${cargs_URL}" cargs_URL) + message(STATUS "Found local downloaded cargs: ${cargs_URL}") + set(cargs_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(cargs + URL + ${cargs_URL} + ${cargs_URL2} + URL_HASH + ${cargs_HASH} + ) + + 
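+  # Both URLs are passed to FetchContent_Declare(); they are tried in order,
+  # so the GitHub release is preferred and the mirror acts as a fallback.
+  # When a pre-downloaded archive is found above, the second URL is cleared
+  # and only the local file is used.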
FetchContent_GetProperties(cargs) + if(NOT cargs_POPULATED) + message(STATUS "Downloading cargs ${cargs_URL}") + FetchContent_Populate(cargs) + endif() + message(STATUS "cargs is downloaded to ${cargs_SOURCE_DIR}") + add_subdirectory(${cargs_SOURCE_DIR} ${cargs_BINARY_DIR} EXCLUDE_FROM_ALL) + + install(TARGETS cargs DESTINATION lib) + install(FILES ${cargs_SOURCE_DIR}/include/cargs.h + DESTINATION include + ) +endfunction() + +download_cargs() diff --git a/apps/frameworks/sherpa-mnn/cmake/cmake_extension.py b/apps/frameworks/sherpa-mnn/cmake/cmake_extension.py new file mode 100644 index 00000000..4b9028ba --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/cmake_extension.py @@ -0,0 +1,227 @@ +# cmake/cmake_extension.py +# Copyright (c) 2023 Xiaomi Corporation +# +# flake8: noqa + +import os +import platform +import shutil +import sys +from pathlib import Path + +import setuptools +from setuptools.command.build_ext import build_ext + + +def is_for_pypi(): + ans = os.environ.get("SHERPA_ONNX_IS_FOR_PYPI", None) + return ans is not None + + +def is_macos(): + return platform.system() == "Darwin" + + +def is_windows(): + return platform.system() == "Windows" + + +def is_linux(): + return platform.system() == "Linux" + + +def is_arm64(): + return platform.machine() in ["arm64", "aarch64"] + + +def is_x86(): + return platform.machine() in ["i386", "i686", "x86_64"] + + +def enable_alsa(): + build_alsa = os.environ.get("SHERPA_ONNX_ENABLE_ALSA", None) + return build_alsa and is_linux() and (is_arm64() or is_x86()) + + +def get_binaries(): + binaries = [ + "sherpa-onnx", + "sherpa-onnx-keyword-spotter", + "sherpa-onnx-microphone", + "sherpa-onnx-microphone-offline", + "sherpa-onnx-microphone-offline-audio-tagging", + "sherpa-onnx-microphone-offline-speaker-identification", + "sherpa-onnx-offline", + "sherpa-onnx-offline-audio-tagging", + "sherpa-onnx-offline-language-identification", + "sherpa-onnx-offline-punctuation", + "sherpa-onnx-offline-speaker-diarization", + "sherpa-onnx-offline-tts", + "sherpa-onnx-offline-tts-play", + "sherpa-onnx-offline-websocket-server", + "sherpa-onnx-online-punctuation", + "sherpa-onnx-online-websocket-client", + "sherpa-onnx-online-websocket-server", + "sherpa-onnx-vad-microphone", + "sherpa-onnx-vad-microphone-offline-asr", + "sherpa-onnx-vad-with-offline-asr", + ] + + if enable_alsa(): + binaries += [ + "sherpa-onnx-alsa", + "sherpa-onnx-alsa-offline", + "sherpa-onnx-alsa-offline-speaker-identification", + "sherpa-onnx-offline-tts-play-alsa", + "sherpa-onnx-vad-alsa", + "sherpa-onnx-alsa-offline-audio-tagging", + ] + + if is_windows(): + binaries += [ + "onnxruntime.dll", + "sherpa-onnx-c-api.dll", + "sherpa-onnx-cxx-api.dll", + ] + + return binaries + + +try: + from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + + class bdist_wheel(_bdist_wheel): + def finalize_options(self): + _bdist_wheel.finalize_options(self) + # In this case, the generated wheel has a name in the form + # sherpa-xxx-pyxx-none-any.whl + if is_for_pypi() and not is_macos(): + self.root_is_pure = True + else: + # The generated wheel has a name ending with + # -linux_x86_64.whl + self.root_is_pure = False + +except ImportError: + bdist_wheel = None + + +def cmake_extension(name, *args, **kwargs) -> setuptools.Extension: + kwargs["language"] = "c++" + sources = [] + return setuptools.Extension(name, sources, *args, **kwargs) + + +class BuildExtension(build_ext): + def build_extension(self, ext: setuptools.extension.Extension): + # build/temp.linux-x86_64-3.8 + 
os.makedirs(self.build_temp, exist_ok=True) + + # build/lib.linux-x86_64-3.8 + os.makedirs(self.build_lib, exist_ok=True) + + out_bin_dir = Path(self.build_lib).parent / "sherpa_onnx" / "bin" + install_dir = Path(self.build_lib).resolve() / "sherpa_onnx" + + sherpa_onnx_dir = Path(__file__).parent.parent.resolve() + + cmake_args = os.environ.get("SHERPA_ONNX_CMAKE_ARGS", "") + make_args = os.environ.get("SHERPA_ONNX_MAKE_ARGS", "") + system_make_args = os.environ.get("MAKEFLAGS", "") + + if cmake_args == "": + cmake_args = "-DCMAKE_BUILD_TYPE=Release" + + extra_cmake_args = f" -DCMAKE_INSTALL_PREFIX={install_dir} " + extra_cmake_args += " -DBUILD_SHARED_LIBS=ON " + extra_cmake_args += " -DBUILD_PIPER_PHONMIZE_EXE=OFF " + extra_cmake_args += " -DBUILD_PIPER_PHONMIZE_TESTS=OFF " + extra_cmake_args += " -DBUILD_ESPEAK_NG_EXE=OFF " + extra_cmake_args += " -DBUILD_ESPEAK_NG_TESTS=OFF " + extra_cmake_args += " -DSHERPA_ONNX_ENABLE_C_API=ON " + + extra_cmake_args += " -DSHERPA_ONNX_BUILD_C_API_EXAMPLES=OFF " + extra_cmake_args += " -DSHERPA_ONNX_ENABLE_CHECK=OFF " + extra_cmake_args += " -DSHERPA_ONNX_ENABLE_PYTHON=ON " + extra_cmake_args += " -DSHERPA_ONNX_ENABLE_PORTAUDIO=ON " + extra_cmake_args += " -DSHERPA_ONNX_ENABLE_WEBSOCKET=ON " + + if "PYTHON_EXECUTABLE" not in cmake_args: + print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") + cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}" + + cmake_args += extra_cmake_args + + if is_windows(): + build_cmd = f""" + cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir} + cmake --build {self.build_temp} --target install --config Release -- -m:2 + """ + print(f"build command is:\n{build_cmd}") + ret = os.system( + f"cmake {cmake_args} -B {self.build_temp} -S {sherpa_onnx_dir}" + ) + if ret != 0: + raise Exception("Failed to configure sherpa") + + ret = os.system( + f"cmake --build {self.build_temp} --target install --config Release -- -m:2" # noqa + ) + if ret != 0: + raise Exception("Failed to build and install sherpa") + else: + if make_args == "" and system_make_args == "": + print("for fast compilation, run:") + print('export SHERPA_ONNX_MAKE_ARGS="-j"; python setup.py install') + print('Setting make_args to "-j4"') + make_args = "-j4" + + if "-G Ninja" in cmake_args: + build_cmd = f""" + cd {self.build_temp} + cmake {cmake_args} {sherpa_onnx_dir} + ninja {make_args} install + """ + else: + build_cmd = f""" + cd {self.build_temp} + + cmake {cmake_args} {sherpa_onnx_dir} + + make {make_args} install/strip + """ + print(f"build command is:\n{build_cmd}") + + ret = os.system(build_cmd) + if ret != 0: + raise Exception( + "\nBuild sherpa-onnx failed. Please check the error message.\n" + "You can ask for help by creating an issue on GitHub.\n" + "\nClick:\n\thttps://github.com/k2-fsa/sherpa-onnx/issues/new\n" # noqa + ) + + suffix = ".exe" if is_windows() else "" + # Remember to also change setup.py + + binaries = get_binaries() + + for f in binaries: + suffix = "" if ".dll" in f else suffix + src_file = install_dir / "bin" / (f + suffix) + if not src_file.is_file(): + src_file = install_dir / "lib" / (f + suffix) + if not src_file.is_file(): + src_file = install_dir / ".." 
/ (f + suffix) + + print(f"Copying {src_file} to {out_bin_dir}/") + shutil.copy(f"{src_file}", f"{out_bin_dir}/") + + shutil.rmtree(f"{install_dir}/bin") + shutil.rmtree(f"{install_dir}/share") + shutil.rmtree(f"{install_dir}/lib/pkgconfig") + + if is_macos(): + os.remove(f"{install_dir}/lib/libonnxruntime.dylib") + + if is_windows(): + shutil.rmtree(f"{install_dir}/lib") diff --git a/apps/frameworks/sherpa-mnn/cmake/cppjieba.cmake b/apps/frameworks/sherpa-mnn/cmake/cppjieba.cmake new file mode 100644 index 00000000..167da338 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/cppjieba.cmake @@ -0,0 +1,45 @@ +function(download_cppjieba) + include(FetchContent) + + set(cppjieba_URL "https://github.com/csukuangfj/cppjieba/archive/refs/tags/sherpa-onnx-2024-04-19.tar.gz") + set(cppjieba_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/cppjieba-sherpa-onnx-2024-04-19.tar.gz") + set(cppjieba_HASH "SHA256=03e5264687f0efaef05487a07d49c3f4c0f743347bfbf825df4b30cc75ac5288") + + # If you don't have access to the Internet, + # please pre-download cppjieba + set(possible_file_locations + $ENV{HOME}/Downloads/cppjieba-sherpa-onnx-2024-04-19.tar.gz + ${CMAKE_SOURCE_DIR}/cppjieba-sherpa-onnx-2024-04-19.tar.gz + ${CMAKE_BINARY_DIR}/cppjieba-sherpa-onnx-2024-04-19.tar.gz + /tmp/cppjieba-sherpa-onnx-2024-04-19.tar.gz + /star-fj/fangjun/download/github/cppjieba-sherpa-onnx-2024-04-19.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(cppjieba_URL "${f}") + file(TO_CMAKE_PATH "${cppjieba_URL}" cppjieba_URL) + message(STATUS "Found local downloaded cppjieba: ${cppjieba_URL}") + set(cppjieba_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(cppjieba + URL + ${cppjieba_URL} + ${cppjieba_URL2} + URL_HASH + ${cppjieba_HASH} + ) + + FetchContent_GetProperties(cppjieba) + if(NOT cppjieba_POPULATED) + message(STATUS "Downloading cppjieba ${cppjieba_URL}") + FetchContent_Populate(cppjieba) + endif() + message(STATUS "cppjieba is downloaded to ${cppjieba_SOURCE_DIR}") + add_subdirectory(${cppjieba_SOURCE_DIR} ${cppjieba_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_cppjieba() diff --git a/apps/frameworks/sherpa-mnn/cmake/eigen.cmake b/apps/frameworks/sherpa-mnn/cmake/eigen.cmake new file mode 100644 index 00000000..9aef9abc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/eigen.cmake @@ -0,0 +1,48 @@ +function(download_eigen) + include(FetchContent) + + set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz") + set(eigen_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/eigen-3.4.0.tar.gz") + set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72") + + # If you don't have access to the Internet, + # please pre-download eigen + set(possible_file_locations + $ENV{HOME}/Downloads/eigen-3.4.0.tar.gz + ${CMAKE_SOURCE_DIR}/eigen-3.4.0.tar.gz + ${CMAKE_BINARY_DIR}/eigen-3.4.0.tar.gz + /tmp/eigen-3.4.0.tar.gz + /star-fj/fangjun/download/github/eigen-3.4.0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(eigen_URL "${f}") + file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL) + message(STATUS "Found local downloaded eigen: ${eigen_URL}") + set(eigen_URL2) + break() + endif() + endforeach() + + set(BUILD_TESTING OFF CACHE BOOL "" FORCE) + set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(eigen + URL ${eigen_URL} + URL_HASH ${eigen_HASH} + ) + + FetchContent_GetProperties(eigen) + if(NOT eigen_POPULATED) + 
message(STATUS "Downloading eigen from ${eigen_URL}") + FetchContent_Populate(eigen) + endif() + message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}") + message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}") + + add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_eigen() + diff --git a/apps/frameworks/sherpa-mnn/cmake/espeak-ng-for-piper.cmake b/apps/frameworks/sherpa-mnn/cmake/espeak-ng-for-piper.cmake new file mode 100644 index 00000000..0ef82530 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/espeak-ng-for-piper.cmake @@ -0,0 +1,134 @@ +function(download_espeak_ng_for_piper) + include(FetchContent) + + set(espeak_ng_URL "https://github.com/csukuangfj/espeak-ng/archive/f6fed6c58b5e0998b8e68c6610125e2d07d595a7.zip") + set(espeak_ng_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/espeak-ng-f6fed6c58b5e0998b8e68c6610125e2d07d595a7.zip") + set(espeak_ng_HASH "SHA256=70cbf4050e7a014aae19140b05e57249da4720f56128459fbe3a93beaf971ae6") + + set(BUILD_ESPEAK_NG_TESTS OFF CACHE BOOL "" FORCE) + set(USE_ASYNC OFF CACHE BOOL "" FORCE) + set(USE_MBROLA OFF CACHE BOOL "" FORCE) + set(USE_LIBSONIC OFF CACHE BOOL "" FORCE) + set(USE_LIBPCAUDIO OFF CACHE BOOL "" FORCE) + set(USE_KLATT OFF CACHE BOOL "" FORCE) + set(USE_SPEECHPLAYER OFF CACHE BOOL "" FORCE) + set(EXTRA_cmn ON CACHE BOOL "" FORCE) + set(EXTRA_ru ON CACHE BOOL "" FORCE) + if (NOT SHERPA_ONNX_ENABLE_EPSEAK_NG_EXE) + set(BUILD_ESPEAK_NG_EXE OFF CACHE BOOL "" FORCE) + endif() + + # If you don't have access to the Internet, + # please pre-download kaldi-decoder + set(possible_file_locations + $ENV{HOME}/Downloads/espeak-ng-f6fed6c58b5e0998b8e68c6610125e2d07d595a7.zip + ${CMAKE_SOURCE_DIR}/espeak-ng-f6fed6c58b5e0998b8e68c6610125e2d07d595a7.zip + ${CMAKE_BINARY_DIR}/espeak-ng-f6fed6c58b5e0998b8e68c6610125e2d07d595a7.zip + /tmp/espeak-ng-f6fed6c58b5e0998b8e68c6610125e2d07d595a7.zip + /star-fj/fangjun/download/github/espeak-ng-f6fed6c58b5e0998b8e68c6610125e2d07d595a7.zip + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(espeak_ng_URL "${f}") + file(TO_CMAKE_PATH "${espeak_ng_URL}" espeak_ng_URL) + message(STATUS "Found local downloaded espeak-ng: ${espeak_ng_URL}") + set(espeak_ng_URL2 ) + break() + endif() + endforeach() + + FetchContent_Declare(espeak_ng + URL + ${espeak_ng_URL} + ${espeak_ng_URL2} + URL_HASH ${espeak_ng_HASH} + ) + + FetchContent_GetProperties(espeak_ng) + if(NOT espeak_ng_POPULATED) + message(STATUS "Downloading espeak-ng from ${espeak_ng_URL}") + FetchContent_Populate(espeak_ng) + endif() + message(STATUS "espeak-ng is downloaded to ${espeak_ng_SOURCE_DIR}") + message(STATUS "espeak-ng binary dir is ${espeak_ng_BINARY_DIR}") + + if(BUILD_SHARED_LIBS) + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF) + endif() + + add_subdirectory(${espeak_ng_SOURCE_DIR} ${espeak_ng_BINARY_DIR}) + + if(_build_shared_libs_bak) + set_target_properties(espeak-ng + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) + set(BUILD_SHARED_LIBS ON) + endif() + + set(espeak_ng_SOURCE_DIR ${espeak_ng_SOURCE_DIR} PARENT_SCOPE) + + if(WIN32 AND MSVC) + target_compile_options(ucd PUBLIC + /wd4309 + ) + + target_compile_options(espeak-ng PUBLIC + /wd4005 + /wd4018 + /wd4067 + /wd4068 + /wd4090 + /wd4101 + /wd4244 + /wd4267 + /wd4996 + ) + + if(TARGET espeak-ng-bin) + target_compile_options(espeak-ng-bin PRIVATE + /wd4244 + /wd4024 + /wd4047 + /wd4067 + /wd4267 
+ /wd4996 + ) + endif() + endif() + + if(UNIX AND NOT APPLE) + target_compile_options(espeak-ng PRIVATE + -Wno-unused-result + -Wno-format-overflow + -Wno-format-truncation + -Wno-uninitialized + -Wno-format + ) + + if(TARGET espeak-ng-bin) + target_compile_options(espeak-ng-bin PRIVATE + -Wno-unused-result + ) + endif() + endif() + + target_include_directories(espeak-ng + INTERFACE + ${espeak_ng_SOURCE_DIR}/src/include + ${espeak_ng_SOURCE_DIR}/src/ucd-tools/src/include + ) + + if(NOT BUILD_SHARED_LIBS) + install(TARGETS + espeak-ng + ucd + DESTINATION lib) + endif() +endfunction() + +download_espeak_ng_for_piper() diff --git a/apps/frameworks/sherpa-mnn/cmake/googletest.cmake b/apps/frameworks/sherpa-mnn/cmake/googletest.cmake new file mode 100644 index 00000000..a9bfd443 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/googletest.cmake @@ -0,0 +1,76 @@ +function(download_googltest) + include(FetchContent) + + set(googletest_URL "https://github.com/google/googletest/archive/refs/tags/v1.13.0.tar.gz") + set(googletest_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/googletest-1.13.0.tar.gz") + set(googletest_HASH "SHA256=ad7fdba11ea011c1d925b3289cf4af2c66a352e18d4c7264392fead75e919363") + + # If you don't have access to the Internet, + # please pre-download googletest + set(possible_file_locations + $ENV{HOME}/Downloads/googletest-1.13.0.tar.gz + ${CMAKE_SOURCE_DIR}/googletest-1.13.0.tar.gz + ${CMAKE_BINARY_DIR}/googletest-1.13.0.tar.gz + /tmp/googletest-1.13.0.tar.gz + /star-fj/fangjun/download/github/googletest-1.13.0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(googletest_URL "${f}") + file(TO_CMAKE_PATH "${googletest_URL}" googletest_URL) + message(STATUS "Found local downloaded googletest: ${googletest_URL}") + set(googletest_URL2) + break() + endif() + endforeach() + + set(BUILD_GMOCK ON CACHE BOOL "" FORCE) + set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) + set(gtest_disable_pthreads ON CACHE BOOL "" FORCE) + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + + FetchContent_Declare(googletest + URL + ${googletest_URL} + ${googletest_URL2} + URL_HASH ${googletest_HASH} + ) + + FetchContent_GetProperties(googletest) + if(NOT googletest_POPULATED) + message(STATUS "Downloading googletest from ${googletest_URL}") + FetchContent_Populate(googletest) + endif() + message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}") + message(STATUS "googletest's binary dir is ${googletest_BINARY_DIR}") + + if(APPLE) + set(CMAKE_MACOSX_RPATH ON) # to solve the following warning on macOS + endif() + #[==[ + -- Generating done + Policy CMP0042 is not set: MACOSX_RPATH is enabled by default. Run "cmake + --help-policy CMP0042" for policy details. Use the cmake_policy command to + set the policy and suppress this warning. + + MACOSX_RPATH is not specified for the following targets: + + gmock + gmock_main + gtest + gtest_main + + This warning is for project developers. Use -Wno-dev to suppress it. 
+ ]==] + + add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) + + target_include_directories(gtest + INTERFACE + ${googletest_SOURCE_DIR}/googletest/include + ${googletest_SOURCE_DIR}/googlemock/include + ) +endfunction() + +download_googltest() diff --git a/apps/frameworks/sherpa-mnn/cmake/hclust-cpp.cmake b/apps/frameworks/sherpa-mnn/cmake/hclust-cpp.cmake new file mode 100644 index 00000000..c84ccafc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/hclust-cpp.cmake @@ -0,0 +1,47 @@ +function(download_hclust_cpp) + include(FetchContent) + + # The latest commit as of 2024.09.29 + set(hclust_cpp_URL "https://github.com/csukuangfj/hclust-cpp/archive/refs/tags/2024-09-29.tar.gz") + set(hclust_cpp_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/hclust-cpp-2024-09-29.tar.gz") + set(hclust_cpp_HASH "SHA256=abab51448a3cb54272aae07522970306e0b2cc6479d59d7b19e7aee4d6cedd33") + + # If you don't have access to the Internet, + # please pre-download hclust-cpp + set(possible_file_locations + $ENV{HOME}/Downloads/hclust-cpp-2024-09-29.tar.gz + ${CMAKE_SOURCE_DIR}/hclust-cpp-2024-09-29.tar.gz + ${CMAKE_BINARY_DIR}/hclust-cpp-2024-09-29.tar.gz + /tmp/hclust-cpp-2024-09-29.tar.gz + /star-fj/fangjun/download/github/hclust-cpp-2024-09-29.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(hclust_cpp_URL "${f}") + file(TO_CMAKE_PATH "${hclust_cpp_URL}" hclust_cpp_URL) + message(STATUS "Found local downloaded hclust_cpp: ${hclust_cpp_URL}") + set(hclust_cpp_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(hclust_cpp + URL + ${hclust_cpp_URL} + ${hclust_cpp_URL2} + URL_HASH ${hclust_cpp_HASH} + ) + + FetchContent_GetProperties(hclust_cpp) + if(NOT hclust_cpp_POPULATED) + message(STATUS "Downloading hclust_cpp from ${hclust_cpp_URL}") + FetchContent_Populate(hclust_cpp) + endif() + + message(STATUS "hclust_cpp is downloaded to ${hclust_cpp_SOURCE_DIR}") + message(STATUS "hclust_cpp's binary dir is ${hclust_cpp_BINARY_DIR}") + include_directories(${hclust_cpp_SOURCE_DIR}) +endfunction() + +download_hclust_cpp() diff --git a/apps/frameworks/sherpa-mnn/cmake/kaldi-decoder.cmake b/apps/frameworks/sherpa-mnn/cmake/kaldi-decoder.cmake new file mode 100644 index 00000000..91202342 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/kaldi-decoder.cmake @@ -0,0 +1,89 @@ +function(download_kaldi_decoder) + include(FetchContent) + + set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.6.tar.gz") + set(kaldi_decoder_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-decoder-0.2.6.tar.gz") + set(kaldi_decoder_HASH "SHA256=b13c78b37495cafc6ef3f8a7b661b349c55a51abbd7f7f42f389408dcf86a463") + + set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE) + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + # If you don't have access to the Internet, + # please pre-download kaldi-decoder + set(possible_file_locations + $ENV{HOME}/Downloads/kaldi-decoder-0.2.6.tar.gz + ${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.6.tar.gz + ${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.6.tar.gz + /tmp/kaldi-decoder-0.2.6.tar.gz + /star-fj/fangjun/download/github/kaldi-decoder-0.2.6.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(kaldi_decoder_URL "${f}") + file(TO_CMAKE_PATH "${kaldi_decoder_URL}" kaldi_decoder_URL) + message(STATUS "Found local downloaded kaldi-decoder: ${kaldi_decoder_URL}") + 
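+      # Clearing the mirror URL below ensures FetchContent uses only the
+      # local archive that was just found.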
set(kaldi_decoder_URL2 ) + break() + endif() + endforeach() + + FetchContent_Declare(kaldi_decoder + URL + ${kaldi_decoder_URL} + ${kaldi_decoder_URL2} + URL_HASH ${kaldi_decoder_HASH} + ) + + FetchContent_GetProperties(kaldi_decoder) + if(NOT kaldi_decoder_POPULATED) + message(STATUS "Downloading kaldi-decoder from ${kaldi_decoder_URL}") + FetchContent_Populate(kaldi_decoder) + endif() + message(STATUS "kaldi-decoder is downloaded to ${kaldi_decoder_SOURCE_DIR}") + message(STATUS "kaldi-decoder's binary dir is ${kaldi_decoder_BINARY_DIR}") + + include_directories(${kaldi_decoder_SOURCE_DIR}) + + if(BUILD_SHARED_LIBS) + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF) + endif() + + add_subdirectory(${kaldi_decoder_SOURCE_DIR} ${kaldi_decoder_BINARY_DIR} EXCLUDE_FROM_ALL) + + if(_build_shared_libs_bak) + set_target_properties( + kaldi-decoder-core + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) + set(BUILD_SHARED_LIBS ON) + endif() + + if(WIN32 AND MSVC) + target_compile_options(kaldi-decoder-core PUBLIC + /wd4018 + /wd4291 + ) + endif() + + target_include_directories(kaldi-decoder-core + INTERFACE + ${kaldi-decoder_SOURCE_DIR}/ + ) + if(NOT BUILD_SHARED_LIBS) + install(TARGETS + kaldi-decoder-core + kaldifst_core + fst + fstfar + DESTINATION lib) + endif() +endfunction() + +download_kaldi_decoder() + diff --git a/apps/frameworks/sherpa-mnn/cmake/kaldi-native-fbank.cmake b/apps/frameworks/sherpa-mnn/cmake/kaldi-native-fbank.cmake new file mode 100644 index 00000000..f7aba1b5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/kaldi-native-fbank.cmake @@ -0,0 +1,74 @@ +function(download_kaldi_native_fbank) + include(FetchContent) + + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.21.1.tar.gz") + set(kaldi_native_fbank_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.21.1.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=37c1aa230b00fe062791d800d8fc50aa3de215918d3dce6440699e67275d859e") + + set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + set(KALDI_NATIVE_FBANK_ENABLE_CHECK OFF CACHE BOOL "" FORCE) + + # If you don't have access to the Internet, + # please pre-download kaldi-native-fbank + set(possible_file_locations + $ENV{HOME}/Downloads/kaldi-native-fbank-1.21.1.tar.gz + ${CMAKE_SOURCE_DIR}/kaldi-native-fbank-1.21.1.tar.gz + ${CMAKE_BINARY_DIR}/kaldi-native-fbank-1.21.1.tar.gz + /tmp/kaldi-native-fbank-1.21.1.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.21.1.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(kaldi_native_fbank_URL "${f}") + file(TO_CMAKE_PATH "${kaldi_native_fbank_URL}" kaldi_native_fbank_URL) + message(STATUS "Found local downloaded kaldi-native-fbank: ${kaldi_native_fbank_URL}") + set(kaldi_native_fbank_URL2 ) + break() + endif() + endforeach() + + FetchContent_Declare(kaldi_native_fbank + URL + ${kaldi_native_fbank_URL} + ${kaldi_native_fbank_URL2} + URL_HASH ${kaldi_native_fbank_HASH} + ) + + FetchContent_GetProperties(kaldi_native_fbank) + if(NOT kaldi_native_fbank_POPULATED) + message(STATUS "Downloading kaldi-native-fbank from ${kaldi_native_fbank_URL}") + FetchContent_Populate(kaldi_native_fbank) + endif() + message(STATUS "kaldi-native-fbank is downloaded to ${kaldi_native_fbank_SOURCE_DIR}") + message(STATUS "kaldi-native-fbank's binary dir is 
${kaldi_native_fbank_BINARY_DIR}") + + if(BUILD_SHARED_LIBS) + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF) + endif() + + add_subdirectory(${kaldi_native_fbank_SOURCE_DIR} ${kaldi_native_fbank_BINARY_DIR} EXCLUDE_FROM_ALL) + + if(_build_shared_libs_bak) + set_target_properties(kaldi-native-fbank-core + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) + set(BUILD_SHARED_LIBS ON) + endif() + + target_include_directories(kaldi-native-fbank-core + INTERFACE + ${kaldi_native_fbank_SOURCE_DIR}/ + ) + + if(NOT BUILD_SHARED_LIBS) + install(TARGETS kaldi-native-fbank-core DESTINATION lib) + endif() +endfunction() + +download_kaldi_native_fbank() diff --git a/apps/frameworks/sherpa-mnn/cmake/kaldifst.cmake b/apps/frameworks/sherpa-mnn/cmake/kaldifst.cmake new file mode 100644 index 00000000..e0c11baf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/kaldifst.cmake @@ -0,0 +1,72 @@ +function(download_kaldifst) + include(FetchContent) + + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.11.tar.gz") + set(kaldifst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldifst-1.7.11.tar.gz") + set(kaldifst_HASH "SHA256=b43b3332faa2961edc730e47995a58cd4e22ead21905d55b0c4a41375b4a525f") + + # If you don't have access to the Internet, + # please pre-download kaldifst + set(possible_file_locations + $ENV{HOME}/Downloads/kaldifst-1.7.11.tar.gz + ${CMAKE_SOURCE_DIR}/kaldifst-1.7.11.tar.gz + ${CMAKE_BINARY_DIR}/kaldifst-1.7.11.tar.gz + /tmp/kaldifst-1.7.11.tar.gz + /star-fj/fangjun/download/github/kaldifst-1.7.11.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(kaldifst_URL "${f}") + file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL) + message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}") + set(kaldifst_URL2) + break() + endif() + endforeach() + + set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(kaldifst + URL ${kaldifst_URL} + URL_HASH ${kaldifst_HASH} + ) + + FetchContent_GetProperties(kaldifst) + if(NOT kaldifst_POPULATED) + message(STATUS "Downloading kaldifst from ${kaldifst_URL}") + FetchContent_Populate(kaldifst) + endif() + message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}") + message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}") + + list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake) + + if(BUILD_SHARED_LIBS) + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF) + endif() + + add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR} EXCLUDE_FROM_ALL) + + if(_build_shared_libs_bak) + set_target_properties(kaldifst_core + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) + set(BUILD_SHARED_LIBS ON) + endif() + + target_include_directories(kaldifst_core + PUBLIC + ${kaldifst_SOURCE_DIR}/ + ) + + set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "sherpa-mnn-kaldifst-core") + # installed in ./kaldi-decoder.cmake +endfunction() + +download_kaldifst() diff --git a/apps/frameworks/sherpa-mnn/cmake/openfst.cmake b/apps/frameworks/sherpa-mnn/cmake/openfst.cmake new file mode 100644 index 00000000..59211d34 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/openfst.cmake @@ -0,0 +1,109 @@ +# Copyright (c) 2020 Xiaomi Corporation (author: Fangjun Kuang) + +function(download_openfst) + include(FetchContent) + + 
set(openfst_URL "https://github.com/csukuangfj/openfst/archive/refs/tags/sherpa-onnx-2024-06-19.tar.gz") + set(openfst_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/openfst-sherpa-onnx-2024-06-19.tar.gz") + set(openfst_HASH "SHA256=5c98e82cc509c5618502dde4860b8ea04d843850ed57e6d6b590b644b268853d") + + # If you don't have access to the Internet, + # please pre-download it + set(possible_file_locations + $ENV{HOME}/Downloads/openfst-sherpa-onnx-2024-06-19.tar.gz + ${CMAKE_SOURCE_DIR}/openfst-sherpa-onnx-2024-06-19.tar.gz + ${CMAKE_BINARY_DIR}/openfst-sherpa-onnx-2024-06-19.tar.gz + /tmp/openfst-sherpa-onnx-2024-06-19.tar.gz + /star-fj/fangjun/download/github/openfst-sherpa-onnx-2024-06-19.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(openfst_URL "${f}") + file(TO_CMAKE_PATH "${openfst_URL}" openfst_URL) + set(openfst_URL2) + break() + endif() + endforeach() + + set(HAVE_BIN OFF CACHE BOOL "" FORCE) + set(HAVE_SCRIPT OFF CACHE BOOL "" FORCE) + set(HAVE_COMPACT OFF CACHE BOOL "" FORCE) + set(HAVE_COMPRESS OFF CACHE BOOL "" FORCE) + set(HAVE_CONST OFF CACHE BOOL "" FORCE) + set(HAVE_FAR ON CACHE BOOL "" FORCE) + set(HAVE_GRM OFF CACHE BOOL "" FORCE) + set(HAVE_PDT OFF CACHE BOOL "" FORCE) + set(HAVE_MPDT OFF CACHE BOOL "" FORCE) + set(HAVE_LINEAR OFF CACHE BOOL "" FORCE) + set(HAVE_LOOKAHEAD OFF CACHE BOOL "" FORCE) + set(HAVE_NGRAM OFF CACHE BOOL "" FORCE) + set(HAVE_PYTHON OFF CACHE BOOL "" FORCE) + set(HAVE_SPECIAL OFF CACHE BOOL "" FORCE) + + if(NOT WIN32) + FetchContent_Declare(openfst + URL + ${openfst_URL} + ${openfst_URL2} + URL_HASH ${openfst_HASH} + PATCH_COMMAND + sed -i.bak s/enable_testing\(\)//g "src/CMakeLists.txt" && + sed -i.bak s/add_subdirectory\(test\)//g "src/CMakeLists.txt" && + sed -i.bak /message/d "src/script/CMakeLists.txt" + # sed -i.bak s/add_subdirectory\(script\)//g "src/CMakeLists.txt" && + # sed -i.bak s/add_subdirectory\(extensions\)//g "src/CMakeLists.txt" + ) + else() + FetchContent_Declare(openfst + URL ${openfst_URL} + URL_HASH ${openfst_HASH} + ) + endif() + + FetchContent_GetProperties(openfst) + if(NOT openfst_POPULATED) + message(STATUS "Downloading openfst from ${openfst_URL}") + FetchContent_Populate(openfst) + endif() + message(STATUS "openfst is downloaded to ${openfst_SOURCE_DIR}") + + if(_build_shared_libs_bak) + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF) + endif() + + add_subdirectory(${openfst_SOURCE_DIR} ${openfst_BINARY_DIR} EXCLUDE_FROM_ALL) + + if(_build_shared_libs_bak) + set_target_properties(fst fstfar + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) + set(BUILD_SHARED_LIBS ON) + endif() + + set(openfst_SOURCE_DIR ${openfst_SOURCE_DIR} PARENT_SCOPE) + + set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-mnn-fst") + set_target_properties(fstfar PROPERTIES OUTPUT_NAME "sherpa-mnn-fstfar") + + if(LINUX) + target_compile_options(fst PUBLIC -Wno-missing-template-keyword) + endif() + + target_include_directories(fst + PUBLIC + ${openfst_SOURCE_DIR}/src/include + ) + + target_include_directories(fstfar + PUBLIC + ${openfst_SOURCE_DIR}/src/include + ) + # installed in ./kaldi-decoder.cmake +endfunction() + +download_openfst() diff --git a/apps/frameworks/sherpa-mnn/cmake/piper-phonemize.cmake b/apps/frameworks/sherpa-mnn/cmake/piper-phonemize.cmake new file mode 100644 index 00000000..0e11fd17 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/piper-phonemize.cmake @@ -0,0 +1,78 @@ 
+function(download_piper_phonemize) + include(FetchContent) + + set(piper_phonemize_URL "https://github.com/csukuangfj/piper-phonemize/archive/78a788e0b719013401572d70fef372e77bff8e43.zip") + set(piper_phonemize_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/piper-phonemize-78a788e0b719013401572d70fef372e77bff8e43.zip") + set(piper_phonemize_HASH "SHA256=89641a46489a4898754643ce57bda9c9b54b4ca46485fdc02bf0dc84b866645d") + + # If you don't have access to the Internet, + # please pre-download kaldi-decoder + set(possible_file_locations + $ENV{HOME}/Downloads/piper-phonemize-78a788e0b719013401572d70fef372e77bff8e43.zip + ${CMAKE_SOURCE_DIR}/piper-phonemize-78a788e0b719013401572d70fef372e77bff8e43.zip + ${CMAKE_BINARY_DIR}/piper-phonemize-78a788e0b719013401572d70fef372e77bff8e43.zip + /tmp/piper-phonemize-78a788e0b719013401572d70fef372e77bff8e43.zip + /star-fj/fangjun/download/github/piper-phonemize-78a788e0b719013401572d70fef372e77bff8e43.zip + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(piper_phonemize_URL "${f}") + file(TO_CMAKE_PATH "${piper_phonemize_URL}" piper_phonemize_URL) + message(STATUS "Found local downloaded espeak-ng: ${piper_phonemize_URL}") + set(piper_phonemize_URL2 ) + break() + endif() + endforeach() + + FetchContent_Declare(piper_phonemize + URL + ${piper_phonemize_URL} + ${piper_phonemize_URL2} + URL_HASH ${piper_phonemize_HASH} + ) + + FetchContent_GetProperties(piper_phonemize) + if(NOT piper_phonemize_POPULATED) + message(STATUS "Downloading piper-phonemize from ${piper_phonemize_URL}") + FetchContent_Populate(piper_phonemize) + endif() + message(STATUS "piper-phonemize is downloaded to ${piper_phonemize_SOURCE_DIR}") + message(STATUS "piper-phonemize binary dir is ${piper_phonemize_BINARY_DIR}") + + if(BUILD_SHARED_LIBS) + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF) + endif() + + add_subdirectory(${piper_phonemize_SOURCE_DIR} ${piper_phonemize_BINARY_DIR} EXCLUDE_FROM_ALL) + + if(_build_shared_libs_bak) + set_target_properties(piper_phonemize + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) + set(BUILD_SHARED_LIBS ON) + endif() + + if(WIN32 AND MSVC) + target_compile_options(piper_phonemize PUBLIC + /wd4309 + ) + endif() + + target_include_directories(piper_phonemize + INTERFACE + ${piper_phonemize_SOURCE_DIR}/src/include + ) + + if(NOT BUILD_SHARED_LIBS) + install(TARGETS + piper_phonemize + DESTINATION lib) + endif() +endfunction() + +download_piper_phonemize() diff --git a/apps/frameworks/sherpa-mnn/cmake/portaudio.cmake b/apps/frameworks/sherpa-mnn/cmake/portaudio.cmake new file mode 100644 index 00000000..d8af8d49 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/portaudio.cmake @@ -0,0 +1,71 @@ +function(download_portaudio) + include(FetchContent) + + set(portaudio_URL "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz") + set(portaudio_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/pa_stable_v190700_20210406.tgz") + set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def") + + # If you don't have access to the Internet, please download it to your + # local drive and modify the following line according to your needs. 
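+  # For example, an archive saved as $ENV{HOME}/Downloads/pa_stable_v190700_20210406.tgz
+  # is picked up automatically by the first entry in the list below.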
+ set(possible_file_locations + $ENV{HOME}/Downloads/pa_stable_v190700_20210406.tgz + $ENV{HOME}/asr/pa_stable_v190700_20210406.tgz + ${CMAKE_SOURCE_DIR}/pa_stable_v190700_20210406.tgz + ${CMAKE_BINARY_DIR}/pa_stable_v190700_20210406.tgz + /tmp/pa_stable_v190700_20210406.tgz + /star-fj/fangjun/download/github/pa_stable_v190700_20210406.tgz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(portaudio_URL "${f}") + file(TO_CMAKE_PATH "${portaudio_URL}" portaudio_URL) + message(STATUS "Found local downloaded portaudio: ${portaudio_URL}") + set(portaudio_URL2) + break() + endif() + endforeach() + + # Always use static build + set(PA_BUILD_SHARED OFF CACHE BOOL "" FORCE) + set(PA_BUILD_STATIC ON CACHE BOOL "" FORCE) + + FetchContent_Declare(portaudio + URL + ${portaudio_URL} + ${portaudio_URL2} + URL_HASH ${portaudio_HASH} + ) + + FetchContent_GetProperties(portaudio) + if(NOT portaudio_POPULATED) + message(STATUS "Downloading portaudio from ${portaudio_URL}") + FetchContent_Populate(portaudio) + endif() + message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}") + message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}") + + if(APPLE) + set(CMAKE_MACOSX_RPATH ON) # to solve the following warning on macOS + endif() + + add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL) + + set_target_properties(portaudio_static PROPERTIES OUTPUT_NAME "sherpa-onnx-portaudio_static") + if(NOT WIN32) + target_compile_options(portaudio_static PRIVATE "-Wno-deprecated-declarations") + endif() + + if(NOT BUILD_SHARED_LIBS AND SHERPA_ONNX_ENABLE_BINARY) + install(TARGETS + portaudio_static + DESTINATION lib) + endif() + +endfunction() + +download_portaudio() + +# Note +# See http://portaudio.com/docs/v19-doxydocs/tutorial_start.html +# for how to use portaudio diff --git a/apps/frameworks/sherpa-mnn/cmake/pybind11.cmake b/apps/frameworks/sherpa-mnn/cmake/pybind11.cmake new file mode 100644 index 00000000..bc06a3d1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/pybind11.cmake @@ -0,0 +1,44 @@ +function(download_pybind11) + include(FetchContent) + + set(pybind11_URL "https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.tar.gz") + set(pybind11_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/pybind11-2.12.0.tar.gz") + set(pybind11_HASH "SHA256=bf8f242abd1abcd375d516a7067490fb71abd79519a282d22b6e4d19282185a7") + + # If you don't have access to the Internet, + # please pre-download pybind11 + set(possible_file_locations + $ENV{HOME}/Downloads/pybind11-2.12.0.tar.gz + ${CMAKE_SOURCE_DIR}/pybind11-2.12.0.tar.gz + ${CMAKE_BINARY_DIR}/pybind11-2.12.0.tar.gz + /tmp/pybind11-2.12.0.tar.gz + /star-fj/fangjun/download/github/pybind11-2.12.0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(pybind11_URL "${f}") + file(TO_CMAKE_PATH "${pybind11_URL}" pybind11_URL) + message(STATUS "Found local downloaded pybind11: ${pybind11_URL}") + set(pybind11_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(pybind11 + URL + ${pybind11_URL} + ${pybind11_URL2} + URL_HASH ${pybind11_HASH} + ) + + FetchContent_GetProperties(pybind11) + if(NOT pybind11_POPULATED) + message(STATUS "Downloading pybind11 from ${pybind11_URL}") + FetchContent_Populate(pybind11) + endif() + message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}") + add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_pybind11() diff --git 
a/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-shared.pc.in b/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-shared.pc.in
new file mode 100644
index 00000000..e0b0eea1
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-shared.pc.in
@@ -0,0 +1,25 @@
+# Note: If you use Python, then the prefix might not be correct.
+#
+# You need to either manually modify this file to change the prefix to the location
+# where this sherpa-onnx.pc file actually resides
+# or
+# you can use
+#
+# pkg-config --define-variable=prefix=/path/to/the/dir/containing/this/file --cflags sherpa-onnx
+
+prefix="@CMAKE_INSTALL_PREFIX@"
+exec_prefix="${prefix}"
+includedir="${prefix}/include"
+libdir="${exec_prefix}/lib"
+
+Name: sherpa-onnx
+Description: pkg-config for sherpa-onnx
+URL: https://github.com/k2-fsa/sherpa-onnx
+
+Version: @SHERPA_ONNX_VERSION@
+Cflags: -I"${includedir}"
+
+# Note: -lcargs is required only for the following file
+# https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c
+# We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c
+Libs: -L"${libdir}" -lsherpa-onnx-c-api -lonnxruntime -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_WITH_CARGS@ @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
diff --git a/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-static-no-tts.pc.in b/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-static-no-tts.pc.in
new file mode 100644
index 00000000..f100cd95
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-static-no-tts.pc.in
@@ -0,0 +1,25 @@
+# Note: If you use Python, then the prefix might not be correct.
+#
+# You need to either manually modify this file to change the prefix to the location
+# where this sherpa-onnx.pc file actually resides
+# or
+# you can use
+#
+# pkg-config --define-variable=prefix=/path/to/the/dir/containing/this/file --cflags sherpa-onnx
+
+prefix="@CMAKE_INSTALL_PREFIX@"
+exec_prefix="${prefix}"
+includedir="${prefix}/include"
+libdir="${exec_prefix}/lib"
+
+Name: sherpa-onnx
+Description: pkg-config for sherpa-onnx without TTS support
+URL: https://github.com/k2-fsa/sherpa-onnx
+
+Version: @SHERPA_ONNX_VERSION@
+Cflags: -I"${includedir}"
+
+# Note: -lcargs is required only for the following file
+# https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c
+# We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c
+Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fst -lkaldi-native-fbank-core -lonnxruntime -lssentencepiece_core -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_WITH_CARGS@ @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
diff --git a/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-static.pc.in b/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-static.pc.in
new file mode 100644
index 00000000..1f788b00
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cmake/sherpa-onnx-static.pc.in
@@ -0,0 +1,25 @@
+# Note: If you use Python, then the prefix might not be correct.
+# +# You need to either manually modify this file to change the prefix to the location +# where this sherpa-onnx.pc file actually resides +# or +# you can use +# +# pkg-config --define-variable=prefix=/path/to/the/dir/containing/this/file --cflags sherpa-onnx + +prefix="@CMAKE_INSTALL_PREFIX@" +exec_prefix="${prefix}" +includedir="${prefix}/include" +libdir="${exec_prefix}/lib" + +Name: sherpa-onnx +Description: pkg-config for sherpa-onnx +URL: https://github.com/k2-fsa/sherpa-onnx + +Version: @SHERPA_ONNX_VERSION@ +Cflags: -I"${includedir}" + +# Note: -lcargs is required only for the following file +# https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c +# We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c +Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fstfar -lsherpa-onnx-fst -lkaldi-native-fbank-core -lpiper_phonemize -lespeak-ng -lucd -lonnxruntime -lssentencepiece_core -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_WITH_CARGS@ @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@ diff --git a/apps/frameworks/sherpa-mnn/cmake/simple-sentencepiece.cmake b/apps/frameworks/sherpa-mnn/cmake/simple-sentencepiece.cmake new file mode 100644 index 00000000..4b6750d0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/simple-sentencepiece.cmake @@ -0,0 +1,73 @@ +function(download_simple_sentencepiece) + include(FetchContent) + + set(simple-sentencepiece_URL "https://github.com/pkufool/simple-sentencepiece/archive/refs/tags/v0.7.tar.gz") + set(simple-sentencepiece_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/simple-sentencepiece-0.7.tar.gz") + set(simple-sentencepiece_HASH "SHA256=1748a822060a35baa9f6609f84efc8eb54dc0e74b9ece3d82367b7119fdc75af") + + # If you don't have access to the Internet, + # please pre-download simple-sentencepiece + set(possible_file_locations + $ENV{HOME}/Downloads/simple-sentencepiece-0.7.tar.gz + ${CMAKE_SOURCE_DIR}/simple-sentencepiece-0.7.tar.gz + ${CMAKE_BINARY_DIR}/simple-sentencepiece-0.7.tar.gz + /tmp/simple-sentencepiece-0.7.tar.gz + /star-fj/fangjun/download/github/simple-sentencepiece-0.7.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(simple-sentencepiece_URL "${f}") + file(TO_CMAKE_PATH "${simple-sentencepiece_URL}" simple-sentencepiece_URL) + message(STATUS "Found local downloaded simple-sentencepiece: ${simple-sentencepiece_URL}") + set(simple-sentencepiece_URL2) + break() + endif() + endforeach() + + set(SBPE_ENABLE_TESTS OFF CACHE BOOL "" FORCE) + set(SBPE_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(simple-sentencepiece + URL + ${simple-sentencepiece_URL} + ${simple-sentencepiece_URL2} + URL_HASH + ${simple-sentencepiece_HASH} + ) + + FetchContent_GetProperties(simple-sentencepiece) + if(NOT simple-sentencepiece_POPULATED) + message(STATUS "Downloading simple-sentencepiece ${simple-sentencepiece_URL}") + FetchContent_Populate(simple-sentencepiece) + endif() + message(STATUS "simple-sentencepiece is downloaded to ${simple-sentencepiece_SOURCE_DIR}") + + if(BUILD_SHARED_LIBS) + set(_build_shared_libs_bak ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF) + endif() + + add_subdirectory(${simple-sentencepiece_SOURCE_DIR} ${simple-sentencepiece_BINARY_DIR} EXCLUDE_FROM_ALL) + + if(_build_shared_libs_bak) + set_target_properties(ssentencepiece_core + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + ) 
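+    # Restore the caller's BUILD_SHARED_LIBS setting now that the dependency
+    # itself has been added as a static library.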
+ set(BUILD_SHARED_LIBS ON) + endif() + + target_include_directories(ssentencepiece_core + PUBLIC + ${simple-sentencepiece_SOURCE_DIR}/ + ) + + if(NOT BUILD_SHARED_LIBS) + install(TARGETS ssentencepiece_core DESTINATION lib) + endif() +endfunction() + +download_simple_sentencepiece() diff --git a/apps/frameworks/sherpa-mnn/cmake/websocketpp.cmake b/apps/frameworks/sherpa-mnn/cmake/websocketpp.cmake new file mode 100644 index 00000000..79b0585b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cmake/websocketpp.cmake @@ -0,0 +1,46 @@ +function(download_websocketpp) + include(FetchContent) + + # The latest commit on the develop branch os as 2022-10-22 + set(websocketpp_URL "https://github.com/zaphoyd/websocketpp/archive/b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip") + set(websocketpp_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip") + set(websocketpp_HASH "SHA256=1385135ede8191a7fbef9ec8099e3c5a673d48df0c143958216cd1690567f583") + + # If you don't have access to the Internet, + # please pre-download websocketpp + set(possible_file_locations + $ENV{HOME}/Downloads/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + ${CMAKE_SOURCE_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + ${CMAKE_BINARY_DIR}/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + /tmp/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + /star-fj/fangjun/download/github/websocketpp-b9aeec6eaf3d5610503439b4fae3581d9aff08e8.zip + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(websocketpp_URL "${f}") + file(TO_CMAKE_PATH "${websocketpp_URL}" websocketpp_URL) + message(STATUS "Found local downloaded websocketpp: ${websocketpp_URL}") + set(websocketpp_URL2) + break() + endif() + endforeach() + + FetchContent_Declare(websocketpp + URL + ${websocketpp_URL} + ${websocketpp_URL2} + URL_HASH ${websocketpp_HASH} + ) + + FetchContent_GetProperties(websocketpp) + if(NOT websocketpp_POPULATED) + message(STATUS "Downloading websocketpp from ${websocketpp_URL}") + FetchContent_Populate(websocketpp) + endif() + message(STATUS "websocketpp is downloaded to ${websocketpp_SOURCE_DIR}") + # add_subdirectory(${websocketpp_SOURCE_DIR} ${websocketpp_BINARY_DIR} EXCLUDE_FROM_ALL) + include_directories(${websocketpp_SOURCE_DIR}) +endfunction() + +download_websocketpp() diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/CMakeLists.txt b/apps/frameworks/sherpa-mnn/cxx-api-examples/CMakeLists.txt new file mode 100644 index 00000000..447934eb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/CMakeLists.txt @@ -0,0 +1,39 @@ +include_directories(${CMAKE_SOURCE_DIR}) + +add_executable(streaming-zipformer-cxx-api ./streaming-zipformer-cxx-api.cc) +target_link_libraries(streaming-zipformer-cxx-api sherpa-mnn-cxx-api) + +add_executable(speech-enhancement-gtcrn-cxx-api ./speech-enhancement-gtcrn-cxx-api.cc) +target_link_libraries(speech-enhancement-gtcrn-cxx-api sherpa-mnn-cxx-api) + +add_executable(kws-cxx-api ./kws-cxx-api.cc) +target_link_libraries(kws-cxx-api sherpa-mnn-cxx-api) + +add_executable(streaming-zipformer-rtf-cxx-api ./streaming-zipformer-rtf-cxx-api.cc) +target_link_libraries(streaming-zipformer-rtf-cxx-api sherpa-mnn-cxx-api) + +add_executable(whisper-cxx-api ./whisper-cxx-api.cc) +target_link_libraries(whisper-cxx-api sherpa-mnn-cxx-api) + +add_executable(fire-red-asr-cxx-api ./fire-red-asr-cxx-api.cc) +target_link_libraries(fire-red-asr-cxx-api sherpa-mnn-cxx-api) + 
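+# Each example in this file follows the same two-line pattern; to add a new
+# one, e.g. a hypothetical foo-cxx-api.cc, it should be enough to write:
+#
+#   add_executable(foo-cxx-api ./foo-cxx-api.cc)
+#   target_link_libraries(foo-cxx-api sherpa-mnn-cxx-api)
+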
+add_executable(moonshine-cxx-api ./moonshine-cxx-api.cc)
+target_link_libraries(moonshine-cxx-api sherpa-mnn-cxx-api)
+
+add_executable(sense-voice-cxx-api ./sense-voice-cxx-api.cc)
+target_link_libraries(sense-voice-cxx-api sherpa-mnn-cxx-api)
+
+if(SHERPA_MNN_ENABLE_TTS)
+  add_executable(matcha-tts-zh-cxx-api ./matcha-tts-zh-cxx-api.cc)
+  target_link_libraries(matcha-tts-zh-cxx-api sherpa-mnn-cxx-api)
+
+  add_executable(matcha-tts-en-cxx-api ./matcha-tts-en-cxx-api.cc)
+  target_link_libraries(matcha-tts-en-cxx-api sherpa-mnn-cxx-api)
+
+  add_executable(kokoro-tts-en-cxx-api ./kokoro-tts-en-cxx-api.cc)
+  target_link_libraries(kokoro-tts-en-cxx-api sherpa-mnn-cxx-api)
+
+  add_executable(kokoro-tts-zh-en-cxx-api ./kokoro-tts-zh-en-cxx-api.cc)
+  target_link_libraries(kokoro-tts-zh-en-cxx-api sherpa-mnn-cxx-api)
+endif()
diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/fire-red-asr-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/fire-red-asr-cxx-api.cc
new file mode 100644
index 00000000..b93bdd3f
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/fire-red-asr-cxx-api.cc
@@ -0,0 +1,77 @@
+// cxx-api-examples/fire-red-asr-cxx-api.cc
+// Copyright (c) 2025 Xiaomi Corporation
+
+//
+// This file demonstrates how to use FireRedAsr AED with sherpa-onnx's C++ API.
+//
+// clang-format off
+//
+// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
+// tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
+// rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2
+//
+// clang-format on
+
+#include <chrono>  // NOLINT
+#include <iostream>
+#include <string>
+
+#include "sherpa-mnn/c-api/cxx-api.h"
+
+int32_t main() {
+  using namespace sherpa_mnn::cxx;  // NOLINT
+  OfflineRecognizerConfig config;
+
+  config.model_config.fire_red_asr.encoder =
+      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx";
+  config.model_config.fire_red_asr.decoder =
+      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx";
+  config.model_config.tokens =
+      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt";
+
+  config.model_config.num_threads = 1;
+
+  std::cout << "Loading model\n";
+  OfflineRecognizer recognizer = OfflineRecognizer::Create(config);
+  if (!recognizer.Get()) {
+    std::cerr << "Please check your config\n";
+    return -1;
+  }
+  std::cout << "Loading model done\n";
+
+  std::string wave_filename =
+      "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav";
+  Wave wave = ReadWave(wave_filename);
+  if (wave.samples.empty()) {
+    std::cerr << "Failed to read: '" << wave_filename << "'\n";
+    return -1;
+  }
+
+  std::cout << "Start recognition\n";
+  const auto begin = std::chrono::steady_clock::now();
+
+  OfflineStream stream = recognizer.CreateStream();
+  stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
+                        wave.samples.size());
+
+  recognizer.Decode(&stream);
+
+  OfflineRecognizerResult result = recognizer.GetResult(&stream);
+
+  const auto end = std::chrono::steady_clock::now();
+  const float elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+          .count() /
+      1000.;
+  float duration = wave.samples.size() / static_cast<float>(wave.sample_rate);
+  float rtf = elapsed_seconds / duration;
+
+  std::cout << "text: " << result.text << "\n";
+  printf("Number of threads: %d\n", config.model_config.num_threads);
+  printf("Duration: %.3fs\n", duration);
+  printf("Elapsed seconds: %.3fs\n", elapsed_seconds);
+  printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds,
+         duration, rtf);
+
+  return 0;
+}
diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/kokoro-tts-en-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/kokoro-tts-en-cxx-api.cc
new file mode 100644
index 00000000..5364f84f
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/kokoro-tts-en-cxx-api.cc
@@ -0,0 +1,73 @@
+// cxx-api-examples/kokoro-tts-en-cxx-api.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+// This file shows how to use sherpa-onnx CXX API
+// for English TTS with Kokoro.
+//
+// clang-format off
+/*
+Usage
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2
+tar xf kokoro-en-v0_19.tar.bz2
+rm kokoro-en-v0_19.tar.bz2
+
+./kokoro-tts-en-cxx-api
+
+ */
+// clang-format on
+
+#include <string>
+
+#include "sherpa-mnn/c-api/cxx-api.h"
+
+static int32_t ProgressCallback(const float *samples, int32_t num_samples,
+                                float progress, void *arg) {
+  fprintf(stderr, "Progress: %.3f%%\n", progress * 100);
+  // return 1 to continue generating
+  // return 0 to stop generating
+  return 1;
+}
+
+int32_t main(int32_t argc, char *argv[]) {
+  using namespace sherpa_mnn::cxx;  // NOLINT
+  OfflineTtsConfig config;
+
+  config.model.kokoro.model = "./kokoro-en-v0_19/model.onnx";
+  config.model.kokoro.voices = "./kokoro-en-v0_19/voices.bin";
+  config.model.kokoro.tokens = "./kokoro-en-v0_19/tokens.txt";
+  config.model.kokoro.data_dir = "./kokoro-en-v0_19/espeak-ng-data";
+
+  config.model.num_threads = 2;
+
+  // If you don't want to see debug messages, please set it to 0
+  config.model.debug = 1;
+
+  std::string filename = "./generated-kokoro-en-cxx.wav";
+  std::string text =
+      "Today as always, men fall into two groups: slaves and free men. Whoever "
+      "does not have two-thirds of his day for himself, is a slave, whatever "
+      "he may be: a statesman, a businessman, an official, or a scholar. "
+      "Friends fell out often because life was changing so fast. The easiest "
+      "thing in the world was to lose touch with someone.";
+
+  auto tts = OfflineTts::Create(config);
+  int32_t sid = 0;
+  float speed = 1.0;  // larger -> faster in speech speed
+
+#if 0
+  // If you don't want to use a callback, then please enable this branch
+  GeneratedAudio audio = tts.Generate(text, sid, speed);
+#else
+  GeneratedAudio audio = tts.Generate(text, sid, speed, ProgressCallback);
+#endif
+
+  WriteWave(filename, {audio.samples, audio.sample_rate});
+
+  fprintf(stderr, "Input text is: %s\n", text.c_str());
+  fprintf(stderr, "Speaker ID is: %d\n", sid);
+  fprintf(stderr, "Saved to: %s\n", filename.c_str());
+
+  return 0;
+}
diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/kokoro-tts-zh-en-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/kokoro-tts-zh-en-cxx-api.cc
new file mode 100644
index 00000000..cb0c8775
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/kokoro-tts-zh-en-cxx-api.cc
@@ -0,0 +1,74 @@
+// cxx-api-examples/kokoro-tts-zh-en-cxx-api.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+// This file shows how to use sherpa-onnx CXX API
+// for Chinese + English TTS with Kokoro.
+//
+// clang-format off
+/*
+Usage
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-multi-lang-v1_0.tar.bz2
+tar xf kokoro-multi-lang-v1_0.tar.bz2
+rm kokoro-multi-lang-v1_0.tar.bz2
+
+./kokoro-tts-zh-en-cxx-api
+
+ */
+// clang-format on
+
+#include <string>
+
+#include "sherpa-mnn/c-api/cxx-api.h"
+
+static int32_t ProgressCallback(const float *samples, int32_t num_samples,
+                                float progress, void *arg) {
+  fprintf(stderr, "Progress: %.3f%%\n", progress * 100);
+  // return 1 to continue generating
+  // return 0 to stop generating
+  return 1;
+}
+
+int32_t main(int32_t argc, char *argv[]) {
+  using namespace sherpa_mnn::cxx;  // NOLINT
+  OfflineTtsConfig config;
+
+  config.model.kokoro.model = "./kokoro-multi-lang-v1_0/model.onnx";
+  config.model.kokoro.voices = "./kokoro-multi-lang-v1_0/voices.bin";
+  config.model.kokoro.tokens = "./kokoro-multi-lang-v1_0/tokens.txt";
+  config.model.kokoro.data_dir = "./kokoro-multi-lang-v1_0/espeak-ng-data";
+  config.model.kokoro.dict_dir = "./kokoro-multi-lang-v1_0/dict";
+  config.model.kokoro.lexicon =
+      "./kokoro-multi-lang-v1_0/lexicon-us-en.txt,./kokoro-multi-lang-v1_0/"
+      "lexicon-zh.txt";
+
+  config.model.num_threads = 2;
+
+  // If you don't want to see debug messages, please set it to 0
+  config.model.debug = 1;
+
+  std::string filename = "./generated-kokoro-zh-en-cxx.wav";
+  std::string text =
+      "中英文语音合成测试。This is generated by next generation Kaldi using "
+      "Kokoro without Misaki. 你觉得中英文说的如何呢?";
+
+  auto tts = OfflineTts::Create(config);
+  int32_t sid = 50;
+  float speed = 1.0;  // larger -> faster in speech speed
+
+#if 0
+  // If you don't want to use a callback, then please enable this branch
+  GeneratedAudio audio = tts.Generate(text, sid, speed);
+#else
+  GeneratedAudio audio = tts.Generate(text, sid, speed, ProgressCallback);
+#endif
+
+  WriteWave(filename, {audio.samples, audio.sample_rate});
+
+  fprintf(stderr, "Input text is: %s\n", text.c_str());
+  fprintf(stderr, "Speaker ID is: %d\n", sid);
+  fprintf(stderr, "Saved to: %s\n", filename.c_str());
+
+  return 0;
+}
diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/kws-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/kws-cxx-api.cc
new file mode 100644
index 00000000..2ef41536
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/kws-cxx-api.cc
@@ -0,0 +1,143 @@
+// cxx-api-examples/kws-cxx-api.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+//
+// This file demonstrates how to use the keyword spotter with sherpa-onnx's C++ API.
+// clang-format off
+//
+// Usage
+//
+// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
+// tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
+// rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile.tar.bz2
+//
+// ./kws-cxx-api
+//
+// clang-format on
+#include <iostream>
+#include <string>
+
+#include "sherpa-mnn/c-api/cxx-api.h"
+
+int32_t main() {
+  using namespace sherpa_mnn::cxx;  // NOLINT
+
+  KeywordSpotterConfig config;
+  config.model_config.transducer.encoder =
+      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/"
+      "encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx";
+
+  config.model_config.transducer.decoder =
+      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/"
+      "decoder-epoch-12-avg-2-chunk-16-left-64.onnx";
+
+  config.model_config.transducer.joiner =
+      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/"
+      "joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx";
+
+  config.model_config.tokens =
+      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/"
+      "tokens.txt";
+
+  config.model_config.provider = "cpu";
+  config.model_config.num_threads = 1;
+  config.model_config.debug = 1;
+
+  config.keywords_file =
+      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/"
+      "test_wavs/test_keywords.txt";
+
+  KeywordSpotter kws = KeywordSpotter::Create(config);
+  if (!kws.Get()) {
+    std::cerr << "Please check your config\n";
+    return -1;
+  }
+
+  std::cout
+      << "--Test pre-defined keywords from test_wavs/test_keywords.txt--\n";
+
+  std::string wave_filename =
+      "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01-mobile/"
+      "test_wavs/3.wav";
+
+  std::array<float, 8000> tail_paddings = {0};  // 0.5 seconds
+
+  Wave wave = ReadWave(wave_filename);
+  if (wave.samples.empty()) {
+    std::cerr << "Failed to read: '" << wave_filename << "'\n";
+    return -1;
+  }
+
+  OnlineStream stream = kws.CreateStream();
+  if (!stream.Get()) {
+    std::cerr << "Failed to create stream\n";
+    return -1;
+  }
+
+  stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
+                        wave.samples.size());
+
+  stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
+                        tail_paddings.size());
+  stream.InputFinished();
+
+  while (kws.IsReady(&stream)) {
+    kws.Decode(&stream);
+    auto r = kws.GetResult(&stream);
+    if (!r.keyword.empty()) {
+      std::cout << "Detected keyword: " << r.json << "\n";
+
+      // Remember to reset the keyword stream right after a keyword is detected
+      kws.Reset(&stream);
+    }
+  }
+
+  // --------------------------------------------------------------------------
+
+  std::cout << "--Use pre-defined keywords + add a new keyword--\n";
+
+  stream = kws.CreateStream("y ǎn y uán @演员");
+
+  stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
+                        wave.samples.size());
+
+  stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
+                        tail_paddings.size());
+  stream.InputFinished();
+
+  while (kws.IsReady(&stream)) {
+    kws.Decode(&stream);
+    auto r = kws.GetResult(&stream);
+    if (!r.keyword.empty()) {
+      std::cout << "Detected keyword: " << r.json << "\n";
+
+      // Remember to reset the keyword stream right after a keyword is detected
+      kws.Reset(&stream);
+    }
+  }
+
+  // --------------------------------------------------------------------------
+
+  std::cout << "--Use pre-defined keywords + add two new keywords--\n";
+
+  stream = kws.CreateStream("y ǎn y uán @演员/zh ī m íng @知名");
+
+  stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
+                        wave.samples.size());
+
+  stream.AcceptWaveform(wave.sample_rate, tail_paddings.data(),
+                        tail_paddings.size());
+  stream.InputFinished();
+
+  while (kws.IsReady(&stream)) {
+    kws.Decode(&stream);
+    auto r = kws.GetResult(&stream);
+    if (!r.keyword.empty()) {
+      std::cout << "Detected keyword: " << r.json << "\n";
+
+      // Remember to reset the keyword stream right after a keyword is detected
+      kws.Reset(&stream);
+    }
+  }
+  return 0;
+}
diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/matcha-tts-en-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/matcha-tts-en-cxx-api.cc
new file mode 100644
index 00000000..833bd9fd
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/matcha-tts-en-cxx-api.cc
@@ -0,0 +1,80 @@
+// cxx-api-examples/matcha-tts-en-cxx-api.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+// This file shows how to use sherpa-onnx CXX API
+// for English TTS with MatchaTTS.
+// +// clang-format off +/* +Usage + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-en_US-ljspeech.tar.bz2 +tar xvf matcha-icefall-en_US-ljspeech.tar.bz2 +rm matcha-icefall-en_US-ljspeech.tar.bz2 + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +./matcha-tts-en-cxx-api + + */ +// clang-format on + +#include + +#include "sherpa-mnn/c-api/cxx-api.h" + +static int32_t ProgressCallback(const float *samples, int32_t num_samples, + float progress, void *arg) { + fprintf(stderr, "Progress: %.3f%%\n", progress * 100); + // return 1 to continue generating + // return 0 to stop generating + return 1; +} + +int32_t main(int32_t argc, char *argv[]) { + using namespace sherpa_mnn::cxx; // NOLINT + OfflineTtsConfig config; + + config.model.matcha.acoustic_model = + "./matcha-icefall-en_US-ljspeech/model-steps-3.onnx"; + + config.model.matcha.vocoder = "./hifigan_v2.onnx"; + + config.model.matcha.tokens = "./matcha-icefall-en_US-ljspeech/tokens.txt"; + + config.model.matcha.data_dir = + "./matcha-icefall-en_US-ljspeech/espeak-ng-data"; + + config.model.num_threads = 1; + + // If you don't want to see debug messages, please set it to 0 + config.model.debug = 1; + + std::string filename = "./generated-matcha-en-cxx.wav"; + std::string text = + "Today as always, men fall into two groups: slaves and free men. Whoever " + "does not have two-thirds of his day for himself, is a slave, whatever " + "he may be: a statesman, a businessman, an official, or a scholar. " + "Friends fell out often because life was changing so fast. The easiest " + "thing in the world was to lose touch with someone."; + + auto tts = OfflineTts::Create(config); + int32_t sid = 0; + float speed = 1.0; // larger -> faster in speech speed + +#if 0 + // If you don't want to use a callback, then please enable this branch + GeneratedAudio audio = tts.Generate(text, sid, speed); +#else + GeneratedAudio audio = tts.Generate(text, sid, speed, ProgressCallback); +#endif + + WriteWave(filename, {audio.samples, audio.sample_rate}); + + fprintf(stderr, "Input text is: %s\n", text.c_str()); + fprintf(stderr, "Speaker ID is is: %d\n", sid); + fprintf(stderr, "Saved to: %s\n", filename.c_str()); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/matcha-tts-zh-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/matcha-tts-zh-cxx-api.cc new file mode 100644 index 00000000..93eb72e9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/matcha-tts-zh-cxx-api.cc @@ -0,0 +1,79 @@ +// cxx-api-examples/matcha-tts-zh-cxx-api.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +// This file shows how to use sherpa-onnx CXX API +// for Chinese TTS with MatchaTTS. 
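+//
+// Note: besides the acoustic model and vocoder, this example also sets
+// config.rule_fsts (phone.fst, date.fst, number.fst) so that phone numbers,
+// dates, and digit strings in the Chinese input text are normalized before
+// synthesis.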
+// +// clang-format off +/* +Usage + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +./matcha-tts-zh-cxx-api + + */ +// clang-format on + +#include + +#include "sherpa-mnn/c-api/cxx-api.h" + +static int32_t ProgressCallback(const float *samples, int32_t num_samples, + float progress, void *arg) { + fprintf(stderr, "Progress: %.3f%%\n", progress * 100); + // return 1 to continue generating + // return 0 to stop generating + return 1; +} + +int32_t main(int32_t argc, char *argv[]) { + using namespace sherpa_mnn::cxx; // NOLINT + OfflineTtsConfig config; + config.model.matcha.acoustic_model = + "./matcha-icefall-zh-baker/model-steps-3.onnx"; + config.model.matcha.vocoder = "./hifigan_v2.onnx"; + config.model.matcha.lexicon = "./matcha-icefall-zh-baker/lexicon.txt"; + config.model.matcha.tokens = "./matcha-icefall-zh-baker/tokens.txt"; + config.model.matcha.dict_dir = "./matcha-icefall-zh-baker/dict"; + config.model.num_threads = 1; + + // If you don't want to see debug messages, please set it to 0 + config.model.debug = 1; + + // clang-format off + config.rule_fsts = "./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst"; // NOLINT + // clang-format on + + std::string filename = "./generated-matcha-zh-cxx.wav"; + std::string text = + "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如" + "涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感" + "受着生命的奇迹与温柔." + "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; " + "经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。"; + + auto tts = OfflineTts::Create(config); + int32_t sid = 0; + float speed = 1.0; // larger -> faster in speech speed + +#if 0 + // If you don't want to use a callback, then please enable this branch + GeneratedAudio audio = tts.Generate(text, sid, speed); +#else + GeneratedAudio audio = tts.Generate(text, sid, speed, ProgressCallback); +#endif + + WriteWave(filename, {audio.samples, audio.sample_rate}); + + fprintf(stderr, "Input text is: %s\n", text.c_str()); + fprintf(stderr, "Speaker ID is is: %d\n", sid); + fprintf(stderr, "Saved to: %s\n", filename.c_str()); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/moonshine-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/moonshine-cxx-api.cc new file mode 100644 index 00000000..6cbc4192 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/moonshine-cxx-api.cc @@ -0,0 +1,81 @@ +// cxx-api-examples/moonshine-cxx-api.cc +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use Moonshine with sherpa-onnx's C++ API. 
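+//
+// Note: Moonshine is a non-streaming (offline) model exported as four ONNX
+// files (preprocessor, encoder, uncached decoder, cached decoder); all four
+// paths are configured below.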
+//
+// clang-format off
+//
+// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+// tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+// rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
+//
+// clang-format on
+
+#include <chrono>  // NOLINT
+#include <cstdio>
+#include <iostream>
+#include <string>
+
+#include "sherpa-mnn/c-api/cxx-api.h"
+
+int32_t main() {
+  using namespace sherpa_mnn::cxx;  // NOLINT
+  OfflineRecognizerConfig config;
+
+  config.model_config.moonshine.preprocessor =
+      "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx";
+  config.model_config.moonshine.encoder =
+      "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx";
+  config.model_config.moonshine.uncached_decoder =
+      "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx";
+  config.model_config.moonshine.cached_decoder =
+      "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx";
+  config.model_config.tokens =
+      "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt";
+
+  config.model_config.num_threads = 1;
+
+  std::cout << "Loading model\n";
+  OfflineRecognizer recognizer = OfflineRecognizer::Create(config);
+  if (!recognizer.Get()) {
+    std::cerr << "Please check your config\n";
+    return -1;
+  }
+  std::cout << "Loading model done\n";
+
+  std::string wave_filename =
+      "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav";
+  Wave wave = ReadWave(wave_filename);
+  if (wave.samples.empty()) {
+    std::cerr << "Failed to read: '" << wave_filename << "'\n";
+    return -1;
+  }
+
+  std::cout << "Start recognition\n";
+  const auto begin = std::chrono::steady_clock::now();
+
+  OfflineStream stream = recognizer.CreateStream();
+  stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
+                        wave.samples.size());
+
+  recognizer.Decode(&stream);
+
+  OfflineRecognizerResult result = recognizer.GetResult(&stream);
+
+  const auto end = std::chrono::steady_clock::now();
+  const float elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+          .count() /
+      1000.;
+  float duration = wave.samples.size() / static_cast<float>(wave.sample_rate);
+  float rtf = elapsed_seconds / duration;
+
+  std::cout << "text: " << result.text << "\n";
+  printf("Number of threads: %d\n", config.model_config.num_threads);
+  printf("Duration: %.3fs\n", duration);
+  printf("Elapsed seconds: %.3fs\n", elapsed_seconds);
+  printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds,
+         duration, rtf);
+
+  return 0;
+}
diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/sense-voice-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/sense-voice-cxx-api.cc
new file mode 100644
index 00000000..e6fcffa7
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/sense-voice-cxx-api.cc
@@ -0,0 +1,78 @@
+// cxx-api-examples/sense-voice-cxx-api.cc
+// Copyright (c) 2024 Xiaomi Corporation

+//
+// This file demonstrates how to use SenseVoice with sherpa-onnx's C++ API.
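+//
+// Note: this SenseVoice model is multilingual (zh, en, ja, ko, yue). The
+// config below sets use_itn = true to enable inverse text normalization and
+// language = "auto" so the spoken language is detected from the audio.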
+// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +// +// clang-format on + +#include // NOLINT +#include +#include + +#include "sherpa-mnn/c-api/cxx-api.h" + +int32_t main() { + using namespace sherpa_mnn::cxx; // NOLINT + OfflineRecognizerConfig config; + + config.model_config.sense_voice.model = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx"; + config.model_config.sense_voice.use_itn = true; + config.model_config.sense_voice.language = "auto"; + config.model_config.tokens = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt"; + + config.model_config.num_threads = 1; + + std::cout << "Loading model\n"; + OfflineRecognizer recongizer = OfflineRecognizer::Create(config); + if (!recongizer.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + std::cout << "Loading model done\n"; + + std::string wave_filename = + "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav"; + + Wave wave = ReadWave(wave_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wave_filename << "'\n"; + return -1; + } + + std::cout << "Start recognition\n"; + const auto begin = std::chrono::steady_clock::now(); + + OfflineStream stream = recongizer.CreateStream(); + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + + recongizer.Decode(&stream); + + OfflineRecognizerResult result = recongizer.GetResult(&stream); + + const auto end = std::chrono::steady_clock::now(); + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = wave.samples.size() / static_cast(wave.sample_rate); + float rtf = elapsed_seconds / duration; + + std::cout << "text: " << result.text << "\n"; + printf("Number of threads: %d\n", config.model_config.num_threads); + printf("Duration: %.3fs\n", duration); + printf("Elapsed seconds: %.3fs\n", elapsed_seconds); + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/speech-enhancement-gtcrn-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/speech-enhancement-gtcrn-cxx-api.cc new file mode 100644 index 00000000..9e96c373 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/speech-enhancement-gtcrn-cxx-api.cc @@ -0,0 +1,65 @@ +// cxx-api-examples/speech-enhancement-gtcrn-cxx-api.cc +// +// Copyright (c) 2025 Xiaomi Corporation +// +// We assume you have pre-downloaded model +// from +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models +// +// +// An example command to download +// clang-format off +/* +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/inp_16k.wav +*/ +// clang-format on +#include // NOLINT +#include +#include + +#include "sherpa-mnn/c-api/cxx-api.h" + +int32_t main() { + using namespace sherpa_mnn::cxx; // NOLINT + + OfflineSpeechDenoiserConfig config; + std::string wav_filename = "./inp_16k.wav"; + std::string out_wave_filename = "./enhanced_16k.wav"; + + config.model.gtcrn.model = "./gtcrn_simple.onnx"; + + auto sd = 
OfflineSpeechDenoiser::Create(config); + if (!sd.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + + Wave wave = ReadWave(wav_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wav_filename << "'\n"; + return -1; + } + + std::cout << "Started\n"; + const auto begin = std::chrono::steady_clock::now(); + auto denoised = + sd.Run(wave.samples.data(), wave.samples.size(), wave.sample_rate); + const auto end = std::chrono::steady_clock::now(); + std::cout << "Done\n"; + + WriteWave(out_wave_filename, {denoised.samples, denoised.sample_rate}); + + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = wave.samples.size() / static_cast(wave.sample_rate); + float rtf = elapsed_seconds / duration; + + std::cout << "Saved to " << out_wave_filename << "\n"; + printf("Duration: %.3fs\n", duration); + printf("Elapsed seconds: %.3fs\n", elapsed_seconds); + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds, + duration, rtf); +} diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/streaming-zipformer-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/streaming-zipformer-cxx-api.cc new file mode 100644 index 00000000..1d8b3067 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/streaming-zipformer-cxx-api.cc @@ -0,0 +1,93 @@ +// cxx-api-examples/streaming-zipformer-cxx-api.cc +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use streaming Zipformer +// with sherpa-onnx's C++ API. +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +// tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +// rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +// +// clang-format on + +#include // NOLINT +#include +#include + +#include "sherpa-mnn/c-api/cxx-api.h" + +int32_t main() { + using namespace sherpa_mnn::cxx; // NOLINT + OnlineRecognizerConfig config; + + // please see + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + config.model_config.transducer.encoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/" + "encoder-epoch-99-avg-1.int8.onnx"; + + // Note: We recommend not using int8.onnx for the decoder. 
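+  // (In this config the encoder and joiner use int8 weights while the decoder
+  // stays float32; the decoder is small, so quantizing it saves little time
+  // but can hurt accuracy.)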
+ config.model_config.transducer.decoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/" + "decoder-epoch-99-avg-1.onnx"; + + config.model_config.transducer.joiner = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/" + "joiner-epoch-99-avg-1.int8.onnx"; + + config.model_config.tokens = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"; + + config.model_config.num_threads = 1; + + std::cout << "Loading model\n"; + OnlineRecognizer recongizer = OnlineRecognizer::Create(config); + if (!recongizer.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + std::cout << "Loading model done\n"; + + std::string wave_filename = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/" + "0.wav"; + Wave wave = ReadWave(wave_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wave_filename << "'\n"; + return -1; + } + + std::cout << "Start recognition\n"; + const auto begin = std::chrono::steady_clock::now(); + + OnlineStream stream = recongizer.CreateStream(); + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + stream.InputFinished(); + + while (recongizer.IsReady(&stream)) { + recongizer.Decode(&stream); + } + + OnlineRecognizerResult result = recongizer.GetResult(&stream); + + const auto end = std::chrono::steady_clock::now(); + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = wave.samples.size() / static_cast(wave.sample_rate); + float rtf = elapsed_seconds / duration; + + std::cout << "text: " << result.text << "\n"; + printf("Number of threads: %d\n", config.model_config.num_threads); + printf("Duration: %.3fs\n", duration); + printf("Elapsed seconds: %.3fs\n", elapsed_seconds); + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc new file mode 100644 index 00000000..a9e05567 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc @@ -0,0 +1,132 @@ +// cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use streaming Zipformer +// with sherpa-onnx's C++ API. +// +// clang-format off +// +// cd /path/sherpa-onnx/ +// mkdir build +// cd build +// cmake .. +// make +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +// tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +// rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 +// +// # 1. Test on CPU, run once +// +// ./bin/streaming-zipformer-rtf-cxx-api +// +// # 2. Test on CPU, run 10 times +// +// ./bin/streaming-zipformer-rtf-cxx-api 10 +// +// # 3. 
Test on GPU, run 10 times +// +// ./bin/streaming-zipformer-rtf-cxx-api 10 cuda +// +// clang-format on + +#include // NOLINT +#include +#include + +#include "sherpa-mnn/c-api/cxx-api.h" + +int32_t main(int argc, char *argv[]) { + int32_t num_runs = 1; + if (argc >= 2) { + num_runs = atoi(argv[1]); + if (num_runs < 0) { + num_runs = 1; + } + } + + bool use_gpu = (argc == 3); + + using namespace sherpa_mnn::cxx; // NOLINT + OnlineRecognizerConfig config; + + // please see + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + config.model_config.transducer.encoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/" + "encoder-epoch-99-avg-1.int8.onnx"; + + // Note: We recommend not using int8.onnx for the decoder. + config.model_config.transducer.decoder = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/" + "decoder-epoch-99-avg-1.onnx"; + + config.model_config.transducer.joiner = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/" + "joiner-epoch-99-avg-1.int8.onnx"; + + config.model_config.tokens = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt"; + + config.model_config.num_threads = 1; + config.model_config.provider = use_gpu ? "cuda" : "cpu"; + + std::cout << "Loading model\n"; + OnlineRecognizer recongizer = OnlineRecognizer::Create(config); + if (!recongizer.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + std::cout << "Loading model done\n"; + + std::string wave_filename = + "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/" + "0.wav"; + Wave wave = ReadWave(wave_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wave_filename << "'\n"; + return -1; + } + + std::cout << "Start recognition\n"; + float total_elapsed_seconds = 0; + OnlineRecognizerResult result; + for (int32_t i = 0; i < num_runs; ++i) { + const auto begin = std::chrono::steady_clock::now(); + + OnlineStream stream = recongizer.CreateStream(); + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + stream.InputFinished(); + + while (recongizer.IsReady(&stream)) { + recongizer.Decode(&stream); + } + + result = recongizer.GetResult(&stream); + + auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + printf("Run %d/%d, elapsed seconds: %.3f\n", i, num_runs, elapsed_seconds); + total_elapsed_seconds += elapsed_seconds; + } + float average_elapsed_secodns = total_elapsed_seconds / num_runs; + float duration = wave.samples.size() / static_cast(wave.sample_rate); + float rtf = total_elapsed_seconds / num_runs / duration; + + std::cout << "text: " << result.text << "\n"; + printf("Number of threads: %d\n", config.model_config.num_threads); + printf("Duration: %.3fs\n", duration); + printf("Total Elapsed seconds: %.3fs\n", total_elapsed_seconds); + printf("Num runs: %d\n", num_runs); + printf("Elapsed seconds per run: %.3f/%d=%.3f\n", total_elapsed_seconds, + num_runs, average_elapsed_secodns); + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", + average_elapsed_secodns, duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/cxx-api-examples/whisper-cxx-api.cc b/apps/frameworks/sherpa-mnn/cxx-api-examples/whisper-cxx-api.cc new file mode 100644 index 00000000..0e1c05c3 --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/cxx-api-examples/whisper-cxx-api.cc @@ -0,0 +1,76 @@ +// cxx-api-examples/whisper-cxx-api.cc +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to use whisper with sherpa-onnx's C++ API. +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 +// tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 +// rm sherpa-onnx-whisper-tiny.en.tar.bz2 +// +// clang-format on + +#include // NOLINT +#include +#include + +#include "sherpa-mnn/c-api/cxx-api.h" + +int32_t main() { + using namespace sherpa_mnn::cxx; // NOLINT + OfflineRecognizerConfig config; + + config.model_config.whisper.encoder = + "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx"; + config.model_config.whisper.decoder = + "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx"; + config.model_config.tokens = + "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt"; + + config.model_config.num_threads = 1; + + std::cout << "Loading model\n"; + OfflineRecognizer recongizer = OfflineRecognizer::Create(config); + if (!recongizer.Get()) { + std::cerr << "Please check your config\n"; + return -1; + } + std::cout << "Loading model done\n"; + + std::string wave_filename = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav"; + Wave wave = ReadWave(wave_filename); + if (wave.samples.empty()) { + std::cerr << "Failed to read: '" << wave_filename << "'\n"; + return -1; + } + + std::cout << "Start recognition\n"; + const auto begin = std::chrono::steady_clock::now(); + + OfflineStream stream = recongizer.CreateStream(); + stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), + wave.samples.size()); + + recongizer.Decode(&stream); + + OfflineRecognizerResult result = recongizer.GetResult(&stream); + + const auto end = std::chrono::steady_clock::now(); + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = wave.samples.size() / static_cast(wave.sample_rate); + float rtf = elapsed_seconds / duration; + + std::cout << "text: " << result.text << "\n"; + printf("Number of threads: %d\n", config.model_config.num_threads); + printf("Duration: %.3fs\n", duration); + printf("Elapsed seconds: %.3fs\n", elapsed_seconds); + printf("(Real time factor) RTF = %.3f / %.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/.gitignore b/apps/frameworks/sherpa-mnn/kotlin-api-examples/.gitignore new file mode 100644 index 00000000..681205b6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/.gitignore @@ -0,0 +1,3 @@ +hs_err* +vits-zh-aishell3 +*.jar diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/AudioTagging.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/AudioTagging.kt new file mode 100644 index 00000000..984b8579 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/AudioTagging.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/AudioTagging.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/FeatureConfig.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/FeatureConfig.kt new file mode 100644 index 00000000..8a432d4e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/FeatureConfig.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/FeatureConfig.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflinePunctuation.kt 
b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflinePunctuation.kt new file mode 100644 index 00000000..063334e1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflinePunctuation.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/OfflinePunctuation.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineRecognizer.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineRecognizer.kt new file mode 100644 index 00000000..39be0a2f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineRecognizer.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/OfflineRecognizer.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineSpeakerDiarization.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineSpeakerDiarization.kt new file mode 100644 index 00000000..238777fa --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineSpeakerDiarization.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/OfflineSpeakerDiarization.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineStream.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineStream.kt new file mode 100644 index 00000000..e82e994d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OfflineStream.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/OfflineStream.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlinePunctuation.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlinePunctuation.kt new file mode 100644 index 00000000..98c0a805 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlinePunctuation.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/OnlinePunctuation.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlineRecognizer.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlineRecognizer.kt new file mode 100644 index 00000000..f0b47776 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlineRecognizer.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/OnlineRecognizer.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlineStream.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlineStream.kt new file mode 100644 index 00000000..98c760ed --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/OnlineStream.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/OnlineStream.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/Speaker.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/Speaker.kt new file mode 100644 index 00000000..de0983f2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/Speaker.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/Speaker.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/SpeakerEmbeddingExtractorConfig.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/SpeakerEmbeddingExtractorConfig.kt new file mode 100644 index 00000000..601d0832 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/SpeakerEmbeddingExtractorConfig.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/SpeakerEmbeddingExtractorConfig.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/SpokenLanguageIdentification.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/SpokenLanguageIdentification.kt new file mode 100644 index 00000000..4e1bba44 --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/kotlin-api-examples/SpokenLanguageIdentification.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/SpokenLanguageIdentification.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/Tts.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/Tts.kt new file mode 100644 index 00000000..db803c65 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/Tts.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/Tts.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/Vad.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/Vad.kt new file mode 100644 index 00000000..016caab0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/Vad.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/Vad.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/WaveReader.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/WaveReader.kt new file mode 100644 index 00000000..b836cb96 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/WaveReader.kt @@ -0,0 +1 @@ +../sherpa-mnn/kotlin-api/WaveReader.kt \ No newline at end of file diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/faked-asset-manager.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/faked-asset-manager.kt new file mode 100644 index 00000000..477c29bc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/faked-asset-manager.kt @@ -0,0 +1,3 @@ +package android.content.res + +class AssetManager {} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/faked-log.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/faked-log.kt new file mode 100644 index 00000000..e25d5a31 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/faked-log.kt @@ -0,0 +1,10 @@ +package android.util + +class Log { + companion object { + fun i(tag: String, msg: String) { + println("$tag, $msg") + } + } +} + diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/run.sh b/apps/frameworks/sherpa-mnn/kotlin-api-examples/run.sh new file mode 100755 index 00000000..72fffa90 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/run.sh @@ -0,0 +1,384 @@ +#!/usr/bin/env bash +# +# This scripts shows how to build JNI libs for sherpa-onnx +# Note: This scripts runs only on Linux and macOS, though sherpa-onnx +# supports building JNI libs for Windows. + +set -ex + +if [[ ! -f ../build/lib/libsherpa-onnx-jni.dylib && ! -f ../build/lib/libsherpa-onnx-jni.so ]]; then + mkdir -p ../build + pushd ../build + cmake \ + -DSHERPA_MNN_ENABLE_PYTHON=OFF \ + -DSHERPA_MNN_ENABLE_TESTS=OFF \ + -DSHERPA_MNN_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=ON \ + -DBUILD_SHARED_LIBS=ON \ + -DSHERPA_MNN_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_MNN_ENABLE_JNI=ON \ + .. + + make -j4 + ls -lh lib + popd +fi + +export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH + +function testSpeakerEmbeddingExtractor() { + if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx + fi + + if [ ! -f ./speaker1_a_cn_16k.wav ]; then + curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav + fi + + if [ ! -f ./speaker1_b_cn_16k.wav ]; then + curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav + fi + + if [ ! 
-f ./speaker2_a_cn_16k.wav ]; then + curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav + fi + + out_filename=test_speaker_id.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_speaker_id.kt \ + OnlineStream.kt \ + Speaker.kt \ + SpeakerEmbeddingExtractorConfig.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + + +function testOnlineAsr() { + if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then + git lfs install + GIT_CLONE_PROTECTION_ACTIVE=false git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 + fi + + if [ ! -f ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 + tar xvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 + rm sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms.tar.bz2 + fi + + if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + fi + + if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18 ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + fi + + out_filename=test_online_asr.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_online_asr.kt \ + FeatureConfig.kt \ + OnlineRecognizer.kt \ + OnlineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + +function testTts() { + if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 + tar xf vits-piper-en_US-amy-low.tar.bz2 + rm vits-piper-en_US-amy-low.tar.bz2 + fi + + if [ ! -f ./matcha-icefall-zh-baker/model-steps-3.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 + tar xvf matcha-icefall-zh-baker.tar.bz2 + rm matcha-icefall-zh-baker.tar.bz2 + fi + + if [ ! -f ./hifigan_v2.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + fi + + if [ ! -f ./kokoro-multi-lang-v1_0/model.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-multi-lang-v1_0.tar.bz2 + tar xf kokoro-multi-lang-v1_0.tar.bz2 + rm kokoro-multi-lang-v1_0.tar.bz2 + fi + + if [ ! 
-f ./kokoro-en-v0_19/model.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2 + tar xf kokoro-en-v0_19.tar.bz2 + rm kokoro-en-v0_19.tar.bz2 + fi + + out_filename=test_tts.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_tts.kt \ + Tts.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + + +function testAudioTagging() { + if [ ! -d sherpa-onnx-zipformer-audio-tagging-2024-04-09 ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + fi + + out_filename=test_audio_tagging.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_audio_tagging.kt \ + AudioTagging.kt \ + OfflineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + + +function testSpokenLanguageIdentification() { + if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 + tar xvf sherpa-onnx-whisper-tiny.tar.bz2 + rm sherpa-onnx-whisper-tiny.tar.bz2 + fi + + if [ ! -f ./spoken-language-identification-test-wavs/ar-arabic.wav ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2 + tar xvf spoken-language-identification-test-wavs.tar.bz2 + rm spoken-language-identification-test-wavs.tar.bz2 + fi + + out_filename=test_language_id.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_language_id.kt \ + SpokenLanguageIdentification.kt \ + OfflineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + +function testOfflineAsr() { + if [ ! -f ./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 + tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 + rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 + ls -lh sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16 + fi + + if [ ! -f ./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 + tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 + rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 + fi + + if [ ! -f ./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + fi + + if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 + tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 + rm sherpa-onnx-whisper-tiny.en.tar.bz2 + fi + + if [ ! 
-f ./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + tar xvf sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + rm sherpa-onnx-nemo-ctc-en-citrinet-512.tar.bz2 + fi + + if [ ! -f ./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + rm sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + fi + + if [ ! -f ./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2 + tar xvf sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2 + rm sherpa-onnx-zipformer-multi-zh-hans-2023-9-2.tar.bz2 + fi + + out_filename=test_offline_asr.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_offline_asr.kt \ + FeatureConfig.kt \ + OfflineRecognizer.kt \ + OfflineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt + + ls -lh $out_filename + java -Djava.library.path=../build/lib -jar $out_filename +} + +function testInverseTextNormalizationOfflineAsr() { + if [ ! -f ./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + rm sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + fi + + if [ ! -f ./itn-zh-number.wav ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + fi + + if [ ! -f ./itn_zh_number.fst ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + fi + + out_filename=test_itn_offline_asr.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_itn_offline_asr.kt \ + FeatureConfig.kt \ + OfflineRecognizer.kt \ + OfflineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt + + ls -lh $out_filename + java -Djava.library.path=../build/lib -jar $out_filename +} + +function testInverseTextNormalizationOnlineAsr() { + if [ ! -f ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + fi + + if [ ! -f ./itn-zh-number.wav ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + fi + + if [ ! -f ./itn_zh_number.fst ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + fi + + out_filename=test_itn_online_asr.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_itn_online_asr.kt \ + FeatureConfig.kt \ + OnlineRecognizer.kt \ + OnlineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt + + ls -lh $out_filename + java -Djava.library.path=../build/lib -jar $out_filename +} + +function testOfflinePunctuation() { + if [ ! 
-f ./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 + tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 + rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 + fi + + out_filename=test_offline_punctuation.jar + kotlinc-jvm -include-runtime -d $out_filename \ + ./test_offline_punctuation.kt \ + ./OfflinePunctuation.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + +function testOnlinePunctuation() { + if [ ! -f ./sherpa-onnx-online-punct-en-2024-08-06/model.int8.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 + tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 + rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 + fi + + out_filename=test_online_punctuation.jar + kotlinc-jvm -include-runtime -d $out_filename \ + ./test_online_punctuation.kt \ + ./OnlinePunctuation.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + +function testOfflineSpeakerDiarization() { + if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + fi + + if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + fi + + if [ ! 
-f ./0-four-speakers-zh.wav ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + fi + + out_filename=test_offline_speaker_diarization.jar + kotlinc-jvm -include-runtime -d $out_filename \ + test_offline_speaker_diarization.kt \ + OfflineSpeakerDiarization.kt \ + Speaker.kt \ + SpeakerEmbeddingExtractorConfig.kt \ + OnlineStream.kt \ + WaveReader.kt \ + faked-asset-manager.kt \ + faked-log.kt + + ls -lh $out_filename + + java -Djava.library.path=../build/lib -jar $out_filename +} + +testOfflineSpeakerDiarization +testSpeakerEmbeddingExtractor +testOnlineAsr +testTts +testAudioTagging +testSpokenLanguageIdentification +testOfflineAsr +testOfflinePunctuation +testOnlinePunctuation +testInverseTextNormalizationOfflineAsr +testInverseTextNormalizationOnlineAsr diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_audio_tagging.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_audio_tagging.kt new file mode 100644 index 00000000..95d2f79b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_audio_tagging.kt @@ -0,0 +1,49 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testAudioTagging() +} + +fun testAudioTagging() { + val config = AudioTaggingConfig( + model=AudioTaggingModelConfig( + zipformer=OfflineZipformerAudioTaggingModelConfig( + model="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx", + ), + numThreads=1, + debug=true, + provider="cpu", + ), + labels="./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv", + topK=5, + ) + val tagger = AudioTagging(config=config) + + val testFiles = arrayOf( + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/2.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/3.wav", + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/4.wav", + ) + println("----------") + for (waveFilename in testFiles) { + val stream = tagger.createStream() + + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + stream.acceptWaveform(samples, sampleRate = sampleRate) + val events = tagger.compute(stream) + stream.release() + + println(waveFilename) + println(events) + println("----------") + } + + tagger.release() +} + diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_itn_offline_asr.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_itn_offline_asr.kt new file mode 100644 index 00000000..2d723a82 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_itn_offline_asr.kt @@ -0,0 +1,37 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + test() +} + +fun test() { + val recognizer = createOfflineRecognizer() + val waveFilename = "./itn-zh-number.wav"; + + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + val stream = recognizer.createStream() + stream.acceptWaveform(samples, sampleRate=sampleRate) + recognizer.decode(stream) + + val result = recognizer.getResult(stream) + println(result) + + stream.release() + recognizer.release() +} + +fun createOfflineRecognizer(): OfflineRecognizer { + val config = OfflineRecognizerConfig( + featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), + modelConfig = getOfflineModelConfig(0)!!, + ruleFsts = 
"./itn_zh_number.fst", + ) + + return OfflineRecognizer(config = config) +} + diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_itn_online_asr.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_itn_online_asr.kt new file mode 100644 index 00000000..02544d6c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_itn_online_asr.kt @@ -0,0 +1,41 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + test() +} + +fun test() { + val recognizer = createOnlineRecognizer() + val waveFilename = "./itn-zh-number.wav"; + + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + val stream = recognizer.createStream() + stream.acceptWaveform(samples, sampleRate=sampleRate) + while (recognizer.isReady(stream)) { + recognizer.decode(stream) + } + + val result = recognizer.getResult(stream).text + println(result) + + stream.release() + recognizer.release() +} + +fun createOnlineRecognizer(): OnlineRecognizer { + val config = OnlineRecognizerConfig( + featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), + modelConfig = getModelConfig(8)!!, + ) + + config.ruleFsts = "./itn_zh_number.fst" + println(config) + + return OnlineRecognizer(config = config) +} + diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_language_id.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_language_id.kt new file mode 100644 index 00000000..9af97652 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_language_id.kt @@ -0,0 +1,43 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testSpokenLanguageIdentifcation() +} + +fun testSpokenLanguageIdentifcation() { + val config = SpokenLanguageIdentificationConfig( + whisper = SpokenLanguageIdentificationWhisperConfig( + encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx", + decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx", + tailPaddings = 33, + ), + numThreads=1, + debug=true, + provider="cpu", + ) + val slid = SpokenLanguageIdentification(config=config) + + val testFiles = arrayOf( + "./spoken-language-identification-test-wavs/ar-arabic.wav", + "./spoken-language-identification-test-wavs/bg-bulgarian.wav", + "./spoken-language-identification-test-wavs/de-german.wav", + ) + + for (waveFilename in testFiles) { + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + val stream = slid.createStream() + stream.acceptWaveform(samples, sampleRate = sampleRate) + val lang = slid.compute(stream) + stream.release() + println(waveFilename) + println(lang) + } + + slid.release() +} + diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_asr.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_asr.kt new file mode 100644 index 00000000..572a1268 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_asr.kt @@ -0,0 +1,48 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + val types = arrayOf(0, 2, 5, 6, 15, 21, 24) + for (type in types) { + test(type) + } +} + +fun test(type: Int) { + val recognizer = createOfflineRecognizer(type) + + val waveFilename = when (type) { + 0 -> "./sherpa-onnx-paraformer-zh-2023-09-14/test_wavs/0.wav" + 2 -> "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" + 5 -> "./sherpa-onnx-zipformer-multi-zh-hans-2023-9-2/test_wavs/1.wav" + 6 -> 
"./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav" + 15 -> "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav" + 21 -> "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav" + 24 -> "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav" + else -> null + } + + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename!!, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + val stream = recognizer.createStream() + stream.acceptWaveform(samples, sampleRate=sampleRate) + recognizer.decode(stream) + + val result = recognizer.getResult(stream) + println(result) + + stream.release() + recognizer.release() +} + +fun createOfflineRecognizer(type: Int): OfflineRecognizer { + val config = OfflineRecognizerConfig( + featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80), + modelConfig = getOfflineModelConfig(type = type)!!, + ) + + return OfflineRecognizer(config = config) +} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_punctuation.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_punctuation.kt new file mode 100644 index 00000000..bc402c13 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_punctuation.kt @@ -0,0 +1,31 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testPunctuation() +} + +fun testPunctuation() { + val config = OfflinePunctuationConfig( + model=OfflinePunctuationModelConfig( + ctTransformer="./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx", + numThreads=1, + debug=true, + provider="cpu", + ) + ) + val punct = OfflinePunctuation(config = config) + val sentences = arrayOf( + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你", + "我们都是木头人不会说话不会动", + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry", + ) + println("---") + for (text in sentences) { + val out = punct.addPunctuation(text) + println("Input: $text") + println("Output: $out") + println("---") + } + println(sentences) + +} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_speaker_diarization.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_speaker_diarization.kt new file mode 100644 index 00000000..44ae1ba6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_offline_speaker_diarization.kt @@ -0,0 +1,53 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testOfflineSpeakerDiarization() +} + +fun callback(numProcessedChunks: Int, numTotalChunks: Int, arg: Long): Int { + val progress = numProcessedChunks.toFloat() / numTotalChunks * 100 + val s = "%.2f".format(progress) + println("Progress: ${s}%"); + + return 0 +} + +fun testOfflineSpeakerDiarization() { + var config = OfflineSpeakerDiarizationConfig( + segmentation=OfflineSpeakerSegmentationModelConfig( + pyannote=OfflineSpeakerSegmentationPyannoteModelConfig("./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"), + ), + embedding=SpeakerEmbeddingExtractorConfig( + model="./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx", + ), + + // The test wave file ./0-four-speakers-zh.wav contains four speakers, so + // we use numClusters=4 here. If you don't know the number of speakers + // in the test wave file, please set the threshold like below. + // + // clustering=FastClusteringConfig(threshold=0.5), + // + // WARNING: You need to tune threshold by yourself. 
+ // A larger threshold leads to fewer clusters, i.e., few speakers. + // A smaller threshold leads to more clusters, i.e., more speakers. + // + clustering=FastClusteringConfig(numClusters=4), + ) + + val sd = OfflineSpeakerDiarization(config=config) + + val waveData = WaveReader.readWave( + filename = "./0-four-speakers-zh.wav", + ) + + if (sd.sampleRate() != waveData.sampleRate) { + println("Expected sample rate: ${sd.sampleRate()}, given: ${waveData.sampleRate}") + return + } + + // val segments = sd.process(waveData.samples) // this one is also ok + val segments = sd.processWithCallback(waveData.samples, callback=::callback) + for (segment in segments) { + println("${segment.start} -- ${segment.end} speaker_${segment.speaker}") + } +} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_online_asr.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_online_asr.kt new file mode 100644 index 00000000..29d56ced --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_online_asr.kt @@ -0,0 +1,114 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testOnlineAsr("transducer") + testOnlineAsr("zipformer2-ctc") + testOnlineAsr("ctc-hlg") + testOnlineAsr("nemo-ctc") +} + +fun testOnlineAsr(type: String) { + val featConfig = FeatureConfig( + sampleRate = 16000, + featureDim = 80, + ) + + var ctcFstDecoderConfig = OnlineCtcFstDecoderConfig() + val waveFilename: String + val modelConfig: OnlineModelConfig = when (type) { + "transducer" -> { + waveFilename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/0.wav" + // please refer to + // https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + // to dowload pre-trained models + OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx", + decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", + joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", + ), + tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", + numThreads = 1, + debug = false, + ) + } + "zipformer2-ctc" -> { + waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav" + OnlineModelConfig( + zipformer2Ctc = OnlineZipformer2CtcModelConfig( + model = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx", + ), + tokens = "./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt", + numThreads = 1, + debug = false, + ) + } + "nemo-ctc" -> { + waveFilename = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/test_wavs/0.wav" + OnlineModelConfig( + neMoCtc = OnlineNeMoCtcModelConfig( + model = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx", + ), + tokens = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt", + numThreads = 1, + debug = false, + ) + } + "ctc-hlg" -> { + waveFilename = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/1.wav" + ctcFstDecoderConfig.graph = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst" + OnlineModelConfig( + zipformer2Ctc = OnlineZipformer2CtcModelConfig( + model = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx", + ), + tokens = "./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt", + numThreads = 1, + debug = false, + ) + } + else -> throw 
IllegalArgumentException(type) + } + + val endpointConfig = EndpointConfig() + + val lmConfig = OnlineLMConfig() + + val config = OnlineRecognizerConfig( + modelConfig = modelConfig, + lmConfig = lmConfig, + featConfig = featConfig, + ctcFstDecoderConfig=ctcFstDecoderConfig, + endpointConfig = endpointConfig, + enableEndpoint = true, + decodingMethod = "greedy_search", + maxActivePaths = 4, + ) + + val recognizer = OnlineRecognizer( + config = config, + ) + + val objArray = WaveReader.readWaveFromFile( + filename = waveFilename, + ) + val samples: FloatArray = objArray[0] as FloatArray + val sampleRate: Int = objArray[1] as Int + + val stream = recognizer.createStream() + stream.acceptWaveform(samples, sampleRate = sampleRate) + while (recognizer.isReady(stream)) { + recognizer.decode(stream) + } + + val tailPaddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds + stream.acceptWaveform(tailPaddings, sampleRate = sampleRate) + stream.inputFinished() + while (recognizer.isReady(stream)) { + recognizer.decode(stream) + } + + println("results: ${recognizer.getResult(stream).text}") + + stream.release() + recognizer.release() +} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_online_punctuation.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_online_punctuation.kt new file mode 100644 index 00000000..2e6272a9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_online_punctuation.kt @@ -0,0 +1,30 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testPunctuation() +} + +// https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 +fun testPunctuation() { + val config = OnlinePunctuationConfig( + model=OnlinePunctuationModelConfig( + cnnBilstm="./sherpa-onnx-online-punct-en-2024-08-06/model.int8.onnx", + bpeVocab="./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab", + numThreads=1, + debug=true, + provider="cpu", + ) + ) + val punct = OnlinePunctuation(config = config) + val sentences = arrayOf( + "how are you doing fantastic thank you what is about you", + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry", + ) + println("---") + for (text in sentences) { + val out = punct.addPunctuation(text) + println("Input: $text") + println("Output: $out") + println("---") + } +} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_speaker_id.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_speaker_id.kt new file mode 100644 index 00000000..a7c9f1f8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_speaker_id.kt @@ -0,0 +1,62 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testSpeakerRecognition() +} + +fun testSpeakerRecognition() { + val config = SpeakerEmbeddingExtractorConfig( + model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx", + ) + val extractor = SpeakerEmbeddingExtractor(config = config) + + val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav") + val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav") + val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav") + + var manager = SpeakerEmbeddingManager(extractor.dim()) + var ok = manager.add(name = "speaker1", embedding=embedding1a) + check(ok) + + manager.add(name = "speaker2", embedding=embedding2a) + check(ok) + + var name = manager.search(embedding=embedding1b, threshold=0.5f) + check(name == "speaker1") + + manager.release() + + 
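+    // Re-create the manager and register two embeddings of speaker1 under the
+    // single name "s1"; a speaker1 embedding should then match "s1", while a
+    // speaker2 embedding should not match any registered speaker.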
manager = SpeakerEmbeddingManager(extractor.dim()) + val embeddingList = mutableListOf(embedding1a, embedding1b) + ok = manager.add(name = "s1", embedding=embeddingList.toTypedArray()) + check(ok) + + name = manager.search(embedding=embedding1b, threshold=0.5f) + check(name == "s1") + + name = manager.search(embedding=embedding2a, threshold=0.5f) + check(name.length == 0) + + manager.release() + extractor.release() + println("Speaker ID test done!") +} + +fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray { + var objArray = WaveReader.readWaveFromFile( + filename = filename, + ) + var samples: FloatArray = objArray[0] as FloatArray + var sampleRate: Int = objArray[1] as Int + + val stream = extractor.createStream() + stream.acceptWaveform(sampleRate = sampleRate, samples=samples) + stream.inputFinished() + check(extractor.isReady(stream)) + + val embedding = extractor.compute(stream) + + stream.release() + + return embedding +} diff --git a/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_tts.kt b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_tts.kt new file mode 100644 index 00000000..b7c0991e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/kotlin-api-examples/test_tts.kt @@ -0,0 +1,141 @@ +package com.k2fsa.sherpa.mnn + +fun main() { + testVits() + testMatcha() + testKokoroEn() + testKokoroZhEn() +} + +fun testKokoroZhEn() { + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + var config = OfflineTtsConfig( + model=OfflineTtsModelConfig( + kokoro=OfflineTtsKokoroModelConfig( + model="./kokoro-multi-lang-v1_0/model.onnx", + voices="./kokoro-multi-lang-v1_0/voices.bin", + tokens="./kokoro-multi-lang-v1_0/tokens.txt", + dataDir="./kokoro-multi-lang-v1_0/espeak-ng-data", + dictDir="./kokoro-multi-lang-v1_0/dict", + lexicon="./kokoro-multi-lang-v1_0/lexicon-us-en.txt,./kokoro-multi-lang-v1_0/lexicon-zh.txt", + ), + numThreads=2, + debug=true, + ), + ) + val tts = OfflineTts(config=config) + val audio = tts.generateWithCallback(text="中英文语音合成测试。This is generated by next generation Kaldi using Kokoro without Misaki. 
你觉得中英文说的如何呢?", callback=::callback) + audio.save(filename="test-kokoro-zh-en.wav") + tts.release() + println("Saved to test-kokoro-zh-en.wav") +} + +fun testKokoroEn() { + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + var config = OfflineTtsConfig( + model=OfflineTtsModelConfig( + kokoro=OfflineTtsKokoroModelConfig( + model="./kokoro-en-v0_19/model.onnx", + voices="./kokoro-en-v0_19/voices.bin", + tokens="./kokoro-en-v0_19/tokens.txt", + dataDir="./kokoro-en-v0_19/espeak-ng-data", + ), + numThreads=2, + debug=true, + ), + ) + val tts = OfflineTts(config=config) + val audio = tts.generateWithCallback(text="How are you doing today?", callback=::callback) + audio.save(filename="test-kokoro-en.wav") + tts.release() + println("Saved to test-kokoro-en.wav") +} + +fun testMatcha() { + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 + var config = OfflineTtsConfig( + model=OfflineTtsModelConfig( + matcha=OfflineTtsMatchaModelConfig( + acousticModel="./matcha-icefall-zh-baker/model-steps-3.onnx", + vocoder="./hifigan_v2.onnx", + tokens="./matcha-icefall-zh-baker/tokens.txt", + lexicon="./matcha-icefall-zh-baker/lexicon.txt", + dictDir="./matcha-icefall-zh-baker/dict", + ), + numThreads=1, + debug=true, + ), + ruleFsts="./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst", + ) + val tts = OfflineTts(config=config) + val audio = tts.generateWithCallback(text="某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。", callback=::callback) + audio.save(filename="test-matcha-zh.wav") + tts.release() + println("Saved to test-matcha-zh.wav") +} + +fun testVits() { + // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 + var config = OfflineTtsConfig( + model=OfflineTtsModelConfig( + vits=OfflineTtsVitsModelConfig( + model="./vits-piper-en_US-amy-low/en_US-amy-low.onnx", + tokens="./vits-piper-en_US-amy-low/tokens.txt", + dataDir="./vits-piper-en_US-amy-low/espeak-ng-data", + ), + numThreads=1, + debug=true, + ) + ) + val tts = OfflineTts(config=config) + val audio = tts.generateWithCallback(text="“Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.”", callback=::callback) + audio.save(filename="test-en.wav") + tts.release() + println("Saved to test-en.wav") +} + +/* +1. Unzip test_tts.jar +2. +javap ./com.k2fsa.sherpa.mnn/Test_ttsKt\$testTts\$audio\$1.class + +3. It prints: +Compiled from "test_tts.kt" +final class com.k2fsa.sherpa.mnn.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1 { + public static final com.k2fsa.sherpa.mnn.Test_ttsKt$testTts$audio$1 INSTANCE; + com.k2fsa.sherpa.mnn.Test_ttsKt$testTts$audio$1(); + public final java.lang.Integer invoke(float[]); + public java.lang.Object invoke(java.lang.Object); + static {}; +} + +4. +javap -s ./com.k2fsa.sherpa.mnn/Test_ttsKt\$testTts\$audio\$1.class + +5. 
It prints +Compiled from "test_tts.kt" +final class com.k2fsa.sherpa.mnn.Test_ttsKt$testTts$audio$1 extends kotlin.jvm.internal.FunctionReferenceImpl implements kotlin.jvm.functions.Function1 { + public static final com.k2fsa.sherpa.mnn.Test_ttsKt$testTts$audio$1 INSTANCE; + descriptor: Lcom.k2fsa.sherpa.mnn/Test_ttsKt$testTts$audio$1; + com.k2fsa.sherpa.mnn.Test_ttsKt$testTts$audio$1(); + descriptor: ()V + + public final java.lang.Integer invoke(float[]); + descriptor: ([F)Ljava/lang/Integer; + + public java.lang.Object invoke(java.lang.Object); + descriptor: (Ljava/lang/Object;)Ljava/lang/Object; + + static {}; + descriptor: ()V +} +*/ +fun callback(samples: FloatArray): Int { + println("callback got called with ${samples.size} samples"); + + // 1 means to continue + // 0 means to stop + return 1 +} diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/README.md b/apps/frameworks/sherpa-mnn/python-api-examples/README.md new file mode 100644 index 00000000..24176bea --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/README.md @@ -0,0 +1,12 @@ +# File description + +- [./http_server.py](./http_server.py) It defines which files to server. + Files are saved in [./web](./web). +- [non_streaming_server.py](./non_streaming_server.py) WebSocket server for + non-streaming models. +- [vad-remove-non-speech-segments.py](./vad-remove-non-speech-segments.py) It uses + [silero-vad](https://github.com/snakers4/silero-vad) to remove non-speech + segments and concatenate all speech segments into a single one. +- [vad-with-non-streaming-asr.py](./vad-with-non-streaming-asr.py) It shows + how to use VAD with a non-streaming ASR model for speech recognition from + a microphone diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/add-punctuation-online.py b/apps/frameworks/sherpa-mnn/python-api-examples/add-punctuation-online.py new file mode 100755 index 00000000..89dc204c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/add-punctuation-online.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +""" +This script shows how to add punctuations to text using sherpa-onnx Python API. 
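+It uses the English online punctuation model and restores both punctuation
+and casing for each input sentence via add_punctuation_with_case().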
+ +Please download the model from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models + +The following is an example + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 +tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 +rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2 +""" + +from pathlib import Path + +import sherpa_mnn + + +def main(): + model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx" + bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab" + if not Path(model).is_file(): + raise ValueError(f"{model} does not exist") + if not Path(bpe).is_file(): + raise ValueError(f"{bpe} does not exist") + + model_config = sherpa_mnn.OnlinePunctuationModelConfig( + cnn_bilstm=model, bpe_vocab=bpe + ) + config = sherpa_mnn.OnlinePunctuationConfig(model_config=model_config) + punct = sherpa_mnn.OnlinePunctuation(config) + + texts = [ + "how are you i am fine thank you", + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry", + ] + for text in texts: + text_with_punct = punct.add_punctuation_with_case(text) + print("----------") + print(f"input : {text}") + print(f"output: {text_with_punct}") + print("----------") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/add-punctuation.py b/apps/frameworks/sherpa-mnn/python-api-examples/add-punctuation.py new file mode 100755 index 00000000..96c67734 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/add-punctuation.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +""" +This script shows how to add punctuations to text using sherpa-onnx Python API. + +Please download the model from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models + +The following is an example + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +""" + +from pathlib import Path + +import sherpa_mnn + + +def main(): + model = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx" + if not Path(model).is_file(): + raise ValueError(f"{model} does not exist") + config = sherpa_mnn.OfflinePunctuationConfig( + model=sherpa_mnn.OfflinePunctuationModelConfig(ct_transformer=model), + ) + + punct = sherpa_mnn.OfflinePunctuation(config) + + text_list = [ + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你", + "我们都是木头人不会说话不会动", + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry", + ] + for text in text_list: + text_with_punct = punct.add_punctuation(text) + print("----------") + print(f"input: {text}") + print(f"output: {text_with_punct}") + + print("----------") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/audio-tagging-from-a-file-ced.py b/apps/frameworks/sherpa-mnn/python-api-examples/audio-tagging-from-a-file-ced.py new file mode 100755 index 00000000..efc8f7cc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/audio-tagging-from-a-file-ced.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use audio tagging Python APIs to tag a file. 
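+It loads the CED mini audio tagging model, tags a test wave file, and prints
+the top 5 detected audio events together with the real-time factor.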
+ +Please read the code to download the required model files and test wave file. +""" + +import logging +import time +from pathlib import Path + +import numpy as np +import sherpa_mnn +import soundfile as sf + + +def read_test_wave(): + # Please download the model files and test wave files from + # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models + test_wave = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/test_wavs/6.wav" + + if not Path(test_wave).is_file(): + raise ValueError( + f"Please download {test_wave} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read + data, sample_rate = sf.read( + test_wave, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + + # samples is a 1-d array of dtype float32 + # sample_rate is a scalar + return samples, sample_rate + + +def create_audio_tagger(): + # Please download the model files and test wave files from + # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models + model_file = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/model.int8.onnx" + label_file = ( + "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/class_labels_indices.csv" + ) + + if not Path(model_file).is_file(): + raise ValueError( + f"Please download {model_file} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + if not Path(label_file).is_file(): + raise ValueError( + f"Please download {label_file} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + config = sherpa_mnn.AudioTaggingConfig( + model=sherpa_mnn.AudioTaggingModelConfig( + ced=model_file, + num_threads=1, + debug=True, + provider="cpu", + ), + labels=label_file, + top_k=5, + ) + if not config.validate(): + raise ValueError(f"Please check the config: {config}") + + print(config) + + return sherpa_mnn.AudioTagging(config) + + +def main(): + logging.info("Create audio tagger") + audio_tagger = create_audio_tagger() + + logging.info("Read test wave") + samples, sample_rate = read_test_wave() + + logging.info("Computing") + + start_time = time.time() + + stream = audio_tagger.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + result = audio_tagger.compute(stream) + end_time = time.time() + + elapsed_seconds = end_time - start_time + audio_duration = len(samples) / sample_rate + + real_time_factor = elapsed_seconds / audio_duration + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}") + logging.info(f"Audio duration in seconds: {audio_duration:.3f}") + logging.info( + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" + ) + + s = "\n" + for i, e in enumerate(result): + s += f"{i}: {e}\n" + + logging.info(s) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/audio-tagging-from-a-file.py b/apps/frameworks/sherpa-mnn/python-api-examples/audio-tagging-from-a-file.py new file mode 100755 index 00000000..fd57f5da --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/audio-tagging-from-a-file.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use audio tagging Python APIs to tag a file. 
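+It loads the Zipformer audio tagging model, tags a test wave file, and prints
+the top 5 detected audio events together with the real-time factor.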
+ +Please read the code to download the required model files and test wave file. +""" + +import logging +import time +from pathlib import Path + +import numpy as np +import sherpa_mnn +import soundfile as sf + + +def read_test_wave(): + # Please download the model files and test wave files from + # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models + test_wave = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/1.wav" + + if not Path(test_wave).is_file(): + raise ValueError( + f"Please download {test_wave} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + # See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read + data, sample_rate = sf.read( + test_wave, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + + # samples is a 1-d array of dtype float32 + # sample_rate is a scalar + return samples, sample_rate + + +def create_audio_tagger(): + # Please download the model files and test wave files from + # https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models + model_file = "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx" + label_file = ( + "./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv" + ) + + if not Path(model_file).is_file(): + raise ValueError( + f"Please download {model_file} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + if not Path(label_file).is_file(): + raise ValueError( + f"Please download {label_file} from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models" + ) + + config = sherpa_mnn.AudioTaggingConfig( + model=sherpa_mnn.AudioTaggingModelConfig( + zipformer=sherpa_mnn.OfflineZipformerAudioTaggingModelConfig( + model=model_file, + ), + num_threads=1, + debug=True, + provider="cpu", + ), + labels=label_file, + top_k=5, + ) + if not config.validate(): + raise ValueError(f"Please check the config: {config}") + + print(config) + + return sherpa_mnn.AudioTagging(config) + + +def main(): + logging.info("Create audio tagger") + audio_tagger = create_audio_tagger() + + logging.info("Read test wave") + samples, sample_rate = read_test_wave() + + logging.info("Computing") + + start_time = time.time() + + stream = audio_tagger.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + result = audio_tagger.compute(stream) + end_time = time.time() + + elapsed_seconds = end_time - start_time + audio_duration = len(samples) / sample_rate + + real_time_factor = elapsed_seconds / audio_duration + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}") + logging.info(f"Audio duration in seconds: {audio_duration:.3f}") + logging.info( + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" + ) + + s = "\n" + for i, e in enumerate(result): + s += f"{i}: {e}\n" + + logging.info(s) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/generate-subtitles.py b/apps/frameworks/sherpa-mnn/python-api-examples/generate-subtitles.py new file mode 100755 index 00000000..20a29dc8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/generate-subtitles.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how 
to use sherpa-onnx Python APIs to generate +subtitles. + +Supported file formats are those supported by ffmpeg; for instance, +*.mov, *.mp4, *.wav, etc. + +Note that you need a non-streaming model for this script. + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx + +(1) For paraformer + + ./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/test.mp4 + +(2) For transducer models from icefall + + ./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/test.mp4 + +(3) For Moonshine models + +./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \ + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \ + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \ + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \ + --num-threads=2 \ + /path/to/test.mp4 + +(4) For Whisper models + +./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=2 \ + /path/to/test.mp4 + +(5) For SenseVoice CTC models + +./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --sense-voice=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx \ + --tokens=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt \ + --num-threads=2 \ + /path/to/test.mp4 + + +(6) For WeNet CTC models + +./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + --num-threads=2 \ + /path/to/test.mp4 + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. 
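+
+Note: This script pipes the input file through ffmpeg, so ffmpeg must be
+installed and available in your PATH.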
+""" +import argparse +import datetime as dt +import shutil +import subprocess +import sys +from dataclasses import dataclass +from datetime import timedelta +from pathlib import Path + +import numpy as np +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the transducer joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--sense-voice", + default="", + type=str, + help="Path to the model.onnx from SenseVoice", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the CTC model.onnx from WeNet", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=2, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--moonshine-preprocessor", + default="", + type=str, + help="Path to moonshine preprocessor model", + ) + + parser.add_argument( + "--moonshine-encoder", + default="", + type=str, + help="Path to moonshine encoder model", + ) + + parser.add_argument( + "--moonshine-uncached-decoder", + default="", + type=str, + help="Path to moonshine uncached decoder model", + ) + + parser.add_argument( + "--moonshine-cached-decoder", + default="", + type=str, + help="Path to moonshine cached decoder model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Valid values are greedy_search and modified_beam_search. + modified_beam_search is valid only for transducer models. + """, + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages when loading modes.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. 
Must match the one + expected by the model. Note: The input sound files can have a + different sample rate from this argument.""", + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to generate subtitles ", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def create_recognizer(args) -> sherpa_mnn.OfflineRecognizer: + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.sense_voice) == 0, args.sense_voice + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.paraformer: + assert len(args.sense_voice) == 0, args.sense_voice + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.paraformer) + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.sense_voice: + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.sense_voice) + recognizer = sherpa_mnn.OfflineRecognizer.from_sense_voice( + model=args.sense_voice, + tokens=args.tokens, + num_threads=args.num_threads, + use_itn=True, + debug=args.debug, + ) + elif args.wenet_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 
0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.whisper_encoder: + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + recognizer = sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + language=args.whisper_language, + task=args.whisper_task, + tail_paddings=args.whisper_tail_paddings, + ) + elif args.moonshine_preprocessor: + assert_file_exists(args.moonshine_preprocessor) + assert_file_exists(args.moonshine_encoder) + assert_file_exists(args.moonshine_uncached_decoder) + assert_file_exists(args.moonshine_cached_decoder) + + recognizer = sherpa_mnn.OfflineRecognizer.from_moonshine( + preprocessor=args.moonshine_preprocessor, + encoder=args.moonshine_encoder, + uncached_decoder=args.moonshine_uncached_decoder, + cached_decoder=args.moonshine_cached_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + ) + else: + raise ValueError("Please specify at least one model") + + return recognizer + + +@dataclass +class Segment: + start: float + duration: float + text: str = "" + + @property + def end(self): + return self.start + self.duration + + def __str__(self): + s = f"{timedelta(seconds=self.start)}"[:-3] + s += " --> " + s += f"{timedelta(seconds=self.end)}"[:-3] + s = s.replace(".", ",") + s += "\n" + s += self.text + return s + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + assert_file_exists(args.silero_vad_model) + + assert args.num_threads > 0, args.num_threads + + if not Path(args.sound_file).is_file(): + raise ValueError(f"{args.sound_file} does not exist") + + assert ( + args.sample_rate == 16000 + ), f"Only sample rate 16000 is supported.Given: {args.sample_rate}" + + recognizer = create_recognizer(args) + + ffmpeg_cmd = [ + "ffmpeg", + "-i", + args.sound_file, + "-f", + "s16le", + "-acodec", + "pcm_s16le", + "-ac", + "1", + "-ar", + str(args.sample_rate), + "-", + ] + + process = subprocess.Popen( + ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) + + frames_per_read = int(args.sample_rate * 100) # 100 second + + stream = recognizer.create_stream() + + config = sherpa_mnn.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.silero_vad.threshold = 0.5 + config.silero_vad.min_silence_duration = 0.25 # seconds + config.silero_vad.min_speech_duration = 0.25 # seconds + + # If the current segment is larger than this value, then it increases + # the threshold to 0.9 internally. 
After detecting this segment, + # it resets the threshold to its original value. + config.silero_vad.max_speech_duration = 5 # seconds + config.sample_rate = args.sample_rate + + window_size = config.silero_vad.window_size + + buffer = [] + vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=100) + + segment_list = [] + + print("Started!") + start_t = dt.datetime.now() + num_processed_samples = 0 + + is_eof = False + # TODO(fangjun): Support multithreads + while not is_eof: + # *2 because int16_t has two bytes + data = process.stdout.read(frames_per_read * 2) + if not data: + vad.flush() + is_eof = True + else: + samples = np.frombuffer(data, dtype=np.int16) + samples = samples.astype(np.float32) / 32768 + + num_processed_samples += samples.shape[0] + + buffer = np.concatenate([buffer, samples]) + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + + streams = [] + segments = [] + while not vad.empty(): + segment = Segment( + start=vad.front.start / args.sample_rate, + duration=len(vad.front.samples) / args.sample_rate, + ) + segments.append(segment) + + stream = recognizer.create_stream() + stream.accept_waveform(args.sample_rate, vad.front.samples) + + streams.append(stream) + + vad.pop() + + for s in streams: + recognizer.decode_stream(s) + + for seg, stream in zip(segments, streams): + seg.text = stream.result.text + segment_list.append(seg) + + end_t = dt.datetime.now() + elapsed_seconds = (end_t - start_t).total_seconds() + duration = num_processed_samples / 16000 + rtf = elapsed_seconds / duration + + srt_filename = Path(args.sound_file).with_suffix(".srt") + with open(srt_filename, "w", encoding="utf-8") as f: + for i, seg in enumerate(segment_list): + print(i + 1, file=f) + print(seg, file=f) + print("", file=f) + + print(f"Saved to {srt_filename}") + print(f"Audio duration:\t{duration:.3f} s") + print(f"Elapsed:\t{elapsed_seconds:.3f} s") + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + print("Done!") + + +if __name__ == "__main__": + if shutil.which("ffmpeg") is None: + sys.exit("Please install ffmpeg first!") + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/http_server.py b/apps/frameworks/sherpa-mnn/python-api-examples/http_server.py new file mode 100644 index 00000000..b67154ff --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/http_server.py @@ -0,0 +1,82 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
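+
+# This file implements a tiny in-memory static file server: every file listed
+# in `_static_files` below is read from the document root when `HttpServer`
+# is constructed and is then served directly from memory.
+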
+from typing import Tuple + +# Please sort it alphabetically +_static_files = ( + ("/css/bootstrap.min.css", "text/css"), + ("/css/bootstrap.min.css.map", "text/css"), + ("/index.html", "text/html"), + ("/js/bootstrap.min.js", "application/javascript"), + ("/js/bootstrap.min.js.map", "application/javascript"), + ("/js/jquery-3.6.0.min.js", "application/javascript"), + ("/js/offline_record.js", "application/javascript"), + ("/js/offline_record.js", "application/javascript"), + ("/js/popper.min.js", "application/javascript"), + ("/js/popper.min.js.map", "application/javascript"), + ("/js/streaming_record.js", "application/javascript"), + ("/js/upload.js", "application/javascript"), + ("/k2-logo.png", "image/png"), + ("/nav-partial.html", "text/html"), + ("/offline_record.html", "text/html"), + ("/streaming_record.html", "text/html"), + ("/upload.html", "text/html"), +) + +_404_page = r""" + +Speech recognition with next-gen Kaldi +

<h1>404 ERROR! Please re-check your URL</h1>

+ +""" + + +def read_file(root: str, name: str) -> str: + try: + with open(f"{root}/{name}") as f: + return f.read() + except: # noqa + with open(f"{root}/{name}", "rb") as f: + return f.read() + + +class HttpServer: + """ + A simple HTTP server that hosts only static files + """ + + def __init__(self, doc_root: str): + content = dict() + for f, mime_type in _static_files: + content[f] = (read_file(doc_root, f), mime_type) + self.content = content + + def process_request(self, f: str) -> Tuple[str, str, str]: + """ + Args: + f: + The filename to read. + Returns: + Return a tuple: + - a bool, True if the given file is found. False otherwise. + - a str, the content of the file if found. Otherwise, it + contains the content for the 404 page + - a str, the MIME type of the returned content + """ + if f in self.content: + return True, self.content[f][0], self.content[f][1] + else: + return False, _404_page, "text/html" diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/inverse-text-normalization-offline-asr.py b/apps/frameworks/sherpa-mnn/python-api-examples/inverse-text-normalization-offline-asr.py new file mode 100755 index 00000000..33b2a073 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/inverse-text-normalization-offline-asr.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2024 Xiaomi Corporation + +""" +This script shows how to use inverse text normalization with non-streaming ASR. + +Usage: + +(1) Download the test model + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 +tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 +rm sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + +(2) Download rule fst + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + +Please refer to +https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb +for how itn_zh_number.fst is generated. 
+ +(3) Download test wave + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + +(4) Run this script + +python3 ./python-api-examples/inverse-text-normalization-offline-asr.py +""" +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx" + tokens = "./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt" + rule_fsts = "./itn_zh_number.fst" + + if ( + not Path(model).is_file() + or not Path(tokens).is_file() + or not Path(rule_fsts).is_file() + ): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=model, + tokens=tokens, + debug=True, + rule_fsts=rule_fsts, + ) + + +def main(): + recognizer = create_recognizer() + wave_filename = "./itn-zh-number.wav" + if not Path(wave_filename).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/inverse-text-normalization-online-asr.py b/apps/frameworks/sherpa-mnn/python-api-examples/inverse-text-normalization-online-asr.py new file mode 100755 index 00000000..1c89c51d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/inverse-text-normalization-online-asr.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2024 Xiaomi Corporation + +""" +This script shows how to use inverse text normalization with streaming ASR. + +Usage: + +(1) Download the test model + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + +(2) Download rule fst + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn_zh_number.fst + +Please refer to +https://github.com/k2-fsa/colab/blob/master/sherpa-onnx/itn_zh_number.ipynb +for how itn_zh_number.fst is generated. 
+ +(3) Download test wave + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/itn-zh-number.wav + +(4) Run this script + +python3 ./python-api-examples/inverse-text-normalization-online-asr.py +""" +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" + decoder = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" + tokens = "./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" + rule_fsts = "./itn_zh_number.fst" + + if ( + not Path(encoder).is_file() + or not Path(decoder).is_file() + or not Path(joiner).is_file() + or not Path(tokens).is_file() + or not Path(rule_fsts).is_file() + ): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return sherpa_mnn.OnlineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + debug=True, + rule_fsts=rule_fsts, + ) + + +def main(): + recognizer = create_recognizer() + wave_filename = "./itn-zh-number.wav" + if not Path(wave_filename).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + + tail_padding = [0] * int(0.3 * sample_rate) + stream.accept_waveform(sample_rate, tail_padding) + + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + print(wave_filename) + print(recognizer.get_result_all(stream)) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/keyword-spotter-from-microphone.py b/apps/frameworks/sherpa-mnn/python-api-examples/keyword-spotter-from-microphone.py new file mode 100755 index 00000000..3565a8f0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/keyword-spotter-from-microphone.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +# Real-time keyword spotting from a microphone with sherpa-onnx Python API +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html +# to download pre-trained models + +import argparse +import sys +from pathlib import Path + +from typing import List + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_mnn + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the transducer joiner model", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--max-active-paths", + type=int, + default=4, + help=""" + It specifies number of active paths to keep during decoding. + """, + ) + + parser.add_argument( + "--num-trailing-blanks", + type=int, + default=1, + help="""The number of trailing blanks a keyword should be followed. Setting + to a larger value (e.g. 8) when your keywords has overlapping tokens + between each other. + """, + ) + + parser.add_argument( + "--keywords-file", + type=str, + help=""" + The file containing keywords, one words/phrases per line, and for each + phrase the bpe/cjkchar/pinyin are separated by a space. For example: + + ▁HE LL O ▁WORLD + x iǎo ài t óng x ué + """, + ) + + parser.add_argument( + "--keywords-score", + type=float, + default=1.0, + help=""" + The boosting score of each token for keywords. The larger the easier to + survive beam search. + """, + ) + + parser.add_argument( + "--keywords-threshold", + type=float, + default=0.25, + help=""" + The trigger threshold (i.e. probability) of the keyword. The larger the + harder to trigger. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + assert_file_exists(args.tokens) + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + assert Path( + args.keywords_file + ).is_file(), ( + f"keywords_file : {args.keywords_file} not exist, please provide a valid path." + ) + + keyword_spotter = sherpa_mnn.KeywordSpotter( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=args.num_threads, + max_active_paths=args.max_active_paths, + keywords_file=args.keywords_file, + keywords_score=args.keywords_score, + keywords_threshold=args.keywords_threshold, + num_trailing_blanks=args.num_trailing_blanks, + provider=args.provider, + ) + + print("Started! 
Please speak") + + idx = 0 + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + stream = keyword_spotter.create_stream() + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + while keyword_spotter.is_ready(stream): + keyword_spotter.decode_stream(stream) + result = keyword_spotter.get_result(stream) + if result: + print(f"{idx}: {result }") + idx += 1 + # Remember to reset stream right after detecting a keyword + keyword_spotter.reset_stream(stream) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/keyword-spotter.py b/apps/frameworks/sherpa-mnn/python-api-examples/keyword-spotter.py new file mode 100755 index 00000000..93be8758 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/keyword-spotter.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 + +""" +This file demonstrates how to use sherpa-onnx Python API to do keyword spotting +from wave file(s). + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html +to download pre-trained models. +""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_mnn + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def create_keyword_spotter(): + kws = sherpa_mnn.KeywordSpotter( + tokens="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt", + encoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx", + decoder="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx", + joiner="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx", + num_threads=2, + keywords_file="./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt", + provider="cpu", + ) + + return kws + + +def main(): + kws = create_keyword_spotter() + + wave_filename = ( + "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" + ) + + samples, sample_rate = read_wave(wave_filename) + + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) + + print("----------Use pre-defined keywords----------") + s = kws.create_stream() + s.accept_waveform(sample_rate, samples) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while kws.is_ready(s): + kws.decode_stream(s) + r = kws.get_result(s) + if r != "": + # Remember to call reset right after detected a keyword + kws.reset_stream(s) + + print(f"Detected {r}") + + print("----------Use pre-defined keywords + add a new keyword----------") + + s = kws.create_stream("y ǎn y uán @演员") + s.accept_waveform(sample_rate, samples) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while kws.is_ready(s): + kws.decode_stream(s) + r = kws.get_result(s) + if r != "": + # Remember to call reset right after detected a keyword + kws.reset_stream(s) + + print(f"Detected {r}") + + print("----------Use pre-defined keywords + add 2 new keywords----------") + + s = kws.create_stream("y ǎn y uán @演员/zh ī m íng @知名") + s.accept_waveform(sample_rate, samples) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while kws.is_ready(s): + kws.decode_stream(s) + r = kws.get_result(s) + if r != "": + # Remember to call reset right after detected a keyword + kws.reset_stream(s) + + print(f"Detected {r}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/non_streaming_server.py b/apps/frameworks/sherpa-mnn/python-api-examples/non_streaming_server.py new file mode 100755 index 00000000..2ca26ef2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/non_streaming_server.py @@ -0,0 +1,1193 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. +""" +A server for non-streaming speech recognition. Non-streaming means you send all +the content of the audio at once for recognition. + +It supports multiple clients sending at the same time. 
+ +Usage: + ./non_streaming_server.py --help + +Please refer to + +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html + +for pre-trained models to download. + +Usage examples: + +(1) Use a non-streaming transducer model + +cd /path/to/sherpa-onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2 +tar xvf sherpa-onnx-zipformer-en-2023-06-26.tar.bz2 +rm sherpa-onnx-zipformer-en-2023-06-26.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \ + --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \ + --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \ + --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt \ + --port 6006 + +(2) Use a non-streaming paraformer + +cd /path/to/sherpa-onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 +tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 +rm sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --paraformer ./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx \ + --tokens ./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt + +(3) Use a non-streaming CTC model from NeMo + +cd /path/to/sherpa-onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-medium.tar.bz2 +tar xvf sherpa-onnx-nemo-ctc-en-conformer-medium.tar.bz2 +rm sherpa-onnx-nemo-ctc-en-conformer-medium.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ + --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt + +(4) Use a non-streaming CTC model from WeNet + +cd /path/to/sherpa-onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zh-wenet-wenetspeech.tar.bz2 +tar xvf sherpa-onnx-zh-wenet-wenetspeech.tar.bz2 +rm sherpa-onnx-zh-wenet-wenetspeech.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt + +(5) Use a Moonshine model + +cd /path/to/sherpa-onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \ + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \ + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \ + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt + +(6) Use a Whisper model + +cd /path/to/sherpa-onnx +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 +tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 +rm 
sherpa-onnx-whisper-tiny.en.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ + --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt + +(7) Use a tdnn model of the yesno recipe from icefall + +cd /path/to/sherpa-onnx + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-tdnn-yesno.tar.bz2 +tar xvf sherpa-onnx-tdnn-yesno.tar.bz2 +rm sherpa-onnx-tdnn-yesno.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --sample-rate=8000 \ + --feat-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt + +(8) Use a Non-streaming SenseVoice model + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --sense-voice=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx \ + --tokens=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt + +(9) Use a Non-streaming telespeech ctc model + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 +tar xvf sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 +rm sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2 + +python3 ./python-api-examples/non_streaming_server.py \ + --telespeech-ctc=./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx \ + --tokens=./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt + +---- + +To use a certificate so that you can use https, please use + +python3 ./python-api-examples/non_streaming_server.py \ + --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \ + --certificate=/path/to/your/cert.pem + +If you don't have a certificate, please run: + + cd ./python-api-examples/web + ./generate-certificate.py + +It will generate 3 files, one of which is the required `cert.pem`. +""" # noqa + +import argparse +import asyncio +import http +import logging +import socket +import ssl +import sys +import warnings +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from pathlib import Path +from typing import Optional, Tuple + +import numpy as np +import sherpa_mnn + +import websockets + +from http_server import HttpServer + + +def setup_logger( + log_filename: str, + log_level: str = "info", + use_console: bool = True, +) -> None: + """Setup log level. + + Args: + log_filename: + The filename to save the log. + log_level: + The log level to use, e.g., "debug", "info", "warning", "error", + "critical" + use_console: + True to also print logs to console. 
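    Example (a minimal sketch mirroring the call at the bottom of this file,
    which uses the default "info" level; the helper appends a timestamp and
    ".txt" to the given name):

        setup_logger("log/log-non-streaming-server", log_level="debug")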
+ """ + now = datetime.now() + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + log_filename = f"{log_filename}-{date_time}.txt" + + Path(log_filename).parent.mkdir(parents=True, exist_ok=True) + + level = logging.ERROR + if log_level == "debug": + level = logging.DEBUG + elif log_level == "info": + level = logging.INFO + elif log_level == "warning": + level = logging.WARNING + elif log_level == "critical": + level = logging.CRITICAL + + logging.basicConfig( + filename=log_filename, + format=formatter, + level=level, + filemode="w", + ) + if use_console: + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger("").addHandler(console) + + +def add_transducer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the transducer joiner model", + ) + + +def add_paraformer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + +def add_sense_voice_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--sense-voice", + default="", + type=str, + help="Path to the model.onnx from SenseVoice", + ) + + +def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + +def add_telespeech_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--telespeech-ctc", + default="", + type=str, + help="Path to the model.onnx from TeleSpeech CTC", + ) + + +def add_wenet_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + +def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + +def add_moonshine_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--moonshine-preprocessor", + default="", + type=str, + help="Path to moonshine preprocessor model", + ) + + parser.add_argument( + "--moonshine-encoder", + default="", + type=str, + help="Path to moonshine encoder model", + ) + + parser.add_argument( + "--moonshine-uncached-decoder", + default="", + type=str, + help="Path to moonshine uncached decoder model", + ) + + parser.add_argument( + "--moonshine-cached-decoder", + default="", + type=str, + help="Path to moonshine cached decoder model", + ) + + +def add_whisper_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. 
+ Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + +def add_model_args(parser: argparse.ArgumentParser): + add_transducer_model_args(parser) + add_paraformer_model_args(parser) + add_sense_voice_model_args(parser) + add_nemo_ctc_model_args(parser) + add_wenet_ctc_model_args(parser) + add_telespeech_ctc_model_args(parser) + add_tdnn_ctc_model_args(parser) + add_whisper_model_args(parser) + add_moonshine_model_args(parser) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=2, + help="Number of threads to run the neural network model", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + +def add_feature_config_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="Sample rate of the data used to train the model. ", + ) + + parser.add_argument( + "--feat-dim", + type=int, + default=80, + help="Feature dimension of the model", + ) + + +def add_decoding_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Decoding method to use. Current supported methods are: + - greedy_search + - modified_beam_search (for transducer models only) + """, + ) + + add_modified_beam_search_args(parser) + + +def add_modified_beam_search_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--max-active-paths", + type=int, + default=4, + help="""Used only when --decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. + """, + ) + + +def add_hotwords_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + +def add_blank_penalty_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). 
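Editor's note: a tiny standalone illustration of the formula quoted above (not code used by the server). Subtracting the penalty from column 0 lowers the blank symbol's score relative to the other tokens.

import numpy as np

logits = np.array([[2.0, 1.5, 0.3]])  # shape [batch_size=1, vocab=3], blank id is 0
blank_penalty = 1.0
logits[:, 0] -= blank_penalty
print(logits)  # [[1.  1.5 0.3]] -> blank no longer has the highest score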
+ """, + ) + + +def check_args(args): + if not Path(args.tokens).is_file(): + raise ValueError(f"{args.tokens} does not exist") + + if args.decoding_method not in ( + "greedy_search", + "modified_beam_search", + ): + raise ValueError(f"Unsupported decoding method {args.decoding_method}") + + if args.decoding_method == "modified_beam_search": + assert args.num_active_paths > 0, args.num_active_paths + assert Path(args.encoder).is_file(), args.encoder + assert Path(args.decoder).is_file(), args.decoder + assert Path(args.joiner).is_file(), args.joiner + + if args.hotwords_file != "": + assert args.decoding_method == "modified_beam_search", args.decoding_method + assert Path(args.hotwords_file).is_file(), args.hotwords_file + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + add_model_args(parser) + add_feature_config_args(parser) + add_decoding_args(parser) + add_hotwords_args(parser) + add_blank_penalty_args(parser) + + parser.add_argument( + "--port", + type=int, + default=6006, + help="The server will listen on this port", + ) + + parser.add_argument( + "--max-batch-size", + type=int, + default=3, + help="""Max batch size for computation. Note if there are not enough + requests in the queue, it will wait for max_wait_ms time. After that, + even if there are not enough requests, it still sends the + available requests in the queue for computation. + """, + ) + + parser.add_argument( + "--max-wait-ms", + type=float, + default=5, + help="""Max time in millisecond to wait to build batches for inference. + If there are not enough requests in the feature queue to build a batch + of max_batch_size, it waits up to this time before fetching available + requests for computation. + """, + ) + + parser.add_argument( + "--nn-pool-size", + type=int, + default=1, + help="Number of threads for NN computation and decoding.", + ) + + parser.add_argument( + "--max-message-size", + type=int, + default=(1 << 20), + help="""Max message size in bytes. + The max size per message cannot exceed this limit. + """, + ) + + parser.add_argument( + "--max-queue-size", + type=int, + default=32, + help="Max number of messages in the queue for each connection.", + ) + + parser.add_argument( + "--max-active-connections", + type=int, + default=200, + help="""Maximum number of active connections. The server will refuse + to accept new connections once the current number of active connections + equals to this limit. + """, + ) + + parser.add_argument( + "--certificate", + type=str, + help="""Path to the X.509 certificate. You need it only if you want to + use a secure websocket connection, i.e., use wss:// instead of ws://. + You can use ./web/generate-certificate.py + to generate the certificate `cert.pem`. + Note ./web/generate-certificate.py will generate three files but you + only need to pass the generated cert.pem to this option. + """, + ) + + parser.add_argument( + "--doc-root", + type=str, + default="./python-api-examples/web", + help="Path to the web root", + ) + + return parser.parse_args() + + +class NonStreamingServer: + def __init__( + self, + recognizer: sherpa_mnn.OfflineRecognizer, + max_batch_size: int, + max_wait_ms: float, + nn_pool_size: int, + max_message_size: int, + max_queue_size: int, + max_active_connections: int, + doc_root: str, + certificate: Optional[str] = None, + ): + """ + Args: + recognizer: + An instance of the sherpa_mnn.OfflineRecognizer. + max_batch_size: + Max batch size for inference. 
+ max_wait_ms: + Max wait time in milliseconds in order to build a batch of + `max_batch_size`. + nn_pool_size: + Number of threads for the thread pool that is used for NN + computation and decoding. + max_message_size: + Max size in bytes per message. + max_queue_size: + Max number of messages in the queue for each connection. + max_active_connections: + Max number of active connections. Once number of active client + equals to this limit, the server refuses to accept new connections. + doc_root: + Path to the directory where files like index.html for the HTTP + server locate. + certificate: + Optional. If not None, it will use secure websocket. + You can use ./web/generate-certificate.py to generate + it (the default generated filename is `cert.pem`). + """ + self.recognizer = recognizer + + self.certificate = certificate + self.http_server = HttpServer(doc_root) + + self.nn_pool_size = nn_pool_size + self.nn_pool = ThreadPoolExecutor( + max_workers=nn_pool_size, + thread_name_prefix="nn", + ) + + self.stream_queue = asyncio.Queue() + + self.max_wait_ms = max_wait_ms + self.max_batch_size = max_batch_size + self.max_message_size = max_message_size + self.max_queue_size = max_queue_size + self.max_active_connections = max_active_connections + + self.current_active_connections = 0 + self.sample_rate = int(recognizer.config.feat_config.sampling_rate) + + async def process_request( + self, + path: str, + request_headers: websockets.Headers, + ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]: + if "sec-websocket-key" not in ( + request_headers.headers # For new request_headers + if hasattr(request_headers, "headers") + else request_headers # For old request_headers + ): + # This is a normal HTTP request + if path == "/": + path = "/index.html" + if path[-1] == "?": + path = path[:-1] + + if path == "/streaming_record.html": + response = r""" + +Speech recognition with next-gen Kaldi +

Only
+/upload.html
+and
+/offline_record.html
+are available for the non-streaming server.

+
+
+Go back to /upload.html +or /offline_record.html + +""" + found = True + mime_type = "text/html" + else: + found, response, mime_type = self.http_server.process_request(path) + if isinstance(response, str): + response = response.encode("utf-8") + + if not found: + status = http.HTTPStatus.NOT_FOUND + else: + status = http.HTTPStatus.OK + header = {"Content-Type": mime_type} + return status, header, response + + if self.current_active_connections < self.max_active_connections: + self.current_active_connections += 1 + return None + + # Refuse new connections + status = http.HTTPStatus.SERVICE_UNAVAILABLE # 503 + header = {"Hint": "The server is overloaded. Please retry later."} + response = b"The server is busy. Please retry later." + + return status, header, response + + async def run(self, port: int): + logging.info("started") + + tasks = [] + for i in range(self.nn_pool_size): + tasks.append(asyncio.create_task(self.stream_consumer_task())) + + if self.certificate: + logging.info(f"Using certificate: {self.certificate}") + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(self.certificate) + else: + ssl_context = None + logging.info("No certificate provided") + + async with websockets.serve( + self.handle_connection, + host="", + port=port, + max_size=self.max_message_size, + max_queue=self.max_queue_size, + process_request=self.process_request, + ssl=ssl_context, + ): + ip_list = ["localhost"] + if ssl_context: + ip_list += ["0.0.0.0", "127.0.0.1"] + ip_list.append(socket.gethostbyname(socket.gethostname())) + + proto = "http://" if ssl_context is None else "https://" + s = "Please visit one of the following addresses:\n\n" + for p in ip_list: + s += " " + proto + p + f":{port}" "\n" + logging.info(s) + + await asyncio.Future() # run forever + + await asyncio.gather(*tasks) # not reachable + + async def recv_audio_samples( + self, + socket: websockets.WebSocketServerProtocol, + ) -> Tuple[Optional[np.ndarray], Optional[float]]: + """Receive a tensor from the client. + + The message from the client is a **bytes** buffer. + + The first message can be either "Done" meaning the client won't send + anything in the future or it can be a buffer containing 8 bytes. + The first 4 bytes in little endian specifies the sample + rate of the audio samples; the second 4 bytes in little endian specifies + the number of bytes in the audio file, which will be sent by the client + in the subsequent messages. + Since there is a limit in the message size posed by the websocket + protocol, the client may send the audio file in multiple messages if the + audio file is very large. + + The second and remaining messages contain audio samples. + + Please refer to ./offline-websocket-client-decode-files-paralell.py + and ./offline-websocket-client-decode-files-sequential.py + for how the client sends the message. + + Args: + socket: + The socket for communicating with the client. + Returns: + Return a containing: + - 1-D np.float32 array containing the audio samples + - sample rate of the audio samples + or return (None, None) indicating the end of utterance. + """ + header = await socket.recv() + if header == "Done": + return None, None + + assert len(header) >= 8, ( + "The first message should contain at least 8 bytes." 
+ + f"Given {len(header)}" + ) + + sample_rate = int.from_bytes(header[:4], "little", signed=True) + expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True) + + received = [] + num_received_bytes = 0 + if len(header) > 8: + received.append(header[8:]) + num_received_bytes += len(header) - 8 + + if num_received_bytes < expected_num_bytes: + async for message in socket: + received.append(message) + num_received_bytes += len(message) + if num_received_bytes >= expected_num_bytes: + break + + assert num_received_bytes == expected_num_bytes, ( + num_received_bytes, + expected_num_bytes, + ) + + samples = b"".join(received) + array = np.frombuffer(samples, dtype=np.float32) + return array, sample_rate + + async def stream_consumer_task(self): + """This function extracts streams from the queue, batches them up, sends + them to the RNN-T model for computation and decoding. + """ + while True: + if self.stream_queue.empty(): + await asyncio.sleep(self.max_wait_ms / 1000) + continue + + batch = [] + try: + while len(batch) < self.max_batch_size: + item = self.stream_queue.get_nowait() + + batch.append(item) + except asyncio.QueueEmpty: + pass + + stream_list = [b[0] for b in batch] + future_list = [b[1] for b in batch] + + loop = asyncio.get_running_loop() + await loop.run_in_executor( + self.nn_pool, + self.recognizer.decode_streams, + stream_list, + ) + + for f in future_list: + self.stream_queue.task_done() + f.set_result(None) + + async def compute_and_decode( + self, + stream: sherpa_mnn.OfflineStream, + ) -> None: + """Put the stream into the queue and wait it to be processed by the + consumer task. + + Args: + stream: + The stream to be processed. Note: It is changed in-place. + """ + loop = asyncio.get_running_loop() + future = loop.create_future() + await self.stream_queue.put((stream, future)) + await future + + async def handle_connection( + self, + socket: websockets.WebSocketServerProtocol, + ): + """Receive audio samples from the client, process it, and sends + deocoding result back to the client. + + Args: + socket: + The socket for communicating with the client. + """ + try: + await self.handle_connection_impl(socket) + except websockets.exceptions.ConnectionClosedError: + logging.info(f"{socket.remote_address} disconnected") + finally: + # Decrement so that it can accept new connections + self.current_active_connections -= 1 + + logging.info( + f"Disconnected: {socket.remote_address}. " + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa + ) + + async def handle_connection_impl( + self, + socket: websockets.WebSocketServerProtocol, + ): + """Receive audio samples from the client, process it, and send + decoding results back to the client. + + Args: + socket: + The socket for communicating with the client. + """ + logging.info( + f"Connected: {socket.remote_address}. " + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa + ) + + while True: + stream = self.recognizer.create_stream() + samples, sample_rate = await self.recv_audio_samples(socket) + if samples is None: + break + # stream.accept_samples() runs in the main thread + + stream.accept_waveform(sample_rate, samples) + + await self.compute_and_decode(stream) + result = stream.result.text + logging.info(f"result: {result}") + + if result: + await socket.send(result) + else: + # If result is an empty string, send something to the client. 
+ # Otherwise, socket.send() is a no-op and the client will + # wait for a reply indefinitely. + await socket.send("") + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def create_recognizer(args) -> sherpa_mnn.OfflineRecognizer: + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.sense_voice) == 0, args.sense_voice + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.telespeech_ctc) == 0, args.telespeech_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + max_active_paths=args.max_active_paths, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, + provider=args.provider, + ) + elif args.paraformer: + assert len(args.sense_voice) == 0, args.sense_voice + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.telespeech_ctc) == 0, args.telespeech_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.paraformer) + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + provider=args.provider, + ) + elif args.sense_voice: + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.telespeech_ctc) == 0, args.telespeech_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.sense_voice) + recognizer = 
sherpa_mnn.OfflineRecognizer.from_sense_voice( + model=args.sense_voice, + tokens=args.tokens, + num_threads=args.num_threads, + use_itn=True, + ) + elif args.nemo_ctc: + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.telespeech_ctc) == 0, args.telespeech_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.nemo_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_nemo_ctc( + model=args.nemo_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + provider=args.provider, + ) + elif args.wenet_ctc: + assert len(args.telespeech_ctc) == 0, args.telespeech_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + provider=args.provider, + ) + elif args.telespeech_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.telespeech_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_telespeech_ctc( + model=args.telespeech_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + provider=args.provider, + ) + elif args.whisper_encoder: + assert len(args.tdnn_model) == 0, args.tdnn_model + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + recognizer = sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + language=args.whisper_language, + 
task=args.whisper_task, + tail_paddings=args.whisper_tail_paddings, + provider=args.provider, + ) + elif args.tdnn_model: + assert_file_exists(args.tdnn_model) + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + recognizer = sherpa_mnn.OfflineRecognizer.from_tdnn_ctc( + model=args.tdnn_model, + tokens=args.tokens, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + provider=args.provider, + ) + elif args.moonshine_preprocessor: + assert_file_exists(args.moonshine_preprocessor) + assert_file_exists(args.moonshine_encoder) + assert_file_exists(args.moonshine_uncached_decoder) + assert_file_exists(args.moonshine_cached_decoder) + + recognizer = sherpa_mnn.OfflineRecognizer.from_moonshine( + preprocessor=args.moonshine_preprocessor, + encoder=args.moonshine_encoder, + uncached_decoder=args.moonshine_uncached_decoder, + cached_decoder=args.moonshine_cached_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + ) + else: + raise ValueError("Please specify at least one model") + + return recognizer + + +def main(): + args = get_args() + logging.info(vars(args)) + check_args(args) + + recognizer = create_recognizer(args) + + port = args.port + max_wait_ms = args.max_wait_ms + max_batch_size = args.max_batch_size + nn_pool_size = args.nn_pool_size + max_message_size = args.max_message_size + max_queue_size = args.max_queue_size + max_active_connections = args.max_active_connections + certificate = args.certificate + doc_root = args.doc_root + + if certificate and not Path(certificate).is_file(): + raise ValueError(f"{certificate} does not exist") + + if not Path(doc_root).is_dir(): + raise ValueError(f"Directory {doc_root} does not exist") + + non_streaming_server = NonStreamingServer( + recognizer=recognizer, + max_wait_ms=max_wait_ms, + max_batch_size=max_batch_size, + nn_pool_size=nn_pool_size, + max_message_size=max_message_size, + max_queue_size=max_queue_size, + max_active_connections=max_active_connections, + certificate=certificate, + doc_root=doc_root, + ) + asyncio.run(non_streaming_server.run(port)) + + +if __name__ == "__main__": + log_filename = "log/log-non-streaming-server" + setup_logger(log_filename) + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-decode-files.py new file mode 100755 index 00000000..0de8edfe --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-decode-files.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a non-streaming model. 
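Editor's note: stripped to its essentials, the decoding flow implemented further down in this file looks like the sketch below. It is only a summary: the real script builds the recognizer from the command-line flags, and read_wave() is the helper defined later in this file.

import sherpa_mnn


def transcribe_files(recognizer: sherpa_mnn.OfflineRecognizer, wave_filenames):
    streams = []
    for wave_filename in wave_filenames:
        # read_wave() returns float32 samples in [-1, 1] and the file's sample rate
        samples, sample_rate = read_wave(wave_filename)
        s = recognizer.create_stream()
        s.accept_waveform(sample_rate, samples)
        streams.append(s)

    recognizer.decode_streams(streams)  # all files are decoded in one batch
    return [s.result.text for s in streams]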
+ +(1) For paraformer + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(2) For transducer models from icefall + + ./python-api-examples/offline-decode-files.py \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 \ + /path/to/0.wav \ + /path/to/1.wav + +(3) For CTC models from NeMo + +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ + --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav + +(4) For Whisper models + +python3 ./python-api-examples/offline-decode-files.py \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=1 \ + ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav + +(5) For CTC models from WeNet + +python3 ./python-api-examples/offline-decode-files.py \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + +(6) For tdnn models of the yesno recipe from icefall + +python3 ./python-api-examples/offline-decode-files.py \ + --sample-rate=8000 \ + --feature-dim=23 \ + --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \ + --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \ + ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. +""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. 
+ Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--nemo-ctc", + default="", + type=str, + help="Path to the model.onnx from NeMo CTC", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + parser.add_argument( + "--tdnn-model", + default="", + type=str, + help="Path to the model.onnx for the tdnn model of the yesno recipe", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. Must match the one + expected by the model. Note: The input sound files can have a + different sample rate from this argument.""", + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. 
Must match the one expected by the model", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + assert args.num_threads > 0, args.num_threads + + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, + blank_penalty=args.blank_penalty, + debug=args.debug, + ) + elif args.paraformer: + assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.paraformer) + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.nemo_ctc: + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.nemo_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_nemo_ctc( + model=args.nemo_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + 
sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.wenet_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.whisper_encoder: + assert len(args.tdnn_model) == 0, args.tdnn_model + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + + recognizer = sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + language=args.whisper_language, + task=args.whisper_task, + tail_paddings=args.whisper_tail_paddings, + ) + elif args.tdnn_model: + assert_file_exists(args.tdnn_model) + + recognizer = sherpa_mnn.OfflineRecognizer.from_tdnn_ctc( + model=args.tdnn_model, + tokens=args.tokens, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + ) + else: + print("Please specify at least one model") + return + + print("Started!") + start_time = time.time() + + streams = [] + total_duration = 0 + for wave_filename in args.sound_files: + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + + streams.append(s) + + recognizer.decode_streams(streams) + results = [s.result.text for s in streams] + end_time = time.time() + print("Done!") + + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-fire-red-asr-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-fire-red-asr-decode-files.py new file mode 100644 index 00000000..ffe5c028 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-fire-red-asr-decode-files.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming FireRedAsr AED model from +https://github.com/FireRedTeam/FireRedASR +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +tar xvf sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +rm sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16.tar.bz2 +""" + +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx" + decoder = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx" + tokens = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt" + test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/0.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/1.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/2.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/8k.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/3-sichuan.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/4-tianjin.wav" + # test_wav = "./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/test_wavs/5-henan.wav" + + if ( + not Path(encoder).is_file() + or not Path(decoder).is_file() + or not Path(test_wav).is_file() + ): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OfflineRecognizer.from_fire_red_asr( + encoder=encoder, + decoder=decoder, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-moonshine-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-moonshine-decode-files.py new file mode 100644 index 00000000..dd707971 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-moonshine-decode-files.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming Moonshine model from +https://github.com/usefulsensors/moonshine +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2 +""" + +import datetime as dt +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + preprocessor = "./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx" + encoder = "./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx" + uncached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx" + cached_decoder = "./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx" + + tokens = "./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt" + test_wav = "./sherpa-onnx-moonshine-tiny-en-int8/test_wavs/0.wav" + + if not Path(preprocessor).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OfflineRecognizer.from_moonshine( + preprocessor=preprocessor, + encoder=encoder, + uncached_decoder=uncached_decoder, + cached_decoder=cached_decoder, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + start_t = dt.datetime.now() + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + + end_t = dt.datetime.now() + elapsed_seconds = (end_t - start_t).total_seconds() + duration = audio.shape[-1] / sample_rate + rtf = elapsed_seconds / duration + + print(stream.result) + print(wave_filename) + print("Text:", stream.result.text) + print(f"Audio duration:\t{duration:.3f} s") + print(f"Elapsed:\t{elapsed_seconds:.3f} s") + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-nemo-ctc-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-nemo-ctc-decode-files.py new file mode 100755 index 00000000..d97b8e31 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-nemo-ctc-decode-files.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming CTC model from NeMo +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + + +The example model supports 10 languages and it is converted from +https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc +""" + +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/model.onnx" + tokens = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt" + + test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav" + + if not Path(model).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OfflineRecognizer.from_nemo_ctc( + model=model, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-nemo-transducer-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-nemo-transducer-decode-files.py new file mode 100755 index 00000000..8b639759 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-nemo-transducer-decode-files.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming transducer model from NeMo +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + + +The example model supports 10 languages and it is converted from +https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_multilingual_fastconformer_hybrid_large_pc +""" + +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/encoder.onnx" + decoder = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/decoder.onnx" + joiner = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/joiner.onnx" + tokens = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/tokens.txt" + + test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/de-german.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/en-english.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/es-spanish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/fr-french.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/hr-croatian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/it-italian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/po-polish.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/ru-russian.wav" + # test_wav = "./sherpa-onnx-nemo-fast-conformer-transducer-be-de-en-es-fr-hr-it-pl-ru-uk-20k/test_wavs/uk-ukrainian.wav" + + if not Path(encoder).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + model_type="nemo_transducer", + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-sense-voice-ctc-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-sense-voice-ctc-decode-files.py new file mode 100644 index 00000000..e53345d8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-sense-voice-ctc-decode-files.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming SenseVoice CTC model from +https://github.com/FunAudioLLM/SenseVoice +to decode files. 
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2 +""" + +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx" + tokens = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt" + test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/zh.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/en.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ja.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/ko.wav" + # test_wav = "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/test_wavs/yue.wav" + + if not Path(model).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OfflineRecognizer.from_sense_voice( + model=model, + tokens=tokens, + use_itn=True, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-speaker-diarization.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-speaker-diarization.py new file mode 100755 index 00000000..fe11bf21 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-speaker-diarization.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024 Xiaomi Corporation + +""" +This file shows how to use sherpa-onnx Python API for +offline/non-streaming speaker diarization. + +Usage: + +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +Step 2: Download a speaker embedding extractor model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +Step 3. Download test wave files + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available test wave files. 
The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +Step 4. Run it + + python3 ./python-api-examples/offline-speaker-diarization.py + +""" +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5): + """ + Args: + num_speakers: + If you know the actual number of speakers in the wave file, then please + specify it. Otherwise, leave it to -1 + cluster_threshold: + If num_speakers is -1, then this threshold is used for clustering. + A smaller cluster_threshold leads to more clusters, i.e., more speakers. + A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers. + """ + segmentation_model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx" + embedding_extractor_model = ( + "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + ) + + config = sherpa_mnn.OfflineSpeakerDiarizationConfig( + segmentation=sherpa_mnn.OfflineSpeakerSegmentationModelConfig( + pyannote=sherpa_mnn.OfflineSpeakerSegmentationPyannoteModelConfig( + model=segmentation_model + ), + ), + embedding=sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=embedding_extractor_model + ), + clustering=sherpa_mnn.FastClusteringConfig( + num_clusters=num_speakers, threshold=cluster_threshold + ), + min_duration_on=0.3, + min_duration_off=0.5, + ) + if not config.validate(): + raise RuntimeError( + "Please check your config and make sure all required files exist" + ) + + return sherpa_mnn.OfflineSpeakerDiarization(config) + + +def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int: + progress = num_processed_chunk / num_total_chunks * 100 + print(f"Progress: {progress:.3f}%") + return 0 + + +def main(): + wave_filename = "./0-four-speakers-zh.wav" + if not Path(wave_filename).is_file(): + raise RuntimeError(f"{wave_filename} does not exist") + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # Since we know there are 4 speakers in the above test wave file, we use + # num_speakers 4 here + sd = init_speaker_diarization(num_speakers=4) + if sample_rate != sd.sample_rate: + raise RuntimeError( + f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}" + ) + + show_progress = True + + if show_progress: + result = sd.process(audio, callback=progress_callback).sort_by_start_time() + else: + result = sd.process(audio).sort_by_start_time() + + for r in result: + print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}") + # print(r) # this one is simpler + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-speech-enhancement-gtcrn.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-speech-enhancement-gtcrn.py new file mode 100755 index 00000000..88c01734 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-speech-enhancement-gtcrn.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use the speech enhancement API. 
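It loads a GTCRN denoising model, runs it on a noisy wave file, writes the
enhanced audio to ./enhanced_16k.wav, and reports the elapsed time, audio
duration and real-time factor.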
+ +Please download files used this script from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models + +Example: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/gtcrn_simple.onnx + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speech-enhancement-models/speech_with_noise.wav +""" + +import time +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_mnn +import soundfile as sf + + +def create_speech_denoiser(): + model_filename = "./gtcrn_simple.onnx" + if not Path(model_filename).is_file(): + raise ValueError( + "Please first download a model from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models" + ) + + config = sherpa_mnn.OfflineSpeechDenoiserConfig( + model=sherpa_mnn.OfflineSpeechDenoiserModelConfig( + gtcrn=sherpa_mnn.OfflineSpeechDenoiserGtcrnModelConfig( + model=model_filename + ), + debug=False, + num_threads=1, + provider="cpu", + ) + ) + if not config.validate(): + print(config) + raise ValueError("Errors in config. Please check previous error logs") + return sherpa_mnn.OfflineSpeechDenoiser(config) + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def main(): + sd = create_speech_denoiser() + test_wave = "./speech_with_noise.wav" + if not Path(test_wave).is_file(): + raise ValueError( + f"{test_wave} does not exist. You can download it from " + "https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models" + ) + + samples, sample_rate = load_audio(test_wave) + + start = time.time() + denoised = sd(samples, sample_rate) + end = time.time() + + elapsed_seconds = end - start + audio_duration = len(samples) / sample_rate + real_time_factor = elapsed_seconds / audio_duration + + sf.write("./enhanced_16k.wav", denoised.samples, denoised.sample_rate) + print("Saved to ./enhanced_16k.wav") + print(f"Elapsed seconds: {elapsed_seconds:.3f}") + print(f"Audio duration in seconds: {audio_duration:.3f}") + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-telespeech-ctc-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-telespeech-ctc-decode-files.py new file mode 100755 index 00000000..a415faf5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-telespeech-ctc-decode-files.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming CTC model from +https://github.com/Tele-AI/TeleSpeech-ASR +to decode files. 
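This example uses the int8-quantized model and creates the recognizer with
OfflineRecognizer.from_telespeech_ctc(); the bundled test waves are named
after the Sichuan, Tianjin and Henan dialects.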
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + + +""" + +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx" + tokens = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt" + test_wav = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/3-sichuan.wav" + # test_wav = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/4-tianjin.wav" + # test_wav = "./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/test_wavs/5-henan.wav" + + if not Path(model).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OfflineRecognizer.from_telespeech_ctc( + model=model, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + print(wave_filename) + print(stream.result) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-tts-play.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-tts-play.py new file mode 100755 index 00000000..2d9307f3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-tts-play.py @@ -0,0 +1,580 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to generate audio +from text, i.e., text-to-speech. + +Different from ./offline-tts.py, this file plays back the generated audio +while the model is still generating. + +Usage: + +Example (1/7) + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 +tar xf vits-piper-en_US-amy-low.tar.bz2 + +python3 ./python-api-examples/offline-tts-play.py \ + --vits-model=./vits-piper-en_US-amy-low/en_US-amy-low.onnx \ + --vits-tokens=./vits-piper-en_US-amy-low/tokens.txt \ + --vits-data-dir=./vits-piper-en_US-amy-low/espeak-ng-data \ + --output-filename=./generated.wav \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." 
+ +Example (2/7) + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-zh-aishell3.tar.bz2 +tar xvf vits-zh-aishell3.tar.bz2 + +python3 ./python-api-examples/offline-tts-play.py \ + --vits-model=./vits-icefall-zh-aishell3/model.onnx \ + --vits-lexicon=./vits-icefall-zh-aishell3/lexicon.txt \ + --vits-tokens=./vits-icefall-zh-aishell3/tokens.txt \ + --tts-rule-fsts='./vits-icefall-zh-aishell3/phone.fst,./vits-icefall-zh-aishell3/date.fst,./vits-icefall-zh-aishell3/number.fst' \ + --sid=21 \ + --output-filename=./liubei-21.wav \ + "勿以恶小而为之,勿以善小而不为。惟贤惟德,能服于人。122334" + +Example (3/7) + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-vits-zh-ll.tar.bz2 +tar xvf sherpa-onnx-vits-zh-ll.tar.bz2 +rm sherpa-onnx-vits-zh-ll.tar.bz2 + +python3 ./python-api-examples/offline-tts-play.py \ + --vits-model=./sherpa-onnx-vits-zh-ll/model.onnx \ + --vits-lexicon=./sherpa-onnx-vits-zh-ll/lexicon.txt \ + --vits-tokens=./sherpa-onnx-vits-zh-ll/tokens.txt \ + --tts-rule-fsts=./sherpa-onnx-vits-zh-ll/phone.fst,./sherpa-onnx-vits-zh-ll/date.fst,./sherpa-onnx-vits-zh-ll/number.fst \ + --vits-dict-dir=./sherpa-onnx-vits-zh-ll/dict \ + --sid=2 \ + --output-filename=./test-2.wav \ + "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。2024年5月11号,拨打110或者18920240511。123456块钱。" + +Example (4/7) + +curl -O -SL https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +python3 ./python-api-examples/offline-tts-play.py \ + --matcha-acoustic-model=./matcha-icefall-zh-baker/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-lexicon=./matcha-icefall-zh-baker/lexicon.txt \ + --matcha-tokens=./matcha-icefall-zh-baker/tokens.txt \ + --tts-rule-fsts=./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst \ + --matcha-dict-dir=./matcha-icefall-zh-baker/dict \ + --output-filename=./test-matcha.wav \ + "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。" + +Example (5/7) + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-en_US-ljspeech.tar.bz2 +tar xvf matcha-icefall-en_US-ljspeech.tar.bz2 +rm matcha-icefall-en_US-ljspeech.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +python3 ./python-api-examples/offline-tts-play.py \ + --matcha-acoustic-model=./matcha-icefall-en_US-ljspeech/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-tokens=./matcha-icefall-en_US-ljspeech/tokens.txt \ + --matcha-data-dir=./matcha-icefall-en_US-ljspeech/espeak-ng-data \ + --output-filename=./test-matcha-ljspeech-en.wav \ + --num-threads=2 \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." 
+ +Example (6/7) + +(This version of kokoro supports only English) + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2 +tar xf kokoro-en-v0_19.tar.bz2 +rm kokoro-en-v0_19.tar.bz2 + +python3 ./python-api-examples/offline-tts.py \ + --debug=1 \ + --kokoro-model=./kokoro-en-v0_19/model.onnx \ + --kokoro-voices=./kokoro-en-v0_19/voices.bin \ + --kokoro-tokens=./kokoro-en-v0_19/tokens.txt \ + --kokoro-data-dir=./kokoro-en-v0_19/espeak-ng-data \ + --num-threads=2 \ + --sid=10 \ + --output-filename="./kokoro-10.wav" \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be a statesman, a businessman, an official, or a scholar." + +Example (7/7) + +(This version of kokoro supports English, Chinese, etc.) + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-multi-lang-v1_0.tar.bz2 +tar xf kokoro-multi-lang-v1_0.tar.bz2 +rm kokoro-multi-lang-v1_0.tar.bz2 + +python3 ./python-api-examples/offline-tts-play.py \ + --debug=1 \ + --kokoro-model=./kokoro-multi-lang-v1_0/model.onnx \ + --kokoro-voices=./kokoro-multi-lang-v1_0/voices.bin \ + --kokoro-tokens=./kokoro-multi-lang-v1_0/tokens.txt \ + --kokoro-data-dir=./kokoro-multi-lang-v1_0/espeak-ng-data \ + --kokoro-dict-dir=./kokoro-multi-lang-v1_0/dict \ + --kokoro-lexicon=./kokoro-multi-lang-v1_0/lexicon-us-en.txt,./kokoro-multi-lang-v1_0/lexicon-zh.txt \ + --num-threads=2 \ + --sid=18 \ + --output-filename="./kokoro-18-zh-en.wav" \ + "中英文语音合成测试。This is generated by next generation Kaldi using Kokoro without Misaki. 你觉得中英文说的如何呢?" + +You can find more models at +https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + +Please see +https://k2-fsa.github.io/sherpa/onnx/tts/index.html +for details. +""" + +import argparse +import logging +import queue +import sys +import threading +import time + +import numpy as np +import sherpa_mnn +import soundfile as sf + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def add_vits_args(parser): + parser.add_argument( + "--vits-model", + type=str, + default="", + help="Path to vits model.onnx", + ) + + parser.add_argument( + "--vits-lexicon", + type=str, + default="", + help="Path to lexicon.txt", + ) + + parser.add_argument( + "--vits-tokens", + type=str, + default="", + help="Path to tokens.txt", + ) + + parser.add_argument( + "--vits-data-dir", + type=str, + default="", + help="""Path to the dict directory of espeak-ng. If it is specified, + --vits-lexicon and --vits-tokens are ignored""", + ) + + parser.add_argument( + "--vits-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba", + ) + + +def add_matcha_args(parser): + parser.add_argument( + "--matcha-acoustic-model", + type=str, + default="", + help="Path to model.onnx for matcha", + ) + + parser.add_argument( + "--matcha-vocoder", + type=str, + default="", + help="Path to vocoder for matcha", + ) + + parser.add_argument( + "--matcha-lexicon", + type=str, + default="", + help="Path to lexicon.txt for matcha", + ) + + parser.add_argument( + "--matcha-tokens", + type=str, + default="", + help="Path to tokens.txt for matcha", + ) + + parser.add_argument( + "--matcha-data-dir", + type=str, + default="", + help="""Path to the dict directory of espeak-ng. 
If it is specified, + --matcha-lexicon and --matcha-tokens are ignored""", + ) + + parser.add_argument( + "--matcha-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba", + ) + + +def add_kokoro_args(parser): + parser.add_argument( + "--kokoro-model", + type=str, + default="", + help="Path to model.onnx for kokoro", + ) + + parser.add_argument( + "--kokoro-voices", + type=str, + default="", + help="Path to voices.bin for kokoro", + ) + + parser.add_argument( + "--kokoro-tokens", + type=str, + default="", + help="Path to tokens.txt for kokoro", + ) + + parser.add_argument( + "--kokoro-data-dir", + type=str, + default="", + help="Path to the dict directory of espeak-ng.", + ) + + parser.add_argument( + "--kokoro-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba. Needed only by multilingual kokoro", + ) + + parser.add_argument( + "--kokoro-lexicon", + type=str, + default="", + help="Path to lexicon.txt for kokoro. Needed only by multilingual kokoro", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + add_vits_args(parser) + add_matcha_args(parser) + add_kokoro_args(parser) + + parser.add_argument( + "--tts-rule-fsts", + type=str, + default="", + help="Path to rule.fst", + ) + + parser.add_argument( + "--output-filename", + type=str, + default="./generated.wav", + help="Path to save generated wave", + ) + + parser.add_argument( + "--sid", + type=int, + default=0, + help="""Speaker ID. Used only for multi-speaker models, e.g. + models trained using the VCTK dataset. Not used for single-speaker + models, e.g., models trained using the LJ speech dataset. + """, + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--speed", + type=float, + default=1.0, + help="Speech speed. Larger->faster; smaller->slower", + ) + + parser.add_argument( + "text", + type=str, + help="The input text to generate audio for", + ) + + return parser.parse_args() + + +# buffer saves audio samples to be played +buffer = queue.Queue() + +# started is set to True once generated_audio_callback is called. +started = False + +# stopped is set to True once all the text has been processed +stopped = False + +# killed is set to True once ctrl + C is pressed +killed = False + +# Note: When started is True, and stopped is True, and buffer is empty, +# we will exit the program since all audio samples have been played. + +sample_rate = None + +event = threading.Event() + +first_message_time = None + + +def generated_audio_callback(samples: np.ndarray, progress: float): + """This function is called whenever max_num_sentences sentences + have been processed. + + Note that it is passed to C++ and is invoked in C++. 
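    It also receives the generation progress as a float, which this example
    ignores, and it must return 1 to keep generating or 0 to stop (see the
    comments near the end of this function).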
+ + Args: + samples: + A 1-D np.float32 array containing audio samples + """ + global first_message_time + if first_message_time is None: + first_message_time = time.time() + + buffer.put(samples) + global started + + if started is False: + logging.info("Start playing ...") + started = True + + # 1 means to keep generating + # 0 means to stop generating + if killed: + return 0 + + return 1 + + +# see https://python-sounddevice.readthedocs.io/en/0.4.6/api/streams.html#sounddevice.OutputStream +def play_audio_callback( + outdata: np.ndarray, frames: int, time, status: sd.CallbackFlags +): + if killed or (started and buffer.empty() and stopped): + event.set() + + # outdata is of shape (frames, num_channels) + if buffer.empty(): + outdata.fill(0) + return + + n = 0 + while n < frames and not buffer.empty(): + remaining = frames - n + k = buffer.queue[0].shape[0] + + if remaining <= k: + outdata[n:, 0] = buffer.queue[0][:remaining] + buffer.queue[0] = buffer.queue[0][remaining:] + n = frames + if buffer.queue[0].shape[0] == 0: + buffer.get() + + break + + outdata[n : n + k, 0] = buffer.get() + n += k + + if n < frames: + outdata[n:, 0] = 0 + + +# Please see +# https://python-sounddevice.readthedocs.io/en/0.4.6/usage.html#device-selection +# for how to select a device +def play_audio(): + if False: + # This if branch can be safely removed. It is here to show you how to + # change the default output device in case you need that. + devices = sd.query_devices() + print(devices) + + # sd.default.device[1] is the output device, if you want to + # select a different device, say, 3, as the output device, please + # use self.default.device[1] = 3 + + default_output_device_idx = sd.default.device[1] + print( + f'Use default output device: {devices[default_output_device_idx]["name"]}' + ) + + with sd.OutputStream( + channels=1, + callback=play_audio_callback, + dtype="float32", + samplerate=sample_rate, + blocksize=1024, + ): + event.wait() + + logging.info("Exiting ...") + + +def main(): + args = get_args() + print(args) + + tts_config = sherpa_mnn.OfflineTtsConfig( + model=sherpa_mnn.OfflineTtsModelConfig( + vits=sherpa_mnn.OfflineTtsVitsModelConfig( + model=args.vits_model, + lexicon=args.vits_lexicon, + data_dir=args.vits_data_dir, + dict_dir=args.vits_dict_dir, + tokens=args.vits_tokens, + ), + matcha=sherpa_mnn.OfflineTtsMatchaModelConfig( + acoustic_model=args.matcha_acoustic_model, + vocoder=args.matcha_vocoder, + lexicon=args.matcha_lexicon, + tokens=args.matcha_tokens, + data_dir=args.matcha_data_dir, + dict_dir=args.matcha_dict_dir, + ), + kokoro=sherpa_mnn.OfflineTtsKokoroModelConfig( + model=args.kokoro_model, + voices=args.kokoro_voices, + tokens=args.kokoro_tokens, + data_dir=args.kokoro_data_dir, + dict_dir=args.kokoro_dict_dir, + lexicon=args.kokoro_lexicon, + ), + provider=args.provider, + debug=args.debug, + num_threads=args.num_threads, + ), + rule_fsts=args.tts_rule_fsts, + max_num_sentences=1, + ) + + if not tts_config.validate(): + raise ValueError("Please check your config") + + logging.info("Loading model ...") + tts = sherpa_mnn.OfflineTts(tts_config) + logging.info("Loading model done.") + + global sample_rate + sample_rate = tts.sample_rate + + play_back_thread = threading.Thread(target=play_audio) + play_back_thread.start() + + logging.info("Start generating ...") + start_time = time.time() + audio = tts.generate( + args.text, + sid=args.sid, + speed=args.speed, + callback=generated_audio_callback, + ) + end_time = time.time() + logging.info("Finished generating!") + global 
stopped + stopped = True + + if len(audio.samples) == 0: + print("Error in generating audios. Please read previous error messages.") + global killed + killed = True + play_back_thread.join() + return + + elapsed_seconds = end_time - start_time + audio_duration = len(audio.samples) / audio.sample_rate + real_time_factor = elapsed_seconds / audio_duration + + sf.write( + args.output_filename, + audio.samples, + samplerate=audio.sample_rate, + subtype="PCM_16", + ) + logging.info(f"The text is '{args.text}'") + logging.info( + "Time in seconds to receive the first " + f"message: {first_message_time-start_time:.3f}" + ) + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}") + logging.info(f"Audio duration in seconds: {audio_duration:.3f}") + logging.info( + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" + ) + + logging.info(f"*** Saved to {args.output_filename} ***") + + print("\n >>>>>>>>> You can safely press ctrl + C to stop the play <<<<<<<<<<\n") + + play_back_thread.join() + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") + killed = True + sys.exit(0) diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-tts.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-tts.py new file mode 100755 index 00000000..1a2471b3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-tts.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python API to generate audio +from text, i.e., text-to-speech. + + +Different from ./offline-tts-play.py, this file does not play back the +generated audio. + +Usage: + +Example (1/7) + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 +tar xf vits-piper-en_US-amy-low.tar.bz2 + +python3 ./python-api-examples/offline-tts.py \ + --vits-model=./vits-piper-en_US-amy-low/en_US-amy-low.onnx \ + --vits-tokens=./vits-piper-en_US-amy-low/tokens.txt \ + --vits-data-dir=./vits-piper-en_US-amy-low/espeak-ng-data \ + --output-filename=./generated.wav \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." 
+ +Example (2/7) + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2 +tar xvf vits-icefall-zh-aishell3.tar.bz2 + +python3 ./python-api-examples/offline-tts.py \ + --vits-model=./vits-icefall-zh-aishell3/model.onnx \ + --vits-lexicon=./vits-icefall-zh-aishell3/lexicon.txt \ + --vits-tokens=./vits-icefall-zh-aishell3/tokens.txt \ + --tts-rule-fsts='./vits-icefall-zh-aishell3/phone.fst,./vits-icefall-zh-aishell3/date.fst,./vits-icefall-zh-aishell3/number.fst' \ + --sid=21 \ + --output-filename=./liubei-21.wav \ + "勿以恶小而为之,勿以善小而不为。惟贤惟德,能服于人。122334" + +Example (3/7) + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-vits-zh-ll.tar.bz2 +tar xvf sherpa-onnx-vits-zh-ll.tar.bz2 +rm sherpa-onnx-vits-zh-ll.tar.bz2 + +python3 ./python-api-examples/offline-tts.py \ + --vits-model=./sherpa-onnx-vits-zh-ll/model.onnx \ + --vits-lexicon=./sherpa-onnx-vits-zh-ll/lexicon.txt \ + --vits-tokens=./sherpa-onnx-vits-zh-ll/tokens.txt \ + --tts-rule-fsts=./sherpa-onnx-vits-zh-ll/phone.fst,./sherpa-onnx-vits-zh-ll/date.fst,./sherpa-onnx-vits-zh-ll/number.fst \ + --vits-dict-dir=./sherpa-onnx-vits-zh-ll/dict \ + --sid=2 \ + --output-filename=./test-2.wav \ + "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。2024年5月11号,拨打110或者18920240511。123456块钱。" + +Example (4/7) + +curl -O -SL https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +python3 ./python-api-examples/offline-tts.py \ + --matcha-acoustic-model=./matcha-icefall-zh-baker/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-lexicon=./matcha-icefall-zh-baker/lexicon.txt \ + --matcha-tokens=./matcha-icefall-zh-baker/tokens.txt \ + --tts-rule-fsts=./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst \ + --matcha-dict-dir=./matcha-icefall-zh-baker/dict \ + --output-filename=./test-matcha.wav \ + "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。" + +Example (5/7) + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-en_US-ljspeech.tar.bz2 +tar xvf matcha-icefall-en_US-ljspeech.tar.bz2 +rm matcha-icefall-en_US-ljspeech.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +python3 ./python-api-examples/offline-tts.py \ + --matcha-acoustic-model=./matcha-icefall-en_US-ljspeech/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-tokens=./matcha-icefall-en_US-ljspeech/tokens.txt \ + --matcha-data-dir=./matcha-icefall-en_US-ljspeech/espeak-ng-data \ + --output-filename=./test-matcha-ljspeech-en.wav \ + --num-threads=2 \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." 
+ +Example (6/7) + +(This version of kokoro supports only English) + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-en-v0_19.tar.bz2 +tar xf kokoro-en-v0_19.tar.bz2 +rm kokoro-en-v0_19.tar.bz2 + +python3 ./python-api-examples/offline-tts.py \ + --debug=1 \ + --kokoro-model=./kokoro-en-v0_19/model.onnx \ + --kokoro-voices=./kokoro-en-v0_19/voices.bin \ + --kokoro-tokens=./kokoro-en-v0_19/tokens.txt \ + --kokoro-data-dir=./kokoro-en-v0_19/espeak-ng-data \ + --num-threads=2 \ + --sid=10 \ + --output-filename="./kokoro-10.wav" \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be a statesman, a businessman, an official, or a scholar." + +Example (7/7) + +(This version of kokoro supports English, Chinese, etc.) + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/kokoro-multi-lang-v1_0.tar.bz2 +tar xf kokoro-multi-lang-v1_0.tar.bz2 +rm kokoro-multi-lang-v1_0.tar.bz2 + +python3 ./python-api-examples/offline-tts.py \ + --debug=1 \ + --kokoro-model=./kokoro-multi-lang-v1_0/model.onnx \ + --kokoro-voices=./kokoro-multi-lang-v1_0/voices.bin \ + --kokoro-tokens=./kokoro-multi-lang-v1_0/tokens.txt \ + --kokoro-data-dir=./kokoro-multi-lang-v1_0/espeak-ng-data \ + --kokoro-dict-dir=./kokoro-multi-lang-v1_0/dict \ + --kokoro-lexicon=./kokoro-multi-lang-v1_0/lexicon-us-en.txt,./kokoro-multi-lang-v1_0/lexicon-zh.txt \ + --num-threads=2 \ + --sid=18 \ + --output-filename="./kokoro-18-zh-en.wav" \ + "中英文语音合成测试。This is generated by next generation Kaldi using Kokoro without Misaki. 你觉得中英文说的如何呢?" + +You can find more models at +https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models + +Please see +https://k2-fsa.github.io/sherpa/onnx/tts/index.html +for details. + +""" + +import argparse +import time + +import sherpa_mnn +import soundfile as sf + + +def add_vits_args(parser): + parser.add_argument( + "--vits-model", + type=str, + default="", + help="Path to vits model.onnx", + ) + + parser.add_argument( + "--vits-lexicon", + type=str, + default="", + help="Path to lexicon.txt", + ) + + parser.add_argument( + "--vits-tokens", + type=str, + default="", + help="Path to tokens.txt", + ) + + parser.add_argument( + "--vits-data-dir", + type=str, + default="", + help="""Path to the dict directory of espeak-ng. If it is specified, + --vits-lexicon and --vits-tokens are ignored""", + ) + + parser.add_argument( + "--vits-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba", + ) + + +def add_matcha_args(parser): + parser.add_argument( + "--matcha-acoustic-model", + type=str, + default="", + help="Path to model.onnx for matcha", + ) + + parser.add_argument( + "--matcha-vocoder", + type=str, + default="", + help="Path to vocoder for matcha", + ) + + parser.add_argument( + "--matcha-lexicon", + type=str, + default="", + help="Path to lexicon.txt for matcha", + ) + + parser.add_argument( + "--matcha-tokens", + type=str, + default="", + help="Path to tokens.txt for matcha", + ) + + parser.add_argument( + "--matcha-data-dir", + type=str, + default="", + help="""Path to the dict directory of espeak-ng. 
If it is specified, + --matcha-lexicon and --matcha-tokens are ignored""", + ) + + parser.add_argument( + "--matcha-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba", + ) + + +def add_kokoro_args(parser): + parser.add_argument( + "--kokoro-model", + type=str, + default="", + help="Path to model.onnx for kokoro", + ) + + parser.add_argument( + "--kokoro-voices", + type=str, + default="", + help="Path to voices.bin for kokoro", + ) + + parser.add_argument( + "--kokoro-tokens", + type=str, + default="", + help="Path to tokens.txt for kokoro", + ) + + parser.add_argument( + "--kokoro-data-dir", + type=str, + default="", + help="Path to the dict directory of espeak-ng.", + ) + + parser.add_argument( + "--kokoro-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba. Needed only by multilingual kokoro", + ) + + parser.add_argument( + "--kokoro-lexicon", + type=str, + default="", + help="Path to lexicon.txt for kokoro. Needed only by multilingual kokoro", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + add_vits_args(parser) + add_matcha_args(parser) + add_kokoro_args(parser) + + parser.add_argument( + "--tts-rule-fsts", + type=str, + default="", + help="Path to rule.fst", + ) + + parser.add_argument( + "--max-num-sentences", + type=int, + default=1, + help="""Max number of sentences in a batch to avoid OOM if the input + text is very long. Set it to -1 to process all the sentences in a + single batch. A smaller value does not mean it is slower compared + to a larger one on CPU. + """, + ) + + parser.add_argument( + "--output-filename", + type=str, + default="./generated.wav", + help="Path to save generated wave", + ) + + parser.add_argument( + "--sid", + type=int, + default=0, + help="""Speaker ID. Used only for multi-speaker models, e.g. + models trained using the VCTK dataset. Not used for single-speaker + models, e.g., models trained using the LJ speech dataset. + """, + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--speed", + type=float, + default=1.0, + help="Speech speed. 
Larger->faster; smaller->slower", + ) + + parser.add_argument( + "text", + type=str, + help="The input text to generate audio for", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + print(args) + + tts_config = sherpa_mnn.OfflineTtsConfig( + model=sherpa_mnn.OfflineTtsModelConfig( + vits=sherpa_mnn.OfflineTtsVitsModelConfig( + model=args.vits_model, + lexicon=args.vits_lexicon, + data_dir=args.vits_data_dir, + dict_dir=args.vits_dict_dir, + tokens=args.vits_tokens, + ), + matcha=sherpa_mnn.OfflineTtsMatchaModelConfig( + acoustic_model=args.matcha_acoustic_model, + vocoder=args.matcha_vocoder, + lexicon=args.matcha_lexicon, + tokens=args.matcha_tokens, + data_dir=args.matcha_data_dir, + dict_dir=args.matcha_dict_dir, + ), + kokoro=sherpa_mnn.OfflineTtsKokoroModelConfig( + model=args.kokoro_model, + voices=args.kokoro_voices, + tokens=args.kokoro_tokens, + data_dir=args.kokoro_data_dir, + dict_dir=args.kokoro_dict_dir, + lexicon=args.kokoro_lexicon, + ), + provider=args.provider, + debug=args.debug, + num_threads=args.num_threads, + ), + rule_fsts=args.tts_rule_fsts, + max_num_sentences=args.max_num_sentences, + ) + if not tts_config.validate(): + raise ValueError("Please check your config") + + tts = sherpa_mnn.OfflineTts(tts_config) + + start = time.time() + audio = tts.generate(args.text, sid=args.sid, speed=args.speed) + end = time.time() + + if len(audio.samples) == 0: + print("Error in generating audios. Please read previous error messages.") + return + + elapsed_seconds = end - start + audio_duration = len(audio.samples) / audio.sample_rate + real_time_factor = elapsed_seconds / audio_duration + + sf.write( + args.output_filename, + audio.samples, + samplerate=audio.sample_rate, + subtype="PCM_16", + ) + print(f"Saved to {args.output_filename}") + print(f"The text is '{args.text}'") + print(f"Elapsed seconds: {elapsed_seconds:.3f}") + print(f"Audio duration in seconds: {audio_duration:.3f}") + print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-websocket-client-decode-files-paralell.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-websocket-client-decode-files-paralell.py new file mode 100755 index 00000000..f97f9c44 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-websocket-client-decode-files-paralell.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +A websocket client for sherpa-onnx-offline-websocket-server + +This file shows how to transcribe multiple +files in parallel. We create a separate connection for transcribing each file. + +Usage: + ./offline-websocket-client-decode-files-parallel.py \ + --server-addr localhost \ + --server-port 6006 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + /path/to/16kHz.wav \ + /path/to/8kHz.wav + +(Note: You have to first start the server before starting the client) + +You can find the server at +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc + +Note: The server is implemented in C++. 
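Each connection first sends a small binary header followed by the samples:
4 bytes for the sample rate (little endian), 4 bytes for the payload size in
bytes (number of samples * 4), and then the float32 samples themselves. The
data is sent in chunks of 10240 bytes, and the text "Done" is sent after the
result has been received, to tell the server that this client is finished.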
+""" + +import argparse +import asyncio +import logging +import wave +from typing import Tuple + +try: + import websockets +except ImportError: + print("please run:") + print("") + print(" pip install websockets") + print("") + print("before you run this script") + print("") + +import numpy as np + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +async def run( + server_addr: str, + server_port: int, + wave_filename: str, +): + async with websockets.connect( + f"ws://{server_addr}:{server_port}" + ) as websocket: # noqa + logging.info(f"Sending {wave_filename}") + samples, sample_rate = read_wave(wave_filename) + assert isinstance(sample_rate, int) + assert samples.dtype == np.float32, samples.dtype + assert samples.ndim == 1, samples.dim + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes + buf += (samples.size * 4).to_bytes(4, byteorder="little") + buf += samples.tobytes() + + payload_len = 10240 + while len(buf) > payload_len: + await websocket.send(buf[:payload_len]) + buf = buf[payload_len:] + + if buf: + await websocket.send(buf) + + decoding_results = await websocket.recv() + logging.info(f"{wave_filename}\n{decoding_results}") + + # to signal that the client has sent all the data + await websocket.send("Done") + + +async def main(): + args = get_args() + logging.info(vars(args)) + + server_addr = args.server_addr + server_port = args.server_port + sound_files = args.sound_files + + all_tasks = [] + for wave_filename in sound_files: + task = asyncio.create_task( + run( + server_addr=server_addr, + server_port=server_port, + wave_filename=wave_filename, + ) + ) + all_tasks.append(task) + + await asyncio.gather(*all_tasks) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa + ) + logging.basicConfig(format=formatter, level=logging.INFO) + asyncio.run(main()) diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-websocket-client-decode-files-sequential.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-websocket-client-decode-files-sequential.py 
new file mode 100755 index 00000000..7dac1fc0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-websocket-client-decode-files-sequential.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +A websocket client for sherpa-onnx-offline-websocket-server + +This file shows how to use a single connection to transcribe multiple +files sequentially. + +Usage: + ./offline-websocket-client-decode-files-sequential.py \ + --server-addr localhost \ + --server-port 6006 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + /path/to/16kHz.wav \ + /path/to/8kHz.wav + +(Note: You have to first start the server before starting the client) + +You can find the server at +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/offline-websocket-server.cc + +Note: The server is implemented in C++. +""" + +import argparse +import asyncio +import logging +import wave +from typing import List, Tuple + +try: + import websockets +except ImportError: + print("please run:") + print("") + print(" pip install websockets") + print("") + print("before you run this script") + print("") + +import numpy as np + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +async def run( + server_addr: str, + server_port: int, + sound_files: List[str], +): + async with websockets.connect( + f"ws://{server_addr}:{server_port}" + ) as websocket: # noqa + for wave_filename in sound_files: + logging.info(f"Sending {wave_filename}") + samples, sample_rate = read_wave(wave_filename) + assert isinstance(sample_rate, int) + assert samples.dtype == np.float32, samples.dtype + assert samples.ndim == 1, samples.dim + + buf = sample_rate.to_bytes(4, byteorder="little") # 4 bytes + buf += (samples.size * 4).to_bytes(4, byteorder="little") + buf += samples.tobytes() + + payload_len = 10240 + while len(buf) > payload_len: + await websocket.send(buf[:payload_len]) + buf = buf[payload_len:] + + if buf: + await websocket.send(buf) + + decoding_results = await websocket.recv() + print(decoding_results) + + # to signal that the client has sent all the data + await websocket.send("Done") + + +async def main(): + args = get_args() + logging.info(vars(args)) + + server_addr = args.server_addr + server_port = args.server_port + sound_files = args.sound_files + + await run( + server_addr=server_addr, + server_port=server_port, + sound_files=sound_files, + ) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa + ) + logging.basicConfig(format=formatter, level=logging.INFO) + asyncio.run(main()) diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/offline-whisper-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/offline-whisper-decode-files.py new file mode 100644 index 00000000..bbc59b62 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/offline-whisper-decode-files.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a non-streaming whisper model from +https://github.com/openai/whisper +to decode files. 
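The example below uses the int8-quantized tiny.en model, which supports
English only; other Whisper models from the same release page should also
work if you adjust the paths in create_recognizer().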
+ +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + +For instance, + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2 +tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2 +rm sherpa-onnx-whisper-tiny.en.tar.bz2 +""" + +import datetime as dt +from pathlib import Path + +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + encoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx" + decoder = "./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx" + tokens = "./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt" + test_wav = "./sherpa-onnx-whisper-tiny.en/test_wavs/0.wav" + + if not Path(encoder).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=encoder, + decoder=decoder, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + start_t = dt.datetime.now() + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + recognizer.decode_stream(stream) + + end_t = dt.datetime.now() + elapsed_seconds = (end_t - start_t).total_seconds() + duration = audio.shape[-1] / sample_rate + rtf = elapsed_seconds / duration + + print(stream.result) + print(wave_filename) + print("Text:", stream.result.text) + print(f"Audio duration:\t{duration:.3f} s") + print(f"Elapsed:\t{elapsed_seconds:.3f} s") + print(f"RTF = {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/online-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/online-decode-files.py new file mode 100755 index 00000000..329f214a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/online-decode-files.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 + +""" +This file demonstrates how to use sherpa-onnx Python API to transcribe +file(s) with a streaming model. 
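It supports streaming transducer, streaming paraformer, streaming zipformer2
CTC, and streaming WeNet CTC models; the usage examples below show how to
run each of them.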
+ +Usage: + +(1) Streaming transducer + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2 +rm sherpa-onnx-streaming-zipformer-en-2023-06-26.tar.bz2 + +./python-api-examples/online-decode-files.py \ + --tokens=./sherpa-onnx-streaming-zipformer-en-2023-06-26/tokens.txt \ + --encoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoder=./sherpa-onnx-streaming-zipformer-en-2023-06-26/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --joiner=./sherpa-onnx-streaming-zipformer-en-2023-06-26/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/0.wav \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/1.wav \ + ./sherpa-onnx-streaming-zipformer-en-2023-06-26/test_wavs/8k.wav + +(2) Streaming paraformer + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +tar xvf sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 +rm sherpa-onnx-streaming-paraformer-bilingual-zh-en.tar.bz2 + +./python-api-examples/online-decode-files.py \ + --tokens=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt \ + --paraformer-encoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx \ + --paraformer-decoder=./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/1.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/2.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ + ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav + +(3) Streaming Zipformer2 CTC + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 +ls -lh sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 + +./python-api-examples/online-decode-files.py \ + --zipformer2-ctc=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.onnx \ + --tokens=./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ + ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav + +(4) Streaming Conformer CTC from WeNet + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zh-wenet-wenetspeech.tar.bz2 +tar xvf sherpa-onnx-zh-wenet-wenetspeech.tar.bz2 +rm sherpa-onnx-zh-wenet-wenetspeech.tar.bz2 + +./python-api-examples/online-decode-files.py \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +to download streaming pre-trained models. 
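Each input file is decoded with its own stream; accept_waveform() is given
the file's actual sample rate, so the 8 kHz test waves in the examples above
can be decoded directly without resampling them first.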
+""" +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the transducer joiner model", + ) + + parser.add_argument( + "--zipformer2-ctc", + type=str, + help="Path to the zipformer2 ctc model", + ) + + parser.add_argument( + "--paraformer-encoder", + type=str, + help="Path to the paraformer encoder model", + ) + + parser.add_argument( + "--paraformer-decoder", + type=str, + help="Path to the paraformer decoder model", + ) + + parser.add_argument( + "--wenet-ctc", + type=str, + help="Path to the wenet ctc model", + ) + + parser.add_argument( + "--wenet-ctc-chunk-size", + type=int, + default=16, + help="The --chunk-size parameter for streaming WeNet models", + ) + + parser.add_argument( + "--wenet-ctc-num-left-chunks", + type=int, + default=4, + help="The --num-left-chunks parameter for streaming WeNet models", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + parser.add_argument( + "--max-active-paths", + type=int, + default=4, + help="""Used only when --decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. + """, + ) + + parser.add_argument( + "--lm", + type=str, + default="", + help="""Used only when --decoding-method is modified_beam_search. + path of language model. + """, + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.1, + help="""Used only when --decoding-method is modified_beam_search. + scale of language model. + """, + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, like + HELLO WORLD + 你好世界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + parser.add_argument( + "--modeling-unit", + type=str, + default="", + help=""" + The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe. + Used only when hotwords-file is given. + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + default="", + help=""" + The path to the bpe vocabulary, the bpe vocabulary is generated by + sentencepiece, you can also export the bpe vocabulary through a bpe model + by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given + and modeling-unit is bpe or cjkchar+bpe. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. 
+ Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to decode. Each file must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def main(): + args = get_args() + assert_file_exists(args.tokens) + + if args.encoder: + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + assert not args.paraformer_encoder, args.paraformer_encoder + assert not args.paraformer_decoder, args.paraformer_decoder + + recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method=args.decoding_method, + max_active_paths=args.max_active_paths, + lm=args.lm, + lm_scale=args.lm_scale, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, + blank_penalty=args.blank_penalty, + ) + elif args.zipformer2_ctc: + recognizer = sherpa_mnn.OnlineRecognizer.from_zipformer2_ctc( + tokens=args.tokens, + model=args.zipformer2_ctc, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + elif args.paraformer_encoder: + recognizer = sherpa_mnn.OnlineRecognizer.from_paraformer( + tokens=args.tokens, + encoder=args.paraformer_encoder, + decoder=args.paraformer_decoder, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + elif args.wenet_ctc: + recognizer = sherpa_mnn.OnlineRecognizer.from_wenet_ctc( + tokens=args.tokens, + model=args.wenet_ctc, + chunk_size=args.wenet_ctc_chunk_size, + num_left_chunks=args.wenet_ctc_num_left_chunks, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + else: + raise ValueError("Please provide a model") + + print("Started!") + start_time = time.time() + + streams = [] + total_duration = 0 + 
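    # Decoding strategy used below: each input file gets its own stream,
    # the full waveform plus about 0.66 s of zero samples (tail padding) is
    # fed to it, input_finished() is called, and then all streams that still
    # have frames ready are decoded together until none of them is ready.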
for wave_filename in args.sound_files: + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + total_duration += duration + + s = recognizer.create_stream() + + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + + s.input_finished() + + streams.append(s) + + while True: + ready_list = [] + for s in streams: + if recognizer.is_ready(s): + ready_list.append(s) + if len(ready_list) == 0: + break + recognizer.decode_streams(ready_list) + results = [recognizer.get_result(s) for s in streams] + end_time = time.time() + print("Done!") + + for wave_filename, result in zip(args.sound_files, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + print(f"num_threads: {args.num_threads}") + print(f"decoding_method: {args.decoding_method}") + print(f"Wave duration: {total_duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/online-nemo-ctc-decode-files.py b/apps/frameworks/sherpa-mnn/python-api-examples/online-nemo-ctc-decode-files.py new file mode 100755 index 00000000..2638ee10 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/online-nemo-ctc-decode-files.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a streaming CTC model from NeMo +to decode files. + +Please download model files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + + +The example model is converted from +https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms +""" + +from pathlib import Path + +import numpy as np +import sherpa_mnn +import soundfile as sf + + +def create_recognizer(): + model = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/model.onnx" + tokens = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/tokens.txt" + + test_wav = "./sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms/test_wavs/0.wav" + + if not Path(model).is_file() or not Path(test_wav).is_file(): + raise ValueError( + """Please download model files from + https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models + """ + ) + return ( + sherpa_mnn.OnlineRecognizer.from_nemo_ctc( + model=model, + tokens=tokens, + debug=True, + ), + test_wav, + ) + + +def main(): + recognizer, wave_filename = create_recognizer() + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # audio is a 1-D float32 numpy array normalized to the range [-1, 1] + # sample_rate does not need to be 16000 Hz + + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, audio) + + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) + stream.accept_waveform(sample_rate, tail_paddings) + stream.input_finished() + + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + print(wave_filename) + print(recognizer.get_result_all(stream)) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/online-websocket-client-decode-file.py 
b/apps/frameworks/sherpa-mnn/python-api-examples/online-websocket-client-decode-file.py new file mode 100755 index 00000000..cbe55c86 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/online-websocket-client-decode-file.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +A websocket client for sherpa-onnx-online-websocket-server + +Usage: + ./online-websocket-client-decode-file.py \ + --server-addr localhost \ + --server-port 6006 \ + --seconds-per-message 0.1 \ + --samples-per-message 8000 \ + /path/to/foo.wav + +(Note: You have to first start the server before starting the client) + +You can find the c++ server at +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc +or use the python server ./python-api-examples/streaming_server.py + +There is also a C++ version of the client. Please see +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc +""" + +import argparse +import asyncio +import json +import logging +import wave + +try: + import websockets +except ImportError: + print("please run:") + print("") + print(" pip install websockets") + print("") + print("before you run this script") + print("") + +import numpy as np + + +def read_wave(wave_filename: str) -> np.ndarray: + """ + Args: + wave_filename: + Path to a wave file. Its sampling rate has to be 16000. + It should be single channel and each sample should be 16-bit. + Returns: + Return a 1-D float32 tensor. + """ + + with wave.open(wave_filename) as f: + assert f.getframerate() == 16000, f.getframerate() + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32 + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "--samples-per-message", + type=int, + default=8000, + help="Number of samples per message", + ) + + parser.add_argument( + "--seconds-per-message", + type=float, + default=0.1, + help="We will simulate that the duration of two messages is of this value", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file. 
Must be wave with a single channel, 16kHz " + "sampling rate, 16-bit of each sample.", + ) + + return parser.parse_args() + + +async def receive_results(socket: websockets.WebSocketServerProtocol): + last_message = "" + async for message in socket: + if message != "Done!": + last_message = message + logging.info(json.loads(message)) + else: + break + return last_message + + +async def run( + server_addr: str, + server_port: int, + wave_filename: str, + samples_per_message: int, + seconds_per_message: float, +): + data = read_wave(wave_filename) + + async with websockets.connect( + f"ws://{server_addr}:{server_port}" + ) as websocket: # noqa + logging.info(f"Sending {wave_filename}") + + receive_task = asyncio.create_task(receive_results(websocket)) + + start = 0 + while start < data.shape[0]: + end = start + samples_per_message + end = min(end, data.shape[0]) + d = data.data[start:end].tobytes() + + await websocket.send(d) + + # Simulate streaming. You can remove the sleep if you want + await asyncio.sleep(seconds_per_message) # in seconds + + start += samples_per_message + + # to signal that the client has sent all the data + await websocket.send("Done") + + decoding_results = await receive_task + logging.info(f"\nFinal result is:\n{json.loads(decoding_results)}") + + +async def main(): + args = get_args() + logging.info(vars(args)) + + server_addr = args.server_addr + server_port = args.server_port + samples_per_message = args.samples_per_message + seconds_per_message = args.seconds_per_message + + await run( + server_addr=server_addr, + server_port=server_port, + wave_filename=args.sound_file, + samples_per_message=samples_per_message, + seconds_per_message=seconds_per_message, + ) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" # noqa + ) + logging.basicConfig(format=formatter, level=logging.INFO) + asyncio.run(main()) diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/online-websocket-client-microphone.py b/apps/frameworks/sherpa-mnn/python-api-examples/online-websocket-client-microphone.py new file mode 100755 index 00000000..f42dd008 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/online-websocket-client-microphone.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +A websocket client for sherpa-onnx-online-websocket-server + +Usage: + ./online-websocket-client-microphone.py \ + --server-addr localhost \ + --server-port 6006 + +(Note: You have to first start the server before starting the client) + +You can find the C++ server at +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-server.cc +or use the python server ./python-api-examples/streaming_server.py + +There is also a C++ version of the client. Please see +https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/csrc/online-websocket-client.cc +""" + +import argparse +import asyncio +import sys + +import numpy as np + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +try: + import websockets +except ImportError: + print("please run:") + print("") + print(" pip install websockets") + print("") + print("before you run this script") + print("") + sys.exit(-1) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + return parser.parse_args() + + +async def inputstream_generator(channels=1): + """Generator that yields blocks of input data as NumPy arrays. + + See https://python-sounddevice.readthedocs.io/en/0.4.6/examples.html#creating-an-asyncio-generator-for-audio-blocks + """ + q_in = asyncio.Queue() + loop = asyncio.get_event_loop() + + def callback(indata, frame_count, time_info, status): + loop.call_soon_threadsafe(q_in.put_nowait, (indata.copy(), status)) + + devices = sd.query_devices() + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + print() + print("Started! Please speak") + + stream = sd.InputStream( + callback=callback, + channels=channels, + dtype="float32", + samplerate=16000, + blocksize=int(0.05 * 16000), # 0.05 seconds + ) + with stream: + while True: + indata, status = await q_in.get() + yield indata, status + + +async def receive_results(socket: websockets.WebSocketServerProtocol): + last_message = "" + async for message in socket: + if message != "Done!": + if last_message != message: + last_message = message + + if last_message: + print(last_message) + else: + return last_message + + +async def run( + server_addr: str, + server_port: int, +): + async with websockets.connect( + f"ws://{server_addr}:{server_port}" + ) as websocket: # noqa + receive_task = asyncio.create_task(receive_results(websocket)) + print("Started! Please Speak") + + async for indata, status in inputstream_generator(): + if status: + print(status) + indata = indata.reshape(-1) + indata = np.ascontiguousarray(indata) + await websocket.send(indata.tobytes()) + + decoding_results = await receive_task + print(f"\nFinal result is:\n{decoding_results}") + + +async def main(): + args = get_args() + print(vars(args)) + + server_addr = args.server_addr + server_port = args.server_port + + await run( + server_addr=server_addr, + server_port=server_port, + ) + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/online-zipformer-ctc-hlg-decode-file.py b/apps/frameworks/sherpa-mnn/python-api-examples/online-zipformer-ctc-hlg-decode-file.py new file mode 100755 index 00000000..7f670748 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/online-zipformer-ctc-hlg-decode-file.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# This file shows how to use a streaming zipformer CTC model and an HLG +# graph for decoding. 
+# +# We use the following model as an example +# +""" +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 +rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2 + +python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \ + --tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \ + --graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \ + --model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \ + ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav + +""" +# (The above model is from https://github.com/k2-fsa/icefall/pull/1557) + +import argparse +import time +import wave +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the ONNX model", + ) + + parser.add_argument( + "--graph", + type=str, + required=True, + help="Path to H.fst, HL.fst, or HLG.fst", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--debug", + type=int, + default=0, + help="Valid values: 1, 0", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to decode. It must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def main(): + args = get_args() + print(vars(args)) + + assert_file_exists(args.tokens) + assert_file_exists(args.graph) + assert_file_exists(args.model) + + recognizer = sherpa_mnn.OnlineRecognizer.from_zipformer2_ctc( + tokens=args.tokens, + model=args.model, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + ctc_graph=args.graph, + ) + + wave_filename = args.sound_file + assert_file_exists(wave_filename) + samples, sample_rate = read_wave(wave_filename) + duration = len(samples) / sample_rate + + print("Started") + + start_time = time.time() + s = recognizer.create_stream() + s.accept_waveform(sample_rate, samples) + tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + while recognizer.is_ready(s): + recognizer.decode_stream(s) + + result = recognizer.get_result(s).lower() + end_time = time.time() + + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / duration + print(f"num_threads: {args.num_threads}") + print(f"Wave duration: {duration:.3f} s") + print(f"Elapsed time: {elapsed_seconds:.3f} s") + print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}") + print(result) + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-dynamic.py b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-dynamic.py new file mode 100755 index 00000000..ceedeadc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-dynamic.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use Python APIs for speaker identification with +a microphone and a VAD model + +Usage: + +(1) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx + +Note that `zh` means Chinese, while `en` means English. + +(2) Download the VAD model +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +(3) Run this script + +python3 ./python-api-examples/speaker-identification-with-vad-dynamic.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --model ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx +""" +import argparse +import sys + +import numpy as np +import sherpa_mnn + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +g_sample_rate = 16000 + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the speaker embedding model file.", + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument("--threshold", type=float, default=0.4) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + return parser.parse_args() + + +def load_speaker_embedding_model(args): + config = sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=args.model, + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + if not config.validate(): + raise ValueError(f"Invalid config. {config}") + extractor = sherpa_mnn.SpeakerEmbeddingExtractor(config) + return extractor + + +def compute_speaker_embedding( + samples: np.ndarray, + extractor: sherpa_mnn.SpeakerEmbeddingExtractor, +) -> np.ndarray: + """ + Args: + samples: + A 1-D float32 array. + extractor: + The return value of function load_speaker_embedding_model(). + Returns: + Return a 1-D float32 array. + """ + if len(samples) < g_sample_rate: + print(f"Your input contains only {len(samples)} samples!") + + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + return embedding + + +def main(): + args = get_args() + print(args) + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + # If you want to select a different device, please change + # sd.default.device[0]. For instance, if you want to select device 10, + # please use + # + # sd.default.device[0] = 4 + # print(devices) + # + + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + extractor = load_speaker_embedding_model(args) + + manager = sherpa_mnn.SpeakerEmbeddingManager(extractor.dim) + + vad_config = sherpa_mnn.VadModelConfig() + vad_config.silero_vad.model = args.silero_vad_model + vad_config.silero_vad.min_silence_duration = 0.25 + vad_config.silero_vad.min_speech_duration = 1.0 + vad_config.sample_rate = g_sample_rate + + window_size = vad_config.silero_vad.window_size + vad = sherpa_mnn.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100) + + samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms + + print("Started! 
Please speak") + + line_num = 0 + speaker_id = 0 + buffer = [] + with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + buffer = np.concatenate([buffer, samples]) + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + + while not vad.empty(): + if len(vad.front.samples) < 0.5 * g_sample_rate: + # this segment is too short, skip it + vad.pop() + continue + stream = extractor.create_stream() + stream.accept_waveform( + sample_rate=g_sample_rate, waveform=vad.front.samples + ) + vad.pop() + stream.input_finished() + + embedding = extractor.compute(stream) + embedding = np.array(embedding) + name = manager.search(embedding, threshold=args.threshold) + if not name: + # register it + new_name = f"speaker_{speaker_id}" + status = manager.add(new_name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {new_name}") + print( + f"{line_num}: Detected new speaker. Register it as {new_name}" + ) + speaker_id += 1 + else: + print(f"{line_num}: Detected existing speaker: {name}") + line_num += 1 + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-non-streaming-asr-alsa.py b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-non-streaming-asr-alsa.py new file mode 100644 index 00000000..2aabb7b2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-non-streaming-asr-alsa.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 + +""" +This script works only on Linux. It uses ALSA for recording. + +This script shows how to use Python APIs for speaker identification with +a microphone, a VAD model, and a non-streaming ASR model. + +Please see also ./generate-subtitles.py + +Usage: + +(1) Prepare a text file containing speaker related files. + +Each line in the text file contains two columns. The first column is the +speaker name, while the second column contains the wave file of the speaker. + +If the text file contains multiple wave files for the same speaker, then the +embeddings of these files are averaged. + +An example text file is given below: + + foo /path/to/a.wav + bar /path/to/b.wav + foo /path/to/c.wav + foobar /path/to/d.wav + +Each wave file should contain only a single channel; the sample format +should be int16_t; the sample rate can be arbitrary. + +(2) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx + +Note that `zh` means Chinese, while `en` means English. + +(3) Download the VAD model +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +(4) Please refer to ./generate-subtitles.py +to download a non-streaming ASR model. + +(5) Run this script + +Assume the filename of the text file is speaker.txt. 
+ +python3 ./python-api-examples/speaker-identification-with-vad-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --speaker-file ./speaker.txt \ + --model ./wespeaker_zh_cnceleb_resnet34.onnx +""" +import argparse +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import sherpa_mnn +import soundfile as sf + +g_sample_rate = 16000 + + +def register_non_streaming_asr_model_args(parser): + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the transducer joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the CTC model.onnx from WeNet", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Valid values are greedy_search and modified_beam_search. + modified_beam_search is valid only for transducer models. + """, + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + register_non_streaming_asr_model_args(parser) + + parser.add_argument( + "--speaker-file", + type=str, + required=True, + help="""Path to the speaker file. 
Read the help doc at the beginning of this + file for the format.""", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the speaker embedding model file.", + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument("--threshold", type=float, default=0.6) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--device-name", + type=str, + required=True, + help=""" +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. + """, + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def create_recognizer(args) -> sherpa_mnn.OfflineRecognizer: + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.paraformer: + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.paraformer) + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=g_sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.wenet_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.whisper_encoder: + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + + recognizer = sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + 
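            # language, task, and tail_paddings below are Whisper-specific;
            # see the --whisper-* help strings above (tail_paddings=-1 keeps
            # the default amount of tail padding).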
decoding_method=args.decoding_method, + debug=args.debug, + language=args.whisper_language, + task=args.whisper_task, + tail_paddings=args.whisper_tail_paddings, + ) + else: + raise ValueError("Please specify at least one model") + + return recognizer + + +def load_speaker_embedding_model(args): + config = sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=args.model, + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + if not config.validate(): + raise ValueError(f"Invalid config. {config}") + extractor = sherpa_mnn.SpeakerEmbeddingExtractor(config) + return extractor + + +def load_speaker_file(args) -> Dict[str, List[str]]: + if not Path(args.speaker_file).is_file(): + raise ValueError(f"--speaker-file {args.speaker_file} does not exist") + + ans = defaultdict(list) + with open(args.speaker_file) as f: + for line in f: + line = line.strip() + if not line: + continue + + fields = line.split() + if len(fields) != 2: + raise ValueError(f"Invalid line: {line}. Fields: {fields}") + + speaker_name, filename = fields + ans[speaker_name].append(filename) + return ans + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def compute_speaker_embedding( + filenames: List[str], + extractor: sherpa_mnn.SpeakerEmbeddingExtractor, +) -> np.ndarray: + assert len(filenames) > 0, "filenames is empty" + + ans = None + for filename in filenames: + print(f"processing {filename}") + samples, sample_rate = load_audio(filename) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + if ans is None: + ans = embedding + else: + ans += embedding + + return ans / len(filenames) + + +def main(): + args = get_args() + print(args) + + device_name = args.device_name + print(f"device_name: {device_name}") + alsa = sherpa_mnn.Alsa(device_name) + + recognizer = create_recognizer(args) + extractor = load_speaker_embedding_model(args) + speaker_file = load_speaker_file(args) + + manager = sherpa_mnn.SpeakerEmbeddingManager(extractor.dim) + for name, filename_list in speaker_file.items(): + embedding = compute_speaker_embedding( + filenames=filename_list, + extractor=extractor, + ) + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + vad_config = sherpa_mnn.VadModelConfig() + vad_config.silero_vad.model = args.silero_vad_model + vad_config.silero_vad.min_silence_duration = 0.25 + vad_config.silero_vad.min_speech_duration = 0.25 + vad_config.sample_rate = g_sample_rate + if not vad_config.validate(): + raise ValueError("Errors in vad config") + + window_size = vad_config.silero_vad.window_size + + vad = sherpa_mnn.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100) + + samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms + + print("Started! 
Please speak") + + idx = 0 + buffer = [] + while True: + samples = alsa.read(samples_per_read) # a blocking read + samples = np.array(samples) + buffer = np.concatenate([buffer, samples]) + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + + while not vad.empty(): + if len(vad.front.samples) < 0.5 * g_sample_rate: + # this segment is too short, skip it + vad.pop() + continue + stream = extractor.create_stream() + stream.accept_waveform( + sample_rate=g_sample_rate, waveform=vad.front.samples + ) + stream.input_finished() + + embedding = extractor.compute(stream) + embedding = np.array(embedding) + name = manager.search(embedding, threshold=args.threshold) + if not name: + name = "unknown" + + # Now for non-streaming ASR + asr_stream = recognizer.create_stream() + asr_stream.accept_waveform( + sample_rate=g_sample_rate, waveform=vad.front.samples + ) + recognizer.decode_stream(asr_stream) + text = asr_stream.result.text + + vad.pop() + + print(f"\r{idx}-{name}: {text}") + idx += 1 + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-non-streaming-asr.py b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-non-streaming-asr.py new file mode 100755 index 00000000..16d34271 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad-non-streaming-asr.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use Python APIs for speaker identification with +a microphone, a VAD model, and a non-streaming ASR model. + +Please see also ./generate-subtitles.py + +Usage: + +(1) Prepare a text file containing speaker related files. + +Each line in the text file contains two columns. The first column is the +speaker name, while the second column contains the wave file of the speaker. + +If the text file contains multiple wave files for the same speaker, then the +embeddings of these files are averaged. + +An example text file is given below: + + foo /path/to/a.wav + bar /path/to/b.wav + foo /path/to/c.wav + foobar /path/to/d.wav + +Each wave file should contain only a single channel; the sample format +should be int16_t; the sample rate can be arbitrary. + +(2) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx + +Note that `zh` means Chinese, while `en` means English. + +(3) Download the VAD model +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +(4) Please refer to ./generate-subtitles.py +to download a non-streaming ASR model. + +(5) Run this script + +Assume the filename of the text file is speaker.txt. 
+ +python3 ./python-api-examples/speaker-identification-with-vad-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --speaker-file ./speaker.txt \ + --model ./wespeaker_zh_cnceleb_resnet34.onnx +""" +import argparse +import sys +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import sherpa_mnn +import soundfile as sf + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +g_sample_rate = 16000 + + +def register_non_streaming_asr_model_args(parser): + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the transducer joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the CTC model.onnx from WeNet", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Valid values are greedy_search and modified_beam_search. + modified_beam_search is valid only for transducer models. + """, + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + parser.add_argument( + "--sense-voice", + default="", + type=str, + help="Path to sense voice model", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + register_non_streaming_asr_model_args(parser) + + parser.add_argument( + "--speaker-file", + type=str, + required=True, + help="""Path to the speaker file. 
Read the help doc at the beginning of this + file for the format.""", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the speaker embedding model file.", + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument("--threshold", type=float, default=0.6) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def create_recognizer(args) -> sherpa_mnn.OfflineRecognizer: + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.paraformer: + assert len(args.wenet_ctc) == 0, args.wenet_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.paraformer) + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=g_sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.wenet_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_mnn.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.whisper_encoder: + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + + recognizer = sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + language=args.whisper_language, + task=args.whisper_task, + tail_paddings=args.whisper_tail_paddings, + ) + elif args.sense_voice: + assert_file_exists(args.sense_voice) + recognizer = sherpa_mnn.OfflineRecognizer.from_sense_voice( + model=args.sense_voice, + tokens=args.tokens, + num_threads=args.num_threads, + use_itn=True, + debug=args.debug, + ) + else: + raise ValueError("Please specify at least one model") + + return recognizer + + +def load_speaker_embedding_model(args): + config = 
sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=args.model, + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + if not config.validate(): + raise ValueError(f"Invalid config. {config}") + extractor = sherpa_mnn.SpeakerEmbeddingExtractor(config) + return extractor + + +def load_speaker_file(args) -> Dict[str, List[str]]: + if not Path(args.speaker_file).is_file(): + raise ValueError(f"--speaker-file {args.speaker_file} does not exist") + + ans = defaultdict(list) + with open(args.speaker_file) as f: + for line in f: + line = line.strip() + if not line: + continue + + fields = line.split() + if len(fields) != 2: + raise ValueError(f"Invalid line: {line}. Fields: {fields}") + + speaker_name, filename = fields + ans[speaker_name].append(filename) + return ans + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def compute_speaker_embedding( + filenames: List[str], + extractor: sherpa_mnn.SpeakerEmbeddingExtractor, +) -> np.ndarray: + assert len(filenames) > 0, "filenames is empty" + + ans = None + for filename in filenames: + print(f"processing {filename}") + samples, sample_rate = load_audio(filename) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + if ans is None: + ans = embedding + else: + ans += embedding + + return ans / len(filenames) + + +def main(): + args = get_args() + print(args) + recognizer = create_recognizer(args) + extractor = load_speaker_embedding_model(args) + speaker_file = load_speaker_file(args) + + manager = sherpa_mnn.SpeakerEmbeddingManager(extractor.dim) + for name, filename_list in speaker_file.items(): + embedding = compute_speaker_embedding( + filenames=filename_list, + extractor=extractor, + ) + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + vad_config = sherpa_mnn.VadModelConfig() + vad_config.silero_vad.model = args.silero_vad_model + vad_config.silero_vad.min_silence_duration = 0.25 + vad_config.silero_vad.min_speech_duration = 0.25 + vad_config.sample_rate = g_sample_rate + if not vad_config.validate(): + raise ValueError("Errors in vad config") + + window_size = vad_config.silero_vad.window_size + + vad = sherpa_mnn.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100) + + samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + print("Started! 
Please speak") + + idx = 0 + buffer = [] + with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + buffer = np.concatenate([buffer, samples]) + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + + while not vad.empty(): + if len(vad.front.samples) < 0.5 * g_sample_rate: + # this segment is too short, skip it + vad.pop() + continue + stream = extractor.create_stream() + stream.accept_waveform( + sample_rate=g_sample_rate, waveform=vad.front.samples + ) + stream.input_finished() + + embedding = extractor.compute(stream) + embedding = np.array(embedding) + name = manager.search(embedding, threshold=args.threshold) + if not name: + name = "unknown" + + # Now for non-streaming ASR + asr_stream = recognizer.create_stream() + asr_stream.accept_waveform( + sample_rate=g_sample_rate, waveform=vad.front.samples + ) + recognizer.decode_stream(asr_stream) + text = asr_stream.result.text + + vad.pop() + + print(f"\r{idx}-{name}: {text}") + idx += 1 + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad.py b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad.py new file mode 100755 index 00000000..1363514a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification-with-vad.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use Python APIs for speaker identification with +a microphone and a VAD model + +Usage: + +(1) Prepare a text file containing speaker related files. + +Each line in the text file contains two columns. The first column is the +speaker name, while the second column contains the wave file of the speaker. + +If the text file contains multiple wave files for the same speaker, then the +embeddings of these files are averaged. + +An example text file is given below: + + foo /path/to/a.wav + bar /path/to/b.wav + foo /path/to/c.wav + foobar /path/to/d.wav + +Each wave file should contain only a single channel; the sample format +should be int16_t; the sample rate can be arbitrary. + +(2) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx + +Note that `zh` means Chinese, while `en` means English. + +(3) Download the VAD model +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +(4) Run this script + +Assume the filename of the text file is speaker.txt. 
+ +python3 ./python-api-examples/speaker-identification-with-vad.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --speaker-file ./speaker.txt \ + --model ./wespeaker_zh_cnceleb_resnet34.onnx +""" +import argparse +import sys +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import sherpa_mnn +import soundfile as sf + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--speaker-file", + type=str, + required=True, + help="""Path to the speaker file. Read the help doc at the beginning of this + file for the format.""", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the speaker embedding model file.", + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument("--threshold", type=float, default=0.6) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + return parser.parse_args() + + +def load_speaker_embedding_model(args): + config = sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=args.model, + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + if not config.validate(): + raise ValueError(f"Invalid config. {config}") + extractor = sherpa_mnn.SpeakerEmbeddingExtractor(config) + return extractor + + +def load_speaker_file(args) -> Dict[str, List[str]]: + if not Path(args.speaker_file).is_file(): + raise ValueError(f"--speaker-file {args.speaker_file} does not exist") + + ans = defaultdict(list) + with open(args.speaker_file) as f: + for line in f: + line = line.strip() + if not line: + continue + + fields = line.split() + if len(fields) != 2: + raise ValueError(f"Invalid line: {line}. 
Fields: {fields}") + + speaker_name, filename = fields + ans[speaker_name].append(filename) + return ans + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def compute_speaker_embedding( + filenames: List[str], + extractor: sherpa_mnn.SpeakerEmbeddingExtractor, +) -> np.ndarray: + assert len(filenames) > 0, "filenames is empty" + + ans = None + for filename in filenames: + print(f"processing {filename}") + samples, sample_rate = load_audio(filename) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + if ans is None: + ans = embedding + else: + ans += embedding + + return ans / len(filenames) + + +g_sample_rate = 16000 + + +def main(): + args = get_args() + print(args) + extractor = load_speaker_embedding_model(args) + speaker_file = load_speaker_file(args) + + manager = sherpa_mnn.SpeakerEmbeddingManager(extractor.dim) + for name, filename_list in speaker_file.items(): + embedding = compute_speaker_embedding( + filenames=filename_list, + extractor=extractor, + ) + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + vad_config = sherpa_mnn.VadModelConfig() + vad_config.silero_vad.model = args.silero_vad_model + vad_config.silero_vad.min_silence_duration = 0.25 + vad_config.silero_vad.min_speech_duration = 0.25 + vad_config.sample_rate = g_sample_rate + + window_size = vad_config.silero_vad.window_size + vad = sherpa_mnn.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100) + + samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + print("Started! Please speak") + + idx = 0 + buffer = [] + with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + buffer = np.concatenate([buffer, samples]) + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + + while not vad.empty(): + if len(vad.front.samples) < 0.5 * g_sample_rate: + # this segment is too short, skip it + vad.pop() + continue + stream = extractor.create_stream() + stream.accept_waveform( + sample_rate=g_sample_rate, waveform=vad.front.samples + ) + vad.pop() + stream.input_finished() + + print("Computing", end="") + embedding = extractor.compute(stream) + embedding = np.array(embedding) + name = manager.search(embedding, threshold=args.threshold) + if not name: + name = "unknown" + print(f"\r{idx}: Predicted name: {name}") + idx += 1 + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. 
Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification.py b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification.py new file mode 100755 index 00000000..48386e7f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speaker-identification.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use Python APIs for speaker identification with +a microphone. + +Usage: + +(1) Prepare a text file containing speaker related files. + +Each line in the text file contains two columns. The first column is the +speaker name, while the second column contains the wave file of the speaker. + +If the text file contains multiple wave files for the same speaker, then the +embeddings of these files are averaged. + +An example text file is given below: + + foo /path/to/a.wav + bar /path/to/b.wav + foo /path/to/c.wav + foobar /path/to/d.wav + +Each wave file should contain only a single channel; the sample format +should be int16_t; the sample rate can be arbitrary. + +(2) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx + +Note that `zh` means Chinese, while `en` means English. + +(3) Run this script + +Assume the filename of the text file is speaker.txt. + +python3 ./python-api-examples/speaker-identification.py \ + --speaker-file ./speaker.txt \ + --model ./wespeaker_zh_cnceleb_resnet34.onnx +""" +import argparse +import queue +import sys +import threading +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import sherpa_mnn +import soundfile as sf + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--speaker-file", + type=str, + required=True, + help="""Path to the speaker file. Read the help doc at the beginning of this + file for the format.""", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the model file.", + ) + + parser.add_argument("--threshold", type=float, default=0.6) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + return parser.parse_args() + + +def load_speaker_embedding_model(args): + config = sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=args.model, + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + if not config.validate(): + raise ValueError(f"Invalid config. 
{config}") + extractor = sherpa_mnn.SpeakerEmbeddingExtractor(config) + return extractor + + +def load_speaker_file(args) -> Dict[str, List[str]]: + if not Path(args.speaker_file).is_file(): + raise ValueError(f"--speaker-file {args.speaker_file} does not exist") + + ans = defaultdict(list) + with open(args.speaker_file) as f: + for line in f: + line = line.strip() + if not line: + continue + + fields = line.split() + if len(fields) != 2: + raise ValueError(f"Invalid line: {line}. Fields: {fields}") + + speaker_name, filename = fields + ans[speaker_name].append(filename) + return ans + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def compute_speaker_embedding( + filenames: List[str], + extractor: sherpa_mnn.SpeakerEmbeddingExtractor, +) -> np.ndarray: + assert len(filenames) > 0, "filenames is empty" + + ans = None + for filename in filenames: + print(f"processing {filename}") + samples, sample_rate = load_audio(filename) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + if ans is None: + ans = embedding + else: + ans += embedding + + return ans / len(filenames) + + +g_buffer = queue.Queue() +g_stop = False +g_sample_rate = 16000 +g_read_mic_thread = None + + +def read_mic(): + print("Please speak!") + samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms + with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s: + while not g_stop: + samples, _ = s.read(samples_per_read) # a blocking read + g_buffer.put(samples) + + +def main(): + args = get_args() + print(args) + extractor = load_speaker_embedding_model(args) + speaker_file = load_speaker_file(args) + + manager = sherpa_mnn.SpeakerEmbeddingManager(extractor.dim) + for name, filename_list in speaker_file.items(): + embedding = compute_speaker_embedding( + filenames=filename_list, + extractor=extractor, + ) + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + global g_stop + global g_read_mic_thread + while True: + key = input("Press Enter to start recording") + if key.lower() in ("q", "quit"): + g_stop = True + break + + g_stop = False + g_buffer.queue.clear() + g_read_mic_thread = threading.Thread(target=read_mic) + g_read_mic_thread.start() + input("Press Enter to stop recording") + g_stop = True + g_read_mic_thread.join() + print("Compute embedding") + stream = extractor.create_stream() + while not g_buffer.empty(): + samples = g_buffer.get() + stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples) + stream.input_finished() + + embedding = extractor.compute(stream) + embedding = np.array(embedding) + name = manager.search(embedding, threshold=args.threshold) + if not name: + name = "unknown" + print(f"Predicted name: {name}") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. 
Exiting") + g_stop = True + if g_read_mic_thread.is_alive(): + g_read_mic_thread.join() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection-alsa.py b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection-alsa.py new file mode 100755 index 00000000..be311e4c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection-alsa.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 + +# Real-time speech recognition from a microphone with sherpa-onnx Python API +# with endpoint detection. +# +# Note: This script uses ALSA and works only on Linux systems, especially +# for embedding Linux systems and for running Linux on Windows using WSL. +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +# to download pre-trained models + +import argparse +import sys +from pathlib import Path +import sherpa_mnn + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + required=True, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + required=True, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + required=True, + help="Path to the joiner model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--device-name", + type=str, + required=True, + help=""" +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. 
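+
+If you are unsure whether the chosen device works, you can optionally record a
+short test clip with it first, for example:
+
+    arecord -D plughw:3,0 -f S16_LE -r 16000 -c 1 -d 3 test.wav
+
+and play back test.wav to confirm the microphone picks up audio.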
+ """, + ) + + return parser.parse_args() + + +def create_recognizer(args): + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + assert_file_exists(args.tokens) + # Please replace the model files if needed. + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + # for download links. + recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=1, + sample_rate=16000, + feature_dim=80, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, # it essentially disables this rule + decoding_method=args.decoding_method, + provider=args.provider, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, + ) + return recognizer + + +def main(): + args = get_args() + device_name = args.device_name + print(f"device_name: {device_name}") + alsa = sherpa_mnn.Alsa(device_name) + + print("Creating recognizer") + recognizer = create_recognizer(args) + print("Started! Please speak") + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + stream = recognizer.create_stream() + + last_result = "" + segment_id = 0 + while True: + samples = alsa.read(samples_per_read) # a blocking read + stream.accept_waveform(sample_rate, samples) + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + is_endpoint = recognizer.is_endpoint(stream) + + result = recognizer.get_result(stream) + + if result and (last_result != result): + last_result = result + print("\r{}:{}".format(segment_id, result), end="", flush=True) + if is_endpoint: + if result: + print("\r{}:{}".format(segment_id, result), flush=True) + segment_id += 1 + recognizer.reset(stream) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py new file mode 100755 index 00000000..af9bfec0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 + +# Real-time speech recognition from a microphone with sherpa-onnx Python API +# with endpoint detection. +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +# to download pre-trained models + +import argparse +import sys +from pathlib import Path + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_mnn + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + required=True, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + required=True, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + required=True, + help="Path to the joiner model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + return parser.parse_args() + + +def create_recognizer(args): + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + assert_file_exists(args.tokens) + # Please replace the model files if needed. + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + # for download links. + recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=1, + sample_rate=16000, + feature_dim=80, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, # it essentially disables this rule + decoding_method=args.decoding_method, + provider=args.provider, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, + ) + return recognizer + + +def main(): + args = get_args() + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + recognizer = create_recognizer(args) + print("Started! Please speak") + + # The model is using 16 kHz, we use 48 kHz here to demonstrate that + # sherpa-onnx will do resampling inside. 
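+    # Note: accept_waveform() below is passed the actual input rate, so the
+    # recognizer resamples internally. If you prefer to skip that step, you
+    # could equally open the microphone at 16 kHz directly, e.g.:
+    #
+    #   sample_rate = 16000
+    #   with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
+    #       ...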
+ sample_rate = 48000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + stream = recognizer.create_stream() + + last_result = "" + segment_id = 0 + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + is_endpoint = recognizer.is_endpoint(stream) + + result = recognizer.get_result(stream) + + if result and (last_result != result): + last_result = result + print("\r{}:{}".format(segment_id, result), end="", flush=True) + if is_endpoint: + if result: + print("\r{}:{}".format(segment_id, result), flush=True) + segment_id += 1 + recognizer.reset(stream) + + +if __name__ == "__main__": + + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone.py b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone.py new file mode 100755 index 00000000..01be4c0e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-microphone.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 + +# Real-time speech recognition from a microphone with sherpa-onnx Python API +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +# to download pre-trained models + +import argparse +import sys +from pathlib import Path + +from typing import List + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_mnn + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + required=True, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + required=True, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + parser.add_argument( + "--max-active-paths", + type=int, + default=4, + help="""Used only when --decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. + """, + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. 
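+    Roughly speaking, a larger score makes the decoder more willing to keep
+    hypotheses that match a hotword, at the cost of more false triggers.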
+ """, + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + return parser.parse_args() + + +def create_recognizer(args): + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + assert_file_exists(args.tokens) + # Please replace the model files if needed. + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + # for download links. + recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method=args.decoding_method, + max_active_paths=args.max_active_paths, + provider=args.provider, + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, + blank_penalty=args.blank_penalty, + ) + return recognizer + + +def main(): + args = get_args() + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + recognizer = create_recognizer(args) + print("Started! Please speak") + + # The model is using 16 kHz, we use 48 kHz here to demonstrate that + # sherpa-onnx will do resampling inside. + sample_rate = 48000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + last_result = "" + stream = recognizer.create_stream() + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + result = recognizer.get_result(stream) + if last_result != result: + last_result = result + print("\r{}".format(result), end="", flush=True) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-url.py b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-url.py new file mode 100755 index 00000000..d0728b03 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/speech-recognition-from-url.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# +# Real-time speech recognition from a URL with sherpa-onnx Python API +# +# Supported URLs are those supported by ffmpeg. 
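+# (Under the hood, main() below simply asks ffmpeg to decode the URL to raw
+# 16 kHz mono s16le PCM on stdout, roughly:
+#
+#   ffmpeg -i <url> -f s16le -acodec pcm_s16le -ac 1 -ar 16000 -
+#
+# so anything ffmpeg can read can be recognized here.)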
+# +# For instance: +# (1) RTMP +# rtmp://localhost/live/livestream +# +# (2) A file +# https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition/resolve/main/test_wavs/wenetspeech/DEV_T0000000000.opus +# https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition/resolve/main/test_wavs/aishell2/ID0012W0030.wav +# file:///Users/fangjun/open-source/sherpa-onnx/a.wav +# +# Note that it supports all file formats supported by ffmpeg +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +# to download pre-trained models + +import argparse +import shutil +import subprocess +import sys +from pathlib import Path + +import numpy as np +import sherpa_mnn + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + type=str, + required=True, + help="Path to the encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + required=True, + help="Path to the decoder model", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the joiner model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + parser.add_argument( + "--url", + type=str, + required=True, + help="""Example values: + rtmp://localhost/live/livestream + https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition/resolve/main/test_wavs/wenetspeech/DEV_T0000000000.opus + https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition/resolve/main/test_wavs/aishell2/ID0012W0030.wav + """, + ) + + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + + + return parser.parse_args() + + +def create_recognizer(args): + # Please replace the model files if needed. + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + # for download links. 
+ recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method=args.decoding_method, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, # it essentially disables this rule + hotwords_file=args.hotwords_file, + hotwords_score=args.hotwords_score, + ) + return recognizer + + +def main(): + args = get_args() + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + assert_file_exists(args.tokens) + + recognizer = create_recognizer(args) + + ffmpeg_cmd = [ + "ffmpeg", + "-i", + args.url, + "-f", + "s16le", + "-acodec", + "pcm_s16le", + "-ac", + "1", + "-ar", + "16000", + "-", + ] + + process = subprocess.Popen( + ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL + ) + + frames_per_read = 1600 # 0.1 second + + stream = recognizer.create_stream() + + last_result = "" + segment_id = 0 + + print("Started!") + while True: + # *2 because int16_t has two bytes + data = process.stdout.read(frames_per_read * 2) + if not data: + break + + samples = np.frombuffer(data, dtype=np.int16) + samples = samples.astype(np.float32) / 32768 + stream.accept_waveform(16000, samples) + + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + is_endpoint = recognizer.is_endpoint(stream) + + result = recognizer.get_result(stream) + + if result and (last_result != result): + last_result = result + print("\r{}:{}".format(segment_id, result), end="", flush=True) + if is_endpoint: + if result: + print("\r{}:{}".format(segment_id, result), flush=True) + segment_id += 1 + recognizer.reset(stream) + + +if __name__ == "__main__": + if shutil.which("ffmpeg") is None: + sys.exit("Please install ffmpeg first!") + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/spoken-language-identification.py b/apps/frameworks/sherpa-mnn/python-api-examples/spoken-language-identification.py new file mode 100755 index 00000000..75a8482c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/spoken-language-identification.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use Python APIs for spoken languge identification. +It detects the language spoken in the given wave file. + +Usage: + +1. Download a whisper multilingual model. We use a tiny model below. +Please refer to https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models +to download more models. + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2 +tar xvf sherpa-onnx-whisper-tiny.tar.bz2 +rm sherpa-onnx-whisper-tiny.tar.bz2 + +We only use the int8.onnx models below. + +2. Download a test wave. 
+ +You can find many wave files for different languages at +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs + +wget https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/resolve/main/test_wavs/de-german.wav + +python3 ./python-api-examples/spoken-language-identification.py + --whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \ + --whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \ + --num-threads=1 \ + ./de-german.wav +""" + +import argparse +import logging +import time +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--whisper-encoder", + required=True, + type=str, + help="Path to a multilingual whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + required=True, + type=str, + help="Path to a multilingual whisper decoder model", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to identify. It must be of WAVE" + "format with a single channel, and each sample has 16-bit, " + "i.e., int16_t. " + "The sample rate of the file can be arbitrary and does not need to " + "be 16 kHz", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html to download it" + ) + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def main(): + args = get_args() + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + assert args.num_threads > 0, args.num_threads + config = sherpa_mnn.SpokenLanguageIdentificationConfig( + whisper=sherpa_mnn.SpokenLanguageIdentificationWhisperConfig( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + ), + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + slid = sherpa_mnn.SpokenLanguageIdentification(config) + + samples, sample_rate = read_wave(args.sound_file) + + start_time = time.time() + stream = slid.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + lang = slid.compute(stream) + end_time = time.time() + + elapsed_seconds = end_time - start_time + audio_duration = len(samples) / sample_rate + real_time_factor = elapsed_seconds / audio_duration + + logging.info(f"File: {args.sound_file}") + logging.info(f"Detected language: {lang}") + logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}") + logging.info(f"Audio duration in seconds: {audio_duration:.3f}") + logging.info( + f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/streaming-paraformer-asr-microphone.py b/apps/frameworks/sherpa-mnn/python-api-examples/streaming-paraformer-asr-microphone.py new file mode 100755 index 00000000..d6196076 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/streaming-paraformer-asr-microphone.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +# Real-time speech recognition from a microphone with sherpa-onnx Python API +# with endpoint detection. +# This script uses a streaming paraformer +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html# +# to download pre-trained models + +import sys +from pathlib import Path + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_mnn + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html to download it" + ) + + +def create_recognizer(): + encoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx" + decoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx" + tokens = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt" + assert_file_exists(encoder) + assert_file_exists(decoder) + assert_file_exists(tokens) + recognizer = sherpa_mnn.OnlineRecognizer.from_paraformer( + tokens=tokens, + encoder=encoder, + decoder=decoder, + num_threads=1, + sample_rate=16000, + feature_dim=80, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, # it essentially disables this rule + ) + return recognizer + + +def main(): + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + recognizer = create_recognizer() + print("Started! Please speak") + + # The model is using 16 kHz, we use 48 kHz here to demonstrate that + # sherpa-onnx will do resampling inside. + sample_rate = 48000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + stream = recognizer.create_stream() + + last_result = "" + segment_id = 0 + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + is_endpoint = recognizer.is_endpoint(stream) + + result = recognizer.get_result(stream) + + if result and (last_result != result): + last_result = result + print("\r{}:{}".format(segment_id, result), end="", flush=True) + if is_endpoint: + if result: + print("\r{}:{}".format(segment_id, result), flush=True) + segment_id += 1 + recognizer.reset(stream) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/streaming_server.py b/apps/frameworks/sherpa-mnn/python-api-examples/streaming_server.py new file mode 100755 index 00000000..8f17c91c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/streaming_server.py @@ -0,0 +1,885 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. +# +""" +A server for streaming ASR recognition. By streaming it means the audio samples +are coming in real-time. You don't need to wait until all audio samples are +captured before sending them for recognition. + +It supports multiple clients sending at the same time. 
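+
+Each client streams raw float32 audio samples (16 kHz by default) over a
+websocket and sends the literal string "Done" when it is finished; the server
+replies with JSON messages containing "text" and "segment" (see
+handle_connection_impl and recv_audio_samples below). A minimal client sketch,
+assuming the server runs without TLS on localhost:6006:
+
+    import asyncio, json
+    import numpy as np
+    import websockets
+
+    async def run():
+        samples = np.zeros(16000, dtype=np.float32)  # 1 second of silence
+        async with websockets.connect("ws://localhost:6006") as ws:
+            await ws.send(samples.tobytes())
+            await ws.send("Done")
+            async for message in ws:  # until the server closes the socket
+                print(json.loads(message))
+
+    asyncio.run(run())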
+ +Usage: + ./streaming_server.py --help + +Example: + +(1) Without a certificate + +python3 ./python-api-examples/streaming_server.py \ + --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt + +(2) With a certificate + +(a) Generate a certificate first: + + cd python-api-examples/web + ./generate-certificate.py + cd ../.. + +(b) Start the server + +python3 ./python-api-examples/streaming_server.py \ + --encoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx \ + --decoder ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx \ + --joiner ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx \ + --tokens ./sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt \ + --certificate ./python-api-examples/web/cert.pem + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html +to download pre-trained models. + +The model in the above help messages is from +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + +To use a WeNet streaming Conformer CTC model, please use + +python3 ./python-api-examples/streaming_server.py \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx +""" + +import argparse +import asyncio +import http +import json +import logging +import socket +import ssl +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import sherpa_mnn +import websockets + +from http_server import HttpServer + + +def setup_logger( + log_filename: str, + log_level: str = "info", + use_console: bool = True, +) -> None: + """Setup log level. + + Args: + log_filename: + The filename to save the log. + log_level: + The log level to use, e.g., "debug", "info", "warning", "error", + "critical" + use_console: + True to also print logs to console. 
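+
+    Example (as used at the bottom of this file):
+
+      setup_logger("log/log-streaming-server")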
+ """ + now = datetime.now() + date_time = now.strftime("%Y-%m-%d-%H-%M-%S") + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + log_filename = f"{log_filename}-{date_time}.txt" + + Path(log_filename).parent.mkdir(parents=True, exist_ok=True) + + level = logging.ERROR + if log_level == "debug": + level = logging.DEBUG + elif log_level == "info": + level = logging.INFO + elif log_level == "warning": + level = logging.WARNING + elif log_level == "critical": + level = logging.CRITICAL + + logging.basicConfig( + filename=log_filename, + format=formatter, + level=level, + filemode="w", + ) + if use_console: + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger("").addHandler(console) + + +def add_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--encoder", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + type=str, + help="Path to the transducer decoder model.", + ) + + parser.add_argument( + "--joiner", + type=str, + help="Path to the transducer joiner model.", + ) + + parser.add_argument( + "--zipformer2-ctc", + type=str, + help="Path to the model file from zipformer2 ctc", + ) + + parser.add_argument( + "--wenet-ctc", + type=str, + help="Path to the model.onnx from WeNet", + ) + + parser.add_argument( + "--paraformer-encoder", + type=str, + help="Path to the paraformer encoder model", + ) + + parser.add_argument( + "--paraformer-decoder", + type=str, + help="Path to the paraformer decoder model.", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="Sample rate of the data used to train the model. " + "Caution: If your input sound files have a different sampling rate, " + "we will do resampling inside", + ) + + parser.add_argument( + "--feat-dim", + type=int, + default=80, + help="Feature dimension of the model", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + +def add_decoding_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Decoding method to use. Current supported methods are: + - greedy_search + - modified_beam_search + """, + ) + + add_modified_beam_search_args(parser) + + +def add_hotwords_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--hotwords-file", + type=str, + default="", + help=""" + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. For example: + + ▁HE LL O ▁WORLD + 你 好 世 界 + """, + ) + + parser.add_argument( + "--hotwords-score", + type=float, + default=1.5, + help=""" + The hotword score of each token for biasing word/phrase. Used only if + --hotwords-file is given. + """, + ) + parser.add_argument( + "--modeling-unit", + type=str, + default='cjkchar', + help=""" + The modeling unit of the used model. Current supported units are: + - cjkchar(for Chinese) + - bpe(for English like languages) + - cjkchar+bpe(for multilingual models) + """, + ) + parser.add_argument( + "--bpe-vocab", + type=str, + default='', + help=""" + The bpe vocabulary generated by sentencepiece toolkit. + It is only used when modeling-unit is bpe or cjkchar+bpe. 
+ if you can’t find bpe.vocab in the model directory, please run: + python script/export_bpe_vocab.py --bpe-model exp/bpe.model + """, + ) + + +def add_modified_beam_search_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-active-paths", + type=int, + default=4, + help="""Used only when --decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. + """, + ) + +def add_blank_penalty_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + +def add_endpointing_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--use-endpoint", + type=int, + default=1, + help="1 to enable endpoiting. 0 to disable it", + ) + + parser.add_argument( + "--rule1-min-trailing-silence", + type=float, + default=2.4, + help="""This endpointing rule1 requires duration of trailing silence + in seconds) to be >= this value""", + ) + + parser.add_argument( + "--rule2-min-trailing-silence", + type=float, + default=1.2, + help="""This endpointing rule2 requires duration of trailing silence in + seconds) to be >= this value.""", + ) + + parser.add_argument( + "--rule3-min-utterance-length", + type=float, + default=20, + help="""This endpointing rule3 requires utterance-length (in seconds) + to be >= this value.""", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + add_model_args(parser) + add_decoding_args(parser) + add_endpointing_args(parser) + add_hotwords_args(parser) + add_blank_penalty_args(parser) + + parser.add_argument( + "--port", + type=int, + default=6006, + help="The server will listen on this port", + ) + + parser.add_argument( + "--nn-pool-size", + type=int, + default=1, + help="Number of threads for NN computation and decoding.", + ) + + parser.add_argument( + "--max-batch-size", + type=int, + default=3, + help="""Max batch size for computation. Note if there are not enough + requests in the queue, it will wait for max_wait_ms time. After that, + even if there are not enough requests, it still sends the + available requests in the queue for computation. + """, + ) + + parser.add_argument( + "--max-wait-ms", + type=float, + default=10, + help="""Max time in millisecond to wait to build batches for inference. + If there are not enough requests in the stream queue to build a batch + of max_batch_size, it waits up to this time before fetching available + requests for computation. + """, + ) + + parser.add_argument( + "--max-message-size", + type=int, + default=(1 << 20), + help="""Max message size in bytes. + The max size per message cannot exceed this limit. + """, + ) + + parser.add_argument( + "--max-queue-size", + type=int, + default=32, + help="Max number of messages in the queue for each connection.", + ) + + parser.add_argument( + "--max-active-connections", + type=int, + default=200, + help="""Maximum number of active connections. The server will refuse + to accept new connections once the current number of active connections + equals to this limit. 
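+        When the limit is reached, new connection attempts receive an HTTP 503
+        (service unavailable) response instead of a websocket upgrade; see
+        process_request() below.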
+ """, + ) + + parser.add_argument( + "--num-threads", + type=int, + default=2, + help="Number of threads to run the neural network model", + ) + + parser.add_argument( + "--certificate", + type=str, + help="""Path to the X.509 certificate. You need it only if you want to + use a secure websocket connection, i.e., use wss:// instead of ws://. + You can use ./web/generate-certificate.py + to generate the certificate `cert.pem`. + Note ./web/generate-certificate.py will generate three files but you + only need to pass the generated cert.pem to this option. + """, + ) + + parser.add_argument( + "--doc-root", + type=str, + default="./python-api-examples/web", + help="Path to the web root", + ) + + return parser.parse_args() + + +def create_recognizer(args) -> sherpa_mnn.OnlineRecognizer: + if args.encoder: + recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + tokens=args.tokens, + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + max_active_paths=args.num_active_paths, + hotwords_score=args.hotwords_score, + hotwords_file=args.hotwords_file, + blank_penalty=args.blank_penalty, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab + ) + elif args.paraformer_encoder: + recognizer = sherpa_mnn.OnlineRecognizer.from_paraformer( + tokens=args.tokens, + encoder=args.paraformer_encoder, + decoder=args.paraformer_decoder, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + ) + elif args.zipformer2_ctc: + recognizer = sherpa_mnn.OnlineRecognizer.from_zipformer2_ctc( + tokens=args.tokens, + model=args.zipformer2_ctc, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + ) + elif args.wenet_ctc: + recognizer = sherpa_mnn.OnlineRecognizer.from_wenet_ctc( + tokens=args.tokens, + model=args.wenet_ctc, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + ) + else: + raise ValueError("Please provide a model") + + return recognizer + + +def format_timestamps(timestamps: List[float]) -> List[str]: + return ["{:.3f}".format(t) for t in timestamps] + + +class StreamingServer(object): + def __init__( + self, + recognizer: sherpa_mnn.OnlineRecognizer, + nn_pool_size: int, 
+ max_wait_ms: float, + max_batch_size: int, + max_message_size: int, + max_queue_size: int, + max_active_connections: int, + doc_root: str, + certificate: Optional[str] = None, + ): + """ + Args: + recognizer: + An instance of online recognizer. + nn_pool_size: + Number of threads for the thread pool that is responsible for + neural network computation and decoding. + max_wait_ms: + Max wait time in milliseconds in order to build a batch of + `batch_size`. + max_batch_size: + Max batch size for inference. + max_message_size: + Max size in bytes per message. + max_queue_size: + Max number of messages in the queue for each connection. + max_active_connections: + Max number of active connections. Once number of active client + equals to this limit, the server refuses to accept new connections. + beam_search_params: + Dictionary containing all the parameters for beam search. + online_endpoint_config: + Config for endpointing. + doc_root: + Path to the directory where files like index.html for the HTTP + server locate. + certificate: + Optional. If not None, it will use secure websocket. + You can use ./web/generate-certificate.py to generate + it (the default generated filename is `cert.pem`). + """ + self.recognizer = recognizer + + self.certificate = certificate + self.http_server = HttpServer(doc_root) + + self.nn_pool_size = nn_pool_size + self.nn_pool = ThreadPoolExecutor( + max_workers=nn_pool_size, + thread_name_prefix="nn", + ) + + self.stream_queue = asyncio.Queue() + + self.max_wait_ms = max_wait_ms + self.max_batch_size = max_batch_size + self.max_message_size = max_message_size + self.max_queue_size = max_queue_size + self.max_active_connections = max_active_connections + + self.current_active_connections = 0 + + self.sample_rate = int(recognizer.config.feat_config.sampling_rate) + + async def stream_consumer_task(self): + """This function extracts streams from the queue, batches them up, sends + them to the neural network model for computation and decoding. + """ + while True: + if self.stream_queue.empty(): + await asyncio.sleep(self.max_wait_ms / 1000) + continue + + batch = [] + try: + while len(batch) < self.max_batch_size: + item = self.stream_queue.get_nowait() + + assert self.recognizer.is_ready(item[0]) + + batch.append(item) + except asyncio.QueueEmpty: + pass + stream_list = [b[0] for b in batch] + future_list = [b[1] for b in batch] + + loop = asyncio.get_running_loop() + await loop.run_in_executor( + self.nn_pool, + self.recognizer.decode_streams, + stream_list, + ) + + for f in future_list: + self.stream_queue.task_done() + f.set_result(None) + + async def compute_and_decode( + self, + stream: sherpa_mnn.OnlineStream, + ) -> None: + """Put the stream into the queue and wait it to be processed by the + consumer task. + + Args: + stream: + The stream to be processed. Note: It is changed in-place. + """ + loop = asyncio.get_running_loop() + future = loop.create_future() + await self.stream_queue.put((stream, future)) + await future + + async def process_request( + self, + path: str, + request_headers: websockets.Headers, + ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]: + if "sec-websocket-key" not in ( + request_headers.headers # For new request_headers + if hasattr(request_headers, "headers") + else request_headers # For old request_headers + ): + # This is a normal HTTP request + if path == "/": + path = "/index.html" + + if path in ("/upload.html", "/offline_record.html"): + response = r""" + +Speech recognition with next-gen Kaldi +

Only /streaming_record.html is available for the streaming server.
+
+
+Go back to /streaming_record.html + +""" + found = True + mime_type = "text/html" + else: + found, response, mime_type = self.http_server.process_request(path) + + if isinstance(response, str): + response = response.encode("utf-8") + + if not found: + status = http.HTTPStatus.NOT_FOUND + else: + status = http.HTTPStatus.OK + header = {"Content-Type": mime_type} + return status, header, response + + if self.current_active_connections < self.max_active_connections: + self.current_active_connections += 1 + return None + + # Refuse new connections + status = http.HTTPStatus.SERVICE_UNAVAILABLE # 503 + header = {"Hint": "The server is overloaded. Please retry later."} + response = b"The server is busy. Please retry later." + + return status, header, response + + async def run(self, port: int): + tasks = [] + for i in range(self.nn_pool_size): + tasks.append(asyncio.create_task(self.stream_consumer_task())) + + if self.certificate: + logging.info(f"Using certificate: {self.certificate}") + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain(self.certificate) + else: + ssl_context = None + logging.info("No certificate provided") + + async with websockets.serve( + self.handle_connection, + host="", + port=port, + max_size=self.max_message_size, + max_queue=self.max_queue_size, + process_request=self.process_request, + ssl=ssl_context, + ): + ip_list = ["localhost"] + if ssl_context: + ip_list += ["0.0.0.0", "127.0.0.1"] + ip_list.append(socket.gethostbyname(socket.gethostname())) + proto = "http://" if ssl_context is None else "https://" + s = "Please visit one of the following addresses:\n\n" + for p in ip_list: + s += " " + proto + p + f":{port}" "\n" + + if not ssl_context: + s += "\nSince you are not providing a certificate, you cannot " + s += "use your microphone from within the browser using " + s += "public IP addresses. Only localhost can be used." + s += "You also cannot use 0.0.0.0 or 127.0.0.1" + + logging.info(s) + + await asyncio.Future() # run forever + + await asyncio.gather(*tasks) # not reachable + + async def handle_connection( + self, + socket: websockets.WebSocketServerProtocol, + ): + """Receive audio samples from the client, process it, and send + decoding result back to the client. + + Args: + socket: + The socket for communicating with the client. + """ + try: + await self.handle_connection_impl(socket) + except websockets.exceptions.ConnectionClosedError: + logging.info(f"{socket.remote_address} disconnected") + finally: + # Decrement so that it can accept new connections + self.current_active_connections -= 1 + + logging.info( + f"Disconnected: {socket.remote_address}. " + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa + ) + + async def handle_connection_impl( + self, + socket: websockets.WebSocketServerProtocol, + ): + """Receive audio samples from the client, process it, and send + decoding result back to the client. + + Args: + socket: + The socket for communicating with the client. + """ + logging.info( + f"Connected: {socket.remote_address}. 
" + f"Number of connections: {self.current_active_connections}/{self.max_active_connections}" # noqa + ) + + stream = self.recognizer.create_stream() + segment = 0 + + while True: + samples = await self.recv_audio_samples(socket) + if samples is None: + break + + # TODO(fangjun): At present, we assume the sampling rate + # of the received audio samples equal to --sample-rate + stream.accept_waveform(sample_rate=self.sample_rate, waveform=samples) + + while self.recognizer.is_ready(stream): + await self.compute_and_decode(stream) + result = self.recognizer.get_result(stream) + + message = { + "text": result, + "segment": segment, + } + if self.recognizer.is_endpoint(stream): + self.recognizer.reset(stream) + segment += 1 + + await socket.send(json.dumps(message)) + + tail_padding = np.zeros(int(self.sample_rate * 0.3)).astype(np.float32) + stream.accept_waveform(sample_rate=self.sample_rate, waveform=tail_padding) + stream.input_finished() + while self.recognizer.is_ready(stream): + await self.compute_and_decode(stream) + + result = self.recognizer.get_result(stream) + + message = { + "text": result, + "segment": segment, + } + + await socket.send(json.dumps(message)) + + async def recv_audio_samples( + self, + socket: websockets.WebSocketServerProtocol, + ) -> Optional[np.ndarray]: + """Receive a tensor from the client. + + Each message contains either a bytes buffer containing audio samples + in 16 kHz or contains "Done" meaning the end of utterance. + + Args: + socket: + The socket for communicating with the client. + Returns: + Return a 1-D np.float32 tensor containing the audio samples or + return None. + """ + message = await socket.recv() + if message == "Done": + return None + + return np.frombuffer(message, dtype=np.float32) + + +def check_args(args): + if args.encoder: + assert Path(args.encoder).is_file(), f"{args.encoder} does not exist" + + assert Path(args.decoder).is_file(), f"{args.decoder} does not exist" + + assert Path(args.joiner).is_file(), f"{args.joiner} does not exist" + + assert args.paraformer_encoder is None, args.paraformer_encoder + assert args.paraformer_decoder is None, args.paraformer_decoder + assert args.zipformer2_ctc is None, args.zipformer2_ctc + assert args.wenet_ctc is None, args.wenet_ctc + elif args.paraformer_encoder: + assert Path( + args.paraformer_encoder + ).is_file(), f"{args.paraformer_encoder} does not exist" + + assert Path( + args.paraformer_decoder + ).is_file(), f"{args.paraformer_decoder} does not exist" + elif args.zipformer2_ctc: + assert Path( + args.zipformer2_ctc + ).is_file(), f"{args.zipformer2_ctc} does not exist" + elif args.wenet_ctc: + assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist" + else: + raise ValueError("Please provide a model") + + if not Path(args.tokens).is_file(): + raise ValueError(f"{args.tokens} does not exist") + + if args.decoding_method not in ( + "greedy_search", + "modified_beam_search", + ): + raise ValueError(f"Unsupported decoding method {args.decoding_method}") + + if args.decoding_method == "modified_beam_search": + assert args.num_active_paths > 0, args.num_active_paths + + +def main(): + args = get_args() + logging.info(vars(args)) + check_args(args) + + recognizer = create_recognizer(args) + + port = args.port + nn_pool_size = args.nn_pool_size + max_batch_size = args.max_batch_size + max_wait_ms = args.max_wait_ms + max_message_size = args.max_message_size + max_queue_size = args.max_queue_size + max_active_connections = args.max_active_connections + certificate = 
args.certificate + doc_root = args.doc_root + + if certificate and not Path(certificate).is_file(): + raise ValueError(f"{certificate} does not exist") + + if not Path(doc_root).is_dir(): + raise ValueError(f"Directory {doc_root} does not exist") + + server = StreamingServer( + recognizer=recognizer, + nn_pool_size=nn_pool_size, + max_batch_size=max_batch_size, + max_wait_ms=max_wait_ms, + max_message_size=max_message_size, + max_queue_size=max_queue_size, + max_active_connections=max_active_connections, + certificate=certificate, + doc_root=doc_root, + ) + asyncio.run(server.run(port)) + + +if __name__ == "__main__": + log_filename = "log/log-streaming-server" + setup_logger(log_filename) + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/two-pass-speech-recognition-from-microphone.py b/apps/frameworks/sherpa-mnn/python-api-examples/two-pass-speech-recognition-from-microphone.py new file mode 100755 index 00000000..0b4fdf73 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/two-pass-speech-recognition-from-microphone.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 + +# Two-pass real-time speech recognition from a microphone with sherpa-onnx +# Python API. +# +# The first pass uses a streaming model, which has two purposes: +# +# (1) Display a temporary result to users +# +# (2) Endpointing +# +# The second pass uses a non-streaming model. It has a higher recognition +# accuracy than the first pass model and its result is used as the final result. +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +# to download pre-trained models + +""" +Usage examples: + +(1) Chinese: Streaming zipformer (1st pass) + Non-streaming paraformer (2nd pass) + +python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \ + --first-encoder ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/encoder-epoch-99-avg-1.onnx \ + --first-decoder ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/decoder-epoch-99-avg-1.onnx \ + --first-joiner ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/joiner-epoch-99-avg-1.onnx \ + --first-tokens ./sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23/tokens.txt \ + \ + --second-paraformer ./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx \ + --second-tokens ./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt + +(2) English: Streaming zipformer (1st pass) + Non-streaming whisper (2nd pass) + +python3 ./python-api-examples/two-pass-speech-recognition-from-microphone.py \ + --first-encoder ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.onnx \ + --first-decoder ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/decoder-epoch-99-avg-1.onnx \ + --first-joiner ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/joiner-epoch-99-avg-1.onnx \ + --first-tokens ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt \ + \ + --second-whisper-encoder ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx \ + --second-whisper-decoder ./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx \ + --second-tokens ./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt +""" + +import argparse +import sys +from pathlib import Path +from typing import List + +import numpy as np + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_mnn + + +def assert_file_exists(filename: str, message: str): + if not filename: + raise ValueError(f"Please specify {message}") + + if not Path(filename).is_file(): + raise ValueError(f"{message} {filename} does not exist") + + +def add_first_pass_streaming_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--first-tokens", + type=str, + required=True, + help="Path to tokens.txt for the first pass", + ) + + parser.add_argument( + "--first-encoder", + type=str, + required=True, + help="Path to the encoder model for the first pass", + ) + + parser.add_argument( + "--first-decoder", + type=str, + required=True, + help="Path to the decoder model for the first pass", + ) + + parser.add_argument( + "--first-joiner", + type=str, + help="Path to the joiner model for the first pass", + ) + + parser.add_argument( + "--first-decoding-method", + type=str, + default="greedy_search", + help="""Decoding method for the first pass. Valid values are + greedy_search and modified_beam_search""", + ) + + parser.add_argument( + "--first-max-active-paths", + type=int, + default=4, + help="""Used only when --first-decoding-method is modified_beam_search. + It specifies number of active paths to keep during decoding. + """, + ) + + +def add_second_pass_transducer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-encoder", + default="", + type=str, + help="Path to the transducer encoder model for the second pass", + ) + + parser.add_argument( + "--second-decoder", + default="", + type=str, + help="Path to the transducer decoder model for the second pass", + ) + + parser.add_argument( + "--second-joiner", + default="", + type=str, + help="Path to the transducer joiner model for the second pass", + ) + + +def add_second_pass_paraformer_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-paraformer", + default="", + type=str, + help="Path to the model.onnx for Paraformer for the second pass", + ) + + +def add_second_pass_nemo_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-nemo-ctc", + default="", + type=str, + help="Path to the model.onnx for NeMo CTC for the second pass", + ) + + +def add_second_pass_whisper_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--second-whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model for the second pass", + ) + + parser.add_argument( + "--second-whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model for the second pass", + ) + + parser.add_argument( + "--second-whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input audio file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--second-whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--second-whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. 
+ We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. + """, + ) + + +def add_second_pass_non_streaming_model_args(parser: argparse.ArgumentParser): + add_second_pass_transducer_model_args(parser) + add_second_pass_nemo_ctc_model_args(parser) + add_second_pass_paraformer_model_args(parser) + add_second_pass_whisper_model_args(parser) + + parser.add_argument( + "--second-tokens", + type=str, + help="Path to tokens.txt for the second pass", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + add_first_pass_streaming_model_args(parser) + add_second_pass_non_streaming_model_args(parser) + + return parser.parse_args() + + +def check_first_pass_args(args): + assert_file_exists(args.first_tokens, "--first-tokens") + assert_file_exists(args.first_encoder, "--first-encoder") + assert_file_exists(args.first_decoder, "--first-decoder") + assert_file_exists(args.first_joiner, "--first-joiner") + + +def check_second_pass_args(args): + assert_file_exists(args.second_tokens, "--second-tokens") + + if args.second_encoder: + assert_file_exists(args.second_encoder, "--second-encoder") + assert_file_exists(args.second_decoder, "--second-decoder") + assert_file_exists(args.second_joiner, "--second-joiner") + elif args.second_paraformer: + assert_file_exists(args.second_paraformer, "--second-paraformer") + elif args.second_nemo_ctc: + assert_file_exists(args.second_nemo_ctc, "--second-nemo-ctc") + elif args.second_whisper_encoder: + assert_file_exists(args.second_whisper_encoder, "--second-whisper-encoder") + assert_file_exists(args.second_whisper_decoder, "--second-whisper-decoder") + else: + raise ValueError("Please specify the model for the second pass") + + +def create_first_pass_recognizer(args): + # Please replace the model files if needed. + # See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html + # for download links. 
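+    # Informal note (a sketch of the usual sherpa-onnx endpointing convention,
+    # not exact documentation): rule1 fires after ~2.4 s of trailing silence
+    # even if nothing has been decoded yet, rule2 fires after ~1.2 s of
+    # trailing silence once some text has been decoded, and rule3 forces an
+    # endpoint once the utterance exceeds ~20 s. Tune the values passed below
+    # if segments are cut too early or too late.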
+ recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + tokens=args.first_tokens, + encoder=args.first_encoder, + decoder=args.first_decoder, + joiner=args.first_joiner, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method=args.first_decoding_method, + max_active_paths=args.first_max_active_paths, + provider=args.provider, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=20, + ) + return recognizer + + +def create_second_pass_recognizer(args) -> sherpa_mnn.OfflineRecognizer: + if args.second_encoder: + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=args.second_encoder, + decoder=args.second_decoder, + joiner=args.second_joiner, + tokens=args.second_tokens, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + max_active_paths=4, + ) + elif args.second_paraformer: + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=args.second_paraformer, + tokens=args.second_tokens, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + elif args.second_nemo_ctc: + recognizer = sherpa_mnn.OfflineRecognizer.from_nemo_ctc( + model=args.second_nemo_ctc, + tokens=args.second_tokens, + num_threads=1, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) + elif args.second_whisper_encoder: + recognizer = sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=args.second_whisper_encoder, + decoder=args.second_whisper_decoder, + tokens=args.second_tokens, + num_threads=1, + decoding_method="greedy_search", + language=args.second_whisper_language, + task=args.second_whisper_task, + tail_paddings=args.second_whisper_tail_paddings, + ) + else: + raise ValueError("Please specify at least one model for the second pass") + + return recognizer + + +def run_second_pass( + recognizer: sherpa_mnn.OfflineRecognizer, + samples: np.ndarray, + sample_rate: int, +): + stream = recognizer.create_stream() + stream.accept_waveform(sample_rate, samples) + + recognizer.decode_stream(stream) + + return stream.result.text + + +def main(): + args = get_args() + check_first_pass_args(args) + check_second_pass_args(args) + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + + # If you want to select a different input device, please use + # sd.default.device[0] = xxx + # where xxx is the device number + + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + print("Creating recognizers. Please wait...") + first_recognizer = create_first_pass_recognizer(args) + second_recognizer = create_second_pass_recognizer(args) + + print("Started! 
Please speak") + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + stream = first_recognizer.create_stream() + + last_result = "" + segment_id = 0 + + sample_buffers = [] + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + + sample_buffers.append(samples) + + while first_recognizer.is_ready(stream): + first_recognizer.decode_stream(stream) + + is_endpoint = first_recognizer.is_endpoint(stream) + + result = first_recognizer.get_result(stream) + result = result.lower().strip() + + if last_result != result: + print( + "\r{}:{}".format(segment_id, " " * len(last_result)), + end="", + flush=True, + ) + last_result = result + print("\r{}:{}".format(segment_id, result), end="", flush=True) + + if is_endpoint: + if result: + samples = np.concatenate(sample_buffers) + # There are internal sample buffers inside the streaming + # feature extractor, so we cannot send all samples to + # the 2nd pass. Here 8000 is just an empirical value + # that should work for most streaming models in sherpa-onnx + sample_buffers = [samples[-8000:]] + samples = samples[:-8000] + result = run_second_pass( + recognizer=second_recognizer, + samples=samples, + sample_rate=sample_rate, + ) + result = result.lower().strip() + + print( + "\r{}:{}".format(segment_id, " " * len(last_result)), + end="", + flush=True, + ) + print("\r{}:{}".format(segment_id, result), flush=True) + segment_id += 1 + else: + sample_buffers = [] + + first_recognizer.reset(stream) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/vad-alsa.py b/apps/frameworks/sherpa-mnn/python-api-examples/vad-alsa.py new file mode 100755 index 00000000..61b65966 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/vad-alsa.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +""" +This script works only on Linux. It uses ALSA for recording. +""" + +import argparse +from pathlib import Path + +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument( + "--device-name", + type=str, + required=True, + help=""" +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + if not Path(args.silero_vad_model).is_file(): + raise RuntimeError( + f"{args.silero_vad_model} does not exist. 
Please download it from " + "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx" + ) + + device_name = args.device_name + print(f"device_name: {device_name}") + alsa = sherpa_mnn.Alsa(device_name) + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_mnn.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + print("Started! Please speak. Press Ctrl C to exit") + + printed = False + k = 0 + try: + while True: + samples = alsa.read(samples_per_read) # a blocking read + + vad.accept_waveform(samples) + + if vad.is_speech_detected() and not printed: + print("Detected speech") + printed = True + + if not vad.is_speech_detected(): + printed = False + + while not vad.empty(): + samples = vad.front.samples + duration = len(samples) / sample_rate + filename = f"seg-{k}-{duration:.3f}-seconds.wav" + k += 1 + sherpa_mnn.write_wave(filename, samples, sample_rate) + print(f"Duration: {duration:.3f} seconds") + print(f"Saved to {filename}") + print("----------") + + vad.pop() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exit") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/vad-microphone.py b/apps/frameworks/sherpa-mnn/python-api-examples/vad-microphone.py new file mode 100755 index 00000000..5e2beb36 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/vad-microphone.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 + +import argparse +import os +import sys +from pathlib import Path + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + if not Path(args.silero_vad_model).is_file(): + raise RuntimeError( + f"{args.silero_vad_model} does not exist. 
Please download it from " + "https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx" + ) + + mic_sample_rate = 16000 + if "SHERPA_ONNX_MIC_SAMPLE_RATE" in os.environ: + mic_sample_rate = int(os.environ.get("SHERPA_ONNX_MIC_SAMPLE_RATE")) + print(f"Change microphone sample rate to {mic_sample_rate}") + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_mnn.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + # python3 -m sounddevice + # can also be used to list all devices + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + print( + "If you are using Linux and you are sure there is a microphone " + "on your system, please use " + "./vad-alsa.py" + ) + sys.exit(0) + + print(devices) + + if "SHERPA_ONNX_MIC_DEVICE" in os.environ: + input_device_idx = int(os.environ.get("SHERPA_ONNX_MIC_DEVICE")) + sd.default.device[0] = input_device_idx + print(f'Use selected device: {devices[input_device_idx]["name"]}') + else: + input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[input_device_idx]["name"]}') + + print("Started! Please speak. Press Ctrl C to exit") + + printed = False + k = 0 + try: + with sd.InputStream( + channels=1, dtype="float32", samplerate=mic_sample_rate + ) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + + if mic_sample_rate != sample_rate: + import librosa + + samples = librosa.resample( + samples, orig_sr=mic_sample_rate, target_sr=sample_rate + ) + + vad.accept_waveform(samples) + + if vad.is_speech_detected() and not printed: + print("Detected speech") + printed = True + + if not vad.is_speech_detected(): + printed = False + + while not vad.empty(): + samples = vad.front.samples + duration = len(samples) / sample_rate + filename = f"seg-{k}-{duration:.3f}-seconds.wav" + k += 1 + sherpa_mnn.write_wave(filename, samples, sample_rate) + print(f"Duration: {duration:.3f} seconds") + print(f"Saved to {filename}") + print("----------") + + vad.pop() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exit") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments-alsa.py b/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments-alsa.py new file mode 100755 index 00000000..eb590e3b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments-alsa.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +""" +This file shows how to remove non-speech segments +and merge all speech segments into a large segment +and save it to a file. + +Different from ./vad-remove-non-speech-segments.py, this file supports only +Linux. 
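+It records from an ALSA capture device; when you stop it with Ctrl + C, the
+detected speech segments and the full recording are each written to a .wav
+file (see the end of main() below).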
+ +Usage + +python3 ./vad-remove-non-speech-segments-alsa.py \ + --silero-vad-model silero_vad.onnx + +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +""" + +import argparse +import time +from pathlib import Path + +import numpy as np +import sherpa_mnn +import soundfile as sf + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument( + "--device-name", + type=str, + required=True, + help=""" +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + assert_file_exists(args.silero_vad_model) + + device_name = args.device_name + print(f"device_name: {device_name}") + alsa = sherpa_mnn.Alsa(device_name) + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_mnn.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + window_size = config.silero_vad.window_size + + buffer = [] + vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + all_samples = [] + + print("Started! Please speak. Press Ctrl C to exit") + + try: + while True: + samples = alsa.read(samples_per_read) # a blocking read + samples = np.array(samples) + + buffer = np.concatenate([buffer, samples]) + + all_samples = np.concatenate([all_samples, samples]) + + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + except KeyboardInterrupt: + print("\nCaught Ctrl + C. 
Saving & Exiting") + + speech_samples = [] + while not vad.empty(): + speech_samples.extend(vad.front.samples) + vad.pop() + + speech_samples = np.array(speech_samples, dtype=np.float32) + + filename_for_speech = time.strftime("%Y%m%d-%H%M%S-speech.wav") + sf.write(filename_for_speech, speech_samples, samplerate=sample_rate) + + filename_for_all = time.strftime("%Y%m%d-%H%M%S-all.wav") + sf.write(filename_for_all, all_samples, samplerate=sample_rate) + + print(f"Saved to {filename_for_speech} and {filename_for_all}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments-from-file.py b/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments-from-file.py new file mode 100755 index 00000000..52f8fd95 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments-from-file.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 + +""" +This file shows how to remove non-speech segments +and merge all speech segments into a large segment +and save it to a file. + +Usage + +python3 ./vad-remove-non-speech-segments-from-file.py \ + --silero-vad-model silero_vad.onnx \ + input.wav \ + output.wav + +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +""" + +import argparse +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_mnn +import soundfile as sf + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument( + "input", + type=str, + help="Path to input.wav", + ) + + parser.add_argument( + "output", + type=str, + help="Path to output.wav", + ) + + return parser.parse_args() + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def main(): + args = get_args() + assert_file_exists(args.silero_vad_model) + assert_file_exists(args.input) + + samples, sample_rate = load_audio(args.input) + if sample_rate != 16000: + import librosa + + samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000) + sample_rate = 16000 + + config = sherpa_mnn.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.silero_vad.threshold = 0.5 + config.silero_vad.min_silence_duration = 0.25 # seconds + config.silero_vad.min_speech_duration = 0.25 # seconds + + # If the current segment is larger than this value, then it increases + # the threshold to 0.9 internally. After detecting this segment, + # it resets the threshold to its original value. 
+ config.silero_vad.max_speech_duration = 5 # seconds + + config.sample_rate = sample_rate + + window_size = config.silero_vad.window_size + + vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + speech_samples = [] + while len(samples) > window_size: + vad.accept_waveform(samples[:window_size]) + samples = samples[window_size:] + + while not vad.empty(): + speech_samples.extend(vad.front.samples) + vad.pop() + + vad.flush() + + while not vad.empty(): + speech_samples.extend(vad.front.samples) + vad.pop() + + speech_samples = np.array(speech_samples, dtype=np.float32) + + sf.write(args.output, speech_samples, samplerate=sample_rate) + + print(f"Saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments.py b/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments.py new file mode 100755 index 00000000..11339269 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/vad-remove-non-speech-segments.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 + +""" +This file shows how to remove non-speech segments +and merge all speech segments into a large segment +and save it to a file. + +Usage + +python3 ./vad-remove-non-speech-segments.py \ + --silero-vad-model silero_vad.onnx + +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +""" + +import argparse +import sys +import time +from pathlib import Path + +import numpy as np +import sherpa_mnn +import soundfile as sf + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + return parser.parse_args() + + +def main(): + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + print( + "If you are using Linux and you are sure there is a microphone " + "on your system, please use " + "./vad-remove-non-speech-segments-alsa.py" + ) + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + args = get_args() + assert_file_exists(args.silero_vad_model) + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_mnn.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + window_size = config.silero_vad.window_size + + buffer = [] + vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + all_samples = [] + + print("Started! Please speak. 
Press Ctrl C to exit") + + try: + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + buffer = np.concatenate([buffer, samples]) + + all_samples = np.concatenate([all_samples, samples]) + + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Saving & Exiting") + + speech_samples = [] + while not vad.empty(): + speech_samples.extend(vad.front.samples) + vad.pop() + + speech_samples = np.array(speech_samples, dtype=np.float32) + + filename_for_speech = time.strftime("%Y%m%d-%H%M%S-speech.wav") + sf.write(filename_for_speech, speech_samples, samplerate=sample_rate) + + filename_for_all = time.strftime("%Y%m%d-%H%M%S-all.wav") + sf.write(filename_for_all, all_samples, samplerate=sample_rate) + + print(f"Saved to {filename_for_speech} and {filename_for_all}") + + +if __name__ == "__main__": + main() diff --git a/apps/frameworks/sherpa-mnn/python-api-examples/vad-with-non-streaming-asr.py b/apps/frameworks/sherpa-mnn/python-api-examples/vad-with-non-streaming-asr.py new file mode 100755 index 00000000..2fc1b287 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/python-api-examples/vad-with-non-streaming-asr.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2023 Xiaomi Corporation + +""" +This file demonstrates how to use sherpa-onnx Python APIs +with VAD and non-streaming ASR models for speech recognition +from a microphone. + +Note that you need a non-streaming model for this script. + +(1) For paraformer + + ./python-api-examples/vad-with-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/paraformer.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 + +(2) For transducer models from icefall + + ./python-api-examples/vad-with-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + --sample-rate=16000 \ + --feature-dim=80 + +(3) For Moonshine models + +./python-api-examples/vad-with-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \ + --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \ + --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \ + --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt \ + --num-threads=2 + +(4) For Whisper models + +./python-api-examples/vad-with-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --whisper-task=transcribe \ + --num-threads=2 + +(5) For SenseVoice CTC models + +./python-api-examples/vad-with-non-streaming-asr.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + 
--sense-voice=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.onnx \ + --tokens=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt \ + --num-threads=2 + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/index.html +to install sherpa-onnx and to download non-streaming pre-trained models +used in this file. + +Please visit +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +""" +import argparse +import sys +from pathlib import Path + +import numpy as np + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_mnn + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder", + default="", + type=str, + help="Path to the transducer encoder model", + ) + + parser.add_argument( + "--decoder", + default="", + type=str, + help="Path to the transducer decoder model", + ) + + parser.add_argument( + "--joiner", + default="", + type=str, + help="Path to the transducer joiner model", + ) + + parser.add_argument( + "--paraformer", + default="", + type=str, + help="Path to the model.onnx from Paraformer", + ) + + parser.add_argument( + "--sense-voice", + default="", + type=str, + help="Path to the model.onnx from SenseVoice", + ) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + + parser.add_argument( + "--whisper-language", + default="", + type=str, + help="""It specifies the spoken language in the input file. + Example values: en, fr, de, zh, jp. + Available languages for multilingual models can be found at + https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + If not specified, we infer the language from the input audio file. + """, + ) + + parser.add_argument( + "--whisper-task", + default="transcribe", + choices=["transcribe", "translate"], + type=str, + help="""For multilingual models, if you specify translate, the output + will be in English. + """, + ) + + parser.add_argument( + "--whisper-tail-paddings", + default=-1, + type=int, + help="""Number of tail padding frames. + We have removed the 30-second constraint from whisper, so you need to + choose the amount of tail padding frames by yourself. + Use -1 to use a default value for tail padding. 
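+    Each frame corresponds to roughly 10 ms of audio.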
+ """, + ) + + parser.add_argument( + "--moonshine-preprocessor", + default="", + type=str, + help="Path to moonshine preprocessor model", + ) + + parser.add_argument( + "--moonshine-encoder", + default="", + type=str, + help="Path to moonshine encoder model", + ) + + parser.add_argument( + "--moonshine-uncached-decoder", + default="", + type=str, + help="Path to moonshine uncached decoder model", + ) + + parser.add_argument( + "--moonshine-cached-decoder", + default="", + type=str, + help="Path to moonshine cached decoder model", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Valid values are greedy_search and modified_beam_search. + modified_beam_search is valid only for transducer models. + """, + ) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages when loading modes.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="""Sample rate of the feature extractor. Must match the one + expected by the model.""", + ) + + parser.add_argument( + "--feature-dim", + type=int, + default=80, + help="Feature dimension. Must match the one expected by the model", + ) + + return parser.parse_args() + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def create_recognizer(args) -> sherpa_mnn.OfflineRecognizer: + if args.encoder: + assert len(args.paraformer) == 0, args.paraformer + assert len(args.sense_voice) == 0, args.sense_voice + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.encoder) + assert_file_exists(args.decoder) + assert_file_exists(args.joiner) + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=args.encoder, + decoder=args.decoder, + joiner=args.joiner, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + blank_penalty=args.blank_penalty, + debug=args.debug, + ) + elif args.paraformer: + assert len(args.sense_voice) == 0, args.sense_voice + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.paraformer) + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=args.paraformer, + tokens=args.tokens, + 
num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) + elif args.sense_voice: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + assert_file_exists(args.sense_voice) + recognizer = sherpa_mnn.OfflineRecognizer.from_sense_voice( + model=args.sense_voice, + tokens=args.tokens, + num_threads=args.num_threads, + use_itn=True, + debug=args.debug, + ) + elif args.whisper_encoder: + assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor + assert len(args.moonshine_encoder) == 0, args.moonshine_encoder + assert ( + len(args.moonshine_uncached_decoder) == 0 + ), args.moonshine_uncached_decoder + assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder + + recognizer = sherpa_mnn.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + language=args.whisper_language, + task=args.whisper_task, + tail_paddings=args.whisper_tail_paddings, + ) + elif args.moonshine_preprocessor: + assert_file_exists(args.moonshine_preprocessor) + assert_file_exists(args.moonshine_encoder) + assert_file_exists(args.moonshine_uncached_decoder) + assert_file_exists(args.moonshine_cached_decoder) + + recognizer = sherpa_mnn.OfflineRecognizer.from_moonshine( + preprocessor=args.moonshine_preprocessor, + encoder=args.moonshine_encoder, + uncached_decoder=args.moonshine_uncached_decoder, + cached_decoder=args.moonshine_cached_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + ) + else: + raise ValueError("Please specify at least one model") + + return recognizer + + +def main(): + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + + # If you want to select a different input device, please use + # sd.default.device[0] = xxx + # where xxx is the device number + + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + args = get_args() + assert_file_exists(args.tokens) + assert_file_exists(args.silero_vad_model) + + assert args.num_threads > 0, args.num_threads + + assert ( + args.sample_rate == 16000 + ), f"Only sample rate 16000 is supported.Given: {args.sample_rate}" + + print("Creating recognizer. Please wait...") + recognizer = create_recognizer(args) + + config = sherpa_mnn.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.silero_vad.min_silence_duration = 0.25 + config.sample_rate = args.sample_rate + + window_size = config.silero_vad.window_size + + vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=100) + + samples_per_read = int(0.1 * args.sample_rate) # 0.1 second = 100 ms + + print("Started! 
Please speak") + + buffer = [] + texts = [] + with sd.InputStream(channels=1, dtype="float32", samplerate=args.sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + + buffer = np.concatenate([buffer, samples]) + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + + while not vad.empty(): + stream = recognizer.create_stream() + stream.accept_waveform(args.sample_rate, vad.front.samples) + + vad.pop() + recognizer.decode_stream(stream) + + text = stream.result.text.strip().lower() + if len(text): + idx = len(texts) + texts.append(text) + print(f"{idx}: {text}") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/apps/frameworks/sherpa-mnn/setup.py b/apps/frameworks/sherpa-mnn/setup.py new file mode 100644 index 00000000..b85e8b67 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/setup.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + +import os +import re +from pathlib import Path + +import setuptools + +from cmake.cmake_extension import ( + BuildExtension, + bdist_wheel, + cmake_extension, + get_binaries, + is_windows, +) + + +def read_long_description(): + with open("README.md", encoding="utf8") as f: + readme = f.read() + return readme + + +def get_package_version(): + with open("CMakeLists.txt") as f: + content = f.read() + + match = re.search(r"set\(sherpa_mnn_VERSION (.*)\)", content) + latest_version = match.group(1).strip('"') + + cmake_args = os.environ.get("sherpa_mnn_CMAKE_ARGS", "") + extra_version = "" + if "-Dsherpa_mnn_ENABLE_GPU=ON" in cmake_args: + extra_version = "+cuda" + + latest_version += extra_version + + return latest_version + + +package_name = "sherpa-mnn" + +with open("sherpa-onnx/python/sherpa_mnn/__init__.py", "a") as f: + f.write(f"__version__ = '{get_package_version()}'\n") + + +def get_binaries_to_install(): + bin_dir = Path("build") / "sherpa_mnn" / "bin" + bin_dir.mkdir(parents=True, exist_ok=True) + suffix = ".exe" if is_windows() else "" + + binaries = get_binaries() + + exe = [] + for f in binaries: + suffix = "" if (".dll" in f or ".lib" in f) else suffix + t = bin_dir / (f + suffix) + exe.append(str(t)) + return exe + + +setuptools.setup( + name=package_name, + python_requires=">=3.6", + version=get_package_version(), + author="The sherpa-onnx development team", + author_email="dpovey@gmail.com", + package_dir={ + "sherpa_mnn": "sherpa-onnx/python/sherpa_mnn", + }, + packages=["sherpa_mnn"], + data_files=[("bin", get_binaries_to_install())], + url="https://github.com/k2-fsa/sherpa-onnx", + long_description=read_long_description(), + long_description_content_type="text/markdown", + ext_modules=[cmake_extension("_sherpa_mnn")], + cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel}, + zip_safe=False, + classifiers=[ + "Programming Language :: C++", + "Programming Language :: Python", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + entry_points={ + "console_scripts": [ + "sherpa-onnx-cli=sherpa_mnn.cli:cli", + ], + }, + license="Apache licensed, as found in the LICENSE file", +) + +with open("sherpa-onnx/python/sherpa_mnn/__init__.py", "r") as f: + lines = f.readlines() + +with open("sherpa-onnx/python/sherpa_mnn/__init__.py", "w") as f: + for line in lines: + if "__version__" in line: + # skip __version__ = "x.x.x" + continue + f.write(line) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/CMakeLists.txt 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/CMakeLists.txt new file mode 100644 index 00000000..2dbae97f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/CMakeLists.txt @@ -0,0 +1,12 @@ +add_subdirectory(csrc) +if(SHERPA_MNN_ENABLE_PYTHON) + add_subdirectory(python) +endif() + +if(SHERPA_MNN_ENABLE_JNI) + add_subdirectory(jni) +endif() + +if(SHERPA_MNN_ENABLE_C_API) + add_subdirectory(c-api) +endif() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/CMakeLists.txt b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/CMakeLists.txt new file mode 100644 index 00000000..6ec9e199 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/CMakeLists.txt @@ -0,0 +1,27 @@ +include_directories(${CMAKE_SOURCE_DIR}) +add_library(sherpa-mnn-c-api c-api.cc) +target_link_libraries(sherpa-mnn-c-api sherpa-mnn-core) + +if(BUILD_SHARED_LIBS) + target_compile_definitions(sherpa-mnn-c-api PUBLIC SHERPA_MNN_BUILD_SHARED_LIBS=1) + target_compile_definitions(sherpa-mnn-c-api PUBLIC SHERPA_MNN_BUILD_MAIN_LIB=1) +endif() + +add_library(sherpa-mnn-cxx-api cxx-api.cc) +target_link_libraries(sherpa-mnn-cxx-api sherpa-mnn-c-api) + +install( + TARGETS + sherpa-mnn-c-api + sherpa-mnn-cxx-api + DESTINATION + lib +) + +install( + FILES + c-api.h + cxx-api.h + DESTINATION + include/sherpa-mnn/c-api +) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/c-api.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/c-api.cc new file mode 100644 index 00000000..55a06a7d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/c-api.cc @@ -0,0 +1,2476 @@ +// sherpa-mnn/c-api/c-api.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/c-api/c-api.h" + +#include +#include +#include +#include +#include +#include +#include + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/audio-tagging.h" +#include "sherpa-mnn/csrc/circular-buffer.h" +#include "sherpa-mnn/csrc/display.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/keyword-spotter.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-punctuation.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/offline-speech-denoiser.h" +#include "sherpa-mnn/csrc/online-punctuation.h" +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/resample.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" +#include "sherpa-mnn/csrc/speaker-embedding-manager.h" +#include "sherpa-mnn/csrc/spoken-language-identification.h" +#include "sherpa-mnn/csrc/voice-activity-detector.h" +#include "sherpa-mnn/csrc/wave-reader.h" +#include "sherpa-mnn/csrc/wave-writer.h" + +#if SHERPA_MNN_ENABLE_TTS == 1 +#include "sherpa-mnn/csrc/offline-tts.h" +#endif + +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 +#include "sherpa-mnn/csrc/offline-speaker-diarization.h" +#endif + +struct SherpaMnnOnlineRecognizer { + std::unique_ptr impl; +}; + +struct SherpaMnnOnlineStream { + std::unique_ptr impl; + explicit SherpaMnnOnlineStream(std::unique_ptr p) + : impl(std::move(p)) {} +}; + +struct SherpaMnnDisplay { + std::unique_ptr impl; +}; + +#define SHERPA_ONNX_OR(x, y) (x ? 
x : y) + +static sherpa_mnn::OnlineRecognizerConfig GetOnlineRecognizerConfig( + const SherpaMnnOnlineRecognizerConfig *config) { + sherpa_mnn::OnlineRecognizerConfig recognizer_config; + + recognizer_config.feat_config.sampling_rate = + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); + recognizer_config.feat_config.feature_dim = + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); + + recognizer_config.model_config.transducer.encoder = + SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + recognizer_config.model_config.transducer.decoder = + SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + recognizer_config.model_config.transducer.joiner = + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + + recognizer_config.model_config.paraformer.encoder = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); + recognizer_config.model_config.paraformer.decoder = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); + + recognizer_config.model_config.zipformer2_ctc.model = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); + + recognizer_config.model_config.tokens = + SHERPA_ONNX_OR(config->model_config.tokens, ""); + if (config->model_config.tokens_buf && + config->model_config.tokens_buf_size > 0) { + recognizer_config.model_config.tokens_buf = std::string( + config->model_config.tokens_buf, config->model_config.tokens_buf_size); + } + + recognizer_config.model_config.num_threads = + SHERPA_ONNX_OR(config->model_config.num_threads, 1); + recognizer_config.model_config.provider_config.provider = + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + + if (recognizer_config.model_config.provider_config.provider.empty()) { + recognizer_config.model_config.provider_config.provider = "cpu"; + } + + recognizer_config.model_config.model_type = + SHERPA_ONNX_OR(config->model_config.model_type, ""); + recognizer_config.model_config.debug = + SHERPA_ONNX_OR(config->model_config.debug, 0); + recognizer_config.model_config.modeling_unit = + SHERPA_ONNX_OR(config->model_config.modeling_unit, "cjkchar"); + + if (recognizer_config.model_config.modeling_unit.empty()) { + recognizer_config.model_config.modeling_unit = "cjkchar"; + } + + recognizer_config.model_config.bpe_vocab = + SHERPA_ONNX_OR(config->model_config.bpe_vocab, ""); + + recognizer_config.decoding_method = + SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); + if (recognizer_config.decoding_method.empty()) { + recognizer_config.decoding_method = "greedy_search"; + } + + recognizer_config.max_active_paths = + SHERPA_ONNX_OR(config->max_active_paths, 4); + + recognizer_config.enable_endpoint = + SHERPA_ONNX_OR(config->enable_endpoint, 0); + + recognizer_config.endpoint_config.rule1.min_trailing_silence = + SHERPA_ONNX_OR(config->rule1_min_trailing_silence, 2.4); + + recognizer_config.endpoint_config.rule2.min_trailing_silence = + SHERPA_ONNX_OR(config->rule2_min_trailing_silence, 1.2); + + recognizer_config.endpoint_config.rule3.min_utterance_length = + SHERPA_ONNX_OR(config->rule3_min_utterance_length, 20); + + recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); + recognizer_config.hotwords_score = + SHERPA_ONNX_OR(config->hotwords_score, 1.5); + if (config->hotwords_buf && config->hotwords_buf_size > 0) { + recognizer_config.hotwords_buf = + std::string(config->hotwords_buf, config->hotwords_buf_size); + } + + recognizer_config.blank_penalty = config->blank_penalty; + + recognizer_config.ctc_fst_decoder_config.graph = + 
SHERPA_ONNX_OR(config->ctc_fst_decoder_config.graph, ""); + recognizer_config.ctc_fst_decoder_config.max_active = + SHERPA_ONNX_OR(config->ctc_fst_decoder_config.max_active, 3000); + + recognizer_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, ""); + recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, ""); + + if (config->model_config.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", recognizer_config.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", recognizer_config.ToString().c_str()); +#endif + } + + return recognizer_config; +} + +const SherpaMnnOnlineRecognizer *SherpaMnnCreateOnlineRecognizer( + const SherpaMnnOnlineRecognizerConfig *config) { + sherpa_mnn::OnlineRecognizerConfig recognizer_config = + GetOnlineRecognizerConfig(config); + + if (!recognizer_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config!"); + return nullptr; + } + + SherpaMnnOnlineRecognizer *recognizer = new SherpaMnnOnlineRecognizer; + + recognizer->impl = + std::make_unique(recognizer_config); + + return recognizer; +} + +void SherpaMnnDestroyOnlineRecognizer( + const SherpaMnnOnlineRecognizer *recognizer) { + delete recognizer; +} + +const SherpaMnnOnlineStream *SherpaMnnCreateOnlineStream( + const SherpaMnnOnlineRecognizer *recognizer) { + SherpaMnnOnlineStream *stream = + new SherpaMnnOnlineStream(recognizer->impl->CreateStream()); + return stream; +} + +const SherpaMnnOnlineStream *SherpaMnnCreateOnlineStreamWithHotwords( + const SherpaMnnOnlineRecognizer *recognizer, const char *hotwords) { + SherpaMnnOnlineStream *stream = + new SherpaMnnOnlineStream(recognizer->impl->CreateStream(hotwords)); + return stream; +} + +void SherpaMnnDestroyOnlineStream(const SherpaMnnOnlineStream *stream) { + delete stream; +} + +void SherpaMnnOnlineStreamAcceptWaveform(const SherpaMnnOnlineStream *stream, + int32_t sample_rate, + const float *samples, int32_t n) { + stream->impl->AcceptWaveform(sample_rate, samples, n); +} + +int32_t SherpaMnnIsOnlineStreamReady( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream) { + return recognizer->impl->IsReady(stream->impl.get()); +} + +void SherpaMnnDecodeOnlineStream(const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream) { + recognizer->impl->DecodeStream(stream->impl.get()); +} + +void SherpaMnnDecodeMultipleOnlineStreams( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream **streams, int32_t n) { + std::vector ss(n); + for (int32_t i = 0; i != n; ++i) { + ss[i] = streams[i]->impl.get(); + } + recognizer->impl->DecodeStreams(ss.data(), n); +} + +const SherpaMnnOnlineRecognizerResult *SherpaMnnGetOnlineStreamResult( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream) { + sherpa_mnn::OnlineRecognizerResult result = + recognizer->impl->GetResult(stream->impl.get()); + const auto &text = result.text; + + auto r = new SherpaMnnOnlineRecognizerResult; + memset(r, 0, sizeof(SherpaMnnOnlineRecognizerResult)); + + // copy text + char *pText = new char[text.size() + 1]; + std::copy(text.begin(), text.end(), pText); + pText[text.size()] = 0; + r->text = pText; + + // copy json + std::string json = result.AsJsonString(); + char *pJson = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), pJson); + pJson[json.size()] = 0; + r->json = pJson; + + // copy tokens + auto count = result.tokens.size(); + if (count > 0) { + size_t total_length = 0; + for (const auto &token : result.tokens) { + // +1 for the null character at the end 
of each token + total_length += token.size() + 1; + } + + r->count = count; + // Each word ends with nullptr + char *tokens = new char[total_length]{}; + char **tokens_temp = new char *[r->count]; + int32_t pos = 0; + for (int32_t i = 0; i < r->count; ++i) { + tokens_temp[i] = tokens + pos; + memcpy(tokens + pos, result.tokens[i].c_str(), result.tokens[i].size()); + // +1 to move past the null character + pos += result.tokens[i].size() + 1; + } + r->tokens_arr = tokens_temp; + + if (!result.timestamps.empty() && result.timestamps.size() == r->count) { + r->timestamps = new float[r->count]; + std::copy(result.timestamps.begin(), result.timestamps.end(), + r->timestamps); + } else { + r->timestamps = nullptr; + } + + r->tokens = tokens; + } else { + r->count = 0; + r->timestamps = nullptr; + r->tokens = nullptr; + r->tokens_arr = nullptr; + } + + return r; +} + +void SherpaMnnDestroyOnlineRecognizerResult( + const SherpaMnnOnlineRecognizerResult *r) { + if (r) { + delete[] r->text; + delete[] r->json; + delete[] r->tokens; + delete[] r->tokens_arr; + delete[] r->timestamps; + delete r; + } +} + +const char *SherpaMnnGetOnlineStreamResultAsJson( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream) { + sherpa_mnn::OnlineRecognizerResult result = + recognizer->impl->GetResult(stream->impl.get()); + std::string json = result.AsJsonString(); + char *pJson = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), pJson); + pJson[json.size()] = 0; + return pJson; +} + +void SherpaMnnDestroyOnlineStreamResultJson(const char *s) { delete[] s; } + +void SherpaMnnOnlineStreamReset(const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream) { + recognizer->impl->Reset(stream->impl.get()); +} + +void SherpaMnnOnlineStreamInputFinished(const SherpaMnnOnlineStream *stream) { + stream->impl->InputFinished(); +} + +int32_t SherpaMnnOnlineStreamIsEndpoint( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream) { + return recognizer->impl->IsEndpoint(stream->impl.get()); +} + +const SherpaMnnDisplay *SherpaMnnCreateDisplay(int32_t max_word_per_line) { + SherpaMnnDisplay *ans = new SherpaMnnDisplay; + ans->impl = std::make_unique(max_word_per_line); + return ans; +} + +void SherpaMnnDestroyDisplay(const SherpaMnnDisplay *display) { + delete display; +} + +void SherpaMnnPrint(const SherpaMnnDisplay *display, int32_t idx, + const char *s) { + display->impl->Print(idx, s); +} + +// ============================================================ +// For offline ASR (i.e., non-streaming ASR) +// ============================================================ +// +struct SherpaMnnOfflineRecognizer { + std::unique_ptr impl; +}; + +struct SherpaMnnOfflineStream { + std::unique_ptr impl; + explicit SherpaMnnOfflineStream( + std::unique_ptr p) + : impl(std::move(p)) {} +}; + +static sherpa_mnn::OfflineRecognizerConfig GetOfflineRecognizerConfig( + const SherpaMnnOfflineRecognizerConfig *config) { + sherpa_mnn::OfflineRecognizerConfig recognizer_config; + + recognizer_config.feat_config.sampling_rate = + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); + + recognizer_config.feat_config.feature_dim = + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); + + recognizer_config.model_config.transducer.encoder_filename = + SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + + recognizer_config.model_config.transducer.decoder_filename = + SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + + 
recognizer_config.model_config.transducer.joiner_filename = + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + + recognizer_config.model_config.paraformer.model = + SHERPA_ONNX_OR(config->model_config.paraformer.model, ""); + + recognizer_config.model_config.nemo_ctc.model = + SHERPA_ONNX_OR(config->model_config.nemo_ctc.model, ""); + + recognizer_config.model_config.whisper.encoder = + SHERPA_ONNX_OR(config->model_config.whisper.encoder, ""); + + recognizer_config.model_config.whisper.decoder = + SHERPA_ONNX_OR(config->model_config.whisper.decoder, ""); + + recognizer_config.model_config.whisper.language = + SHERPA_ONNX_OR(config->model_config.whisper.language, ""); + + recognizer_config.model_config.whisper.task = + SHERPA_ONNX_OR(config->model_config.whisper.task, "transcribe"); + if (recognizer_config.model_config.whisper.task.empty()) { + recognizer_config.model_config.whisper.task = "transcribe"; + } + + recognizer_config.model_config.whisper.tail_paddings = + SHERPA_ONNX_OR(config->model_config.whisper.tail_paddings, -1); + + recognizer_config.model_config.tdnn.model = + SHERPA_ONNX_OR(config->model_config.tdnn.model, ""); + + recognizer_config.model_config.tokens = + SHERPA_ONNX_OR(config->model_config.tokens, ""); + recognizer_config.model_config.num_threads = + SHERPA_ONNX_OR(config->model_config.num_threads, 1); + recognizer_config.model_config.debug = + SHERPA_ONNX_OR(config->model_config.debug, 0); + recognizer_config.model_config.provider = + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + if (recognizer_config.model_config.provider.empty()) { + recognizer_config.model_config.provider = "cpu"; + } + + recognizer_config.model_config.model_type = + SHERPA_ONNX_OR(config->model_config.model_type, ""); + recognizer_config.model_config.modeling_unit = + SHERPA_ONNX_OR(config->model_config.modeling_unit, "cjkchar"); + + if (recognizer_config.model_config.modeling_unit.empty()) { + recognizer_config.model_config.modeling_unit = "cjkchar"; + } + + recognizer_config.model_config.bpe_vocab = + SHERPA_ONNX_OR(config->model_config.bpe_vocab, ""); + + recognizer_config.model_config.telespeech_ctc = + SHERPA_ONNX_OR(config->model_config.telespeech_ctc, ""); + + recognizer_config.model_config.sense_voice.model = + SHERPA_ONNX_OR(config->model_config.sense_voice.model, ""); + + recognizer_config.model_config.sense_voice.language = + SHERPA_ONNX_OR(config->model_config.sense_voice.language, ""); + + recognizer_config.model_config.sense_voice.use_itn = + config->model_config.sense_voice.use_itn; + + recognizer_config.model_config.moonshine.preprocessor = + SHERPA_ONNX_OR(config->model_config.moonshine.preprocessor, ""); + + recognizer_config.model_config.moonshine.encoder = + SHERPA_ONNX_OR(config->model_config.moonshine.encoder, ""); + + recognizer_config.model_config.moonshine.uncached_decoder = + SHERPA_ONNX_OR(config->model_config.moonshine.uncached_decoder, ""); + + recognizer_config.model_config.moonshine.cached_decoder = + SHERPA_ONNX_OR(config->model_config.moonshine.cached_decoder, ""); + + recognizer_config.model_config.fire_red_asr.encoder = + SHERPA_ONNX_OR(config->model_config.fire_red_asr.encoder, ""); + + recognizer_config.model_config.fire_red_asr.decoder = + SHERPA_ONNX_OR(config->model_config.fire_red_asr.decoder, ""); + + recognizer_config.lm_config.model = + SHERPA_ONNX_OR(config->lm_config.model, ""); + recognizer_config.lm_config.scale = + SHERPA_ONNX_OR(config->lm_config.scale, 1.0); + + recognizer_config.decoding_method = + 
SHERPA_ONNX_OR(config->decoding_method, "greedy_search"); + + if (recognizer_config.decoding_method.empty()) { + recognizer_config.decoding_method = "greedy_search"; + } + + recognizer_config.max_active_paths = + SHERPA_ONNX_OR(config->max_active_paths, 4); + + recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); + recognizer_config.hotwords_score = + SHERPA_ONNX_OR(config->hotwords_score, 1.5); + + recognizer_config.blank_penalty = config->blank_penalty; + + recognizer_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, ""); + recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, ""); + + if (config->model_config.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", recognizer_config.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", recognizer_config.ToString().c_str()); +#endif + } + + return recognizer_config; +} + +const SherpaMnnOfflineRecognizer *SherpaMnnCreateOfflineRecognizer( + const SherpaMnnOfflineRecognizerConfig *config) { + sherpa_mnn::OfflineRecognizerConfig recognizer_config = + GetOfflineRecognizerConfig(config); + + if (!recognizer_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnOfflineRecognizer *recognizer = new SherpaMnnOfflineRecognizer; + + recognizer->impl = + std::make_unique(recognizer_config); + + return recognizer; +} + +void SherpaMnnOfflineRecognizerSetConfig( + const SherpaMnnOfflineRecognizer *recognizer, + const SherpaMnnOfflineRecognizerConfig *config) { + sherpa_mnn::OfflineRecognizerConfig recognizer_config = + GetOfflineRecognizerConfig(config); + recognizer->impl->SetConfig(recognizer_config); +} + +void SherpaMnnDestroyOfflineRecognizer( + const SherpaMnnOfflineRecognizer *recognizer) { + delete recognizer; +} + +const SherpaMnnOfflineStream *SherpaMnnCreateOfflineStream( + const SherpaMnnOfflineRecognizer *recognizer) { + SherpaMnnOfflineStream *stream = + new SherpaMnnOfflineStream(recognizer->impl->CreateStream()); + return stream; +} + +const SherpaMnnOfflineStream *SherpaMnnCreateOfflineStreamWithHotwords( + const SherpaMnnOfflineRecognizer *recognizer, const char *hotwords) { + SherpaMnnOfflineStream *stream = + new SherpaMnnOfflineStream(recognizer->impl->CreateStream(hotwords)); + return stream; +} + +void SherpaMnnDestroyOfflineStream(const SherpaMnnOfflineStream *stream) { + delete stream; +} + +void SherpaMnnAcceptWaveformOffline(const SherpaMnnOfflineStream *stream, + int32_t sample_rate, const float *samples, + int32_t n) { + stream->impl->AcceptWaveform(sample_rate, samples, n); +} + +void SherpaMnnDecodeOfflineStream( + const SherpaMnnOfflineRecognizer *recognizer, + const SherpaMnnOfflineStream *stream) { + recognizer->impl->DecodeStream(stream->impl.get()); +} + +void SherpaMnnDecodeMultipleOfflineStreams( + const SherpaMnnOfflineRecognizer *recognizer, + const SherpaMnnOfflineStream **streams, int32_t n) { + std::vector ss(n); + for (int32_t i = 0; i != n; ++i) { + ss[i] = streams[i]->impl.get(); + } + recognizer->impl->DecodeStreams(ss.data(), n); +} + +const SherpaMnnOfflineRecognizerResult *SherpaMnnGetOfflineStreamResult( + const SherpaMnnOfflineStream *stream) { + const sherpa_mnn::OfflineRecognitionResult &result = + stream->impl->GetResult(); + const auto &text = result.text; + + auto r = new SherpaMnnOfflineRecognizerResult; + memset(r, 0, sizeof(SherpaMnnOfflineRecognizerResult)); + + char *pText = new char[text.size() + 1]; + std::copy(text.begin(), text.end(), pText); + pText[text.size()] = 0; + r->text = pText; + + // lang + const 
auto &lang = result.lang; + char *c_lang = new char[lang.size() + 1]; + std::copy(lang.begin(), lang.end(), c_lang); + c_lang[lang.size()] = '\0'; + r->lang = c_lang; + + // emotion + const auto &emotion = result.emotion; + char *c_emotion = new char[emotion.size() + 1]; + std::copy(emotion.begin(), emotion.end(), c_emotion); + c_emotion[emotion.size()] = '\0'; + r->emotion = c_emotion; + + // event + const auto &event = result.event; + char *c_event = new char[event.size() + 1]; + std::copy(event.begin(), event.end(), c_event); + c_event[event.size()] = '\0'; + r->event = c_event; + + // copy json + std::string json = result.AsJsonString(); + char *pJson = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), pJson); + pJson[json.size()] = 0; + r->json = pJson; + + // copy tokens + auto count = result.tokens.size(); + if (count > 0) { + size_t total_length = 0; + for (const auto &token : result.tokens) { + // +1 for the null character at the end of each token + total_length += token.size() + 1; + } + + r->count = count; + // Each word ends with nullptr + char *tokens = new char[total_length]{}; + char **tokens_temp = new char *[r->count]; + int32_t pos = 0; + for (int32_t i = 0; i < r->count; ++i) { + tokens_temp[i] = tokens + pos; + memcpy(tokens + pos, result.tokens[i].c_str(), result.tokens[i].size()); + // +1 to move past the null character + pos += result.tokens[i].size() + 1; + } + r->tokens_arr = tokens_temp; + + if (!result.timestamps.empty() && result.timestamps.size() == r->count) { + r->timestamps = new float[r->count]; + std::copy(result.timestamps.begin(), result.timestamps.end(), + r->timestamps); + } else { + r->timestamps = nullptr; + } + + r->tokens = tokens; + } else { + r->count = 0; + r->timestamps = nullptr; + r->tokens = nullptr; + r->tokens_arr = nullptr; + } + + return r; +} + +void SherpaMnnDestroyOfflineRecognizerResult( + const SherpaMnnOfflineRecognizerResult *r) { + if (r) { + delete[] r->text; + delete[] r->timestamps; + delete[] r->tokens; + delete[] r->tokens_arr; + delete[] r->json; + delete[] r->lang; + delete[] r->emotion; + delete[] r->event; + delete r; + } +} + +const char *SherpaMnnGetOfflineStreamResultAsJson( + const SherpaMnnOfflineStream *stream) { + const sherpa_mnn::OfflineRecognitionResult &result = + stream->impl->GetResult(); + std::string json = result.AsJsonString(); + char *pJson = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), pJson); + pJson[json.size()] = 0; + return pJson; +} + +void SherpaMnnDestroyOfflineStreamResultJson(const char *s) { delete[] s; } + +// ============================================================ +// For Keyword Spot +// ============================================================ + +struct SherpaMnnKeywordSpotter { + std::unique_ptr impl; +}; + +static sherpa_mnn::KeywordSpotterConfig GetKeywordSpotterConfig( + const SherpaMnnKeywordSpotterConfig *config) { + sherpa_mnn::KeywordSpotterConfig spotter_config; + + spotter_config.feat_config.sampling_rate = + SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); + spotter_config.feat_config.feature_dim = + SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); + + spotter_config.model_config.transducer.encoder = + SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); + spotter_config.model_config.transducer.decoder = + SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); + spotter_config.model_config.transducer.joiner = + SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); + + 
spotter_config.model_config.paraformer.encoder = + SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); + spotter_config.model_config.paraformer.decoder = + SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); + + spotter_config.model_config.zipformer2_ctc.model = + SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); + + spotter_config.model_config.tokens = + SHERPA_ONNX_OR(config->model_config.tokens, ""); + if (config->model_config.tokens_buf && + config->model_config.tokens_buf_size > 0) { + spotter_config.model_config.tokens_buf = std::string( + config->model_config.tokens_buf, config->model_config.tokens_buf_size); + } + + spotter_config.model_config.num_threads = + SHERPA_ONNX_OR(config->model_config.num_threads, 1); + spotter_config.model_config.provider_config.provider = + SHERPA_ONNX_OR(config->model_config.provider, "cpu"); + if (spotter_config.model_config.provider_config.provider.empty()) { + spotter_config.model_config.provider_config.provider = "cpu"; + } + + spotter_config.model_config.model_type = + SHERPA_ONNX_OR(config->model_config.model_type, ""); + spotter_config.model_config.debug = + SHERPA_ONNX_OR(config->model_config.debug, 0); + + spotter_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4); + + spotter_config.num_trailing_blanks = + SHERPA_ONNX_OR(config->num_trailing_blanks, 1); + + spotter_config.keywords_score = SHERPA_ONNX_OR(config->keywords_score, 1.0); + + spotter_config.keywords_threshold = + SHERPA_ONNX_OR(config->keywords_threshold, 0.25); + + spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, ""); + if (config->keywords_buf && config->keywords_buf_size > 0) { + spotter_config.keywords_buf = + std::string(config->keywords_buf, config->keywords_buf_size); + } + + if (spotter_config.model_config.debug) { +#if OHOS + SHERPA_ONNX_LOGE("%{public}s\n", spotter_config.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); +#endif + } + + return spotter_config; +} + +const SherpaMnnKeywordSpotter *SherpaMnnCreateKeywordSpotter( + const SherpaMnnKeywordSpotterConfig *config) { + auto spotter_config = GetKeywordSpotterConfig(config); + if (!spotter_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config!"); + return nullptr; + } + + SherpaMnnKeywordSpotter *spotter = new SherpaMnnKeywordSpotter; + + spotter->impl = std::make_unique(spotter_config); + + return spotter; +} + +void SherpaMnnDestroyKeywordSpotter(const SherpaMnnKeywordSpotter *spotter) { + delete spotter; +} + +const SherpaMnnOnlineStream *SherpaMnnCreateKeywordStream( + const SherpaMnnKeywordSpotter *spotter) { + SherpaMnnOnlineStream *stream = + new SherpaMnnOnlineStream(spotter->impl->CreateStream()); + return stream; +} + +const SherpaMnnOnlineStream *SherpaMnnCreateKeywordStreamWithKeywords( + const SherpaMnnKeywordSpotter *spotter, const char *keywords) { + SherpaMnnOnlineStream *stream = + new SherpaMnnOnlineStream(spotter->impl->CreateStream(keywords)); + return stream; +} + +int32_t SherpaMnnIsKeywordStreamReady(const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream) { + return spotter->impl->IsReady(stream->impl.get()); +} + +void SherpaMnnDecodeKeywordStream(const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream) { + spotter->impl->DecodeStream(stream->impl.get()); +} + +void SherpaMnnResetKeywordStream(const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream) { + spotter->impl->Reset(stream->impl.get()); +} + +void 
SherpaMnnDecodeMultipleKeywordStreams( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream **streams, int32_t n) { + std::vector ss(n); + for (int32_t i = 0; i != n; ++i) { + ss[i] = streams[i]->impl.get(); + } + spotter->impl->DecodeStreams(ss.data(), n); +} + +const SherpaMnnKeywordResult *SherpaMnnGetKeywordResult( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream) { + const sherpa_mnn::KeywordResult &result = + spotter->impl->GetResult(stream->impl.get()); + const auto &keyword = result.keyword; + + auto r = new SherpaMnnKeywordResult; + memset(r, 0, sizeof(SherpaMnnKeywordResult)); + + r->start_time = result.start_time; + + // copy keyword + char *pKeyword = new char[keyword.size() + 1]; + std::copy(keyword.begin(), keyword.end(), pKeyword); + pKeyword[keyword.size()] = 0; + r->keyword = pKeyword; + + // copy json + std::string json = result.AsJsonString(); + char *pJson = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), pJson); + pJson[json.size()] = 0; + r->json = pJson; + + // copy tokens + auto count = result.tokens.size(); + if (count > 0) { + size_t total_length = 0; + for (const auto &token : result.tokens) { + // +1 for the null character at the end of each token + total_length += token.size() + 1; + } + + r->count = count; + // Each word ends with nullptr + char *pTokens = new char[total_length]{}; + char **tokens_temp = new char *[r->count]; + int32_t pos = 0; + for (int32_t i = 0; i < r->count; ++i) { + tokens_temp[i] = pTokens + pos; + memcpy(pTokens + pos, result.tokens[i].c_str(), result.tokens[i].size()); + // +1 to move past the null character + pos += result.tokens[i].size() + 1; + } + r->tokens = pTokens; + r->tokens_arr = tokens_temp; + + if (!result.timestamps.empty()) { + r->timestamps = new float[result.timestamps.size()]; + std::copy(result.timestamps.begin(), result.timestamps.end(), + r->timestamps); + } else { + r->timestamps = nullptr; + } + + } else { + r->count = 0; + r->timestamps = nullptr; + r->tokens = nullptr; + r->tokens_arr = nullptr; + } + + return r; +} + +void SherpaMnnDestroyKeywordResult(const SherpaMnnKeywordResult *r) { + if (r) { + delete[] r->keyword; + delete[] r->json; + delete[] r->tokens; + delete[] r->tokens_arr; + delete[] r->timestamps; + delete r; + } +} + +const char *SherpaMnnGetKeywordResultAsJson( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream) { + const sherpa_mnn::KeywordResult &result = + spotter->impl->GetResult(stream->impl.get()); + + std::string json = result.AsJsonString(); + char *pJson = new char[json.size() + 1]; + std::copy(json.begin(), json.end(), pJson); + pJson[json.size()] = 0; + return pJson; +} + +void SherpaMnnFreeKeywordResultJson(const char *s) { delete[] s; } + +// ============================================================ +// For VAD +// ============================================================ +// +struct SherpaMnnCircularBuffer { + std::unique_ptr impl; +}; + +const SherpaMnnCircularBuffer *SherpaMnnCreateCircularBuffer( + int32_t capacity) { + SherpaMnnCircularBuffer *buffer = new SherpaMnnCircularBuffer; + buffer->impl = std::make_unique(capacity); + return buffer; +} + +void SherpaMnnDestroyCircularBuffer(const SherpaMnnCircularBuffer *buffer) { + delete buffer; +} + +void SherpaMnnCircularBufferPush(const SherpaMnnCircularBuffer *buffer, + const float *p, int32_t n) { + buffer->impl->Push(p, n); +} + +const float *SherpaMnnCircularBufferGet(const SherpaMnnCircularBuffer *buffer, + int32_t 
start_index, int32_t n) { + std::vector v = buffer->impl->Get(start_index, n); + + float *p = new float[n]; + std::copy(v.begin(), v.end(), p); + return p; +} + +void SherpaMnnCircularBufferFree(const float *p) { delete[] p; } + +void SherpaMnnCircularBufferPop(const SherpaMnnCircularBuffer *buffer, + int32_t n) { + buffer->impl->Pop(n); +} + +int32_t SherpaMnnCircularBufferSize(const SherpaMnnCircularBuffer *buffer) { + return buffer->impl->Size(); +} + +int32_t SherpaMnnCircularBufferHead(const SherpaMnnCircularBuffer *buffer) { + return buffer->impl->Head(); +} + +void SherpaMnnCircularBufferReset(const SherpaMnnCircularBuffer *buffer) { + buffer->impl->Reset(); +} + +struct SherpaMnnVoiceActivityDetector { + std::unique_ptr impl; +}; + +sherpa_mnn::VadModelConfig GetVadModelConfig( + const SherpaMnnVadModelConfig *config) { + sherpa_mnn::VadModelConfig vad_config; + + vad_config.silero_vad.model = SHERPA_ONNX_OR(config->silero_vad.model, ""); + vad_config.silero_vad.threshold = + SHERPA_ONNX_OR(config->silero_vad.threshold, 0.5); + + vad_config.silero_vad.min_silence_duration = + SHERPA_ONNX_OR(config->silero_vad.min_silence_duration, 0.5); + + vad_config.silero_vad.min_speech_duration = + SHERPA_ONNX_OR(config->silero_vad.min_speech_duration, 0.25); + + vad_config.silero_vad.window_size = + SHERPA_ONNX_OR(config->silero_vad.window_size, 512); + + vad_config.silero_vad.max_speech_duration = + SHERPA_ONNX_OR(config->silero_vad.max_speech_duration, 20); + + vad_config.sample_rate = SHERPA_ONNX_OR(config->sample_rate, 16000); + vad_config.num_threads = SHERPA_ONNX_OR(config->num_threads, 1); + vad_config.provider = SHERPA_ONNX_OR(config->provider, "cpu"); + if (vad_config.provider.empty()) { + vad_config.provider = "cpu"; + } + + vad_config.debug = SHERPA_ONNX_OR(config->debug, false); + + if (vad_config.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", vad_config.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", vad_config.ToString().c_str()); +#endif + } + + return vad_config; +} + +const SherpaMnnVoiceActivityDetector *SherpaMnnCreateVoiceActivityDetector( + const SherpaMnnVadModelConfig *config, float buffer_size_in_seconds) { + auto vad_config = GetVadModelConfig(config); + + if (!vad_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnVoiceActivityDetector *p = new SherpaMnnVoiceActivityDetector; + p->impl = std::make_unique( + vad_config, buffer_size_in_seconds); + + return p; +} + +void SherpaMnnDestroyVoiceActivityDetector( + const SherpaMnnVoiceActivityDetector *p) { + delete p; +} + +void SherpaMnnVoiceActivityDetectorAcceptWaveform( + const SherpaMnnVoiceActivityDetector *p, const float *samples, int32_t n) { + p->impl->AcceptWaveform(samples, n); +} + +int32_t SherpaMnnVoiceActivityDetectorEmpty( + const SherpaMnnVoiceActivityDetector *p) { + return p->impl->Empty(); +} + +int32_t SherpaMnnVoiceActivityDetectorDetected( + const SherpaMnnVoiceActivityDetector *p) { + return p->impl->IsSpeechDetected(); +} + +void SherpaMnnVoiceActivityDetectorPop( + const SherpaMnnVoiceActivityDetector *p) { + p->impl->Pop(); +} + +void SherpaMnnVoiceActivityDetectorClear( + const SherpaMnnVoiceActivityDetector *p) { + p->impl->Clear(); +} + +const SherpaMnnSpeechSegment *SherpaMnnVoiceActivityDetectorFront( + const SherpaMnnVoiceActivityDetector *p) { + const sherpa_mnn::SpeechSegment &segment = p->impl->Front(); + + SherpaMnnSpeechSegment *ans = new SherpaMnnSpeechSegment; + ans->start = segment.start; + ans->samples = new 
float[segment.samples.size()]; + std::copy(segment.samples.begin(), segment.samples.end(), ans->samples); + ans->n = segment.samples.size(); + + return ans; +} + +void SherpaMnnDestroySpeechSegment(const SherpaMnnSpeechSegment *p) { + if (p) { + delete[] p->samples; + delete p; + } +} + +void SherpaMnnVoiceActivityDetectorReset( + const SherpaMnnVoiceActivityDetector *p) { + p->impl->Reset(); +} + +void SherpaMnnVoiceActivityDetectorFlush( + const SherpaMnnVoiceActivityDetector *p) { + p->impl->Flush(); +} + +#if SHERPA_MNN_ENABLE_TTS == 1 +struct SherpaMnnOfflineTts { + std::unique_ptr impl; +}; + +static sherpa_mnn::OfflineTtsConfig GetOfflineTtsConfig( + const SherpaMnnOfflineTtsConfig *config) { + sherpa_mnn::OfflineTtsConfig tts_config; + + // vits + tts_config.model.vits.model = SHERPA_ONNX_OR(config->model.vits.model, ""); + tts_config.model.vits.lexicon = + SHERPA_ONNX_OR(config->model.vits.lexicon, ""); + tts_config.model.vits.tokens = SHERPA_ONNX_OR(config->model.vits.tokens, ""); + tts_config.model.vits.data_dir = + SHERPA_ONNX_OR(config->model.vits.data_dir, ""); + tts_config.model.vits.noise_scale = + SHERPA_ONNX_OR(config->model.vits.noise_scale, 0.667); + tts_config.model.vits.noise_scale_w = + SHERPA_ONNX_OR(config->model.vits.noise_scale_w, 0.8); + tts_config.model.vits.length_scale = + SHERPA_ONNX_OR(config->model.vits.length_scale, 1.0); + tts_config.model.vits.dict_dir = + SHERPA_ONNX_OR(config->model.vits.dict_dir, ""); + + // matcha + tts_config.model.matcha.acoustic_model = + SHERPA_ONNX_OR(config->model.matcha.acoustic_model, ""); + tts_config.model.matcha.vocoder = + SHERPA_ONNX_OR(config->model.matcha.vocoder, ""); + tts_config.model.matcha.lexicon = + SHERPA_ONNX_OR(config->model.matcha.lexicon, ""); + tts_config.model.matcha.tokens = + SHERPA_ONNX_OR(config->model.matcha.tokens, ""); + tts_config.model.matcha.data_dir = + SHERPA_ONNX_OR(config->model.matcha.data_dir, ""); + tts_config.model.matcha.noise_scale = + SHERPA_ONNX_OR(config->model.matcha.noise_scale, 0.667); + tts_config.model.matcha.length_scale = + SHERPA_ONNX_OR(config->model.matcha.length_scale, 1.0); + tts_config.model.matcha.dict_dir = + SHERPA_ONNX_OR(config->model.matcha.dict_dir, ""); + + // kokoro + tts_config.model.kokoro.model = + SHERPA_ONNX_OR(config->model.kokoro.model, ""); + tts_config.model.kokoro.voices = + SHERPA_ONNX_OR(config->model.kokoro.voices, ""); + tts_config.model.kokoro.tokens = + SHERPA_ONNX_OR(config->model.kokoro.tokens, ""); + tts_config.model.kokoro.data_dir = + SHERPA_ONNX_OR(config->model.kokoro.data_dir, ""); + tts_config.model.kokoro.length_scale = + SHERPA_ONNX_OR(config->model.kokoro.length_scale, 1.0); + tts_config.model.kokoro.dict_dir = + SHERPA_ONNX_OR(config->model.kokoro.dict_dir, ""); + tts_config.model.kokoro.lexicon = + SHERPA_ONNX_OR(config->model.kokoro.lexicon, ""); + + tts_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); + tts_config.model.debug = config->model.debug; + tts_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); + if (tts_config.model.provider.empty()) { + tts_config.model.provider = "cpu"; + } + + tts_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, ""); + tts_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, ""); + tts_config.max_num_sentences = SHERPA_ONNX_OR(config->max_num_sentences, 1); + tts_config.silence_scale = SHERPA_ONNX_OR(config->silence_scale, 0.2); + + if (tts_config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", tts_config.ToString().c_str()); +#else + 
SHERPA_ONNX_LOGE("%s\n", tts_config.ToString().c_str()); +#endif + } + + return tts_config; +} + +const SherpaMnnOfflineTts *SherpaMnnCreateOfflineTts( + const SherpaMnnOfflineTtsConfig *config) { + auto tts_config = GetOfflineTtsConfig(config); + + if (!tts_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnOfflineTts *tts = new SherpaMnnOfflineTts; + + tts->impl = std::make_unique(tts_config); + + return tts; +} + +void SherpaMnnDestroyOfflineTts(const SherpaMnnOfflineTts *tts) { + delete tts; +} + +int32_t SherpaMnnOfflineTtsSampleRate(const SherpaMnnOfflineTts *tts) { + return tts->impl->SampleRate(); +} + +int32_t SherpaMnnOfflineTtsNumSpeakers(const SherpaMnnOfflineTts *tts) { + return tts->impl->NumSpeakers(); +} + +static const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerateInternal( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + std::function callback) { + sherpa_mnn::GeneratedAudio audio = + tts->impl->Generate(text, sid, speed, callback); + + if (audio.samples.empty()) { + return nullptr; + } + + SherpaMnnGeneratedAudio *ans = new SherpaMnnGeneratedAudio; + + float *samples = new float[audio.samples.size()]; + std::copy(audio.samples.begin(), audio.samples.end(), samples); + + ans->samples = samples; + ans->n = audio.samples.size(); + ans->sample_rate = audio.sample_rate; + + return ans; +} + +const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerate( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, + float speed) { + return SherpaMnnOfflineTtsGenerateInternal(tts, text, sid, speed, nullptr); +} + +const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerateWithCallback( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioCallback callback) { + auto wrapper = [callback](const float *samples, int32_t n, + float /*progress*/) { + return callback(samples, n); + }; + + return SherpaMnnOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); +} + +const SherpaMnnGeneratedAudio * +SherpaMnnOfflineTtsGenerateWithProgressCallback( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioProgressCallback callback) { + auto wrapper = [callback](const float *samples, int32_t n, float progress) { + return callback(samples, n, progress); + }; + return SherpaMnnOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); +} + +const SherpaMnnGeneratedAudio * +SherpaMnnOfflineTtsGenerateWithProgressCallbackWithArg( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioProgressCallbackWithArg callback, void *arg) { + auto wrapper = [callback, arg](const float *samples, int32_t n, + float progress) { + return callback(samples, n, progress, arg); + }; + return SherpaMnnOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); +} + +const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerateWithCallbackWithArg( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioCallbackWithArg callback, void *arg) { + auto wrapper = [callback, arg](const float *samples, int32_t n, + float /*progress*/) { + return callback(samples, n, arg); + }; + + return SherpaMnnOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper); +} + +void SherpaMnnDestroyOfflineTtsGeneratedAudio( + const SherpaMnnGeneratedAudio *p) { + if (p) { + delete[] p->samples; + delete p; + } +} +#else +const SherpaMnnOfflineTts 
*SherpaMnnCreateOfflineTts( + const SherpaMnnOfflineTtsConfig *config) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +void SherpaMnnDestroyOfflineTts(const SherpaMnnOfflineTts *tts) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); +} + +int32_t SherpaMnnOfflineTtsSampleRate(const SherpaMnnOfflineTts *tts) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return 0; +} + +int32_t SherpaMnnOfflineTtsNumSpeakers(const SherpaMnnOfflineTts *tts) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return 0; +} + +const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerate( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, + float speed) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerateWithCallback( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioCallback callback) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +const SherpaMnnGeneratedAudio * +SherpaMnnOfflineTtsGenerateWithProgressCallback( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioProgressCallback callback) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +const SherpaMnnGeneratedAudio * +SherpaMnnOfflineTtsGenerateWithProgressCallbackWithArg( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioProgressCallbackWithArg callback, void *arg) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerateWithCallbackWithArg( + const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaMnnGeneratedAudioCallbackWithArg callback, void *arg) { + SHERPA_ONNX_LOGE("TTS is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +void SherpaMnnDestroyOfflineTtsGeneratedAudio( + const SherpaMnnGeneratedAudio *p) { + SHERPA_ONNX_LOGE("TTS is not enabled. 
Please rebuild sherpa-mnn"); +} + +#endif // SHERPA_MNN_ENABLE_TTS == 1 + +int32_t SherpaMnnWriteWave(const float *samples, int32_t n, + int32_t sample_rate, const char *filename) { + return sherpa_mnn::WriteWave(filename, sample_rate, samples, n); +} + +int64_t SherpaMnnWaveFileSize(int32_t n_samples) { + return sherpa_mnn::WaveFileSize(n_samples); +} + +SHERPA_ONNX_API void SherpaMnnWriteWaveToBuffer(const float *samples, + int32_t n, int32_t sample_rate, + char *buffer) { + sherpa_mnn::WriteWave(buffer, sample_rate, samples, n); +} + +const SherpaMnnWave *SherpaMnnReadWave(const char *filename) { + int32_t sample_rate = -1; + bool is_ok = false; + std::vector samples = + sherpa_mnn::ReadWave(filename, &sample_rate, &is_ok); + if (!is_ok) { + return nullptr; + } + + float *c_samples = new float[samples.size()]; + std::copy(samples.begin(), samples.end(), c_samples); + + SherpaMnnWave *wave = new SherpaMnnWave; + wave->samples = c_samples; + wave->sample_rate = sample_rate; + wave->num_samples = samples.size(); + return wave; +} + +const SherpaMnnWave *SherpaMnnReadWaveFromBinaryData(const char *data, + int32_t n) { + int32_t sample_rate = -1; + bool is_ok = false; + + std::istrstream is(data, n); + + std::vector samples = sherpa_mnn::ReadWave(is, &sample_rate, &is_ok); + if (!is_ok) { + return nullptr; + } + + float *c_samples = new float[samples.size()]; + std::copy(samples.begin(), samples.end(), c_samples); + + SherpaMnnWave *wave = new SherpaMnnWave; + wave->samples = c_samples; + wave->sample_rate = sample_rate; + wave->num_samples = samples.size(); + return wave; +} + +void SherpaMnnFreeWave(const SherpaMnnWave *wave) { + if (wave) { + delete[] wave->samples; + delete wave; + } +} + +struct SherpaMnnSpokenLanguageIdentification { + std::unique_ptr impl; +}; + +const SherpaMnnSpokenLanguageIdentification * +SherpaMnnCreateSpokenLanguageIdentification( + const SherpaMnnSpokenLanguageIdentificationConfig *config) { + sherpa_mnn::SpokenLanguageIdentificationConfig slid_config; + slid_config.whisper.encoder = SHERPA_ONNX_OR(config->whisper.encoder, ""); + slid_config.whisper.decoder = SHERPA_ONNX_OR(config->whisper.decoder, ""); + slid_config.whisper.tail_paddings = + SHERPA_ONNX_OR(config->whisper.tail_paddings, -1); + slid_config.num_threads = SHERPA_ONNX_OR(config->num_threads, 1); + slid_config.debug = config->debug; + slid_config.provider = SHERPA_ONNX_OR(config->provider, "cpu"); + if (slid_config.provider.empty()) { + slid_config.provider = "cpu"; + } + + if (slid_config.debug) { + SHERPA_ONNX_LOGE("%s\n", slid_config.ToString().c_str()); + } + + if (!slid_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnSpokenLanguageIdentification *slid = + new SherpaMnnSpokenLanguageIdentification; + slid->impl = + std::make_unique(slid_config); + + return slid; +} + +void SherpaMnnDestroySpokenLanguageIdentification( + const SherpaMnnSpokenLanguageIdentification *slid) { + delete slid; +} + +SherpaMnnOfflineStream * +SherpaMnnSpokenLanguageIdentificationCreateOfflineStream( + const SherpaMnnSpokenLanguageIdentification *slid) { + SherpaMnnOfflineStream *stream = + new SherpaMnnOfflineStream(slid->impl->CreateStream()); + return stream; +} + +const SherpaMnnSpokenLanguageIdentificationResult * +SherpaMnnSpokenLanguageIdentificationCompute( + const SherpaMnnSpokenLanguageIdentification *slid, + const SherpaMnnOfflineStream *s) { + std::string lang = slid->impl->Compute(s->impl.get()); + char *c_lang = new char[lang.size() + 1]; + 
std::copy(lang.begin(), lang.end(), c_lang); + c_lang[lang.size()] = '\0'; + SherpaMnnSpokenLanguageIdentificationResult *r = + new SherpaMnnSpokenLanguageIdentificationResult; + r->lang = c_lang; + return r; +} + +void SherpaMnnDestroySpokenLanguageIdentificationResult( + const SherpaMnnSpokenLanguageIdentificationResult *r) { + if (r) { + delete[] r->lang; + delete r; + } +} + +struct SherpaMnnSpeakerEmbeddingExtractor { + std::unique_ptr impl; +}; + +static sherpa_mnn::SpeakerEmbeddingExtractorConfig +GetSpeakerEmbeddingExtractorConfig( + const SherpaMnnSpeakerEmbeddingExtractorConfig *config) { + sherpa_mnn::SpeakerEmbeddingExtractorConfig c; + c.model = SHERPA_ONNX_OR(config->model, ""); + + c.num_threads = SHERPA_ONNX_OR(config->num_threads, 1); + c.debug = SHERPA_ONNX_OR(config->debug, 0); + c.provider = SHERPA_ONNX_OR(config->provider, "cpu"); + if (c.provider.empty()) { + c.provider = "cpu"; + } + + if (config->debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", c.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str()); +#endif + } + + return c; +} + +const SherpaMnnSpeakerEmbeddingExtractor * +SherpaMnnCreateSpeakerEmbeddingExtractor( + const SherpaMnnSpeakerEmbeddingExtractorConfig *config) { + auto c = GetSpeakerEmbeddingExtractorConfig(config); + + if (!c.Validate()) { + SHERPA_ONNX_LOGE("Errors in config!"); + return nullptr; + } + + auto p = new SherpaMnnSpeakerEmbeddingExtractor; + + p->impl = std::make_unique(c); + + return p; +} + +void SherpaMnnDestroySpeakerEmbeddingExtractor( + const SherpaMnnSpeakerEmbeddingExtractor *p) { + delete p; +} + +int32_t SherpaMnnSpeakerEmbeddingExtractorDim( + const SherpaMnnSpeakerEmbeddingExtractor *p) { + return p->impl->Dim(); +} + +const SherpaMnnOnlineStream *SherpaMnnSpeakerEmbeddingExtractorCreateStream( + const SherpaMnnSpeakerEmbeddingExtractor *p) { + SherpaMnnOnlineStream *stream = + new SherpaMnnOnlineStream(p->impl->CreateStream()); + return stream; +} + +int32_t SherpaMnnSpeakerEmbeddingExtractorIsReady( + const SherpaMnnSpeakerEmbeddingExtractor *p, + const SherpaMnnOnlineStream *s) { + return p->impl->IsReady(s->impl.get()); +} + +const float *SherpaMnnSpeakerEmbeddingExtractorComputeEmbedding( + const SherpaMnnSpeakerEmbeddingExtractor *p, + const SherpaMnnOnlineStream *s) { + std::vector v = p->impl->Compute(s->impl.get()); + float *ans = new float[v.size()]; + std::copy(v.begin(), v.end(), ans); + return ans; +} + +void SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding(const float *v) { + delete[] v; +} + +struct SherpaMnnSpeakerEmbeddingManager { + std::unique_ptr impl; +}; + +const SherpaMnnSpeakerEmbeddingManager * +SherpaMnnCreateSpeakerEmbeddingManager(int32_t dim) { + auto p = new SherpaMnnSpeakerEmbeddingManager; + p->impl = std::make_unique(dim); + return p; +} + +void SherpaMnnDestroySpeakerEmbeddingManager( + const SherpaMnnSpeakerEmbeddingManager *p) { + delete p; +} + +int32_t SherpaMnnSpeakerEmbeddingManagerAdd( + const SherpaMnnSpeakerEmbeddingManager *p, const char *name, + const float *v) { + return p->impl->Add(name, v); +} + +int32_t SherpaMnnSpeakerEmbeddingManagerAddList( + const SherpaMnnSpeakerEmbeddingManager *p, const char *name, + const float **v) { + int32_t n = 0; + auto q = v; + while (q && q[0]) { + ++n; + ++q; + } + + if (n == 0) { + SHERPA_ONNX_LOGE("Empty embedding!"); + return 0; + } + + std::vector> vec(n); + int32_t dim = p->impl->Dim(); + + for (int32_t i = 0; i != n; ++i) { + vec[i] = std::vector(v[i], v[i] + dim); + } + + return p->impl->Add(name, vec); 
+} + +int32_t SherpaMnnSpeakerEmbeddingManagerAddListFlattened( + const SherpaMnnSpeakerEmbeddingManager *p, const char *name, + const float *v, int32_t n) { + std::vector> vec(n); + + int32_t dim = p->impl->Dim(); + + for (int32_t i = 0; i != n; ++i, v += dim) { + vec[i] = std::vector(v, v + dim); + } + + return p->impl->Add(name, vec); +} + +int32_t SherpaMnnSpeakerEmbeddingManagerRemove( + const SherpaMnnSpeakerEmbeddingManager *p, const char *name) { + return p->impl->Remove(name); +} + +const char *SherpaMnnSpeakerEmbeddingManagerSearch( + const SherpaMnnSpeakerEmbeddingManager *p, const float *v, + float threshold) { + auto r = p->impl->Search(v, threshold); + if (r.empty()) { + return nullptr; + } + + char *name = new char[r.size() + 1]; + std::copy(r.begin(), r.end(), name); + name[r.size()] = '\0'; + + return name; +} + +void SherpaMnnSpeakerEmbeddingManagerFreeSearch(const char *name) { + delete[] name; +} + +const SherpaMnnSpeakerEmbeddingManagerBestMatchesResult * +SherpaMnnSpeakerEmbeddingManagerGetBestMatches( + const SherpaMnnSpeakerEmbeddingManager *p, const float *v, float threshold, + int32_t n) { + auto matches = p->impl->GetBestMatches(v, threshold, n); + + if (matches.empty()) { + return nullptr; + } + + auto resultMatches = + new SherpaMnnSpeakerEmbeddingManagerSpeakerMatch[matches.size()]; + for (int i = 0; i < matches.size(); ++i) { + resultMatches[i].score = matches[i].score; + + char *name = new char[matches[i].name.size() + 1]; + std::copy(matches[i].name.begin(), matches[i].name.end(), name); + name[matches[i].name.size()] = '\0'; + + resultMatches[i].name = name; + } + + auto *result = new SherpaMnnSpeakerEmbeddingManagerBestMatchesResult(); + result->count = matches.size(); + result->matches = resultMatches; + + return result; +} + +void SherpaMnnSpeakerEmbeddingManagerFreeBestMatches( + const SherpaMnnSpeakerEmbeddingManagerBestMatchesResult *r) { + if (r == nullptr) { + return; + } + + for (int32_t i = 0; i < r->count; ++i) { + delete[] r->matches[i].name; + } + delete[] r->matches; + delete r; +} + +int32_t SherpaMnnSpeakerEmbeddingManagerVerify( + const SherpaMnnSpeakerEmbeddingManager *p, const char *name, + const float *v, float threshold) { + return p->impl->Verify(name, v, threshold); +} + +int32_t SherpaMnnSpeakerEmbeddingManagerContains( + const SherpaMnnSpeakerEmbeddingManager *p, const char *name) { + return p->impl->Contains(name); +} + +int32_t SherpaMnnSpeakerEmbeddingManagerNumSpeakers( + const SherpaMnnSpeakerEmbeddingManager *p) { + return p->impl->NumSpeakers(); +} + +const char *const *SherpaMnnSpeakerEmbeddingManagerGetAllSpeakers( + const SherpaMnnSpeakerEmbeddingManager *manager) { + std::vector all_speakers = manager->impl->GetAllSpeakers(); + int32_t num_speakers = all_speakers.size(); + char **p = new char *[num_speakers + 1]; + p[num_speakers] = nullptr; + + int32_t i = 0; + for (const auto &name : all_speakers) { + p[i] = new char[name.size() + 1]; + std::copy(name.begin(), name.end(), p[i]); + p[i][name.size()] = '\0'; + + i += 1; + } + return p; +} + +void SherpaMnnSpeakerEmbeddingManagerFreeAllSpeakers( + const char *const *names) { + auto p = names; + + while (p && p[0]) { + delete[] p[0]; + ++p; + } + + delete[] names; +} + +struct SherpaMnnAudioTagging { + std::unique_ptr impl; +}; + +const SherpaMnnAudioTagging *SherpaMnnCreateAudioTagging( + const SherpaMnnAudioTaggingConfig *config) { + sherpa_mnn::AudioTaggingConfig ac; + ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, ""); + ac.model.ced = 
SHERPA_ONNX_OR(config->model.ced, ""); + ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); + ac.model.debug = config->model.debug; + ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); + if (ac.model.provider.empty()) { + ac.model.provider = "cpu"; + } + + ac.labels = SHERPA_ONNX_OR(config->labels, ""); + ac.top_k = SHERPA_ONNX_OR(config->top_k, 5); + + if (ac.model.debug) { + SHERPA_ONNX_LOGE("%s\n", ac.ToString().c_str()); + } + + if (!ac.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnAudioTagging *tagger = new SherpaMnnAudioTagging; + tagger->impl = std::make_unique(ac); + + return tagger; +} + +void SherpaMnnDestroyAudioTagging(const SherpaMnnAudioTagging *tagger) { + delete tagger; +} + +const SherpaMnnOfflineStream *SherpaMnnAudioTaggingCreateOfflineStream( + const SherpaMnnAudioTagging *tagger) { + const SherpaMnnOfflineStream *stream = + new SherpaMnnOfflineStream(tagger->impl->CreateStream()); + return stream; +} + +const SherpaMnnAudioEvent *const *SherpaMnnAudioTaggingCompute( + const SherpaMnnAudioTagging *tagger, const SherpaMnnOfflineStream *s, + int32_t top_k) { + std::vector events = + tagger->impl->Compute(s->impl.get(), top_k); + + int32_t n = static_cast(events.size()); + SherpaMnnAudioEvent **ans = new SherpaMnnAudioEvent *[n + 1]; + ans[n] = nullptr; + + int32_t i = 0; + for (const auto &e : events) { + SherpaMnnAudioEvent *p = new SherpaMnnAudioEvent; + + char *name = new char[e.name.size() + 1]; + std::copy(e.name.begin(), e.name.end(), name); + name[e.name.size()] = 0; + + p->name = name; + + p->index = e.index; + p->prob = e.prob; + + ans[i] = p; + i += 1; + } + + return ans; +} + +void SherpaMnnAudioTaggingFreeResults( + const SherpaMnnAudioEvent *const *events) { + auto p = events; + + while (p && *p) { + auto e = *p; + + delete[] e->name; + delete e; + + ++p; + } + + delete[] events; +} + +struct SherpaMnnOfflinePunctuation { + std::unique_ptr impl; +}; + +const SherpaMnnOfflinePunctuation *SherpaMnnCreateOfflinePunctuation( + const SherpaMnnOfflinePunctuationConfig *config) { + sherpa_mnn::OfflinePunctuationConfig c; + c.model.ct_transformer = SHERPA_ONNX_OR(config->model.ct_transformer, ""); + c.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); + c.model.debug = config->model.debug; + c.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); + if (c.model.provider.empty()) { + c.model.provider = "cpu"; + } + + if (c.model.debug) { + SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str()); + } + + if (!c.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnOfflinePunctuation *punct = new SherpaMnnOfflinePunctuation; + punct->impl = std::make_unique(c); + + return punct; +} + +void SherpaMnnDestroyOfflinePunctuation( + const SherpaMnnOfflinePunctuation *punct) { + delete punct; +} + +const char *SherpaOfflinePunctuationAddPunct( + const SherpaMnnOfflinePunctuation *punct, const char *text) { + std::string text_with_punct = punct->impl->AddPunctuation(text); + + char *ans = new char[text_with_punct.size() + 1]; + std::copy(text_with_punct.begin(), text_with_punct.end(), ans); + ans[text_with_punct.size()] = 0; + + return ans; +} + +void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; } + +struct SherpaMnnOnlinePunctuation { + std::unique_ptr impl; +}; + +const SherpaMnnOnlinePunctuation *SherpaMnnCreateOnlinePunctuation( + const SherpaMnnOnlinePunctuationConfig *config) { + auto p = new SherpaMnnOnlinePunctuation; + 
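+  // The construction below is wrapped in try/catch so that a failure during
+  // construction surfaces as a nullptr return instead of an exception
+  // propagating across the C API boundary.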
try { + sherpa_mnn::OnlinePunctuationConfig punctuation_config; + punctuation_config.model.cnn_bilstm = + SHERPA_ONNX_OR(config->model.cnn_bilstm, ""); + punctuation_config.model.bpe_vocab = + SHERPA_ONNX_OR(config->model.bpe_vocab, ""); + punctuation_config.model.num_threads = + SHERPA_ONNX_OR(config->model.num_threads, 1); + punctuation_config.model.debug = config->model.debug; + punctuation_config.model.provider = + SHERPA_ONNX_OR(config->model.provider, "cpu"); + + p->impl = + std::make_unique(punctuation_config); + } catch (const std::exception &e) { + SHERPA_ONNX_LOGE("Failed to create online punctuation: %s", e.what()); + delete p; + return nullptr; + } + return p; +} + +void SherpaMnnDestroyOnlinePunctuation(const SherpaMnnOnlinePunctuation *p) { + delete p; +} + +const char *SherpaMnnOnlinePunctuationAddPunct( + const SherpaMnnOnlinePunctuation *punctuation, const char *text) { + if (!punctuation || !text) return nullptr; + + try { + std::string s = punctuation->impl->AddPunctuationWithCase(text); + char *p = new char[s.size() + 1]; + std::copy(s.begin(), s.end(), p); + p[s.size()] = '\0'; + return p; + } catch (const std::exception &e) { + SHERPA_ONNX_LOGE("Failed to add punctuation: %s", e.what()); + return nullptr; + } +} + +void SherpaMnnOnlinePunctuationFreeText(const char *text) { delete[] text; } + +struct SherpaMnnLinearResampler { + std::unique_ptr impl; +}; + +const SherpaMnnLinearResampler *SherpaMnnCreateLinearResampler( + int32_t samp_rate_in_hz, int32_t samp_rate_out_hz, float filter_cutoff_hz, + int32_t num_zeros) { + SherpaMnnLinearResampler *p = new SherpaMnnLinearResampler; + p->impl = std::make_unique( + samp_rate_in_hz, samp_rate_out_hz, filter_cutoff_hz, num_zeros); + + return p; +} + +void SherpaMnnDestroyLinearResampler(const SherpaMnnLinearResampler *p) { + delete p; +} + +const SherpaMnnResampleOut *SherpaMnnLinearResamplerResample( + const SherpaMnnLinearResampler *p, const float *input, int32_t input_dim, + int32_t flush) { + std::vector o; + p->impl->Resample(input, input_dim, flush, &o); + + float *s = new float[o.size()]; + std::copy(o.begin(), o.end(), s); + + SherpaMnnResampleOut *ans = new SherpaMnnResampleOut; + ans->samples = s; + ans->n = static_cast(o.size()); + + return ans; +} + +void SherpaMnnLinearResamplerResampleFree(const SherpaMnnResampleOut *p) { + delete[] p->samples; + delete p; +} + +int32_t SherpaMnnLinearResamplerResampleGetInputSampleRate( + const SherpaMnnLinearResampler *p) { + return p->impl->GetInputSamplingRate(); +} + +int32_t SherpaMnnLinearResamplerResampleGetOutputSampleRate( + const SherpaMnnLinearResampler *p) { + return p->impl->GetOutputSamplingRate(); +} + +void SherpaMnnLinearResamplerReset(SherpaMnnLinearResampler *p) { + p->impl->Reset(); +} + +int32_t SherpaMnnFileExists(const char *filename) { + return sherpa_mnn::FileExists(filename); +} + +struct SherpaMnnOfflineSpeechDenoiser { + std::unique_ptr impl; +}; + +static sherpa_mnn::OfflineSpeechDenoiserConfig GetOfflineSpeechDenoiserConfig( + const SherpaMnnOfflineSpeechDenoiserConfig *config) { + sherpa_mnn::OfflineSpeechDenoiserConfig c; + c.model.gtcrn.model = SHERPA_ONNX_OR(config->model.gtcrn.model, ""); + c.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); + c.model.debug = config->model.debug; + c.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); + + if (c.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", c.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str()); +#endif + } + + return c; +} 
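+
+// A minimal usage sketch of the offline speech denoiser API defined below,
+// assuming a gtcrn denoising model; the model and wave file paths are
+// hypothetical placeholders.
+//
+//   SherpaMnnOfflineSpeechDenoiserConfig config;
+//   memset(&config, 0, sizeof(config));
+//   config.model.gtcrn.model = "./gtcrn.onnx";  // hypothetical model path
+//   config.model.num_threads = 1;
+//
+//   const SherpaMnnOfflineSpeechDenoiser *sd =
+//       SherpaMnnCreateOfflineSpeechDenoiser(&config);
+//   if (!sd) { /* invalid config */ }
+//
+//   const SherpaMnnWave *wave = SherpaMnnReadWave("./noisy.wav");
+//   const SherpaMnnDenoisedAudio *audio = SherpaMnnOfflineSpeechDenoiserRun(
+//       sd, wave->samples, wave->num_samples, wave->sample_rate);
+//   SherpaMnnWriteWave(audio->samples, audio->n, audio->sample_rate,
+//                      "./denoised.wav");
+//
+//   SherpaMnnDestroyDenoisedAudio(audio);
+//   SherpaMnnFreeWave(wave);
+//   SherpaMnnDestroyOfflineSpeechDenoiser(sd);
+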
+ +const SherpaMnnOfflineSpeechDenoiser *SherpaMnnCreateOfflineSpeechDenoiser( + const SherpaMnnOfflineSpeechDenoiserConfig *config) { + auto sd_config = GetOfflineSpeechDenoiserConfig(config); + + if (!sd_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnOfflineSpeechDenoiser *sd = new SherpaMnnOfflineSpeechDenoiser; + + sd->impl = std::make_unique(sd_config); + + return sd; +} + +void SherpaMnnDestroyOfflineSpeechDenoiser( + const SherpaMnnOfflineSpeechDenoiser *sd) { + delete sd; +} + +int32_t SherpaMnnOfflineSpeechDenoiserGetSampleRate( + const SherpaMnnOfflineSpeechDenoiser *sd) { + return sd->impl->GetSampleRate(); +} + +const SherpaMnnDenoisedAudio *SherpaMnnOfflineSpeechDenoiserRun( + const SherpaMnnOfflineSpeechDenoiser *sd, const float *samples, int32_t n, + int32_t sample_rate) { + auto audio = sd->impl->Run(samples, n, sample_rate); + + auto ans = new SherpaMnnDenoisedAudio; + + float *denoised_samples = new float[audio.samples.size()]; + std::copy(audio.samples.begin(), audio.samples.end(), denoised_samples); + + ans->samples = denoised_samples; + ans->n = audio.samples.size(); + ans->sample_rate = audio.sample_rate; + + return ans; +} + +void SherpaMnnDestroyDenoisedAudio(const SherpaMnnDenoisedAudio *p) { + delete[] p->samples; + delete p; +} + +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 + +struct SherpaMnnOfflineSpeakerDiarization { + std::unique_ptr impl; +}; + +struct SherpaMnnOfflineSpeakerDiarizationResult { + sherpa_mnn::OfflineSpeakerDiarizationResult impl; +}; + +static sherpa_mnn::OfflineSpeakerDiarizationConfig +GetOfflineSpeakerDiarizationConfig( + const SherpaMnnOfflineSpeakerDiarizationConfig *config) { + sherpa_mnn::OfflineSpeakerDiarizationConfig sd_config; + + sd_config.segmentation.pyannote.model = + SHERPA_ONNX_OR(config->segmentation.pyannote.model, ""); + sd_config.segmentation.num_threads = + SHERPA_ONNX_OR(config->segmentation.num_threads, 1); + sd_config.segmentation.debug = config->segmentation.debug; + sd_config.segmentation.provider = + SHERPA_ONNX_OR(config->segmentation.provider, "cpu"); + if (sd_config.segmentation.provider.empty()) { + sd_config.segmentation.provider = "cpu"; + } + + sd_config.embedding.model = SHERPA_ONNX_OR(config->embedding.model, ""); + sd_config.embedding.num_threads = + SHERPA_ONNX_OR(config->embedding.num_threads, 1); + sd_config.embedding.debug = config->embedding.debug; + sd_config.embedding.provider = + SHERPA_ONNX_OR(config->embedding.provider, "cpu"); + if (sd_config.embedding.provider.empty()) { + sd_config.embedding.provider = "cpu"; + } + + sd_config.clustering.num_clusters = + SHERPA_ONNX_OR(config->clustering.num_clusters, -1); + + sd_config.clustering.threshold = + SHERPA_ONNX_OR(config->clustering.threshold, 0.5); + + sd_config.min_duration_on = SHERPA_ONNX_OR(config->min_duration_on, 0.3); + + sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5); + + if (sd_config.segmentation.debug || sd_config.embedding.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", sd_config.ToString().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str()); +#endif + } + + return sd_config; +} + +const SherpaMnnOfflineSpeakerDiarization * +SherpaMnnCreateOfflineSpeakerDiarization( + const SherpaMnnOfflineSpeakerDiarizationConfig *config) { + auto sd_config = GetOfflineSpeakerDiarizationConfig(config); + + if (!sd_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaMnnOfflineSpeakerDiarization *sd = 
+ new SherpaMnnOfflineSpeakerDiarization; + + sd->impl = + std::make_unique(sd_config); + + return sd; +} + +void SherpaMnnDestroyOfflineSpeakerDiarization( + const SherpaMnnOfflineSpeakerDiarization *sd) { + delete sd; +} + +int32_t SherpaMnnOfflineSpeakerDiarizationGetSampleRate( + const SherpaMnnOfflineSpeakerDiarization *sd) { + return sd->impl->SampleRate(); +} + +void SherpaMnnOfflineSpeakerDiarizationSetConfig( + const SherpaMnnOfflineSpeakerDiarization *sd, + const SherpaMnnOfflineSpeakerDiarizationConfig *config) { + sherpa_mnn::OfflineSpeakerDiarizationConfig sd_config; + + sd_config.clustering.num_clusters = + SHERPA_ONNX_OR(config->clustering.num_clusters, -1); + + sd_config.clustering.threshold = + SHERPA_ONNX_OR(config->clustering.threshold, 0.5); + + sd->impl->SetConfig(sd_config); +} + +int32_t SherpaMnnOfflineSpeakerDiarizationResultGetNumSpeakers( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + return r->impl.NumSpeakers(); +} + +int32_t SherpaMnnOfflineSpeakerDiarizationResultGetNumSegments( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + return r->impl.NumSegments(); +} + +const SherpaMnnOfflineSpeakerDiarizationSegment * +SherpaMnnOfflineSpeakerDiarizationResultSortByStartTime( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + if (r->impl.NumSegments() == 0) { + return nullptr; + } + + auto segments = r->impl.SortByStartTime(); + + int32_t n = segments.size(); + SherpaMnnOfflineSpeakerDiarizationSegment *ans = + new SherpaMnnOfflineSpeakerDiarizationSegment[n]; + + for (int32_t i = 0; i != n; ++i) { + const auto &s = segments[i]; + + ans[i].start = s.Start(); + ans[i].end = s.End(); + ans[i].speaker = s.Speaker(); + } + + return ans; +} + +void SherpaMnnOfflineSpeakerDiarizationDestroySegment( + const SherpaMnnOfflineSpeakerDiarizationSegment *s) { + delete[] s; +} + +const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcess( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n) { + auto ans = new SherpaMnnOfflineSpeakerDiarizationResult; + ans->impl = sd->impl->Process(samples, n); + + return ans; +} + +void SherpaMnnOfflineSpeakerDiarizationDestroyResult( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + delete r; +} + +const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcessWithCallback( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, SherpaMnnOfflineSpeakerDiarizationProgressCallback callback, + void *arg) { + auto ans = new SherpaMnnOfflineSpeakerDiarizationResult; + ans->impl = sd->impl->Process(samples, n, callback, arg); + + return ans; +} + +const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcessWithCallbackNoArg( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, + SherpaMnnOfflineSpeakerDiarizationProgressCallbackNoArg callback) { + auto wrapper = [callback](int32_t num_processed_chunks, + int32_t num_total_chunks, void *) { + return callback(num_processed_chunks, num_total_chunks); + }; + + auto ans = new SherpaMnnOfflineSpeakerDiarizationResult; + ans->impl = sd->impl->Process(samples, n, wrapper); + + return ans; +} +#else + +const SherpaMnnOfflineSpeakerDiarization * +SherpaMnnCreateOfflineSpeakerDiarization( + const SherpaMnnOfflineSpeakerDiarizationConfig *config) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. 
Please rebuild sherpa-mnn"); + return nullptr; +} + +void SherpaMnnDestroyOfflineSpeakerDiarization( + const SherpaMnnOfflineSpeakerDiarization *sd) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); +} + +int32_t SherpaMnnOfflineSpeakerDiarizationGetSampleRate( + const SherpaMnnOfflineSpeakerDiarization *sd) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); + return 0; +} + +void SherpaMnnOfflineSpeakerDiarizationSetConfig( + const SherpaMnnOfflineSpeakerDiarization *sd, + const SherpaMnnOfflineSpeakerDiarizationConfig *config) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); +} + +int32_t SherpaMnnOfflineSpeakerDiarizationResultGetNumSpeakers( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); + return 0; +} + +int32_t SherpaMnnOfflineSpeakerDiarizationResultGetNumSegments( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); + return 0; +} + +const SherpaMnnOfflineSpeakerDiarizationSegment * +SherpaMnnOfflineSpeakerDiarizationResultSortByStartTime( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +void SherpaMnnOfflineSpeakerDiarizationDestroySegment( + const SherpaMnnOfflineSpeakerDiarizationSegment *s) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); +} + +const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcess( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcessWithCallback( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, SherpaMnnOfflineSpeakerDiarizationProgressCallback callback, + void *arg) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcessWithCallbackNoArg( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, + SherpaMnnOfflineSpeakerDiarizationProgressCallbackNoArg callback) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. Please rebuild sherpa-mnn"); + return nullptr; +} + +void SherpaMnnOfflineSpeakerDiarizationDestroyResult( + const SherpaMnnOfflineSpeakerDiarizationResult *r) { + SHERPA_ONNX_LOGE( + "Speaker diarization is not enabled. 
Please rebuild sherpa-mnn"); +} + +#endif + +#ifdef __OHOS__ + +const SherpaMnnOfflineSpeechDenoiser * +SherpaMnnCreateOfflineSpeechDenoiserOHOS( + const SherpaMnnOfflineSpeechDenoiserConfig *config, + NativeResourceManager *mgr) { + auto sd_config = GetOfflineSpeechDenoiserConfig(config); + + SherpaMnnOfflineSpeechDenoiser *sd = new SherpaMnnOfflineSpeechDenoiser; + + sd->impl = std::make_unique(sd_config); + + return sd; +} + +const SherpaMnnOnlineRecognizer *SherpaMnnCreateOnlineRecognizerOHOS( + const SherpaMnnOnlineRecognizerConfig *config, + NativeResourceManager *mgr) { + sherpa_mnn::OnlineRecognizerConfig recognizer_config = + GetOnlineRecognizerConfig(config); + + SherpaMnnOnlineRecognizer *recognizer = new SherpaMnnOnlineRecognizer; + + recognizer->impl = + std::make_unique(mgr, recognizer_config); + + return recognizer; +} + +const SherpaMnnOfflineRecognizer *SherpaMnnCreateOfflineRecognizerOHOS( + const SherpaMnnOfflineRecognizerConfig *config, + NativeResourceManager *mgr) { + if (mgr == nullptr) { + return SherpaMnnCreateOfflineRecognizer(config); + } + + sherpa_mnn::OfflineRecognizerConfig recognizer_config = + GetOfflineRecognizerConfig(config); + + SherpaMnnOfflineRecognizer *recognizer = new SherpaMnnOfflineRecognizer; + + recognizer->impl = + std::make_unique(mgr, recognizer_config); + + return recognizer; +} + +const SherpaMnnVoiceActivityDetector * +SherpaMnnCreateVoiceActivityDetectorOHOS( + const SherpaMnnVadModelConfig *config, float buffer_size_in_seconds, + NativeResourceManager *mgr) { + if (mgr == nullptr) { + return SherpaMnnCreateVoiceActivityDetector(config, + buffer_size_in_seconds); + } + + auto vad_config = GetVadModelConfig(config); + + SherpaMnnVoiceActivityDetector *p = new SherpaMnnVoiceActivityDetector; + p->impl = std::make_unique( + mgr, vad_config, buffer_size_in_seconds); + + return p; +} + +const SherpaMnnSpeakerEmbeddingExtractor * +SherpaMnnCreateSpeakerEmbeddingExtractorOHOS( + const SherpaMnnSpeakerEmbeddingExtractorConfig *config, + NativeResourceManager *mgr) { + if (!mgr) { + return SherpaMnnCreateSpeakerEmbeddingExtractor(config); + } + + auto c = GetSpeakerEmbeddingExtractorConfig(config); + + auto p = new SherpaMnnSpeakerEmbeddingExtractor; + + p->impl = std::make_unique(mgr, c); + + return p; +} + +const SherpaMnnKeywordSpotter *SherpaMnnCreateKeywordSpotterOHOS( + const SherpaMnnKeywordSpotterConfig *config, NativeResourceManager *mgr) { + if (!mgr) { + return SherpaMnnCreateKeywordSpotter(config); + } + + auto spotter_config = GetKeywordSpotterConfig(config); + + SherpaMnnKeywordSpotter *spotter = new SherpaMnnKeywordSpotter; + + spotter->impl = + std::make_unique(mgr, spotter_config); + + return spotter; +} + +#if SHERPA_MNN_ENABLE_TTS == 1 +const SherpaMnnOfflineTts *SherpaMnnCreateOfflineTtsOHOS( + const SherpaMnnOfflineTtsConfig *config, NativeResourceManager *mgr) { + if (!mgr) { + return SherpaMnnCreateOfflineTts(config); + } + + auto tts_config = GetOfflineTtsConfig(config); + + SherpaMnnOfflineTts *tts = new SherpaMnnOfflineTts; + + tts->impl = std::make_unique(mgr, tts_config); + + return tts; +} +#else +const SherpaMnnOfflineTts *SherpaMnnCreateOfflineTtsOHOS( + const SherpaMnnOfflineTtsConfig *config, NativeResourceManager *mgr) { + SHERPA_ONNX_LOGE("TTS is not enabled. 
Please rebuild sherpa-mnn");
+  return nullptr;
+}
+#endif  // #if SHERPA_MNN_ENABLE_TTS == 1
+        //
+#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
+const SherpaMnnOfflineSpeakerDiarization *
+SherpaMnnCreateOfflineSpeakerDiarizationOHOS(
+    const SherpaMnnOfflineSpeakerDiarizationConfig *config,
+    NativeResourceManager *mgr) {
+  if (!mgr) {
+    return SherpaMnnCreateOfflineSpeakerDiarization(config);
+  }
+
+  auto sd_config = GetOfflineSpeakerDiarizationConfig(config);
+
+  SherpaMnnOfflineSpeakerDiarization *sd =
+      new SherpaMnnOfflineSpeakerDiarization;
+
+  sd->impl =
+      std::make_unique<sherpa_mnn::OfflineSpeakerDiarization>(mgr, sd_config);
+
+  return sd;
+}
+#else
+
+const SherpaMnnOfflineSpeakerDiarization *
+SherpaMnnCreateOfflineSpeakerDiarizationOHOS(
+    const SherpaMnnOfflineSpeakerDiarizationConfig *config,
+    NativeResourceManager *mgr) {
+  SHERPA_ONNX_LOGE(
+      "Speaker diarization is not enabled. Please rebuild sherpa-mnn");
+  return nullptr;
+}
+
+#endif  // #if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
+
+#endif  // #ifdef __OHOS__
diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/c-api.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/c-api.h
new file mode 100644
index 00000000..611db03c
--- /dev/null
+++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/c-api.h
@@ -0,0 +1,1759 @@
+// sherpa-mnn/c-api/c-api.h
+//
+// Copyright (c) 2023 Xiaomi Corporation

+// C API for sherpa-mnn
+//
+// Please refer to
+// https://github.com/k2-fsa/sherpa-mnn/blob/master/c-api-examples/decode-file-c-api.c
+// for usage examples.
+//

+#ifndef SHERPA_ONNX_C_API_C_API_H_
+#define SHERPA_ONNX_C_API_C_API_H_

+#include <stdint.h>

+#ifdef __cplusplus
+extern "C" {
+#endif

+// See https://github.com/pytorch/pytorch/blob/main/c10/macros/Export.h
+// We will set SHERPA_ONNX_BUILD_SHARED_LIBS and SHERPA_ONNX_BUILD_MAIN_LIB in
+// CMakeLists.txt

+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wattributes"
+#endif

+#if defined(_WIN32)
+#if defined(SHERPA_ONNX_BUILD_SHARED_LIBS)
+#define SHERPA_ONNX_EXPORT __declspec(dllexport)
+#define SHERPA_ONNX_IMPORT __declspec(dllimport)
+#else
+#define SHERPA_ONNX_EXPORT
+#define SHERPA_ONNX_IMPORT
+#endif
+#else  // WIN32
+#define SHERPA_ONNX_EXPORT __attribute__((visibility("default")))

+#define SHERPA_ONNX_IMPORT SHERPA_ONNX_EXPORT
+#endif  // WIN32

+#if defined(SHERPA_ONNX_BUILD_MAIN_LIB)
+#define SHERPA_ONNX_API SHERPA_ONNX_EXPORT
+#else
+#define SHERPA_ONNX_API SHERPA_ONNX_IMPORT
+#endif

+/// Please refer to
+/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
+/// to download pre-trained models. That is, you can find encoder-xxx.onnx,
+/// decoder-xxx.onnx, joiner-xxx.onnx, and tokens.txt for this struct
+/// from there.
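+///
+/// A minimal sketch of filling in this struct (the paths below are
+/// placeholders; use the files of the model you actually downloaded):
+///
+///   SherpaMnnOnlineTransducerModelConfig transducer_config = {0};
+///   transducer_config.encoder = "./encoder-xxx.onnx";
+///   transducer_config.decoder = "./decoder-xxx.onnx";
+///   transducer_config.joiner = "./joiner-xxx.onnx";
+///   // then assign it to SherpaMnnOnlineModelConfig.transducer together
+///   // with tokens, num_threads, etc.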
+SHERPA_ONNX_API typedef struct SherpaMnnOnlineTransducerModelConfig { + const char *encoder; + const char *decoder; + const char *joiner; +} SherpaMnnOnlineTransducerModelConfig; + +// please visit +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/index.html +// to download pre-trained streaming paraformer models +SHERPA_ONNX_API typedef struct SherpaMnnOnlineParaformerModelConfig { + const char *encoder; + const char *decoder; +} SherpaMnnOnlineParaformerModelConfig; + +// Please visit +// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-ctc/zipformer-ctc-models.html# +// to download pre-trained streaming zipformer2 ctc models +SHERPA_ONNX_API typedef struct SherpaMnnOnlineZipformer2CtcModelConfig { + const char *model; +} SherpaMnnOnlineZipformer2CtcModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOnlineModelConfig { + SherpaMnnOnlineTransducerModelConfig transducer; + SherpaMnnOnlineParaformerModelConfig paraformer; + SherpaMnnOnlineZipformer2CtcModelConfig zipformer2_ctc; + const char *tokens; + int32_t num_threads; + const char *provider; + int32_t debug; // true to print debug information of the model + const char *model_type; + // Valid values: + // - cjkchar + // - bpe + // - cjkchar+bpe + const char *modeling_unit; + const char *bpe_vocab; + /// if non-null, loading the tokens from the buffer instead of from the + /// "tokens" file + const char *tokens_buf; + /// byte size excluding the trailing '\0' + int32_t tokens_buf_size; +} SherpaMnnOnlineModelConfig; + +/// It expects 16 kHz 16-bit single channel wave format. +SHERPA_ONNX_API typedef struct SherpaMnnFeatureConfig { + /// Sample rate of the input data. MUST match the one expected + /// by the model. For instance, it should be 16000 for models provided + /// by us. + int32_t sample_rate; + + /// Feature dimension of the model. + /// For instance, it should be 80 for models provided by us. + int32_t feature_dim; +} SherpaMnnFeatureConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOnlineCtcFstDecoderConfig { + const char *graph; + int32_t max_active; +} SherpaMnnOnlineCtcFstDecoderConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOnlineRecognizerConfig { + SherpaMnnFeatureConfig feat_config; + SherpaMnnOnlineModelConfig model_config; + + /// Possible values are: greedy_search, modified_beam_search + const char *decoding_method; + + /// Used only when decoding_method is modified_beam_search + /// Example value: 4 + int32_t max_active_paths; + + /// 0 to disable endpoint detection. + /// A non-zero value to enable endpoint detection. + int32_t enable_endpoint; + + /// An endpoint is detected if trailing silence in seconds is larger than + /// this value even if nothing has been decoded. + /// Used only when enable_endpoint is not 0. + float rule1_min_trailing_silence; + + /// An endpoint is detected if trailing silence in seconds is larger than + /// this value after something that is not blank has been decoded. + /// Used only when enable_endpoint is not 0. + float rule2_min_trailing_silence; + + /// An endpoint is detected if the utterance in seconds is larger than + /// this value. + /// Used only when enable_endpoint is not 0. + float rule3_min_utterance_length; + + /// Path to the hotwords. + const char *hotwords_file; + + /// Bonus score for each token in hotwords. 
+ float hotwords_score; + + SherpaMnnOnlineCtcFstDecoderConfig ctc_fst_decoder_config; + const char *rule_fsts; + const char *rule_fars; + float blank_penalty; + + /// if non-nullptr, loading the hotwords from the buffered string directly in + const char *hotwords_buf; + /// byte size excluding the tailing '\0' + int32_t hotwords_buf_size; +} SherpaMnnOnlineRecognizerConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOnlineRecognizerResult { + // Recognized text + const char *text; + + // Pointer to continuous memory which holds string based tokens + // which are separated by \0 + const char *tokens; + + // a pointer array containing the address of the first item in tokens + const char *const *tokens_arr; + + // Pointer to continuous memory which holds timestamps + // + // Caution: If timestamp information is not available, this pointer is NULL. + // Please check whether it is NULL before you access it; otherwise, you would + // get segmentation fault. + float *timestamps; + + // The number of tokens/timestamps in above pointer + int32_t count; + + /** Return a json string. + * + * The returned string contains: + * { + * "text": "The recognition result", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "segment": x, + * "start_time": x, + * "is_final": true|false + * } + */ + const char *json; +} SherpaMnnOnlineRecognizerResult; + +/// Note: OnlineRecognizer here means StreamingRecognizer. +/// It does not need to access the Internet during recognition. +/// Everything is run locally. +SHERPA_ONNX_API typedef struct SherpaMnnOnlineRecognizer + SherpaMnnOnlineRecognizer; +SHERPA_ONNX_API typedef struct SherpaMnnOnlineStream SherpaMnnOnlineStream; + +/// @param config Config for the recognizer. +/// @return Return a pointer to the recognizer. The user has to invoke +// SherpaMnnDestroyOnlineRecognizer() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOnlineRecognizer * +SherpaMnnCreateOnlineRecognizer( + const SherpaMnnOnlineRecognizerConfig *config); + +/// Free a pointer returned by SherpaMnnCreateOnlineRecognizer() +/// +/// @param p A pointer returned by SherpaMnnCreateOnlineRecognizer() +SHERPA_ONNX_API void SherpaMnnDestroyOnlineRecognizer( + const SherpaMnnOnlineRecognizer *recognizer); + +/// Create an online stream for accepting wave samples. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOnlineRecognizer() +/// @return Return a pointer to an OnlineStream. The user has to invoke +/// SherpaMnnDestroyOnlineStream() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOnlineStream *SherpaMnnCreateOnlineStream( + const SherpaMnnOnlineRecognizer *recognizer); + +/// Create an online stream for accepting wave samples with the specified hot +/// words. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOnlineRecognizer() +/// @return Return a pointer to an OnlineStream. The user has to invoke +/// SherpaMnnDestroyOnlineStream() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOnlineStream * +SherpaMnnCreateOnlineStreamWithHotwords( + const SherpaMnnOnlineRecognizer *recognizer, const char *hotwords); + +/// Destroy an online stream. +/// +/// @param stream A pointer returned by SherpaMnnCreateOnlineStream() +SHERPA_ONNX_API void SherpaMnnDestroyOnlineStream( + const SherpaMnnOnlineStream *stream); + +/// Accept input audio samples and compute the features. +/// The user has to invoke SherpaMnnDecodeOnlineStream() to run the neural +/// network and decoding. 
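+///
+/// A rough sketch of a typical streaming loop (the 16000 Hz sample rate and
+/// the sample buffer are placeholders; samples must lie in [-1, 1]):
+///
+///   SherpaMnnOnlineStreamAcceptWaveform(stream, 16000, samples, n);
+///   while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) {
+///     SherpaMnnDecodeOnlineStream(recognizer, stream);
+///   }
+///   const SherpaMnnOnlineRecognizerResult *r =
+///       SherpaMnnGetOnlineStreamResult(recognizer, stream);
+///   // use r->text here
+///   SherpaMnnDestroyOnlineRecognizerResult(r);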
+/// +/// @param stream A pointer returned by SherpaMnnCreateOnlineStream(). +/// @param sample_rate Sample rate of the input samples. If it is different +/// from config.feat_config.sample_rate, we will do +/// resampling inside sherpa-mnn. +/// @param samples A pointer to a 1-D array containing audio samples. +/// The range of samples has to be normalized to [-1, 1]. +/// @param n Number of elements in the samples array. +SHERPA_ONNX_API void SherpaMnnOnlineStreamAcceptWaveform( + const SherpaMnnOnlineStream *stream, int32_t sample_rate, + const float *samples, int32_t n); + +/// Return 1 if there are enough number of feature frames for decoding. +/// Return 0 otherwise. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOnlineRecognizer +/// @param stream A pointer returned by SherpaMnnCreateOnlineStream +SHERPA_ONNX_API int32_t +SherpaMnnIsOnlineStreamReady(const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream); + +/// Call this function to run the neural network model and decoding. +// +/// Precondition for this function: SherpaMnnIsOnlineStreamReady() MUST +/// return 1. +/// +/// Usage example: +/// +/// while (SherpaMnnIsOnlineStreamReady(recognizer, stream)) { +/// SherpaMnnDecodeOnlineStream(recognizer, stream); +/// } +/// +SHERPA_ONNX_API void SherpaMnnDecodeOnlineStream( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream); + +/// This function is similar to SherpaMnnDecodeOnlineStream(). It decodes +/// multiple OnlineStream in parallel. +/// +/// Caution: The caller has to ensure each OnlineStream is ready, i.e., +/// SherpaMnnIsOnlineStreamReady() for that stream should return 1. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOnlineRecognizer() +/// @param streams A pointer array containing pointers returned by +/// SherpaMnnCreateOnlineRecognizer() +/// @param n Number of elements in the given streams array. +SHERPA_ONNX_API void SherpaMnnDecodeMultipleOnlineStreams( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream **streams, int32_t n); + +/// Get the decoding results so far for an OnlineStream. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOnlineRecognizer(). +/// @param stream A pointer returned by SherpaMnnCreateOnlineStream(). +/// @return A pointer containing the result. The user has to invoke +/// SherpaMnnDestroyOnlineRecognizerResult() to free the returned +/// pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOnlineRecognizerResult * +SherpaMnnGetOnlineStreamResult(const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream); + +/// Destroy the pointer returned by SherpaMnnGetOnlineStreamResult(). +/// +/// @param r A pointer returned by SherpaMnnGetOnlineStreamResult() +SHERPA_ONNX_API void SherpaMnnDestroyOnlineRecognizerResult( + const SherpaMnnOnlineRecognizerResult *r); + +/// Return the result as a json string. +/// The user has to invoke +/// SherpaMnnDestroyOnlineStreamResultJson() +/// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const char *SherpaMnnGetOnlineStreamResultAsJson( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream); + +SHERPA_ONNX_API void SherpaMnnDestroyOnlineStreamResultJson(const char *s); + +/// SherpaMnnOnlineStreamReset an OnlineStream , which clears the neural +/// network model state and the state for decoding. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOnlineRecognizer(). 
+/// @param stream A pointer returned by SherpaMnnCreateOnlineStream +SHERPA_ONNX_API void SherpaMnnOnlineStreamReset( + const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream); + +/// Signal that no more audio samples would be available. +/// After this call, you cannot call SherpaMnnOnlineStreamAcceptWaveform() any +/// more. +/// +/// @param stream A pointer returned by SherpaMnnCreateOnlineStream() +SHERPA_ONNX_API void SherpaMnnOnlineStreamInputFinished( + const SherpaMnnOnlineStream *stream); + +/// Return 1 if an endpoint has been detected. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOnlineRecognizer() +/// @param stream A pointer returned by SherpaMnnCreateOnlineStream() +/// @return Return 1 if an endpoint is detected. Return 0 otherwise. +SHERPA_ONNX_API int32_t +SherpaMnnOnlineStreamIsEndpoint(const SherpaMnnOnlineRecognizer *recognizer, + const SherpaMnnOnlineStream *stream); + +// for displaying results on Linux/macOS. +SHERPA_ONNX_API typedef struct SherpaMnnDisplay SherpaMnnDisplay; + +/// Create a display object. Must be freed using SherpaMnnDestroyDisplay to +/// avoid memory leak. +SHERPA_ONNX_API const SherpaMnnDisplay *SherpaMnnCreateDisplay( + int32_t max_word_per_line); + +SHERPA_ONNX_API void SherpaMnnDestroyDisplay(const SherpaMnnDisplay *display); + +/// Print the result. +SHERPA_ONNX_API void SherpaMnnPrint(const SherpaMnnDisplay *display, + int32_t idx, const char *s); +// ============================================================ +// For offline ASR (i.e., non-streaming ASR) +// ============================================================ + +/// Please refer to +/// https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +/// to download pre-trained models. That is, you can find encoder-xxx.onnx +/// decoder-xxx.onnx, and joiner-xxx.onnx for this struct +/// from there. 
+SHERPA_ONNX_API typedef struct SherpaMnnOfflineTransducerModelConfig { + const char *encoder; + const char *decoder; + const char *joiner; +} SherpaMnnOfflineTransducerModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineParaformerModelConfig { + const char *model; +} SherpaMnnOfflineParaformerModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineNemoEncDecCtcModelConfig { + const char *model; +} SherpaMnnOfflineNemoEncDecCtcModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineWhisperModelConfig { + const char *encoder; + const char *decoder; + const char *language; + const char *task; + int32_t tail_paddings; +} SherpaMnnOfflineWhisperModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineFireRedAsrModelConfig { + const char *encoder; + const char *decoder; +} SherpaMnnOfflineFireRedAsrModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineMoonshineModelConfig { + const char *preprocessor; + const char *encoder; + const char *uncached_decoder; + const char *cached_decoder; +} SherpaMnnOfflineMoonshineModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineTdnnModelConfig { + const char *model; +} SherpaMnnOfflineTdnnModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineLMConfig { + const char *model; + float scale; +} SherpaMnnOfflineLMConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSenseVoiceModelConfig { + const char *model; + const char *language; + int32_t use_itn; +} SherpaMnnOfflineSenseVoiceModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineModelConfig { + SherpaMnnOfflineTransducerModelConfig transducer; + SherpaMnnOfflineParaformerModelConfig paraformer; + SherpaMnnOfflineNemoEncDecCtcModelConfig nemo_ctc; + SherpaMnnOfflineWhisperModelConfig whisper; + SherpaMnnOfflineTdnnModelConfig tdnn; + + const char *tokens; + int32_t num_threads; + int32_t debug; + const char *provider; + const char *model_type; + // Valid values: + // - cjkchar + // - bpe + // - cjkchar+bpe + const char *modeling_unit; + const char *bpe_vocab; + const char *telespeech_ctc; + SherpaMnnOfflineSenseVoiceModelConfig sense_voice; + SherpaMnnOfflineMoonshineModelConfig moonshine; + SherpaMnnOfflineFireRedAsrModelConfig fire_red_asr; +} SherpaMnnOfflineModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineRecognizerConfig { + SherpaMnnFeatureConfig feat_config; + SherpaMnnOfflineModelConfig model_config; + SherpaMnnOfflineLMConfig lm_config; + + const char *decoding_method; + int32_t max_active_paths; + + /// Path to the hotwords. + const char *hotwords_file; + + /// Bonus score for each token in hotwords. + float hotwords_score; + const char *rule_fsts; + const char *rule_fars; + float blank_penalty; +} SherpaMnnOfflineRecognizerConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineRecognizer + SherpaMnnOfflineRecognizer; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineStream SherpaMnnOfflineStream; + +/// @param config Config for the recognizer. +/// @return Return a pointer to the recognizer. The user has to invoke +// SherpaMnnDestroyOfflineRecognizer() to free it to avoid memory +// leak. +SHERPA_ONNX_API const SherpaMnnOfflineRecognizer * +SherpaMnnCreateOfflineRecognizer( + const SherpaMnnOfflineRecognizerConfig *config); + +/// @param config Config for the recognizer. 
+SHERPA_ONNX_API void SherpaMnnOfflineRecognizerSetConfig( + const SherpaMnnOfflineRecognizer *recognizer, + const SherpaMnnOfflineRecognizerConfig *config); + +/// Free a pointer returned by SherpaMnnCreateOfflineRecognizer() +/// +/// @param p A pointer returned by SherpaMnnCreateOfflineRecognizer() +SHERPA_ONNX_API void SherpaMnnDestroyOfflineRecognizer( + const SherpaMnnOfflineRecognizer *recognizer); + +/// Create an offline stream for accepting wave samples. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOfflineRecognizer() +/// @return Return a pointer to an OfflineStream. The user has to invoke +/// SherpaMnnDestroyOfflineStream() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOfflineStream *SherpaMnnCreateOfflineStream( + const SherpaMnnOfflineRecognizer *recognizer); + +/// Create an offline stream for accepting wave samples with the specified hot +/// words. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOfflineRecognizer() +/// @return Return a pointer to an OfflineStream. The user has to invoke +/// SherpaMnnDestroyOfflineStream() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOfflineStream * +SherpaMnnCreateOfflineStreamWithHotwords( + const SherpaMnnOfflineRecognizer *recognizer, const char *hotwords); + +/// Destroy an offline stream. +/// +/// @param stream A pointer returned by SherpaMnnCreateOfflineStream() +SHERPA_ONNX_API void SherpaMnnDestroyOfflineStream( + const SherpaMnnOfflineStream *stream); + +/// Accept input audio samples and compute the features. +/// The user has to invoke SherpaMnnDecodeOfflineStream() to run the neural +/// network and decoding. +/// +/// @param stream A pointer returned by SherpaMnnCreateOfflineStream(). +/// @param sample_rate Sample rate of the input samples. If it is different +/// from config.feat_config.sample_rate, we will do +/// resampling inside sherpa-mnn. +/// @param samples A pointer to a 1-D array containing audio samples. +/// The range of samples has to be normalized to [-1, 1]. +/// @param n Number of elements in the samples array. +/// +/// @caution: For each offline stream, please invoke this function only once! +SHERPA_ONNX_API void SherpaMnnAcceptWaveformOffline( + const SherpaMnnOfflineStream *stream, int32_t sample_rate, + const float *samples, int32_t n); +/// Decode an offline stream. +/// +/// We assume you have invoked SherpaMnnAcceptWaveformOffline() for the given +/// stream before calling this function. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOfflineRecognizer(). +/// @param stream A pointer returned by SherpaMnnCreateOfflineStream() +SHERPA_ONNX_API void SherpaMnnDecodeOfflineStream( + const SherpaMnnOfflineRecognizer *recognizer, + const SherpaMnnOfflineStream *stream); + +/// Decode a list offline streams in parallel. +/// +/// We assume you have invoked SherpaMnnAcceptWaveformOffline() for each stream +/// before calling this function. +/// +/// @param recognizer A pointer returned by SherpaMnnCreateOfflineRecognizer(). +/// @param streams A pointer pointer array containing pointers returned +/// by SherpaMnnCreateOfflineStream(). +/// @param n Number of entries in the given streams. 
+SHERPA_ONNX_API void SherpaMnnDecodeMultipleOfflineStreams( + const SherpaMnnOfflineRecognizer *recognizer, + const SherpaMnnOfflineStream **streams, int32_t n); + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineRecognizerResult { + const char *text; + + // Pointer to continuous memory which holds timestamps + // + // It is NULL if the model does not support timestamps + float *timestamps; + + // number of entries in timestamps + int32_t count; + + // Pointer to continuous memory which holds string based tokens + // which are separated by \0 + const char *tokens; + + // a pointer array containing the address of the first item in tokens + const char *const *tokens_arr; + + /** Return a json string. + * + * The returned string contains: + * { + * "text": "The recognition result", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "segment": x, + * "start_time": x, + * "is_final": true|false + * } + */ + const char *json; + + // return recognized language + const char *lang; + + // return emotion. + const char *emotion; + + // return event. + const char *event; +} SherpaMnnOfflineRecognizerResult; + +/// Get the result of the offline stream. +/// +/// We assume you have called SherpaMnnDecodeOfflineStream() or +/// SherpaMnnDecodeMultipleOfflineStreams() with the given stream before +/// calling this function. +/// +/// @param stream A pointer returned by SherpaMnnCreateOfflineStream(). +/// @return Return a pointer to the result. The user has to invoke +/// SherpaMnnDestroyOnlineRecognizerResult() to free the returned +/// pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOfflineRecognizerResult * +SherpaMnnGetOfflineStreamResult(const SherpaMnnOfflineStream *stream); + +/// Destroy the pointer returned by SherpaMnnGetOfflineStreamResult(). +/// +/// @param r A pointer returned by SherpaMnnGetOfflineStreamResult() +SHERPA_ONNX_API void SherpaMnnDestroyOfflineRecognizerResult( + const SherpaMnnOfflineRecognizerResult *r); + +/// Return the result as a json string. +/// The user has to use SherpaMnnDestroyOfflineStreamResultJson() +/// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const char *SherpaMnnGetOfflineStreamResultAsJson( + const SherpaMnnOfflineStream *stream); + +SHERPA_ONNX_API void SherpaMnnDestroyOfflineStreamResultJson(const char *s); + +// ============================================================ +// For Keyword Spotter +// ============================================================ +SHERPA_ONNX_API typedef struct SherpaMnnKeywordResult { + /// The triggered keyword. + /// For English, it consists of space separated words. + /// For Chinese, it consists of Chinese words without spaces. + /// Example 1: "hello world" + /// Example 2: "你好世界" + const char *keyword; + + /// Decoded results at the token level. + /// For instance, for BPE-based models it consists of a list of BPE tokens. + const char *tokens; + + const char *const *tokens_arr; + + int32_t count; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + float *timestamps; + + /// Starting time of this segment. + /// When an endpoint is detected, it will change + float start_time; + + /** Return a json string. 
+ * + * The returned string contains: + * { + * "keyword": "The triggered keyword", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "start_time": x, + * } + */ + const char *json; +} SherpaMnnKeywordResult; + +SHERPA_ONNX_API typedef struct SherpaMnnKeywordSpotterConfig { + SherpaMnnFeatureConfig feat_config; + SherpaMnnOnlineModelConfig model_config; + int32_t max_active_paths; + int32_t num_trailing_blanks; + float keywords_score; + float keywords_threshold; + const char *keywords_file; + /// if non-null, loading the keywords from the buffer instead of from the + /// keywords_file + const char *keywords_buf; + /// byte size excluding the trailing '\0' + int32_t keywords_buf_size; +} SherpaMnnKeywordSpotterConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnKeywordSpotter + SherpaMnnKeywordSpotter; + +/// @param config Config for the keyword spotter. +/// @return Return a pointer to the spotter. The user has to invoke +/// SherpaMnnDestroyKeywordSpotter() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnKeywordSpotter *SherpaMnnCreateKeywordSpotter( + const SherpaMnnKeywordSpotterConfig *config); + +/// Free a pointer returned by SherpaMnnCreateKeywordSpotter() +/// +/// @param p A pointer returned by SherpaMnnCreateKeywordSpotter() +SHERPA_ONNX_API void SherpaMnnDestroyKeywordSpotter( + const SherpaMnnKeywordSpotter *spotter); + +/// Create an online stream for accepting wave samples. +/// +/// @param spotter A pointer returned by SherpaMnnCreateKeywordSpotter() +/// @return Return a pointer to an OnlineStream. The user has to invoke +/// SherpaMnnDestroyOnlineStream() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOnlineStream *SherpaMnnCreateKeywordStream( + const SherpaMnnKeywordSpotter *spotter); + +/// Create an online stream for accepting wave samples with the specified hot +/// words. +/// +/// @param spotter A pointer returned by SherpaMnnCreateKeywordSpotter() +/// @param keywords A pointer points to the keywords that you set +/// @return Return a pointer to an OnlineStream. The user has to invoke +/// SherpaMnnDestroyOnlineStream() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOnlineStream * +SherpaMnnCreateKeywordStreamWithKeywords( + const SherpaMnnKeywordSpotter *spotter, const char *keywords); + +/// Return 1 if there are enough number of feature frames for decoding. +/// Return 0 otherwise. +/// +/// @param spotter A pointer returned by SherpaMnnCreateKeywordSpotter +/// @param stream A pointer returned by SherpaMnnCreateKeywordStream +SHERPA_ONNX_API int32_t +SherpaMnnIsKeywordStreamReady(const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream); + +/// Call this function to run the neural network model and decoding. +// +/// Precondition for this function: SherpaMnnIsKeywordStreamReady() MUST +/// return 1. +SHERPA_ONNX_API void SherpaMnnDecodeKeywordStream( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream); + +/// Please call it right after a keyword is detected +SHERPA_ONNX_API void SherpaMnnResetKeywordStream( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream); + +/// This function is similar to SherpaMnnDecodeKeywordStream(). It decodes +/// multiple OnlineStream in parallel. +/// +/// Caution: The caller has to ensure each OnlineStream is ready, i.e., +/// SherpaMnnIsKeywordStreamReady() for that stream should return 1. 
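+///
+/// For reference, the single-stream loop that this function batches looks
+/// roughly like the sketch below (it assumes an empty keyword string means
+/// nothing was triggered; sample rate and buffers are placeholders):
+///
+///   SherpaMnnOnlineStreamAcceptWaveform(stream, 16000, samples, n);
+///   while (SherpaMnnIsKeywordStreamReady(spotter, stream)) {
+///     SherpaMnnDecodeKeywordStream(spotter, stream);
+///     const SherpaMnnKeywordResult *kr =
+///         SherpaMnnGetKeywordResult(spotter, stream);
+///     if (kr->keyword && kr->keyword[0]) {
+///       // a keyword was triggered; see kr->keyword
+///       SherpaMnnResetKeywordStream(spotter, stream);
+///     }
+///     SherpaMnnDestroyKeywordResult(kr);
+///   }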
+/// +/// @param spotter A pointer returned by SherpaMnnCreateKeywordSpotter() +/// @param streams A pointer array containing pointers returned by +/// SherpaMnnCreateKeywordStream() +/// @param n Number of elements in the given streams array. +SHERPA_ONNX_API void SherpaMnnDecodeMultipleKeywordStreams( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream **streams, int32_t n); + +/// Get the decoding results so far for an OnlineStream. +/// +/// @param spotter A pointer returned by SherpaMnnCreateKeywordSpotter(). +/// @param stream A pointer returned by SherpaMnnCreateKeywordStream(). +/// @return A pointer containing the result. The user has to invoke +/// SherpaMnnDestroyKeywordResult() to free the returned pointer to +/// avoid memory leak. +SHERPA_ONNX_API const SherpaMnnKeywordResult *SherpaMnnGetKeywordResult( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream); + +/// Destroy the pointer returned by SherpaMnnGetKeywordResult(). +/// +/// @param r A pointer returned by SherpaMnnGetKeywordResult() +SHERPA_ONNX_API void SherpaMnnDestroyKeywordResult( + const SherpaMnnKeywordResult *r); + +// the user has to call SherpaMnnFreeKeywordResultJson() to free the returned +// pointer to avoid memory leak +SHERPA_ONNX_API const char *SherpaMnnGetKeywordResultAsJson( + const SherpaMnnKeywordSpotter *spotter, + const SherpaMnnOnlineStream *stream); + +SHERPA_ONNX_API void SherpaMnnFreeKeywordResultJson(const char *s); + +// ============================================================ +// For VAD +// ============================================================ + +SHERPA_ONNX_API typedef struct SherpaMnnSileroVadModelConfig { + // Path to the silero VAD model + const char *model; + + // threshold to classify a segment as speech + // + // If the predicted probability of a segment is larger than this + // value, then it is classified as speech. + float threshold; + + // in seconds + float min_silence_duration; + + // in seconds + float min_speech_duration; + + int window_size; + + // If a speech segment is longer than this value, then we increase + // the threshold to 0.9. After finishing detecting the segment, + // the threshold value is reset to its original value. + float max_speech_duration; +} SherpaMnnSileroVadModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnVadModelConfig { + SherpaMnnSileroVadModelConfig silero_vad; + + int32_t sample_rate; + int32_t num_threads; + const char *provider; + int32_t debug; +} SherpaMnnVadModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnCircularBuffer + SherpaMnnCircularBuffer; + +// Return an instance of circular buffer. The user has to use +// SherpaMnnDestroyCircularBuffer() to free the returned pointer to avoid +// memory leak. +SHERPA_ONNX_API const SherpaMnnCircularBuffer *SherpaMnnCreateCircularBuffer( + int32_t capacity); + +// Free the pointer returned by SherpaMnnCreateCircularBuffer() +SHERPA_ONNX_API void SherpaMnnDestroyCircularBuffer( + const SherpaMnnCircularBuffer *buffer); + +SHERPA_ONNX_API void SherpaMnnCircularBufferPush( + const SherpaMnnCircularBuffer *buffer, const float *p, int32_t n); + +// Return n samples starting at the given index. +// +// Return a pointer to an array containing n samples starting at start_index. +// The user has to use SherpaMnnCircularBufferFree() to free the returned +// pointer to avoid memory leak. 
+SHERPA_ONNX_API const float *SherpaMnnCircularBufferGet( + const SherpaMnnCircularBuffer *buffer, int32_t start_index, int32_t n); + +// Free the pointer returned by SherpaMnnCircularBufferGet(). +SHERPA_ONNX_API void SherpaMnnCircularBufferFree(const float *p); + +// Remove n elements from the buffer +SHERPA_ONNX_API void SherpaMnnCircularBufferPop( + const SherpaMnnCircularBuffer *buffer, int32_t n); + +// Return number of elements in the buffer. +SHERPA_ONNX_API int32_t +SherpaMnnCircularBufferSize(const SherpaMnnCircularBuffer *buffer); + +// Return the head of the buffer. It's always non-decreasing until you +// invoke SherpaMnnCircularBufferReset() which resets head to 0. +SHERPA_ONNX_API int32_t +SherpaMnnCircularBufferHead(const SherpaMnnCircularBuffer *buffer); + +// Clear all elements in the buffer +SHERPA_ONNX_API void SherpaMnnCircularBufferReset( + const SherpaMnnCircularBuffer *buffer); + +SHERPA_ONNX_API typedef struct SherpaMnnSpeechSegment { + // The start index in samples of this segment + int32_t start; + + // pointer to the array containing the samples + float *samples; + + // number of samples in this segment + int32_t n; +} SherpaMnnSpeechSegment; + +typedef struct SherpaMnnVoiceActivityDetector SherpaMnnVoiceActivityDetector; + +// Return an instance of VoiceActivityDetector. +// The user has to use SherpaMnnDestroyVoiceActivityDetector() to free +// the returned pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnVoiceActivityDetector * +SherpaMnnCreateVoiceActivityDetector(const SherpaMnnVadModelConfig *config, + float buffer_size_in_seconds); + +SHERPA_ONNX_API void SherpaMnnDestroyVoiceActivityDetector( + const SherpaMnnVoiceActivityDetector *p); + +SHERPA_ONNX_API void SherpaMnnVoiceActivityDetectorAcceptWaveform( + const SherpaMnnVoiceActivityDetector *p, const float *samples, int32_t n); + +// Return 1 if there are no speech segments available. +// Return 0 if there are speech segments. +SHERPA_ONNX_API int32_t +SherpaMnnVoiceActivityDetectorEmpty(const SherpaMnnVoiceActivityDetector *p); + +// Return 1 if there is voice detected. +// Return 0 if voice is silent. +SHERPA_ONNX_API int32_t SherpaMnnVoiceActivityDetectorDetected( + const SherpaMnnVoiceActivityDetector *p); + +// Return the first speech segment. +// It throws if SherpaMnnVoiceActivityDetectorEmpty() returns 1. +SHERPA_ONNX_API void SherpaMnnVoiceActivityDetectorPop( + const SherpaMnnVoiceActivityDetector *p); + +// Clear current speech segments. +SHERPA_ONNX_API void SherpaMnnVoiceActivityDetectorClear( + const SherpaMnnVoiceActivityDetector *p); + +// Return the first speech segment. +// The user has to use SherpaMnnDestroySpeechSegment() to free the returned +// pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnSpeechSegment * +SherpaMnnVoiceActivityDetectorFront(const SherpaMnnVoiceActivityDetector *p); + +// Free the pointer returned SherpaMnnVoiceActivityDetectorFront(). +SHERPA_ONNX_API void SherpaMnnDestroySpeechSegment( + const SherpaMnnSpeechSegment *p); + +// Re-initialize the voice activity detector. 
+SHERPA_ONNX_API void SherpaMnnVoiceActivityDetectorReset( + const SherpaMnnVoiceActivityDetector *p); + +SHERPA_ONNX_API void SherpaMnnVoiceActivityDetectorFlush( + const SherpaMnnVoiceActivityDetector *p); + +// ============================================================ +// For offline Text-to-Speech (i.e., non-streaming TTS) +// ============================================================ +SHERPA_ONNX_API typedef struct SherpaMnnOfflineTtsVitsModelConfig { + const char *model; + const char *lexicon; + const char *tokens; + const char *data_dir; + + float noise_scale; + float noise_scale_w; + float length_scale; // < 1, faster in speech speed; > 1, slower in speed + const char *dict_dir; +} SherpaMnnOfflineTtsVitsModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineTtsMatchaModelConfig { + const char *acoustic_model; + const char *vocoder; + const char *lexicon; + const char *tokens; + const char *data_dir; + + float noise_scale; + float length_scale; // < 1, faster in speech speed; > 1, slower in speed + const char *dict_dir; +} SherpaMnnOfflineTtsMatchaModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineTtsKokoroModelConfig { + const char *model; + const char *voices; + const char *tokens; + const char *data_dir; + + float length_scale; // < 1, faster in speech speed; > 1, slower in speed + const char *dict_dir; + const char *lexicon; +} SherpaMnnOfflineTtsKokoroModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineTtsModelConfig { + SherpaMnnOfflineTtsVitsModelConfig vits; + int32_t num_threads; + int32_t debug; + const char *provider; + SherpaMnnOfflineTtsMatchaModelConfig matcha; + SherpaMnnOfflineTtsKokoroModelConfig kokoro; +} SherpaMnnOfflineTtsModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineTtsConfig { + SherpaMnnOfflineTtsModelConfig model; + const char *rule_fsts; + int32_t max_num_sentences; + const char *rule_fars; + float silence_scale; +} SherpaMnnOfflineTtsConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnGeneratedAudio { + const float *samples; // in the range [-1, 1] + int32_t n; // number of samples + int32_t sample_rate; +} SherpaMnnGeneratedAudio; + +// If the callback returns 0, then it stops generating +// If the callback returns 1, then it keeps generating +typedef int32_t (*SherpaMnnGeneratedAudioCallback)(const float *samples, + int32_t n); + +typedef int32_t (*SherpaMnnGeneratedAudioCallbackWithArg)(const float *samples, + int32_t n, + void *arg); + +typedef int32_t (*SherpaMnnGeneratedAudioProgressCallback)( + const float *samples, int32_t n, float p); + +typedef int32_t (*SherpaMnnGeneratedAudioProgressCallbackWithArg)( + const float *samples, int32_t n, float p, void *arg); + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineTts SherpaMnnOfflineTts; + +// Create an instance of offline TTS. The user has to use DestroyOfflineTts() +// to free the returned pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOfflineTts *SherpaMnnCreateOfflineTts( + const SherpaMnnOfflineTtsConfig *config); + +// Free the pointer returned by SherpaMnnCreateOfflineTts() +SHERPA_ONNX_API void SherpaMnnDestroyOfflineTts( + const SherpaMnnOfflineTts *tts); + +// Return the sample rate of the current TTS object +SHERPA_ONNX_API int32_t +SherpaMnnOfflineTtsSampleRate(const SherpaMnnOfflineTts *tts); + +// Return the number of speakers of the current TTS object +SHERPA_ONNX_API int32_t +SherpaMnnOfflineTtsNumSpeakers(const SherpaMnnOfflineTts *tts); + +// Generate audio from the given text and speaker id (sid). 
+// The user has to use SherpaMnnDestroyOfflineTtsGeneratedAudio() to free the
+// returned pointer to avoid memory leak.
+SHERPA_ONNX_API const SherpaMnnGeneratedAudio *SherpaMnnOfflineTtsGenerate(
+    const SherpaMnnOfflineTts *tts, const char *text, int32_t sid,
+    float speed);

+// The callback is invoked whenever SherpaMnnOfflineTtsConfig.max_num_sentences
+// sentences have been processed. The pointer passed to the callback is freed
+// once the callback returns, so the caller should not keep a reference to it.
+SHERPA_ONNX_API const SherpaMnnGeneratedAudio *
+SherpaMnnOfflineTtsGenerateWithCallback(
+    const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed,
+    SherpaMnnGeneratedAudioCallback callback);

+SHERPA_ONNX_API
+const SherpaMnnGeneratedAudio *
+SherpaMnnOfflineTtsGenerateWithProgressCallback(
+    const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed,
+    SherpaMnnGeneratedAudioProgressCallback callback);

+SHERPA_ONNX_API
+const SherpaMnnGeneratedAudio *
+SherpaMnnOfflineTtsGenerateWithProgressCallbackWithArg(
+    const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed,
+    SherpaMnnGeneratedAudioProgressCallbackWithArg callback, void *arg);

+// Same as SherpaMnnGeneratedAudioCallback but you can pass an additional
+// `void* arg` to the callback.
+SHERPA_ONNX_API const SherpaMnnGeneratedAudio *
+SherpaMnnOfflineTtsGenerateWithCallbackWithArg(
+    const SherpaMnnOfflineTts *tts, const char *text, int32_t sid, float speed,
+    SherpaMnnGeneratedAudioCallbackWithArg callback, void *arg);

+SHERPA_ONNX_API void SherpaMnnDestroyOfflineTtsGeneratedAudio(
+    const SherpaMnnGeneratedAudio *p);

+// Write the generated audio to a wave file.
+// The saved wave file contains a single channel and has 16-bit samples.
+//
+// Return 1 if the write succeeded; return 0 on failure.
+SHERPA_ONNX_API int32_t SherpaMnnWriteWave(const float *samples, int32_t n,
+                                           int32_t sample_rate,
+                                           const char *filename);

+// Return the number of bytes needed to store a wave file that contains a
+// single channel and has 16-bit samples.
+SHERPA_ONNX_API int64_t SherpaMnnWaveFileSize(int32_t n_samples);

+// Similar to SherpaMnnWriteWave(), but it writes the wave data to a
+// caller-allocated buffer instead of to a file.
+//
+// This is useful when, for instance, an HTTP TTS API has to return the wave
+// file as binary data and the server does not need to write it to the
+// filesystem.
+SHERPA_ONNX_API void SherpaMnnWriteWaveToBuffer(const float *samples,
+                                                int32_t n, int32_t sample_rate,
+                                                char *buffer);

+SHERPA_ONNX_API typedef struct SherpaMnnWave {
+  // samples normalized to the range [-1, 1]
+  const float *samples;
+  int32_t sample_rate;
+  int32_t num_samples;
+} SherpaMnnWave;

+// Return a NULL pointer on error. It supports only standard WAVE files with
+// a single channel and 16-bit samples.
+//
+// If the returned pointer is not NULL, the user has to invoke
+// SherpaMnnFreeWave() to free the returned pointer to avoid memory leak.
+SHERPA_ONNX_API const SherpaMnnWave *SherpaMnnReadWave(const char *filename);

+// Similar to SherpaMnnReadWave(), but it assumes the content of the wave file
+// has already been read into the in-memory buffer `data`.
+//
+// If the returned pointer is not NULL, the user has to invoke
+// SherpaMnnFreeWave() to free the returned pointer to avoid memory leak.
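+//
+// Sketch (`data` and `size` are assumed to hold the raw bytes of a wave file
+// that you loaded yourself, e.g. received over the network):
+//
+//   const SherpaMnnWave *wave = SherpaMnnReadWaveFromBinaryData(data, size);
+//   if (wave) {
+//     // wave->samples, wave->num_samples, wave->sample_rate
+//     SherpaMnnFreeWave(wave);
+//   }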
+SHERPA_ONNX_API const SherpaMnnWave *SherpaMnnReadWaveFromBinaryData( + const char *data, int32_t n); + +SHERPA_ONNX_API void SherpaMnnFreeWave(const SherpaMnnWave *wave); + +// ============================================================ +// For spoken language identification +// ============================================================ + +SHERPA_ONNX_API typedef struct + SherpaMnnSpokenLanguageIdentificationWhisperConfig { + const char *encoder; + const char *decoder; + int32_t tail_paddings; +} SherpaMnnSpokenLanguageIdentificationWhisperConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnSpokenLanguageIdentificationConfig { + SherpaMnnSpokenLanguageIdentificationWhisperConfig whisper; + int32_t num_threads; + int32_t debug; + const char *provider; +} SherpaMnnSpokenLanguageIdentificationConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnSpokenLanguageIdentification + SherpaMnnSpokenLanguageIdentification; + +// Create an instance of SpokenLanguageIdentification. +// The user has to invoke SherpaMnnDestroySpokenLanguageIdentification() +// to free the returned pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnSpokenLanguageIdentification * +SherpaMnnCreateSpokenLanguageIdentification( + const SherpaMnnSpokenLanguageIdentificationConfig *config); + +SHERPA_ONNX_API void SherpaMnnDestroySpokenLanguageIdentification( + const SherpaMnnSpokenLanguageIdentification *slid); + +// The user has to invoke SherpaMnnDestroyOfflineStream() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API SherpaMnnOfflineStream * +SherpaMnnSpokenLanguageIdentificationCreateOfflineStream( + const SherpaMnnSpokenLanguageIdentification *slid); + +SHERPA_ONNX_API typedef struct SherpaMnnSpokenLanguageIdentificationResult { + // en for English + // de for German + // zh for Chinese + // es for Spanish + // ... 
+ const char *lang; +} SherpaMnnSpokenLanguageIdentificationResult; + +// The user has to invoke SherpaMnnDestroySpokenLanguageIdentificationResult() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnSpokenLanguageIdentificationResult * +SherpaMnnSpokenLanguageIdentificationCompute( + const SherpaMnnSpokenLanguageIdentification *slid, + const SherpaMnnOfflineStream *s); + +SHERPA_ONNX_API void SherpaMnnDestroySpokenLanguageIdentificationResult( + const SherpaMnnSpokenLanguageIdentificationResult *r); + +// ============================================================ +// For speaker embedding extraction +// ============================================================ +SHERPA_ONNX_API typedef struct SherpaMnnSpeakerEmbeddingExtractorConfig { + const char *model; + int32_t num_threads; + int32_t debug; + const char *provider; +} SherpaMnnSpeakerEmbeddingExtractorConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnSpeakerEmbeddingExtractor + SherpaMnnSpeakerEmbeddingExtractor; + +// The user has to invoke SherpaMnnDestroySpeakerEmbeddingExtractor() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnSpeakerEmbeddingExtractor * +SherpaMnnCreateSpeakerEmbeddingExtractor( + const SherpaMnnSpeakerEmbeddingExtractorConfig *config); + +SHERPA_ONNX_API void SherpaMnnDestroySpeakerEmbeddingExtractor( + const SherpaMnnSpeakerEmbeddingExtractor *p); + +SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingExtractorDim( + const SherpaMnnSpeakerEmbeddingExtractor *p); + +// The user has to invoke SherpaMnnDestroyOnlineStream() to free the returned +// pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnOnlineStream * +SherpaMnnSpeakerEmbeddingExtractorCreateStream( + const SherpaMnnSpeakerEmbeddingExtractor *p); + +// Return 1 if the stream has enough feature frames for computing embeddings. +// Return 0 otherwise. +SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingExtractorIsReady( + const SherpaMnnSpeakerEmbeddingExtractor *p, + const SherpaMnnOnlineStream *s); + +// Compute the embedding of the stream. +// +// @return Return a pointer pointing to an array containing the embedding. +// The length of the array is `dim` as returned by +// SherpaMnnSpeakerEmbeddingExtractorDim(p) +// +// The user has to invoke SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding() +// to free the returned pointer to avoid memory leak. +SHERPA_ONNX_API const float * +SherpaMnnSpeakerEmbeddingExtractorComputeEmbedding( + const SherpaMnnSpeakerEmbeddingExtractor *p, + const SherpaMnnOnlineStream *s); + +SHERPA_ONNX_API void SherpaMnnSpeakerEmbeddingExtractorDestroyEmbedding( + const float *v); + +SHERPA_ONNX_API typedef struct SherpaMnnSpeakerEmbeddingManager + SherpaMnnSpeakerEmbeddingManager; + +// The user has to invoke SherpaMnnDestroySpeakerEmbeddingManager() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnSpeakerEmbeddingManager * +SherpaMnnCreateSpeakerEmbeddingManager(int32_t dim); + +SHERPA_ONNX_API void SherpaMnnDestroySpeakerEmbeddingManager( + const SherpaMnnSpeakerEmbeddingManager *p); + +// Register the embedding of a user +// +// @param name The name of the user +// @param p Pointer to an array containing the embeddings. The length of the +// array must be equal to `dim` used to construct the manager `p`. +// +// @return Return 1 if added successfully. 
Return 0 on error.
+SHERPA_ONNX_API int32_t
+SherpaMnnSpeakerEmbeddingManagerAdd(const SherpaMnnSpeakerEmbeddingManager *p,
+                                    const char *name, const float *v);

+// @param v Pointer to an array of embeddings. If there are n embeddings, then
+//          v[0] is the pointer to the 0-th array containing the embeddings
+//          v[1] is the pointer to the 1-st array containing the embeddings
+//          v[n-1] is the pointer to the last array containing the embeddings
+//          v[n] is a NULL pointer
+// @return Return 1 if added successfully. Return 0 on error.
+SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingManagerAddList(
+    const SherpaMnnSpeakerEmbeddingManager *p, const char *name,
+    const float **v);

+// Similar to SherpaMnnSpeakerEmbeddingManagerAddList() but the memory
+// is flattened.
+//
+// The length of the input array should be `n * dim`.
+//
+// @return Return 1 if added successfully. Return 0 on error.
+SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingManagerAddListFlattened(
+    const SherpaMnnSpeakerEmbeddingManager *p, const char *name,
+    const float *v, int32_t n);

+// Remove a user.
+// @param name The name of the user to remove.
+// @return Return 1 if removed successfully; return 0 on error.
+//
+// Note: if the user does not exist, it also returns 0.
+SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingManagerRemove(
+    const SherpaMnnSpeakerEmbeddingManager *p, const char *name);

+// Search whether an existing user's embedding matches the given one.
+//
+// @param v Pointer to an array containing the embedding. The length of the
+//          array must be equal to `dim` used to construct the manager `p`.
+// @param threshold A value between 0 and 1. If the similarity score exceeds
+//                  this threshold, we say a match is found.
+// @return Returns the name of the user if found. Return NULL if not found.
+//         If not NULL, the caller has to invoke
+//         SherpaMnnSpeakerEmbeddingManagerFreeSearch() to free the returned
+//         pointer to avoid memory leak.
+SHERPA_ONNX_API const char *SherpaMnnSpeakerEmbeddingManagerSearch(
+    const SherpaMnnSpeakerEmbeddingManager *p, const float *v,
+    float threshold);

+SHERPA_ONNX_API void SherpaMnnSpeakerEmbeddingManagerFreeSearch(
+    const char *name);

+SHERPA_ONNX_API typedef struct SherpaMnnSpeakerEmbeddingManagerSpeakerMatch {
+  float score;
+  const char *name;
+} SherpaMnnSpeakerEmbeddingManagerSpeakerMatch;

+SHERPA_ONNX_API typedef struct
+    SherpaMnnSpeakerEmbeddingManagerBestMatchesResult {
+  const SherpaMnnSpeakerEmbeddingManagerSpeakerMatch *matches;
+  int32_t count;
+} SherpaMnnSpeakerEmbeddingManagerBestMatchesResult;

+// Get the best matching speakers whose embeddings match the given
+// embedding.
+//
+// @param p Pointer to the SherpaMnnSpeakerEmbeddingManager instance.
+// @param v Pointer to an array containing the embedding vector.
+// @param threshold Minimum similarity score required for a match
+//                  (between 0 and 1).
+// @param n Number of best matches to retrieve.
+// @return Returns a pointer to
+//         SherpaMnnSpeakerEmbeddingManagerBestMatchesResult
+//         containing the best matches found. Returns NULL if no matches are
+//         found. The caller is responsible for freeing the returned pointer
+//         using SherpaMnnSpeakerEmbeddingManagerFreeBestMatches() to
+//         avoid memory leaks.
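+//
+// Sketch (the threshold 0.5 and n = 3 are example values; `v` is an embedding
+// obtained from SherpaMnnSpeakerEmbeddingExtractorComputeEmbedding()):
+//
+//   const SherpaMnnSpeakerEmbeddingManagerBestMatchesResult *m =
+//       SherpaMnnSpeakerEmbeddingManagerGetBestMatches(manager, v, 0.5f, 3);
+//   if (m) {
+//     for (int32_t i = 0; i < m->count; ++i) {
+//       // m->matches[i].name, m->matches[i].score
+//     }
+//     SherpaMnnSpeakerEmbeddingManagerFreeBestMatches(m);
+//   }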
+SHERPA_ONNX_API const SherpaMnnSpeakerEmbeddingManagerBestMatchesResult *
+SherpaMnnSpeakerEmbeddingManagerGetBestMatches(
+    const SherpaMnnSpeakerEmbeddingManager *p, const float *v, float threshold,
+    int32_t n);

+SHERPA_ONNX_API void SherpaMnnSpeakerEmbeddingManagerFreeBestMatches(
+    const SherpaMnnSpeakerEmbeddingManagerBestMatchesResult *r);

+// Check whether the input embedding matches the embedding of the given
+// speaker.
+//
+// It is for speaker verification.
+//
+// @param name The target speaker name.
+// @param v The input embedding to check.
+// @param threshold A value between 0 and 1.
+// @return Return 1 if it matches. Otherwise, it returns 0.
+SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingManagerVerify(
+    const SherpaMnnSpeakerEmbeddingManager *p, const char *name,
+    const float *v, float threshold);

+// Return 1 if the user with the given name is in the manager.
+// Return 0 if the user does not exist.
+SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingManagerContains(
+    const SherpaMnnSpeakerEmbeddingManager *p, const char *name);

+// Return the number of speakers in the manager.
+SHERPA_ONNX_API int32_t SherpaMnnSpeakerEmbeddingManagerNumSpeakers(
+    const SherpaMnnSpeakerEmbeddingManager *p);

+// Return the names of all speakers in the manager.
+//
+// @return Return an array of pointers `ans`. If there are n speakers, then
+//         - ans[0] contains the name of the 0-th speaker
+//         - ans[1] contains the name of the 1-st speaker
+//         - ans[n-1] contains the name of the last speaker
+//         - ans[n] is NULL
+//         If there are no speakers at all, then ans[0] is NULL. In any case,
+//         `ans` is not NULL.
+//
+// Each name is NULL-terminated.
+//
+// The caller has to invoke SherpaMnnSpeakerEmbeddingManagerFreeAllSpeakers()
+// to free the returned pointer to avoid memory leak.
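+//
+// A short sketch of iterating over the returned array:
+//
+//   const char *const *names =
+//       SherpaMnnSpeakerEmbeddingManagerGetAllSpeakers(manager);
+//   for (int32_t i = 0; names[i] != NULL; ++i) {
+//     // names[i] is the name of the i-th speaker
+//   }
+//   SherpaMnnSpeakerEmbeddingManagerFreeAllSpeakers(names);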
+SHERPA_ONNX_API const char *const * +SherpaMnnSpeakerEmbeddingManagerGetAllSpeakers( + const SherpaMnnSpeakerEmbeddingManager *p); + +SHERPA_ONNX_API void SherpaMnnSpeakerEmbeddingManagerFreeAllSpeakers( + const char *const *names); + +// ============================================================ +// For audio tagging +// ============================================================ +SHERPA_ONNX_API typedef struct + SherpaMnnOfflineZipformerAudioTaggingModelConfig { + const char *model; +} SherpaMnnOfflineZipformerAudioTaggingModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnAudioTaggingModelConfig { + SherpaMnnOfflineZipformerAudioTaggingModelConfig zipformer; + const char *ced; + int32_t num_threads; + int32_t debug; // true to print debug information of the model + const char *provider; +} SherpaMnnAudioTaggingModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnAudioTaggingConfig { + SherpaMnnAudioTaggingModelConfig model; + const char *labels; + int32_t top_k; +} SherpaMnnAudioTaggingConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnAudioEvent { + const char *name; + int32_t index; + float prob; +} SherpaMnnAudioEvent; + +SHERPA_ONNX_API typedef struct SherpaMnnAudioTagging SherpaMnnAudioTagging; + +// The user has to invoke +// SherpaMnnDestroyAudioTagging() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnAudioTagging *SherpaMnnCreateAudioTagging( + const SherpaMnnAudioTaggingConfig *config); + +SHERPA_ONNX_API void SherpaMnnDestroyAudioTagging( + const SherpaMnnAudioTagging *tagger); + +// The user has to invoke SherpaMnnDestroyOfflineStream() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnOfflineStream * +SherpaMnnAudioTaggingCreateOfflineStream(const SherpaMnnAudioTagging *tagger); + +// Return an array of pointers. The length of the array is top_k + 1. +// If top_k is -1, then config.top_k is used, where config is the config +// used to create the input tagger. 
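+//
+// Sketch (top_k = 5 is just an example value; `stream` was created with
+// SherpaMnnAudioTaggingCreateOfflineStream() and has already received its
+// samples):
+//
+//   const SherpaMnnAudioEvent *const *events =
+//       SherpaMnnAudioTaggingCompute(tagger, stream, 5);
+//   for (int32_t i = 0; events[i] != NULL; ++i) {
+//     // events[i]->name, events[i]->prob
+//   }
+//   SherpaMnnAudioTaggingFreeResults(events);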
+// +// The ans[0]->prob has the largest probability among the array elements +// The last element of the array is a null pointer +// +// The user has to use SherpaMnnAudioTaggingFreeResults() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnAudioEvent *const * +SherpaMnnAudioTaggingCompute(const SherpaMnnAudioTagging *tagger, + const SherpaMnnOfflineStream *s, int32_t top_k); + +SHERPA_ONNX_API void SherpaMnnAudioTaggingFreeResults( + const SherpaMnnAudioEvent *const *p); + +// ============================================================ +// For punctuation +// ============================================================ + +SHERPA_ONNX_API typedef struct SherpaMnnOfflinePunctuationModelConfig { + const char *ct_transformer; + int32_t num_threads; + int32_t debug; // true to print debug information of the model + const char *provider; +} SherpaMnnOfflinePunctuationModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflinePunctuationConfig { + SherpaMnnOfflinePunctuationModelConfig model; +} SherpaMnnOfflinePunctuationConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflinePunctuation + SherpaMnnOfflinePunctuation; + +// The user has to invoke SherpaMnnDestroyOfflinePunctuation() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnOfflinePunctuation * +SherpaMnnCreateOfflinePunctuation( + const SherpaMnnOfflinePunctuationConfig *config); + +SHERPA_ONNX_API void SherpaMnnDestroyOfflinePunctuation( + const SherpaMnnOfflinePunctuation *punct); + +// Add punctuations to the input text. +// The user has to invoke SherpaOfflinePunctuationFreeText() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct( + const SherpaMnnOfflinePunctuation *punct, const char *text); + +SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text); + +SHERPA_ONNX_API typedef struct SherpaMnnOnlinePunctuationModelConfig { + const char *cnn_bilstm; + const char *bpe_vocab; + int32_t num_threads; + int32_t debug; + const char *provider; +} SherpaMnnOnlinePunctuationModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOnlinePunctuationConfig { + SherpaMnnOnlinePunctuationModelConfig model; +} SherpaMnnOnlinePunctuationConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOnlinePunctuation + SherpaMnnOnlinePunctuation; + +// Create an online punctuation processor. The user has to invoke +// SherpaMnnDestroyOnlinePunctuation() to free the returned pointer +// to avoid memory leak +SHERPA_ONNX_API const SherpaMnnOnlinePunctuation * +SherpaMnnCreateOnlinePunctuation( + const SherpaMnnOnlinePunctuationConfig *config); + +// Free a pointer returned by SherpaMnnCreateOnlinePunctuation() +SHERPA_ONNX_API void SherpaMnnDestroyOnlinePunctuation( + const SherpaMnnOnlinePunctuation *punctuation); + +// Add punctuations to the input text. 
The user has to invoke +// SherpaMnnOnlinePunctuationFreeText() to free the returned pointer +// to avoid memory leak +SHERPA_ONNX_API const char *SherpaMnnOnlinePunctuationAddPunct( + const SherpaMnnOnlinePunctuation *punctuation, const char *text); + +// Free a pointer returned by SherpaMnnOnlinePunctuationAddPunct() +SHERPA_ONNX_API void SherpaMnnOnlinePunctuationFreeText(const char *text); + +// for resampling +SHERPA_ONNX_API typedef struct SherpaMnnLinearResampler + SherpaMnnLinearResampler; + +/* + float min_freq = min(sampling_rate_in_hz, samp_rate_out_hz); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + int32_t lowpass_filter_width = 6; + + You can set filter_cutoff_hz to lowpass_cutoff + sand set num_zeros to lowpass_filter_width +*/ +// The user has to invoke SherpaMnnDestroyLinearResampler() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnLinearResampler * +SherpaMnnCreateLinearResampler(int32_t samp_rate_in_hz, + int32_t samp_rate_out_hz, + float filter_cutoff_hz, int32_t num_zeros); + +SHERPA_ONNX_API void SherpaMnnDestroyLinearResampler( + const SherpaMnnLinearResampler *p); + +SHERPA_ONNX_API void SherpaMnnLinearResamplerReset( + const SherpaMnnLinearResampler *p); + +typedef struct SherpaMnnResampleOut { + const float *samples; + int32_t n; +} SherpaMnnResampleOut; +// The user has to invoke SherpaMnnLinearResamplerResampleFree() +// to free the returned pointer to avoid memory leak. +// +// If this is the last segment, you can set flush to 1; otherwise, please +// set flush to 0 +SHERPA_ONNX_API const SherpaMnnResampleOut *SherpaMnnLinearResamplerResample( + const SherpaMnnLinearResampler *p, const float *input, int32_t input_dim, + int32_t flush); + +SHERPA_ONNX_API void SherpaMnnLinearResamplerResampleFree( + const SherpaMnnResampleOut *p); + +SHERPA_ONNX_API int32_t SherpaMnnLinearResamplerResampleGetInputSampleRate( + const SherpaMnnLinearResampler *p); + +SHERPA_ONNX_API int32_t SherpaMnnLinearResamplerResampleGetOutputSampleRate( + const SherpaMnnLinearResampler *p); + +// Return 1 if the file exists; return 0 if the file does not exist. +SHERPA_ONNX_API int32_t SherpaMnnFileExists(const char *filename); + +// ========================================================================= +// For offline speaker diarization (i.e., non-streaming speaker diarization) +// ========================================================================= +SHERPA_ONNX_API typedef struct + SherpaMnnOfflineSpeakerSegmentationPyannoteModelConfig { + const char *model; +} SherpaMnnOfflineSpeakerSegmentationPyannoteModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeakerSegmentationModelConfig { + SherpaMnnOfflineSpeakerSegmentationPyannoteModelConfig pyannote; + int32_t num_threads; // 1 + int32_t debug; // false + const char *provider; // "cpu" +} SherpaMnnOfflineSpeakerSegmentationModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnFastClusteringConfig { + // If greater than 0, then threshold is ignored. + // + // We strongly recommend that you set it if you know the number of clusters + // in advance + int32_t num_clusters; + + // distance threshold. + // + // The smaller, the more clusters it will generate. + // The larger, the fewer clusters it will generate. 
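+  // There is no single universally good value; the threshold has to be tuned
+  // together with the speaker embedding model whose vectors are clustered.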
+ float threshold; +} SherpaMnnFastClusteringConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeakerDiarizationConfig { + SherpaMnnOfflineSpeakerSegmentationModelConfig segmentation; + SherpaMnnSpeakerEmbeddingExtractorConfig embedding; + SherpaMnnFastClusteringConfig clustering; + + // if a segment is less than this value, then it is discarded + float min_duration_on; // in seconds + + // if the gap between to segments of the same speaker is less than this value, + // then these two segments are merged into a single segment. + // We do this recursively. + float min_duration_off; // in seconds +} SherpaMnnOfflineSpeakerDiarizationConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeakerDiarization + SherpaMnnOfflineSpeakerDiarization; + +// The users has to invoke SherpaMnnDestroyOfflineSpeakerDiarization() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnOfflineSpeakerDiarization * +SherpaMnnCreateOfflineSpeakerDiarization( + const SherpaMnnOfflineSpeakerDiarizationConfig *config); + +// Free the pointer returned by SherpaMnnCreateOfflineSpeakerDiarization() +SHERPA_ONNX_API void SherpaMnnDestroyOfflineSpeakerDiarization( + const SherpaMnnOfflineSpeakerDiarization *sd); + +// Expected sample rate of the input audio samples +SHERPA_ONNX_API int32_t SherpaMnnOfflineSpeakerDiarizationGetSampleRate( + const SherpaMnnOfflineSpeakerDiarization *sd); + +// Only config->clustering is used. All other fields are ignored +SHERPA_ONNX_API void SherpaMnnOfflineSpeakerDiarizationSetConfig( + const SherpaMnnOfflineSpeakerDiarization *sd, + const SherpaMnnOfflineSpeakerDiarizationConfig *config); + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeakerDiarizationResult + SherpaMnnOfflineSpeakerDiarizationResult; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeakerDiarizationSegment { + float start; + float end; + int32_t speaker; +} SherpaMnnOfflineSpeakerDiarizationSegment; + +SHERPA_ONNX_API int32_t SherpaMnnOfflineSpeakerDiarizationResultGetNumSpeakers( + const SherpaMnnOfflineSpeakerDiarizationResult *r); + +SHERPA_ONNX_API int32_t SherpaMnnOfflineSpeakerDiarizationResultGetNumSegments( + const SherpaMnnOfflineSpeakerDiarizationResult *r); + +// The user has to invoke SherpaMnnOfflineSpeakerDiarizationDestroySegment() +// to free the returned pointer to avoid memory leak. +// +// The returned pointer is the start address of an array. +// Number of entries in the array equals to the value +// returned by SherpaMnnOfflineSpeakerDiarizationResultGetNumSegments() +SHERPA_ONNX_API const SherpaMnnOfflineSpeakerDiarizationSegment * +SherpaMnnOfflineSpeakerDiarizationResultSortByStartTime( + const SherpaMnnOfflineSpeakerDiarizationResult *r); + +SHERPA_ONNX_API void SherpaMnnOfflineSpeakerDiarizationDestroySegment( + const SherpaMnnOfflineSpeakerDiarizationSegment *s); + +typedef int32_t (*SherpaMnnOfflineSpeakerDiarizationProgressCallback)( + int32_t num_processed_chunks, int32_t num_total_chunks, void *arg); + +typedef int32_t (*SherpaMnnOfflineSpeakerDiarizationProgressCallbackNoArg)( + int32_t num_processed_chunks, int32_t num_total_chunks); + +// The user has to invoke SherpaMnnOfflineSpeakerDiarizationDestroyResult() +// to free the returned pointer to avoid memory leak. 
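+//
+// A typical call sequence looks roughly like this (sketch only; error
+// handling omitted):
+//
+//   const SherpaMnnOfflineSpeakerDiarizationResult *result =
+//       SherpaMnnOfflineSpeakerDiarizationProcess(sd, samples, n);
+//
+//   int32_t num_segments =
+//       SherpaMnnOfflineSpeakerDiarizationResultGetNumSegments(result);
+//
+//   const SherpaMnnOfflineSpeakerDiarizationSegment *segments =
+//       SherpaMnnOfflineSpeakerDiarizationResultSortByStartTime(result);
+//
+//   // use segments[0] .. segments[num_segments - 1]
+//
+//   SherpaMnnOfflineSpeakerDiarizationDestroySegment(segments);
+//   SherpaMnnOfflineSpeakerDiarizationDestroyResult(result);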
+SHERPA_ONNX_API const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcess( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n); + +// The user has to invoke SherpaMnnOfflineSpeakerDiarizationDestroyResult() +// to free the returned pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcessWithCallback( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, SherpaMnnOfflineSpeakerDiarizationProgressCallback callback, + void *arg); + +SHERPA_ONNX_API const SherpaMnnOfflineSpeakerDiarizationResult * +SherpaMnnOfflineSpeakerDiarizationProcessWithCallbackNoArg( + const SherpaMnnOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, + SherpaMnnOfflineSpeakerDiarizationProgressCallbackNoArg callback); + +SHERPA_ONNX_API void SherpaMnnOfflineSpeakerDiarizationDestroyResult( + const SherpaMnnOfflineSpeakerDiarizationResult *r); + +// ========================================================================= +// For offline speech enhancement +// ========================================================================= +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeechDenoiserGtcrnModelConfig { + const char *model; +} SherpaMnnOfflineSpeechDenoiserGtcrnModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeechDenoiserModelConfig { + SherpaMnnOfflineSpeechDenoiserGtcrnModelConfig gtcrn; + int32_t num_threads; + int32_t debug; // true to print debug information of the model + const char *provider; +} SherpaMnnOfflineSpeechDenoiserModelConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeechDenoiserConfig { + SherpaMnnOfflineSpeechDenoiserModelConfig model; +} SherpaMnnOfflineSpeechDenoiserConfig; + +SHERPA_ONNX_API typedef struct SherpaMnnOfflineSpeechDenoiser + SherpaMnnOfflineSpeechDenoiser; + +// The users has to invoke SherpaMnnDestroyOfflineSpeechDenoiser() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaMnnOfflineSpeechDenoiser * +SherpaMnnCreateOfflineSpeechDenoiser( + const SherpaMnnOfflineSpeechDenoiserConfig *config); + +// Free the pointer returned by SherpaMnnCreateOfflineSpeechDenoiser() +SHERPA_ONNX_API void SherpaMnnDestroyOfflineSpeechDenoiser( + const SherpaMnnOfflineSpeechDenoiser *sd); + +SHERPA_ONNX_API int32_t SherpaMnnOfflineSpeechDenoiserGetSampleRate( + const SherpaMnnOfflineSpeechDenoiser *sd); + +SHERPA_ONNX_API typedef struct SherpaMnnDenoisedAudio { + const float *samples; // in the range [-1, 1] + int32_t n; // number of samples + int32_t sample_rate; +} SherpaMnnDenoisedAudio; + +// Run speech denosing on input samples +// @param samples A 1-D array containing the input audio samples. Each sample +// should be in the range [-1, 1]. +// @param n Number of samples +// @param sample_rate Sample rate of the input samples +// +// The user MUST use SherpaMnnDestroyDenoisedAudio() to free the returned +// pointer to avoid memory leak. 
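+//
+// Minimal usage sketch (the model path below is a placeholder, not a shipped
+// file):
+//
+//   SherpaMnnOfflineSpeechDenoiserConfig config;
+//   memset(&config, 0, sizeof(config));
+//   config.model.gtcrn.model = "./gtcrn.onnx";  // placeholder path
+//
+//   const SherpaMnnOfflineSpeechDenoiser *sd =
+//       SherpaMnnCreateOfflineSpeechDenoiser(&config);
+//   const SherpaMnnDenoisedAudio *audio =
+//       SherpaMnnOfflineSpeechDenoiserRun(sd, samples, n, sample_rate);
+//
+//   // use audio->samples, audio->n and audio->sample_rate
+//
+//   SherpaMnnDestroyDenoisedAudio(audio);
+//   SherpaMnnDestroyOfflineSpeechDenoiser(sd);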
+SHERPA_ONNX_API const SherpaMnnDenoisedAudio * +SherpaMnnOfflineSpeechDenoiserRun(const SherpaMnnOfflineSpeechDenoiser *sd, + const float *samples, int32_t n, + int32_t sample_rate); + +SHERPA_ONNX_API void SherpaMnnDestroyDenoisedAudio( + const SherpaMnnDenoisedAudio *p); + +#ifdef __OHOS__ + +// It is for HarmonyOS +typedef struct NativeResourceManager NativeResourceManager; + +SHERPA_ONNX_API const SherpaMnnOfflineSpeechDenoiser * +SherpaMnnCreateOfflineSpeechDenoiserOHOS( + const SherpaMnnOfflineSpeechDenoiserConfig *config, + NativeResourceManager *mgr); + +/// @param config Config for the recognizer. +/// @return Return a pointer to the recognizer. The user has to invoke +// SherpaMnnDestroyOnlineRecognizer() to free it to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnOnlineRecognizer * +SherpaMnnCreateOnlineRecognizerOHOS( + const SherpaMnnOnlineRecognizerConfig *config, NativeResourceManager *mgr); + +/// @param config Config for the recognizer. +/// @return Return a pointer to the recognizer. The user has to invoke +// SherpaMnnDestroyOfflineRecognizer() to free it to avoid memory +// leak. +SHERPA_ONNX_API const SherpaMnnOfflineRecognizer * +SherpaMnnCreateOfflineRecognizerOHOS( + const SherpaMnnOfflineRecognizerConfig *config, + NativeResourceManager *mgr); + +// Return an instance of VoiceActivityDetector. +// The user has to use SherpaMnnDestroyVoiceActivityDetector() to free +// the returned pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaMnnVoiceActivityDetector * +SherpaMnnCreateVoiceActivityDetectorOHOS( + const SherpaMnnVadModelConfig *config, float buffer_size_in_seconds, + NativeResourceManager *mgr); + +SHERPA_ONNX_API const SherpaMnnOfflineTts *SherpaMnnCreateOfflineTtsOHOS( + const SherpaMnnOfflineTtsConfig *config, NativeResourceManager *mgr); + +SHERPA_ONNX_API const SherpaMnnSpeakerEmbeddingExtractor * +SherpaMnnCreateSpeakerEmbeddingExtractorOHOS( + const SherpaMnnSpeakerEmbeddingExtractorConfig *config, + NativeResourceManager *mgr); + +SHERPA_ONNX_API const SherpaMnnKeywordSpotter * +SherpaMnnCreateKeywordSpotterOHOS(const SherpaMnnKeywordSpotterConfig *config, + NativeResourceManager *mgr); + +SHERPA_ONNX_API const SherpaMnnOfflineSpeakerDiarization * +SherpaMnnCreateOfflineSpeakerDiarizationOHOS( + const SherpaMnnOfflineSpeakerDiarizationConfig *config, + NativeResourceManager *mgr); +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif // SHERPA_ONNX_C_API_C_API_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/cxx-api.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/cxx-api.cc new file mode 100644 index 00000000..50ce4142 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/cxx-api.cc @@ -0,0 +1,561 @@ +// sherpa-mnn/c-api/cxx-api.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/c-api/cxx-api.h" + +#include +#include + +namespace sherpa_mnn::cxx { + +Wave ReadWave(const std::string &filename) { + auto p = SherpaMnnReadWave(filename.c_str()); + + Wave ans; + if (p) { + ans.samples.resize(p->num_samples); + + std::copy(p->samples, p->samples + p->num_samples, ans.samples.data()); + + ans.sample_rate = p->sample_rate; + SherpaMnnFreeWave(p); + } + + return ans; +} + +bool WriteWave(const std::string &filename, const Wave &wave) { + return SherpaMnnWriteWave(wave.samples.data(), wave.samples.size(), + wave.sample_rate, filename.c_str()); +} + +OnlineStream::OnlineStream(const SherpaMnnOnlineStream *p) + : MoveOnly(p) {} + +void 
OnlineStream::Destroy(const SherpaMnnOnlineStream *p) const { + SherpaMnnDestroyOnlineStream(p); +} + +void OnlineStream::AcceptWaveform(int32_t sample_rate, const float *samples, + int32_t n) const { + SherpaMnnOnlineStreamAcceptWaveform(p_, sample_rate, samples, n); +} + +void OnlineStream::InputFinished() const { + SherpaMnnOnlineStreamInputFinished(p_); +} + +OnlineRecognizer OnlineRecognizer::Create( + const OnlineRecognizerConfig &config) { + struct SherpaMnnOnlineRecognizerConfig c; + memset(&c, 0, sizeof(c)); + + c.feat_config.sample_rate = config.feat_config.sample_rate; + c.feat_config.feature_dim = config.feat_config.feature_dim; + + c.model_config.transducer.encoder = + config.model_config.transducer.encoder.c_str(); + c.model_config.transducer.decoder = + config.model_config.transducer.decoder.c_str(); + c.model_config.transducer.joiner = + config.model_config.transducer.joiner.c_str(); + + c.model_config.paraformer.encoder = + config.model_config.paraformer.encoder.c_str(); + c.model_config.paraformer.decoder = + config.model_config.paraformer.decoder.c_str(); + + c.model_config.zipformer2_ctc.model = + config.model_config.zipformer2_ctc.model.c_str(); + + c.model_config.tokens = config.model_config.tokens.c_str(); + c.model_config.num_threads = config.model_config.num_threads; + c.model_config.provider = config.model_config.provider.c_str(); + c.model_config.debug = config.model_config.debug; + c.model_config.model_type = config.model_config.model_type.c_str(); + c.model_config.modeling_unit = config.model_config.modeling_unit.c_str(); + c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str(); + c.model_config.tokens_buf = config.model_config.tokens_buf.c_str(); + c.model_config.tokens_buf_size = config.model_config.tokens_buf.size(); + + c.decoding_method = config.decoding_method.c_str(); + c.max_active_paths = config.max_active_paths; + c.enable_endpoint = config.enable_endpoint; + c.rule1_min_trailing_silence = config.rule1_min_trailing_silence; + c.rule2_min_trailing_silence = config.rule2_min_trailing_silence; + c.rule3_min_utterance_length = config.rule3_min_utterance_length; + c.hotwords_file = config.hotwords_file.c_str(); + c.hotwords_score = config.hotwords_score; + + c.ctc_fst_decoder_config.graph = config.ctc_fst_decoder_config.graph.c_str(); + c.ctc_fst_decoder_config.max_active = + config.ctc_fst_decoder_config.max_active; + + c.rule_fsts = config.rule_fsts.c_str(); + c.rule_fars = config.rule_fars.c_str(); + + c.blank_penalty = config.blank_penalty; + + c.hotwords_buf = config.hotwords_buf.c_str(); + c.hotwords_buf_size = config.hotwords_buf.size(); + + auto p = SherpaMnnCreateOnlineRecognizer(&c); + return OnlineRecognizer(p); +} + +OnlineRecognizer::OnlineRecognizer(const SherpaMnnOnlineRecognizer *p) + : MoveOnly(p) {} + +void OnlineRecognizer::Destroy(const SherpaMnnOnlineRecognizer *p) const { + SherpaMnnDestroyOnlineRecognizer(p); +} + +OnlineStream OnlineRecognizer::CreateStream() const { + auto s = SherpaMnnCreateOnlineStream(p_); + return OnlineStream{s}; +} + +OnlineStream OnlineRecognizer::CreateStream(const std::string &hotwords) const { + auto s = SherpaMnnCreateOnlineStreamWithHotwords(p_, hotwords.c_str()); + return OnlineStream{s}; +} + +bool OnlineRecognizer::IsReady(const OnlineStream *s) const { + return SherpaMnnIsOnlineStreamReady(p_, s->Get()); +} + +void OnlineRecognizer::Decode(const OnlineStream *s) const { + SherpaMnnDecodeOnlineStream(p_, s->Get()); +} + +void OnlineRecognizer::Reset(const OnlineStream *s) const { + 
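+  // Reset() only forwards to the C API: it clears the per-stream decoding
+  // state so that the same stream can start a new utterance.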
+  SherpaMnnOnlineStreamReset(p_, s->Get());
+}
+
+bool OnlineRecognizer::IsEndpoint(const OnlineStream *s) const {
+  return SherpaMnnOnlineStreamIsEndpoint(p_, s->Get());
+}
+
+void OnlineRecognizer::Decode(const OnlineStream *ss, int32_t n) const {
+  if (n <= 0) {
+    return;
+  }
+
+  std::vector<const SherpaMnnOnlineStream *> streams(n);
+  for (int32_t i = 0; i != n; ++i) {
+    streams[i] = ss[i].Get();
+  }
+
+  SherpaMnnDecodeMultipleOnlineStreams(p_, streams.data(), n);
+}
+
+OnlineRecognizerResult OnlineRecognizer::GetResult(
+    const OnlineStream *s) const {
+  auto r = SherpaMnnGetOnlineStreamResult(p_, s->Get());
+
+  OnlineRecognizerResult ans;
+  ans.text = r->text;
+
+  ans.tokens.resize(r->count);
+  for (int32_t i = 0; i != r->count; ++i) {
+    ans.tokens[i] = r->tokens_arr[i];
+  }
+
+  if (r->timestamps) {
+    ans.timestamps.resize(r->count);
+    std::copy(r->timestamps, r->timestamps + r->count, ans.timestamps.data());
+  }
+
+  ans.json = r->json;
+
+  SherpaMnnDestroyOnlineRecognizerResult(r);
+
+  return ans;
+}
+
+// ============================================================================
+// Non-streaming ASR
+// ============================================================================
+OfflineStream::OfflineStream(const SherpaMnnOfflineStream *p)
+    : MoveOnly(p) {}
+
+void OfflineStream::Destroy(const SherpaMnnOfflineStream *p) const {
+  SherpaMnnDestroyOfflineStream(p);
+}
+
+void OfflineStream::AcceptWaveform(int32_t sample_rate, const float *samples,
+                                   int32_t n) const {
+  SherpaMnnAcceptWaveformOffline(p_, sample_rate, samples, n);
+}
+
+OfflineRecognizer OfflineRecognizer::Create(
+    const OfflineRecognizerConfig &config) {
+  struct SherpaMnnOfflineRecognizerConfig c;
+  memset(&c, 0, sizeof(c));
+
+  c.feat_config.sample_rate = config.feat_config.sample_rate;
+  c.feat_config.feature_dim = config.feat_config.feature_dim;
+  c.model_config.transducer.encoder =
+      config.model_config.transducer.encoder.c_str();
+  c.model_config.transducer.decoder =
+      config.model_config.transducer.decoder.c_str();
+  c.model_config.transducer.joiner =
+      config.model_config.transducer.joiner.c_str();
+
+  c.model_config.paraformer.model =
+      config.model_config.paraformer.model.c_str();
+
+  c.model_config.nemo_ctc.model = config.model_config.nemo_ctc.model.c_str();
+
+  c.model_config.whisper.encoder = config.model_config.whisper.encoder.c_str();
+  c.model_config.whisper.decoder = config.model_config.whisper.decoder.c_str();
+  c.model_config.whisper.language =
+      config.model_config.whisper.language.c_str();
+  c.model_config.whisper.task = config.model_config.whisper.task.c_str();
+  c.model_config.whisper.tail_paddings =
+      config.model_config.whisper.tail_paddings;
+
+  c.model_config.tdnn.model = config.model_config.tdnn.model.c_str();
+
+  c.model_config.tokens = config.model_config.tokens.c_str();
+  c.model_config.num_threads = config.model_config.num_threads;
+  c.model_config.debug = config.model_config.debug;
+  c.model_config.provider = config.model_config.provider.c_str();
+  c.model_config.model_type = config.model_config.model_type.c_str();
+  c.model_config.modeling_unit = config.model_config.modeling_unit.c_str();
+  c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str();
+  c.model_config.telespeech_ctc = config.model_config.telespeech_ctc.c_str();
+
+  c.model_config.sense_voice.model =
+      config.model_config.sense_voice.model.c_str();
+  c.model_config.sense_voice.language =
+      config.model_config.sense_voice.language.c_str();
+  c.model_config.sense_voice.use_itn = config.model_config.sense_voice.use_itn;
+
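+  // Lifetime note: the C struct only borrows the pointers returned by
+  // c_str(), so `config` has to stay alive until
+  // SherpaMnnCreateOfflineRecognizer() below has returned.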
+  c.model_config.moonshine.preprocessor =
+      config.model_config.moonshine.preprocessor.c_str();
+  c.model_config.moonshine.encoder =
+      config.model_config.moonshine.encoder.c_str();
+  c.model_config.moonshine.uncached_decoder =
+      config.model_config.moonshine.uncached_decoder.c_str();
+  c.model_config.moonshine.cached_decoder =
+      config.model_config.moonshine.cached_decoder.c_str();
+
+  c.model_config.fire_red_asr.encoder =
+      config.model_config.fire_red_asr.encoder.c_str();
+  c.model_config.fire_red_asr.decoder =
+      config.model_config.fire_red_asr.decoder.c_str();
+
+  c.lm_config.model = config.lm_config.model.c_str();
+  c.lm_config.scale = config.lm_config.scale;
+
+  c.decoding_method = config.decoding_method.c_str();
+  c.max_active_paths = config.max_active_paths;
+  c.hotwords_file = config.hotwords_file.c_str();
+  c.hotwords_score = config.hotwords_score;
+
+  c.rule_fsts = config.rule_fsts.c_str();
+  c.rule_fars = config.rule_fars.c_str();
+
+  c.blank_penalty = config.blank_penalty;
+
+  auto p = SherpaMnnCreateOfflineRecognizer(&c);
+  return OfflineRecognizer(p);
+}
+
+OfflineRecognizer::OfflineRecognizer(const SherpaMnnOfflineRecognizer *p)
+    : MoveOnly(p) {}
+
+void OfflineRecognizer::Destroy(const SherpaMnnOfflineRecognizer *p) const {
+  SherpaMnnDestroyOfflineRecognizer(p);
+}
+
+OfflineStream OfflineRecognizer::CreateStream() const {
+  auto s = SherpaMnnCreateOfflineStream(p_);
+  return OfflineStream{s};
+}
+
+OfflineStream OfflineRecognizer::CreateStream(
+    const std::string &hotwords) const {
+  auto s = SherpaMnnCreateOfflineStreamWithHotwords(p_, hotwords.c_str());
+  return OfflineStream{s};
+}
+
+void OfflineRecognizer::Decode(const OfflineStream *s) const {
+  SherpaMnnDecodeOfflineStream(p_, s->Get());
+}
+
+void OfflineRecognizer::Decode(const OfflineStream *ss, int32_t n) const {
+  if (n <= 0) {
+    return;
+  }
+
+  std::vector<const SherpaMnnOfflineStream *> streams(n);
+  for (int32_t i = 0; i != n; ++i) {
+    streams[i] = ss[i].Get();
+  }
+
+  SherpaMnnDecodeMultipleOfflineStreams(p_, streams.data(), n);
+}
+
+OfflineRecognizerResult OfflineRecognizer::GetResult(
+    const OfflineStream *s) const {
+  auto r = SherpaMnnGetOfflineStreamResult(s->Get());
+
+  OfflineRecognizerResult ans;
+  if (r) {
+    ans.text = r->text;
+
+    if (r->timestamps) {
+      ans.timestamps.resize(r->count);
+      std::copy(r->timestamps, r->timestamps + r->count,
+                ans.timestamps.data());
+    }
+
+    ans.tokens.resize(r->count);
+    for (int32_t i = 0; i != r->count; ++i) {
+      ans.tokens[i] = r->tokens_arr[i];
+    }
+
+    ans.json = r->json;
+    ans.lang = r->lang ? r->lang : "";
+    ans.emotion = r->emotion ? r->emotion : "";
+    ans.event = r->event ?
r->event : ""; + } + + SherpaMnnDestroyOfflineRecognizerResult(r); + + return ans; +} + +OfflineTts OfflineTts::Create(const OfflineTtsConfig &config) { + struct SherpaMnnOfflineTtsConfig c; + memset(&c, 0, sizeof(c)); + + c.model.vits.model = config.model.vits.model.c_str(); + c.model.vits.lexicon = config.model.vits.lexicon.c_str(); + c.model.vits.tokens = config.model.vits.tokens.c_str(); + c.model.vits.data_dir = config.model.vits.data_dir.c_str(); + c.model.vits.noise_scale = config.model.vits.noise_scale; + c.model.vits.noise_scale_w = config.model.vits.noise_scale_w; + c.model.vits.length_scale = config.model.vits.length_scale; + c.model.vits.dict_dir = config.model.vits.dict_dir.c_str(); + + c.model.matcha.acoustic_model = config.model.matcha.acoustic_model.c_str(); + c.model.matcha.vocoder = config.model.matcha.vocoder.c_str(); + c.model.matcha.lexicon = config.model.matcha.lexicon.c_str(); + c.model.matcha.tokens = config.model.matcha.tokens.c_str(); + c.model.matcha.data_dir = config.model.matcha.data_dir.c_str(); + c.model.matcha.noise_scale = config.model.matcha.noise_scale; + c.model.matcha.length_scale = config.model.matcha.length_scale; + c.model.matcha.dict_dir = config.model.matcha.dict_dir.c_str(); + + c.model.kokoro.model = config.model.kokoro.model.c_str(); + c.model.kokoro.voices = config.model.kokoro.voices.c_str(); + c.model.kokoro.tokens = config.model.kokoro.tokens.c_str(); + c.model.kokoro.data_dir = config.model.kokoro.data_dir.c_str(); + c.model.kokoro.length_scale = config.model.kokoro.length_scale; + c.model.kokoro.dict_dir = config.model.kokoro.dict_dir.c_str(); + c.model.kokoro.lexicon = config.model.kokoro.lexicon.c_str(); + + c.model.num_threads = config.model.num_threads; + c.model.debug = config.model.debug; + c.model.provider = config.model.provider.c_str(); + + c.rule_fsts = config.rule_fsts.c_str(); + c.max_num_sentences = config.max_num_sentences; + c.silence_scale = config.silence_scale; + c.rule_fars = config.rule_fars.c_str(); + + auto p = SherpaMnnCreateOfflineTts(&c); + return OfflineTts(p); +} + +OfflineTts::OfflineTts(const SherpaMnnOfflineTts *p) + : MoveOnly(p) {} + +void OfflineTts::Destroy(const SherpaMnnOfflineTts *p) const { + SherpaMnnDestroyOfflineTts(p); +} + +int32_t OfflineTts::SampleRate() const { + return SherpaMnnOfflineTtsSampleRate(p_); +} + +int32_t OfflineTts::NumSpeakers() const { + return SherpaMnnOfflineTtsNumSpeakers(p_); +} + +GeneratedAudio OfflineTts::Generate(const std::string &text, + int32_t sid /*= 0*/, float speed /*= 1.0*/, + OfflineTtsCallback callback /*= nullptr*/, + void *arg /*= nullptr*/) const { + const SherpaMnnGeneratedAudio *audio; + if (!callback) { + audio = SherpaMnnOfflineTtsGenerate(p_, text.c_str(), sid, speed); + } else { + audio = SherpaMnnOfflineTtsGenerateWithProgressCallbackWithArg( + p_, text.c_str(), sid, speed, callback, arg); + } + + GeneratedAudio ans; + ans.samples = std::vector{audio->samples, audio->samples + audio->n}; + ans.sample_rate = audio->sample_rate; + + SherpaMnnDestroyOfflineTtsGeneratedAudio(audio); + return ans; +} + +KeywordSpotter KeywordSpotter::Create(const KeywordSpotterConfig &config) { + struct SherpaMnnKeywordSpotterConfig c; + memset(&c, 0, sizeof(c)); + + c.feat_config.sample_rate = config.feat_config.sample_rate; + + c.model_config.transducer.encoder = + config.model_config.transducer.encoder.c_str(); + c.model_config.transducer.decoder = + config.model_config.transducer.decoder.c_str(); + c.model_config.transducer.joiner = + 
+      config.model_config.transducer.joiner.c_str();
+  c.feat_config.feature_dim = config.feat_config.feature_dim;
+
+  c.model_config.paraformer.encoder =
+      config.model_config.paraformer.encoder.c_str();
+  c.model_config.paraformer.decoder =
+      config.model_config.paraformer.decoder.c_str();
+
+  c.model_config.zipformer2_ctc.model =
+      config.model_config.zipformer2_ctc.model.c_str();
+
+  c.model_config.tokens = config.model_config.tokens.c_str();
+  c.model_config.num_threads = config.model_config.num_threads;
+  c.model_config.provider = config.model_config.provider.c_str();
+  c.model_config.debug = config.model_config.debug;
+  c.model_config.model_type = config.model_config.model_type.c_str();
+  c.model_config.modeling_unit = config.model_config.modeling_unit.c_str();
+  c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str();
+  c.model_config.tokens_buf = config.model_config.tokens_buf.c_str();
+  c.model_config.tokens_buf_size = config.model_config.tokens_buf.size();
+
+  c.max_active_paths = config.max_active_paths;
+  c.num_trailing_blanks = config.num_trailing_blanks;
+  c.keywords_score = config.keywords_score;
+  c.keywords_threshold = config.keywords_threshold;
+  c.keywords_file = config.keywords_file.c_str();
+
+  auto p = SherpaMnnCreateKeywordSpotter(&c);
+  return KeywordSpotter(p);
+}
+
+KeywordSpotter::KeywordSpotter(const SherpaMnnKeywordSpotter *p)
+    : MoveOnly(p) {}
+
+void KeywordSpotter::Destroy(const SherpaMnnKeywordSpotter *p) const {
+  SherpaMnnDestroyKeywordSpotter(p);
+}
+
+OnlineStream KeywordSpotter::CreateStream() const {
+  auto s = SherpaMnnCreateKeywordStream(p_);
+  return OnlineStream{s};
+}
+
+OnlineStream KeywordSpotter::CreateStream(const std::string &keywords) const {
+  auto s = SherpaMnnCreateKeywordStreamWithKeywords(p_, keywords.c_str());
+  return OnlineStream{s};
+}
+
+bool KeywordSpotter::IsReady(const OnlineStream *s) const {
+  return SherpaMnnIsKeywordStreamReady(p_, s->Get());
+}
+
+void KeywordSpotter::Decode(const OnlineStream *s) const {
+  return SherpaMnnDecodeKeywordStream(p_, s->Get());
+}
+
+void KeywordSpotter::Decode(const OnlineStream *ss, int32_t n) const {
+  if (n <= 0) {
+    return;
+  }
+
+  std::vector<const SherpaMnnOnlineStream *> streams(n);
+  for (int32_t i = 0; i != n; ++i) {
+    streams[i] = ss[i].Get();
+  }
+
+  SherpaMnnDecodeMultipleKeywordStreams(p_, streams.data(), n);
+}
+
+KeywordResult KeywordSpotter::GetResult(const OnlineStream *s) const {
+  auto r = SherpaMnnGetKeywordResult(p_, s->Get());
+
+  KeywordResult ans;
+  ans.keyword = r->keyword;
+
+  ans.tokens.resize(r->count);
+  for (int32_t i = 0; i < r->count; ++i) {
+    ans.tokens[i] = r->tokens_arr[i];
+  }
+
+  if (r->timestamps) {
+    ans.timestamps.resize(r->count);
+    std::copy(r->timestamps, r->timestamps + r->count, ans.timestamps.data());
+  }
+
+  ans.start_time = r->start_time;
+  ans.json = r->json;
+
+  SherpaMnnDestroyKeywordResult(r);
+
+  return ans;
+}
+
+void KeywordSpotter::Reset(const OnlineStream *s) const {
+  SherpaMnnResetKeywordStream(p_, s->Get());
+}
+
+// ============================================================
+// For Offline Speech Enhancement
+// ============================================================
+
+OfflineSpeechDenoiser OfflineSpeechDenoiser::Create(
+    const OfflineSpeechDenoiserConfig &config) {
+  struct SherpaMnnOfflineSpeechDenoiserConfig c;
+  memset(&c, 0, sizeof(c));
+
+  c.model.gtcrn.model = config.model.gtcrn.model.c_str();
+
+  c.model.num_threads = config.model.num_threads;
+  c.model.provider = config.model.provider.c_str();
+  c.model.debug = config.model.debug;
+
+  auto
p = SherpaMnnCreateOfflineSpeechDenoiser(&c); + + return OfflineSpeechDenoiser(p); +} + +void OfflineSpeechDenoiser::Destroy( + const SherpaMnnOfflineSpeechDenoiser *p) const { + SherpaMnnDestroyOfflineSpeechDenoiser(p); +} + +OfflineSpeechDenoiser::OfflineSpeechDenoiser( + const SherpaMnnOfflineSpeechDenoiser *p) + : MoveOnly(p) {} + +DenoisedAudio OfflineSpeechDenoiser::Run(const float *samples, int32_t n, + int32_t sample_rate) const { + auto audio = SherpaMnnOfflineSpeechDenoiserRun(p_, samples, n, sample_rate); + + DenoisedAudio ans; + ans.samples = {audio->samples, audio->samples + audio->n}; + ans.sample_rate = audio->sample_rate; + SherpaMnnDestroyDenoisedAudio(audio); + + return ans; +} + +int32_t OfflineSpeechDenoiser::GetSampleRate() const { + return SherpaMnnOfflineSpeechDenoiserGetSampleRate(p_); +} + +} // namespace sherpa_mnn::cxx diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/cxx-api.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/cxx-api.h new file mode 100644 index 00000000..f423a6f9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/c-api/cxx-api.h @@ -0,0 +1,505 @@ +// sherpa-mnn/c-api/cxx-api.h +// +// Copyright (c) 2024 Xiaomi Corporation + +// C++ Wrapper of the C API for sherpa-mnn +#ifndef SHERPA_ONNX_C_API_CXX_API_H_ +#define SHERPA_ONNX_C_API_CXX_API_H_ + +#include +#include + +#include "sherpa-mnn/c-api/c-api.h" + +namespace sherpa_mnn::cxx { + +// ============================================================================ +// Streaming ASR +// ============================================================================ +struct OnlineTransducerModelConfig { + std::string encoder; + std::string decoder; + std::string joiner; +}; + +struct OnlineParaformerModelConfig { + std::string encoder; + std::string decoder; +}; + +struct OnlineZipformer2CtcModelConfig { + std::string model; +}; + +struct OnlineModelConfig { + OnlineTransducerModelConfig transducer; + OnlineParaformerModelConfig paraformer; + OnlineZipformer2CtcModelConfig zipformer2_ctc; + std::string tokens; + int32_t num_threads = 1; + std::string provider = "cpu"; + bool debug = false; + std::string model_type; + std::string modeling_unit = "cjkchar"; + std::string bpe_vocab; + std::string tokens_buf; +}; + +struct FeatureConfig { + int32_t sample_rate = 16000; + int32_t feature_dim = 80; +}; + +struct OnlineCtcFstDecoderConfig { + std::string graph; + int32_t max_active = 3000; +}; + +struct OnlineRecognizerConfig { + FeatureConfig feat_config; + OnlineModelConfig model_config; + + std::string decoding_method = "greedy_search"; + + int32_t max_active_paths = 4; + + bool enable_endpoint = false; + + float rule1_min_trailing_silence = 2.4; + + float rule2_min_trailing_silence = 1.2; + + float rule3_min_utterance_length = 20; + + std::string hotwords_file; + + float hotwords_score = 1.5; + + OnlineCtcFstDecoderConfig ctc_fst_decoder_config; + std::string rule_fsts; + std::string rule_fars; + float blank_penalty = 0; + + std::string hotwords_buf; +}; + +struct OnlineRecognizerResult { + std::string text; + std::vector tokens; + std::vector timestamps; + std::string json; +}; + +struct Wave { + std::vector samples; + int32_t sample_rate; +}; + +SHERPA_ONNX_API Wave ReadWave(const std::string &filename); + +// Return true on success; +// Return false on failure +SHERPA_ONNX_API bool WriteWave(const std::string &filename, const Wave &wave); + +template +class SHERPA_ONNX_API MoveOnly { + public: + explicit MoveOnly(const T *p) : p_(p) {} + + ~MoveOnly() { Destroy(); } + + MoveOnly(const 
MoveOnly &) = delete; + + MoveOnly &operator=(const MoveOnly &) = delete; + + MoveOnly(MoveOnly &&other) : p_(other.Release()) {} + + MoveOnly &operator=(MoveOnly &&other) { + if (&other == this) { + return *this; + } + + Destroy(); + + p_ = other.Release(); + + return *this; + } + + const T *Get() const { return p_; } + + const T *Release() { + const T *p = p_; + p_ = nullptr; + return p; + } + + private: + void Destroy() { + if (p_ == nullptr) { + return; + } + + static_cast(this)->Destroy(p_); + + p_ = nullptr; + } + + protected: + const T *p_ = nullptr; +}; + +class SHERPA_ONNX_API OnlineStream + : public MoveOnly { + public: + explicit OnlineStream(const SherpaMnnOnlineStream *p); + + void AcceptWaveform(int32_t sample_rate, const float *samples, + int32_t n) const; + + void InputFinished() const; + + void Destroy(const SherpaMnnOnlineStream *p) const; +}; + +class SHERPA_ONNX_API OnlineRecognizer + : public MoveOnly { + public: + static OnlineRecognizer Create(const OnlineRecognizerConfig &config); + + void Destroy(const SherpaMnnOnlineRecognizer *p) const; + + OnlineStream CreateStream() const; + + OnlineStream CreateStream(const std::string &hotwords) const; + + bool IsReady(const OnlineStream *s) const; + + void Decode(const OnlineStream *s) const; + + void Decode(const OnlineStream *ss, int32_t n) const; + + OnlineRecognizerResult GetResult(const OnlineStream *s) const; + + void Reset(const OnlineStream *s) const; + + bool IsEndpoint(const OnlineStream *s) const; + + private: + explicit OnlineRecognizer(const SherpaMnnOnlineRecognizer *p); +}; + +// ============================================================================ +// Non-streaming ASR +// ============================================================================ +struct SHERPA_ONNX_API OfflineTransducerModelConfig { + std::string encoder; + std::string decoder; + std::string joiner; +}; + +struct SHERPA_ONNX_API OfflineParaformerModelConfig { + std::string model; +}; + +struct SHERPA_ONNX_API OfflineNemoEncDecCtcModelConfig { + std::string model; +}; + +struct SHERPA_ONNX_API OfflineWhisperModelConfig { + std::string encoder; + std::string decoder; + std::string language; + std::string task = "transcribe"; + int32_t tail_paddings = -1; +}; + +struct SHERPA_ONNX_API OfflineFireRedAsrModelConfig { + std::string encoder; + std::string decoder; +}; + +struct SHERPA_ONNX_API OfflineTdnnModelConfig { + std::string model; +}; + +struct SHERPA_ONNX_API OfflineSenseVoiceModelConfig { + std::string model; + std::string language; + bool use_itn = false; +}; + +struct SHERPA_ONNX_API OfflineMoonshineModelConfig { + std::string preprocessor; + std::string encoder; + std::string uncached_decoder; + std::string cached_decoder; +}; + +struct SHERPA_ONNX_API OfflineModelConfig { + OfflineTransducerModelConfig transducer; + OfflineParaformerModelConfig paraformer; + OfflineNemoEncDecCtcModelConfig nemo_ctc; + OfflineWhisperModelConfig whisper; + OfflineTdnnModelConfig tdnn; + + std::string tokens; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + std::string model_type; + std::string modeling_unit = "cjkchar"; + std::string bpe_vocab; + std::string telespeech_ctc; + OfflineSenseVoiceModelConfig sense_voice; + OfflineMoonshineModelConfig moonshine; + OfflineFireRedAsrModelConfig fire_red_asr; +}; + +struct SHERPA_ONNX_API OfflineLMConfig { + std::string model; + float scale = 1.0; +}; + +struct SHERPA_ONNX_API OfflineRecognizerConfig { + FeatureConfig feat_config; + OfflineModelConfig model_config; + 
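+  // Optional external language model used for rescoring; when
+  // lm_config.model is left empty the recognizer runs without an LM.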
OfflineLMConfig lm_config; + + std::string decoding_method = "greedy_search"; + int32_t max_active_paths = 4; + + std::string hotwords_file; + + float hotwords_score = 1.5; + std::string rule_fsts; + std::string rule_fars; + float blank_penalty = 0; +}; + +struct SHERPA_ONNX_API OfflineRecognizerResult { + std::string text; + std::vector timestamps; + std::vector tokens; + std::string json; + std::string lang; + std::string emotion; + std::string event; +}; + +class SHERPA_ONNX_API OfflineStream + : public MoveOnly { + public: + explicit OfflineStream(const SherpaMnnOfflineStream *p); + + void AcceptWaveform(int32_t sample_rate, const float *samples, + int32_t n) const; + + void Destroy(const SherpaMnnOfflineStream *p) const; +}; + +class SHERPA_ONNX_API OfflineRecognizer + : public MoveOnly { + public: + static OfflineRecognizer Create(const OfflineRecognizerConfig &config); + + void Destroy(const SherpaMnnOfflineRecognizer *p) const; + + OfflineStream CreateStream() const; + + OfflineStream CreateStream(const std::string &hotwords) const; + + void Decode(const OfflineStream *s) const; + + void Decode(const OfflineStream *ss, int32_t n) const; + + OfflineRecognizerResult GetResult(const OfflineStream *s) const; + + private: + explicit OfflineRecognizer(const SherpaMnnOfflineRecognizer *p); +}; + +// ============================================================================ +// Non-streaming TTS +// ============================================================================ +struct OfflineTtsVitsModelConfig { + std::string model; + std::string lexicon; + std::string tokens; + std::string data_dir; + std::string dict_dir; + + float noise_scale = 0.667; + float noise_scale_w = 0.8; + float length_scale = 1.0; // < 1, faster in speed; > 1, slower in speed +}; + +struct OfflineTtsMatchaModelConfig { + std::string acoustic_model; + std::string vocoder; + std::string lexicon; + std::string tokens; + std::string data_dir; + std::string dict_dir; + + float noise_scale = 0.667; + float length_scale = 1.0; // < 1, faster in speed; > 1, slower in speed +}; + +struct OfflineTtsKokoroModelConfig { + std::string model; + std::string voices; + std::string tokens; + std::string data_dir; + std::string dict_dir; + std::string lexicon; + + float length_scale = 1.0; // < 1, faster in speed; > 1, slower in speed +}; + +struct OfflineTtsModelConfig { + OfflineTtsVitsModelConfig vits; + OfflineTtsMatchaModelConfig matcha; + OfflineTtsKokoroModelConfig kokoro; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; +}; + +struct OfflineTtsConfig { + OfflineTtsModelConfig model; + std::string rule_fsts; + std::string rule_fars; + int32_t max_num_sentences = 1; + float silence_scale = 0.2; +}; + +struct GeneratedAudio { + std::vector samples; // in the range [-1, 1] + int32_t sample_rate; +}; + +// Return 1 to continue generating +// Return 0 to stop generating +using OfflineTtsCallback = int32_t (*)(const float *samples, + int32_t num_samples, float progress, + void *arg); + +class SHERPA_ONNX_API OfflineTts + : public MoveOnly { + public: + static OfflineTts Create(const OfflineTtsConfig &config); + + void Destroy(const SherpaMnnOfflineTts *p) const; + + // Return the sample rate of the generated audio + int32_t SampleRate() const; + + // Number of supported speakers. + // If it supports only a single speaker, then it return 0 or 1. + int32_t NumSpeakers() const; + + // @param text A string containing words separated by spaces + // @param sid Speaker ID. 
Used only for multi-speaker models, e.g., models + // trained using the VCTK dataset. It is not used for + // single-speaker models, e.g., models trained using the ljspeech + // dataset. + // @param speed The speed for the generated speech. E.g., 2 means 2x faster. + // @param callback If not NULL, it is called whenever config.max_num_sentences + // sentences have been processed. The callback is called in + // the current thread. + GeneratedAudio Generate(const std::string &text, int32_t sid = 0, + float speed = 1.0, + OfflineTtsCallback callback = nullptr, + void *arg = nullptr) const; + + private: + explicit OfflineTts(const SherpaMnnOfflineTts *p); +}; + +// ============================================================ +// For Keyword Spotter +// ============================================================ + +struct KeywordResult { + std::string keyword; + std::vector tokens; + std::vector timestamps; + float start_time; + std::string json; +}; + +struct KeywordSpotterConfig { + FeatureConfig feat_config; + OnlineModelConfig model_config; + int32_t max_active_paths = 4; + int32_t num_trailing_blanks = 1; + float keywords_score = 1.0f; + float keywords_threshold = 0.25f; + std::string keywords_file; +}; + +class SHERPA_ONNX_API KeywordSpotter + : public MoveOnly { + public: + static KeywordSpotter Create(const KeywordSpotterConfig &config); + + void Destroy(const SherpaMnnKeywordSpotter *p) const; + + OnlineStream CreateStream() const; + + OnlineStream CreateStream(const std::string &keywords) const; + + bool IsReady(const OnlineStream *s) const; + + void Decode(const OnlineStream *s) const; + + void Decode(const OnlineStream *ss, int32_t n) const; + + void Reset(const OnlineStream *s) const; + + KeywordResult GetResult(const OnlineStream *s) const; + + private: + explicit KeywordSpotter(const SherpaMnnKeywordSpotter *p); +}; + +struct OfflineSpeechDenoiserGtcrnModelConfig { + std::string model; +}; + +struct OfflineSpeechDenoiserModelConfig { + OfflineSpeechDenoiserGtcrnModelConfig gtcrn; + int32_t num_threads = 1; + int32_t debug = false; + std::string provider = "cpu"; +}; + +struct OfflineSpeechDenoiserConfig { + OfflineSpeechDenoiserModelConfig model; +}; + +struct DenoisedAudio { + std::vector samples; // in the range [-1, 1] + int32_t sample_rate; +}; + +class SHERPA_ONNX_API OfflineSpeechDenoiser + : public MoveOnly { + public: + static OfflineSpeechDenoiser Create( + const OfflineSpeechDenoiserConfig &config); + + void Destroy(const SherpaMnnOfflineSpeechDenoiser *p) const; + + DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const; + + int32_t GetSampleRate() const; + + private: + explicit OfflineSpeechDenoiser(const SherpaMnnOfflineSpeechDenoiser *p); +}; + +} // namespace sherpa_mnn::cxx + +#endif // SHERPA_ONNX_C_API_CXX_API_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/.gitignore b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/.gitignore new file mode 100644 index 00000000..09849b0f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/.gitignore @@ -0,0 +1,2 @@ +*.cc-bak +*.h-bak diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/CMakeLists.txt b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/CMakeLists.txt new file mode 100644 index 00000000..1127fe52 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/CMakeLists.txt @@ -0,0 +1,623 @@ +include_directories(${CMAKE_SOURCE_DIR}) +message(STATUS "Currnet: ${CMAKE_SOURCE_DIR}") + +if(SHERPA_MNN_ENABLE_PYTHON) + message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}") + 
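+  # Query the interpreter for its <major>.<minor> version; it is used below
+  # to build rpath entries pointing at the matching site-packages directory.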
execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "import sys; print('.'.join(sys.version.split('.')[:2]))" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE PYTHON_VERSION + ) + message(STATUS "PYTHON_VERSION: ${PYTHON_VERSION}") +endif() +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") + +set(sources + base64-decode.cc + bbpe.cc + cat.cc + circular-buffer.cc + context-graph.cc + endpoint.cc + features.cc + file-utils.cc + fst-utils.cc + hypothesis.cc + keyword-spotter-impl.cc + keyword-spotter.cc + offline-ctc-fst-decoder-config.cc + offline-ctc-fst-decoder.cc + offline-ctc-greedy-search-decoder.cc + offline-ctc-model.cc + offline-fire-red-asr-greedy-search-decoder.cc + offline-fire-red-asr-model-config.cc + offline-fire-red-asr-model.cc + offline-lm-config.cc + offline-lm.cc + offline-model-config.cc + offline-moonshine-greedy-search-decoder.cc + offline-moonshine-model-config.cc + offline-moonshine-model.cc + offline-nemo-enc-dec-ctc-model-config.cc + offline-nemo-enc-dec-ctc-model.cc + offline-paraformer-greedy-search-decoder.cc + offline-paraformer-model-config.cc + offline-paraformer-model.cc + offline-recognizer-impl.cc + offline-recognizer.cc + offline-rnn-lm.cc + offline-sense-voice-model-config.cc + offline-sense-voice-model.cc + offline-stream.cc + offline-tdnn-ctc-model.cc + offline-tdnn-model-config.cc + offline-telespeech-ctc-model.cc + offline-transducer-greedy-search-decoder.cc + offline-transducer-greedy-search-nemo-decoder.cc + offline-transducer-model-config.cc + offline-transducer-model.cc + offline-transducer-modified-beam-search-decoder.cc + offline-transducer-nemo-model.cc + offline-wenet-ctc-model-config.cc + offline-wenet-ctc-model.cc + offline-whisper-greedy-search-decoder.cc + offline-whisper-model-config.cc + offline-whisper-model.cc + offline-zipformer-ctc-model-config.cc + offline-zipformer-ctc-model.cc + online-conformer-transducer-model.cc + online-ctc-fst-decoder-config.cc + online-ctc-fst-decoder.cc + online-ctc-greedy-search-decoder.cc + online-ctc-model.cc + online-ebranchformer-transducer-model.cc + online-lm-config.cc + online-lm.cc + online-lstm-transducer-model.cc + online-model-config.cc + online-nemo-ctc-model-config.cc + online-nemo-ctc-model.cc + online-paraformer-model-config.cc + online-paraformer-model.cc + online-recognizer-impl.cc + online-recognizer.cc + online-rnn-lm.cc + online-stream.cc + online-transducer-decoder.cc + online-transducer-greedy-search-decoder.cc + online-transducer-greedy-search-nemo-decoder.cc + online-transducer-model-config.cc + online-transducer-model.cc + online-transducer-modified-beam-search-decoder.cc + online-transducer-nemo-model.cc + online-wenet-ctc-model-config.cc + online-wenet-ctc-model.cc + online-zipformer-transducer-model.cc + online-zipformer2-ctc-model-config.cc + online-zipformer2-ctc-model.cc + online-zipformer2-transducer-model.cc + MNNUtils.cc + packed-sequence.cc + pad-sequence.cc + parse-options.cc + provider-config.cc + provider.cc + resample.cc + session.cc + silero-vad-model-config.cc + silero-vad-model.cc + slice.cc + spoken-language-identification-impl.cc + spoken-language-identification.cc + stack.cc + symbol-table.cc + text-utils.cc + transducer-keyword-decoder.cc + transpose.cc + unbind.cc + utils.cc + vad-model-config.cc + vad-model.cc + voice-activity-detector.cc + wave-reader.cc + wave-writer.cc +) + +# speaker embedding extractor +list(APPEND sources + speaker-embedding-extractor-impl.cc + speaker-embedding-extractor-model.cc + speaker-embedding-extractor-nemo-model.cc + 
speaker-embedding-extractor.cc + speaker-embedding-manager.cc +) + +# audio tagging +list(APPEND sources + audio-tagging-impl.cc + audio-tagging-label-file.cc + audio-tagging-model-config.cc + audio-tagging.cc + offline-ced-model.cc + offline-zipformer-audio-tagging-model-config.cc + offline-zipformer-audio-tagging-model.cc +) + +# punctuation +list(APPEND sources + offline-ct-transformer-model.cc + offline-punctuation-impl.cc + offline-punctuation-model-config.cc + offline-punctuation.cc + online-cnn-bilstm-model.cc + online-punctuation-impl.cc + online-punctuation-model-config.cc + online-punctuation.cc +) + +if(SHERPA_MNN_ENABLE_TTS) + list(APPEND sources + hifigan-vocoder.cc + jieba-lexicon.cc + kokoro-multi-lang-lexicon.cc + lexicon.cc + melo-tts-lexicon.cc + offline-tts-character-frontend.cc + offline-tts-frontend.cc + offline-tts-impl.cc + offline-tts-kokoro-model-config.cc + offline-tts-kokoro-model.cc + offline-tts-matcha-model-config.cc + offline-tts-matcha-model.cc + offline-tts-model-config.cc + offline-tts-vits-model-config.cc + offline-tts-vits-model.cc + offline-tts.cc + piper-phonemize-lexicon.cc + ) +endif() + +list(APPEND sources + offline-speech-denoiser-gtcrn-model-config.cc + offline-speech-denoiser-gtcrn-model.cc + offline-speech-denoiser-impl.cc + offline-speech-denoiser-model-config.cc + offline-speech-denoiser.cc +) + +if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + list(APPEND sources + fast-clustering-config.cc + fast-clustering.cc + offline-speaker-diarization-impl.cc + offline-speaker-diarization-result.cc + offline-speaker-diarization.cc + offline-speaker-segmentation-model-config.cc + offline-speaker-segmentation-pyannote-model-config.cc + offline-speaker-segmentation-pyannote-model.cc + ) +endif() + +if(SHERPA_MNN_ENABLE_CHECK) + list(APPEND sources log.cc) +endif() + +# Always static build +add_library(sherpa-mnn-core STATIC ${sources}) + +set_target_properties( + sherpa-mnn-core + PROPERTIES + POSITION_INDEPENDENT_CODE ON + C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden +) + +if(APPLE) + target_compile_options(sherpa-mnn-core PRIVATE + -Wno-deprecated-declarations + ) +endif() + +if(ANDROID_NDK) + target_link_libraries(sherpa-mnn-core android log) +endif() +target_link_libraries(sherpa-mnn-core MNN) + +target_link_libraries(sherpa-mnn-core + kaldi-native-fbank-core + kaldi-decoder-core + ssentencepiece_core +) +if(DEFINED OHOS AND x${OHOS} STREQUAL xOHOS) + target_link_libraries(sherpa-mnn-core + hilog_ndk.z + rawfile.z + ) +endif() + + +if(NOT WIN32) + target_link_libraries(sherpa-mnn-core -lm) +endif() + +if(NOT BUILD_SHARED_LIBS AND APPLE) + target_link_libraries(sherpa-mnn-core "-framework Foundation") +endif() + +target_link_libraries(sherpa-mnn-core fstfar fst) + +if(SHERPA_MNN_ENABLE_TTS) + target_link_libraries(sherpa-mnn-core piper_phonemize) + target_link_libraries(sherpa-mnn-core cppjieba) +endif() + +if(SHERPA_MNN_ENABLE_CHECK) + target_compile_definitions(sherpa-mnn-core PUBLIC SHERPA_MNN_ENABLE_CHECK=1) + + if(SHERPA_MNN_HAVE_EXECINFO_H) + target_compile_definitions(sherpa-mnn-core PRIVATE SHERPA_MNN_HAVE_EXECINFO_H=1) + endif() + + if(SHERPA_MNN_HAVE_CXXABI_H) + target_compile_definitions(sherpa-mnn-core PRIVATE SHERPA_MNN_HAVE_CXXABI_H=1) + endif() +endif() + +if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) + # This is for linux arm32 and arm64 + target_link_libraries(sherpa-mnn-core -ldl) +endif() + +if(NOT WIN32 AND NOT SHERPA_MNN_ENABLE_WASM AND CMAKE_SYSTEM_NAME STREQUAL Linux) + 
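+  # std::thread and the other threading primitives used by the core library
+  # need an explicit -pthread when linking on desktop Linux.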
target_link_libraries(sherpa-mnn-core -pthread) +endif() + +if(SHERPA_MNN_ENABLE_BINARY) + add_executable(sherpa-mnn sherpa-onnx.cc) + add_executable(sherpa-mnn-keyword-spotter sherpa-onnx-keyword-spotter.cc) + add_executable(sherpa-mnn-offline sherpa-onnx-offline.cc) + add_executable(sherpa-mnn-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc) + add_executable(sherpa-mnn-offline-language-identification sherpa-onnx-offline-language-identification.cc) + add_executable(sherpa-mnn-offline-parallel sherpa-onnx-offline-parallel.cc) + add_executable(sherpa-mnn-offline-punctuation sherpa-onnx-offline-punctuation.cc) + add_executable(sherpa-mnn-online-punctuation sherpa-onnx-online-punctuation.cc) + add_executable(sherpa-mnn-offline-denoiser sherpa-onnx-offline-denoiser.cc) + + if(SHERPA_MNN_ENABLE_TTS) + add_executable(sherpa-mnn-offline-tts sherpa-onnx-offline-tts.cc) + endif() + + if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + add_executable(sherpa-mnn-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc) + endif() + + set(main_exes + sherpa-mnn + sherpa-mnn-keyword-spotter + sherpa-mnn-offline + sherpa-mnn-offline-audio-tagging + sherpa-mnn-offline-language-identification + sherpa-mnn-offline-parallel + sherpa-mnn-offline-punctuation + sherpa-mnn-offline-denoiser + sherpa-mnn-online-punctuation + ) + if(SHERPA_MNN_ENABLE_TTS) + list(APPEND main_exes + sherpa-mnn-offline-tts + ) + endif() + + if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + list(APPEND main_exes + sherpa-mnn-offline-speaker-diarization + ) + endif() + + foreach(exe IN LISTS main_exes) + target_link_libraries(${exe} sherpa-mnn-core) + endforeach() + + if(NOT WIN32) + foreach(exe IN LISTS main_exes) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib") + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + + if(SHERPA_MNN_ENABLE_PYTHON) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + endif() + endforeach() + endif() +endif() + +if(NOT BUILD_SHARED_LIBS) + install(TARGETS sherpa-mnn-core DESTINATION lib) +endif() + +if(SHERPA_MNN_ENABLE_BINARY) + install( + TARGETS + ${main_exes} + DESTINATION + bin + ) +endif() + +if(SHERPA_MNN_HAS_ALSA AND SHERPA_MNN_ENABLE_BINARY) + add_executable(sherpa-mnn-alsa sherpa-onnx-alsa.cc alsa.cc) + add_executable(sherpa-mnn-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc) + add_executable(sherpa-mnn-alsa-offline-audio-tagging sherpa-onnx-alsa-offline-audio-tagging.cc alsa.cc) + add_executable(sherpa-mnn-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc) + add_executable(sherpa-mnn-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc) + add_executable(sherpa-mnn-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc) + + + if(SHERPA_MNN_ENABLE_TTS) + add_executable(sherpa-mnn-offline-tts-play-alsa sherpa-onnx-offline-tts-play-alsa.cc alsa-play.cc) + endif() + + set(exes + sherpa-mnn-alsa + sherpa-mnn-alsa-offline + sherpa-mnn-alsa-offline-speaker-identification + sherpa-mnn-keyword-spotter-alsa + sherpa-mnn-vad-alsa + sherpa-mnn-alsa-offline-audio-tagging + ) + + if(SHERPA_MNN_ENABLE_TTS) + list(APPEND exes + sherpa-mnn-offline-tts-play-alsa + ) + endif() + + # # To fix the following error for Windows when building exe + # # mismatch detected for 'RuntimeLibrary': value 'MT_StaticRelease' doesn't match value 'MD_Dynamic Release' + + foreach(exe IN LISTS exes) + 
target_link_libraries(${exe} sherpa-mnn-core) + endforeach() + + foreach(exe IN LISTS exes) + if(DEFINED ENV{SHERPA_MNN_ALSA_LIB_DIR}) + target_link_libraries(${exe} -L$ENV{SHERPA_MNN_ALSA_LIB_DIR} -lasound) + else() + target_link_libraries(${exe} asound) + endif() + endforeach() + + if(NOT WIN32) + foreach(exe IN LISTS exes) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib") + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + endforeach() + + if(SHERPA_MNN_ENABLE_PYTHON) + foreach(exe IN LISTS exes) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + endforeach() + endif() + endif() + + install( + TARGETS ${exes} + DESTINATION + bin + ) +endif() + +if(SHERPA_MNN_ENABLE_PORTAUDIO AND SHERPA_MNN_ENABLE_BINARY) + if(SHERPA_MNN_ENABLE_TTS) + add_executable(sherpa-mnn-offline-tts-play + sherpa-onnx-offline-tts-play.cc + microphone.cc + ) + endif() + + add_executable(sherpa-mnn-keyword-spotter-microphone + sherpa-onnx-keyword-spotter-microphone.cc + microphone.cc + ) + + add_executable(sherpa-mnn-microphone + sherpa-onnx-microphone.cc + microphone.cc + ) + + add_executable(sherpa-mnn-microphone-offline + sherpa-onnx-microphone-offline.cc + microphone.cc + ) + + add_executable(sherpa-mnn-vad-microphone + sherpa-onnx-vad-microphone.cc + microphone.cc + ) + + add_executable(sherpa-mnn-vad-with-offline-asr + sherpa-onnx-vad-with-offline-asr.cc + ) + + add_executable(sherpa-mnn-vad-microphone-offline-asr + sherpa-onnx-vad-microphone-offline-asr.cc + microphone.cc + ) + + add_executable(sherpa-mnn-microphone-offline-speaker-identification + sherpa-onnx-microphone-offline-speaker-identification.cc + microphone.cc + ) + + add_executable(sherpa-mnn-microphone-offline-audio-tagging + sherpa-onnx-microphone-offline-audio-tagging.cc + microphone.cc + ) + + set(exes + sherpa-mnn-microphone + sherpa-mnn-keyword-spotter-microphone + sherpa-mnn-microphone-offline + sherpa-mnn-microphone-offline-speaker-identification + sherpa-mnn-microphone-offline-audio-tagging + sherpa-mnn-vad-microphone + sherpa-mnn-vad-microphone-offline-asr + sherpa-mnn-vad-with-offline-asr + ) + if(SHERPA_MNN_ENABLE_TTS) + list(APPEND exes + sherpa-mnn-offline-tts-play + ) + endif() + + foreach(exe IN LISTS exes) + target_link_libraries(${exe} portaudio_static sherpa-mnn-core) + endforeach() + + if(NOT WIN32) + foreach(exe IN LISTS exes) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib") + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + endforeach() + + if(SHERPA_MNN_ENABLE_PYTHON) + foreach(exe IN LISTS exes) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + endforeach() + endif() + endif() + + install( + TARGETS ${exes} + DESTINATION + bin + ) +endif() + +if(SHERPA_MNN_ENABLE_WEBSOCKET AND SHERPA_MNN_ENABLE_BINARY) + add_definitions(-DASIO_STANDALONE) + add_definitions(-D_WEBSOCKETPP_CPP11_STL_) + + add_executable(sherpa-mnn-online-websocket-server + online-websocket-server-impl.cc + online-websocket-server.cc + ) + target_link_libraries(sherpa-mnn-online-websocket-server sherpa-mnn-core) + + add_executable(sherpa-mnn-online-websocket-client + online-websocket-client.cc + ) + target_link_libraries(sherpa-mnn-online-websocket-client sherpa-mnn-core) + + if(NOT WIN32) + 
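    # The bundled websocket/asio headers trigger deprecation warnings on newer
    # compilers; they are silenced here for the non-Windows websocket binaries.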
target_compile_options(sherpa-mnn-online-websocket-server PRIVATE -Wno-deprecated-declarations) + + target_compile_options(sherpa-mnn-online-websocket-client PRIVATE -Wno-deprecated-declarations) + endif() + + # For offline websocket + add_executable(sherpa-mnn-offline-websocket-server + offline-websocket-server-impl.cc + offline-websocket-server.cc + ) + target_link_libraries(sherpa-mnn-offline-websocket-server sherpa-mnn-core) + + if(NOT WIN32) + target_compile_options(sherpa-mnn-offline-websocket-server PRIVATE -Wno-deprecated-declarations) + endif() + + if(NOT WIN32) + target_link_libraries(sherpa-mnn-online-websocket-server "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib") + target_link_libraries(sherpa-mnn-online-websocket-server "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + + target_link_libraries(sherpa-mnn-online-websocket-client "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib") + target_link_libraries(sherpa-mnn-online-websocket-client "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + + target_link_libraries(sherpa-mnn-offline-websocket-server "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib") + target_link_libraries(sherpa-mnn-offline-websocket-server "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + + if(SHERPA_MNN_ENABLE_PYTHON AND NOT WIN32) + target_link_libraries(sherpa-mnn-online-websocket-server "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + target_link_libraries(sherpa-mnn-online-websocket-client "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + target_link_libraries(sherpa-mnn-offline-websocket-server "-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + endif() + endif() + + install( + TARGETS + sherpa-mnn-online-websocket-server + sherpa-mnn-online-websocket-client + sherpa-mnn-offline-websocket-server + DESTINATION + bin + ) +endif() + +if(SHERPA_MNN_ENABLE_TESTS) + set(sherpa_onnx_test_srcs + cat-test.cc + circular-buffer-test.cc + context-graph-test.cc + packed-sequence-test.cc + pad-sequence-test.cc + regex-lang-test.cc + slice-test.cc + stack-test.cc + text-utils-test.cc + text2token-test.cc + transpose-test.cc + unbind-test.cc + utfcpp-test.cc + ) + if(SHERPA_MNN_ENABLE_TTS) + list(APPEND sherpa_onnx_test_srcs + cppjieba-test.cc + piper-phonemize-test.cc + ) + endif() + + if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + list(APPEND sherpa_onnx_test_srcs + fast-clustering-test.cc + ) + endif() + + list(APPEND sherpa_onnx_test_srcs + speaker-embedding-manager-test.cc + ) + + function(sherpa_onnx_add_test source) + get_filename_component(name ${source} NAME_WE) + set(target_name ${name}) + add_executable(${target_name} "${source}") + + target_link_libraries(${target_name} + PRIVATE + gtest + gtest_main + sherpa-mnn-core + ) + + add_test(NAME "${target_name}" + COMMAND + $ + ) + endfunction() + + foreach(source IN LISTS sherpa_onnx_test_srcs) + sherpa_onnx_add_test(${source}) + endforeach() +endif() + +set(srcs_to_check) +foreach(s IN LISTS sources) + list(APPEND srcs_to_check ${CMAKE_CURRENT_LIST_DIR}/${s}) +endforeach() + +# For clang-tidy +add_custom_target( + clang-tidy-check + clang-tidy -p ${CMAKE_BINARY_DIR}/compile_commands.json --config-file ${CMAKE_SOURCE_DIR}/.clang-tidy ${srcs_to_check} + DEPENDS ${sources}) + +add_custom_target(check DEPENDS clang-tidy-check) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/CPPLINT.cfg 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/CPPLINT.cfg new file mode 100644 index 00000000..d0129441 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/CPPLINT.cfg @@ -0,0 +1 @@ +exclude_files=tee-stream.h diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/MNNUtils.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/MNNUtils.cc new file mode 100644 index 00000000..4d6a80cd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/MNNUtils.cc @@ -0,0 +1,301 @@ +// sherpa-mnn/csrc/onnx-utils.cc +// +// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo +#include "sherpa-mnn/csrc/MNNUtils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +static std::string GetInputName(MNN::Express::Module *sess, size_t index, + MNNAllocator *allocator) { + return sess->getInfo()->inputNames[index]; +} + +static std::string GetOutputName(MNN::Express::Module *sess, size_t index, + MNNAllocator *allocator) { + return sess->getInfo()->outputNames[index]; +} + +void GetInputNames(MNN::Express::Module *sess, std::vector *input_names, + std::vector *input_names_ptr) { + MNNAllocator* allocator; + size_t node_count = sess->getInfo()->inputNames.size(); + input_names->resize(node_count); + input_names_ptr->resize(node_count); + for (size_t i = 0; i != node_count; ++i) { + (*input_names)[i] = GetInputName(sess, i, allocator); + (*input_names_ptr)[i] = (*input_names)[i].c_str(); + } +} + +void GetOutputNames(MNN::Express::Module *sess, std::vector *output_names, + std::vector *output_names_ptr) { + MNNAllocator* allocator; + size_t node_count = sess->getInfo()->outputNames.size(); + output_names->resize(node_count); + output_names_ptr->resize(node_count); + for (size_t i = 0; i != node_count; ++i) { + (*output_names)[i] = GetOutputName(sess, i, allocator); + (*output_names_ptr)[i] = (*output_names)[i].c_str(); + } +} + +MNN::Express::VARP GetEncoderOutFrame(MNNAllocator *allocator, MNN::Express::VARP encoder_out, + int32_t t) { + std::vector encoder_out_shape = + encoder_out->getInfo()->dim; + + auto batch_size = encoder_out_shape[0]; + auto num_frames = encoder_out_shape[1]; + assert(t < num_frames); + + auto encoder_out_dim = encoder_out_shape[2]; + + auto offset = num_frames * encoder_out_dim; + + std::array shape{batch_size, encoder_out_dim}; + + MNN::Express::VARP ans = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + + float *dst = ans->writeMap(); + const float *src = encoder_out->readMap(); + + for (int32_t i = 0; i != batch_size; ++i) { + std::copy(src + t * encoder_out_dim, src + (t + 1) * encoder_out_dim, dst); + src += offset; + dst += encoder_out_dim; + } + return ans; +} + +void PrintModelMetadata(std::ostream &os, const MNNMeta &meta_data) { + MNNAllocator* allocator; + for (auto& iter : meta_data) { + os << iter.first << "=" << iter.second <<"\n"; + } +} + +MNN::Express::VARP Clone(MNNAllocator *allocator, MNN::Express::VARP v) { + return MNN::Express::_Clone(v, true); +} + +MNN::Express::VARP View(MNN::Express::VARP v) { + return v; +} + +float ComputeSum(MNN::Express::VARP v, int32_t n /*= -1*/) { + std::vector shape = v->getInfo()->dim; + auto size = static_cast( + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>())); + if (n != -1 && n < size && n > 0) { + size = n; + } + + const float *p = v->readMap(); + + return std::accumulate(p, p + size, 1.0f); +} + +float ComputeMean(MNN::Express::VARP v, int32_t n /*= -1*/) { + 
std::vector shape = v->getInfo()->dim; + auto size = static_cast( + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>())); + + if (n != -1 && n < size && n > 0) { + size = n; + } + + auto sum = ComputeSum(v, n); + return sum / size; +} + +void PrintShape(MNN::Express::VARP v) { + std::vector shape = v->getInfo()->dim; + std::ostringstream os; + for (auto i : shape) { + os << i << ", "; + } + os << "\n"; + fprintf(stderr, "%s", os.str().c_str()); +} + +template +void Print1D(MNN::Express::VARP v) { + std::vector shape = v->getInfo()->dim; + const T *d = v->readMap(); + std::ostringstream os; + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + os << d[i] << " "; + } + os << "\n"; + fprintf(stderr, "%s\n", os.str().c_str()); +} + +template void Print1D(MNN::Express::VARP v); +template void Print1D(MNN::Express::VARP v); + +template +void Print2D(MNN::Express::VARP v) { + std::vector shape = v->getInfo()->dim; + const T *d = v->readMap(); + + std::ostringstream os; + for (int32_t r = 0; r != static_cast(shape[0]); ++r) { + for (int32_t c = 0; c != static_cast(shape[1]); ++c, ++d) { + os << *d << " "; + } + os << "\n"; + } + fprintf(stderr, "%s\n", os.str().c_str()); +} + +template void Print2D(MNN::Express::VARP v); +template void Print2D(MNN::Express::VARP v); + +void Print3D(MNN::Express::VARP v) { + std::vector shape = v->getInfo()->dim; + const float *d = v->readMap(); + + for (int32_t p = 0; p != static_cast(shape[0]); ++p) { + fprintf(stderr, "---plane %d---\n", p); + for (int32_t r = 0; r != static_cast(shape[1]); ++r) { + for (int32_t c = 0; c != static_cast(shape[2]); ++c, ++d) { + fprintf(stderr, "%.3f ", *d); + } + fprintf(stderr, "\n"); + } + } + fprintf(stderr, "\n"); +} + +void Print4D(MNN::Express::VARP v) { + std::vector shape = v->getInfo()->dim; + const float *d = v->readMap(); + + for (int32_t p = 0; p != static_cast(shape[0]); ++p) { + fprintf(stderr, "---plane %d---\n", p); + for (int32_t q = 0; q != static_cast(shape[1]); ++q) { + fprintf(stderr, "---subplane %d---\n", q); + for (int32_t r = 0; r != static_cast(shape[2]); ++r) { + for (int32_t c = 0; c != static_cast(shape[3]); ++c, ++d) { + fprintf(stderr, "%.3f ", *d); + } + fprintf(stderr, "\n"); + } + fprintf(stderr, "\n"); + } + } + fprintf(stderr, "\n"); +} + +MNN::Express::VARP Repeat(MNNAllocator *allocator, MNN::Express::VARP cur_encoder_out, + const std::vector &hyps_num_split) { + std::vector cur_encoder_out_shape = + cur_encoder_out->getInfo()->dim; + + std::array ans_shape{hyps_num_split.back(), + cur_encoder_out_shape[1]}; + + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + + const float *src = cur_encoder_out->readMap(); + float *dst = ans->writeMap(); + int32_t batch_size = static_cast(hyps_num_split.size()) - 1; + for (int32_t b = 0; b != batch_size; ++b) { + int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b]; + for (int32_t i = 0; i != cur_stream_hyps_num; ++i) { + std::copy(src, src + cur_encoder_out_shape[1], dst); + dst += cur_encoder_out_shape[1]; + } + src += cur_encoder_out_shape[1]; + } + return ans; +} + +CopyableOrtValue::CopyableOrtValue(const CopyableOrtValue &other) { + *this = other; +} + +CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) { + if (this == &other) { + return *this; + } + if (nullptr != other.value.get()) { + MNNAllocator* allocator; + value = Clone(allocator, other.value); + } + return *this; +} + +CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other) 
noexcept { + *this = std::move(other); +} + +CopyableOrtValue &CopyableOrtValue::operator=( + CopyableOrtValue &&other) noexcept { + if (this == &other) { + return *this; + } + value = std::move(other.value); + return *this; +} + +std::vector Convert(std::vector values) { + std::vector ans; + ans.reserve(values.size()); + + for (auto &v : values) { + ans.emplace_back(std::move(v)); + } + + return ans; +} + +std::vector Convert(std::vector values) { + std::vector ans; + ans.reserve(values.size()); + + for (auto &v : values) { + ans.emplace_back(std::move(v.value)); + } + + return ans; +} + +std::string LookupCustomModelMetaData(const MNNMeta &meta_data, + const char *key, + MNNAllocator *allocator) { + auto iter = meta_data.find(key); + if (iter == meta_data.end()) { + return ""; + } + return iter->second; +} + +MNN::Express::VARP MNNUtilsCreateTensor(MNNAllocator* allocator, const void* data, size_t data_size, const int* shapedata, + int shapeSize, halide_type_t type ) { + std::vector s(shapedata, shapedata+shapeSize); + return MNN::Express::_Const(data, s, MNN::Express::NCHW, type); +} + +MNN::Express::VARP MNNUtilsCreateTensor(MNNAllocator* allocator, const int* shapedata, + int shapeSize, halide_type_t type) { + std::vector s(shapedata, shapedata+shapeSize); + return MNN::Express::_Input(s, MNN::Express::NCHW, type); + +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/MNNUtils.hpp b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/MNNUtils.hpp new file mode 100644 index 00000000..3dd4113e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/MNNUtils.hpp @@ -0,0 +1,142 @@ +#ifndef MNNUTILS_HPP +#define MNNUTILS_HPP +#include +#include +#include +#include +typedef std::map MNNMeta; +class MNNAllocator { + // Empty +}; + +class MNNEnv { + // Empty +}; + +class MNNConfig { +public: + std::shared_ptr pManager; + MNN::Express::Module::Config pConfig; +}; + +namespace sherpa_mnn { + +MNN::Express::VARP MNNUtilsCreateTensor(MNNAllocator* allocator, const void* data, size_t data_size, const int* shapedata, + int shapeSize, halide_type_t type = halide_type_of()); + +MNN::Express::VARP MNNUtilsCreateTensor(MNNAllocator* allocator, const int* shapedata, + int shapeSize, halide_type_t type = halide_type_of()); + +template +MNN::Express::VARP MNNUtilsCreateTensor(MNNAllocator* allocator, const T* data, size_t data_size, const int* shapedata, + int shapeSize) { + return MNNUtilsCreateTensor(allocator, data, data_size, shapedata, shapeSize, halide_type_of()); +} + + +template +MNN::Express::VARP MNNUtilsCreateTensor(MNNAllocator* allocator, const int* shapedata, + int shapeSize) { + return MNNUtilsCreateTensor(allocator, shapedata, shapeSize, halide_type_of()); +} + + +/** + * Get the input names of a model. + * + * @param sess An onnxruntime session. + * @param input_names. On return, it contains the input names of the model. + * @param input_names_ptr. On return, input_names_ptr[i] contains + * input_names[i].c_str() + */ +void GetInputNames(MNN::Express::Module *sess, std::vector *input_names, + std::vector *input_names_ptr); + +/** + * Get the output names of a model. + * + * @param sess An onnxruntime session. + * @param output_names. On return, it contains the output names of the model. + * @param output_names_ptr. 
On return, output_names_ptr[i] contains + * output_names[i].c_str() + */ +void GetOutputNames(MNN::Express::Module *sess, std::vector *output_names, + std::vector *output_names_ptr); + +/** + * Get the output frame of Encoder + * + * @param allocator allocator of onnxruntime + * @param encoder_out encoder out tensor + * @param t frame_index + * + */ +MNN::Express::VARP GetEncoderOutFrame(MNNAllocator *allocator, MNN::Express::VARP encoder_out, + int32_t t); + +std::string LookupCustomModelMetaData(const MNNMeta &meta_data, + const char *key, MNNAllocator *allocator); + +void PrintModelMetadata(std::ostream &os, + const MNNMeta &meta_data); // NOLINT + +// Return a deep copy of v +MNN::Express::VARP Clone(MNNAllocator *allocator, MNN::Express::VARP v); + +// Return a shallow copy +MNN::Express::VARP View(MNN::Express::VARP v); + +float ComputeSum(MNN::Express::VARP v, int32_t n = -1); +float ComputeMean(MNN::Express::VARP v, int32_t n = -1); + +// Print a 1-D tensor to stderr +template +void Print1D(MNN::Express::VARP v); + +// Print a 2-D tensor to stderr +template +void Print2D(MNN::Express::VARP v); + +// Print a 3-D tensor to stderr +void Print3D(MNN::Express::VARP v); + +// Print a 4-D tensor to stderr +void Print4D(MNN::Express::VARP v); + +void PrintShape(MNN::Express::VARP v); + +template +void Fill(MNN::Express::VARP tensor, T value) { + auto n = tensor->getInfo()->size; + auto p = tensor->writeMap(); + std::fill(p, p + n, value); +} + +// TODO(fangjun): Document it +MNN::Express::VARP Repeat(MNNAllocator *allocator, MNN::Express::VARP cur_encoder_out, + const std::vector &hyps_num_split); + +struct CopyableOrtValue { + MNN::Express::VARP value{nullptr}; + + CopyableOrtValue() = default; + + /*explicit*/ CopyableOrtValue(MNN::Express::VARP v) // NOLINT + : value(std::move(v)) {} + + CopyableOrtValue(const CopyableOrtValue &other); + + CopyableOrtValue &operator=(const CopyableOrtValue &other); + + CopyableOrtValue(CopyableOrtValue &&other) noexcept; + + CopyableOrtValue &operator=(CopyableOrtValue &&other) noexcept; +}; + +std::vector Convert(std::vector values); + +std::vector Convert(std::vector values); + + +}; +#endif diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/README.md b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/README.md new file mode 100644 index 00000000..f073bb06 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/README.md @@ -0,0 +1,29 @@ +# File descriptions + +- [./sherpa-onnx-alsa.cc](./sherpa-onnx-alsa.cc) For Linux only, especially for + embedded Linux, e.g., Raspberry Pi; it uses a streaming model for real-time + speech recognition with a microphone. + +- [./sherpa-onnx-microphone.cc](./sherpa-onnx-microphone.cc) + For Linux/Windows/macOS; it uses a streaming model for real-time speech + recognition with a microphone. + +- [./sherpa-onnx-microphone-offline.cc](./sherpa-onnx-microphone-offline.cc) + For Linux/Windows/macOS; it uses a non-streaming model for speech + recognition with a microphone. + +- [./sherpa-onnx.cc](./sherpa-onnx.cc) + It uses a streaming model to decode wave files + +- [./sherpa-onnx-offline.cc](./sherpa-onnx-offline.cc) + It uses a non-streaming model to decode wave files + +- [./online-websocket-server.cc](./online-websocket-server.cc) + WebSocket server for streaming models. + +- [./offline-websocket-server.cc](./offline-websocket-server.cc) + WebSocket server for non-streaming models. + +- [./sherpa-onnx-vad-microphone.cc](./sherpa-onnx-vad-microphone.cc) + Use silero VAD to detect speeches with a microphone. 
+ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa-play.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa-play.cc new file mode 100644 index 00000000..fb52edc8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa-play.cc @@ -0,0 +1,150 @@ +// sherpa-mnn/csrc/alsa-play.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifdef SHERPA_ONNX_ENABLE_ALSA + +#include "sherpa-mnn/csrc/alsa-play.h" + +#include + +namespace sherpa_mnn { + +AlsaPlay::AlsaPlay(const char *device_name, int32_t sample_rate) { + int32_t err = snd_pcm_open(&handle_, device_name, SND_PCM_STREAM_PLAYBACK, 0); + + if (err) { + fprintf(stderr, "Unable to open: %s. %s\n", device_name, snd_strerror(err)); + exit(-1); + } + + SetParameters(sample_rate); +} + +AlsaPlay::~AlsaPlay() { + if (handle_) { + int32_t err = snd_pcm_close(handle_); + if (err < 0) { + printf("Failed to close pcm: %s\n", snd_strerror(err)); + } + } +} + +void AlsaPlay::SetParameters(int32_t sample_rate) { + // set the following parameters + // 1. sample_rate + // 2. sample format: int16_t + // 3. num_channels: 1 + snd_pcm_hw_params_t *params; + snd_pcm_hw_params_alloca(¶ms); + snd_pcm_hw_params_any(handle_, params); + + int32_t err = snd_pcm_hw_params_set_access(handle_, params, + SND_PCM_ACCESS_RW_INTERLEAVED); + if (err < 0) { + printf("SND_PCM_ACCESS_RW_INTERLEAVED is not supported: %s\n", + snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_hw_params_set_format(handle_, params, SND_PCM_FORMAT_S16_LE); + + if (err < 0) { + printf("Can't set format to 16-bit: %s\n", snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_hw_params_set_channels(handle_, params, 1); + + if (err < 0) { + printf("Can't set channel number to 1: %s\n", snd_strerror(err)); + } + + uint32_t rate = sample_rate; + err = snd_pcm_hw_params_set_rate_near(handle_, params, &rate, 0); + if (err < 0) { + printf("Can't set rate to %d. %s\n", rate, snd_strerror(err)); + } + + err = snd_pcm_hw_params(handle_, params); + if (err < 0) { + printf("Can't set hardware parameters. 
%s\n", snd_strerror(err)); + exit(-1); + } + + uint32_t tmp; + snd_pcm_hw_params_get_rate(params, &tmp, 0); + int32_t actual_sample_rate = tmp; + if (actual_sample_rate != sample_rate) { + fprintf(stderr, + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + sample_rate, actual_sample_rate); + + float min_freq = std::min(actual_sample_rate, sample_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + resampler_ = std::make_unique( + sample_rate, actual_sample_rate, lowpass_cutoff, lowpass_filter_width); + } + + snd_pcm_uframes_t frames; + snd_pcm_hw_params_get_period_size(params, &frames, 0); + buf_.resize(frames); +} + +void AlsaPlay::Play(const std::vector &samples) { + std::vector tmp; + const float *p = samples.data(); + int32_t num_samples = samples.size(); + if (resampler_) { + resampler_->Resample(samples.data(), samples.size(), false, &tmp); + p = tmp.data(); + num_samples = tmp.size(); + } + + int32_t frames = buf_.size(); + int32_t i = 0; + for (; i + frames < num_samples; i += frames) { + for (int32_t k = 0; k != frames; ++k) { + buf_[k] = p[i + k] * 32767; + } + + int32_t err = snd_pcm_writei(handle_, buf_.data(), frames); + if (err == -EPIPE) { + printf("XRUN.\n"); + snd_pcm_prepare(handle_); + } else if (err < 0) { + printf("Can't write to PCM device: %s\n", snd_strerror(err)); + exit(-1); + } + } + + if (i < num_samples) { + for (int32_t k = 0; k + i < num_samples; ++k) { + buf_[k] = p[i + k] * 32767; + } + + int32_t err = snd_pcm_writei(handle_, buf_.data(), num_samples - i); + if (err == -EPIPE) { + printf("XRUN.\n"); + snd_pcm_prepare(handle_); + } else if (err < 0) { + printf("Can't write to PCM device: %s\n", snd_strerror(err)); + exit(-1); + } + } +} + +void AlsaPlay::Drain() { + int32_t err = snd_pcm_drain(handle_); + if (err < 0) { + printf("Failed to drain pcm. 
%s\n", snd_strerror(err)); + } +} + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_ENABLE_ALSA diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa-play.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa-play.h new file mode 100644 index 00000000..2e396558 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa-play.h @@ -0,0 +1,37 @@ +// sherpa-mnn/csrc/alsa-play.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ALSA_PLAY_H_ +#define SHERPA_ONNX_CSRC_ALSA_PLAY_H_ + +#include +#include +#include + +#include "alsa/asoundlib.h" +#include "sherpa-mnn/csrc/resample.h" + +namespace sherpa_mnn { + +class AlsaPlay { + public: + AlsaPlay(const char *device_name, int32_t sample_rate); + ~AlsaPlay(); + void Play(const std::vector &samples); + + // wait for all the samples to be played + void Drain(); + + private: + void SetParameters(int32_t sample_rate); + + private: + snd_pcm_t *handle_ = nullptr; + std::unique_ptr resampler_; + std::vector buf_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ALSA_PLAY_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa.cc new file mode 100644 index 00000000..91f65f09 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa.cc @@ -0,0 +1,180 @@ +// sherpa-mnn/csrc/sherpa-alsa.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifdef SHERPA_ONNX_ENABLE_ALSA + +#include "sherpa-mnn/csrc/alsa.h" + +#include + +#include "alsa/asoundlib.h" + +namespace sherpa_mnn { + +void ToFloat(const std::vector &in, int32_t num_channels, + std::vector *out) { + out->resize(in.size() / num_channels); + + int32_t n = in.size(); + for (int32_t i = 0, k = 0; i < n; i += num_channels, ++k) { + (*out)[k] = in[i] / 32768.; + } +} + +Alsa::Alsa(const char *device_name) { + const char *kDeviceHelp = R"( +Please use the command: + + arecord -l + +to list all available devices. For instance, if the output is: + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + + )"; + + int32_t err = + snd_pcm_open(&capture_handle_, device_name, SND_PCM_STREAM_CAPTURE, 0); + if (err) { + fprintf(stderr, "Unable to open: %s. %s\n", device_name, snd_strerror(err)); + fprintf(stderr, "%s\n", kDeviceHelp); + exit(-1); + } + + snd_pcm_hw_params_t *hw_params; + snd_pcm_hw_params_alloca(&hw_params); + + err = snd_pcm_hw_params_any(capture_handle_, hw_params); + if (err) { + fprintf(stderr, "Failed to initialize hw_params: %s\n", snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_hw_params_set_access(capture_handle_, hw_params, + SND_PCM_ACCESS_RW_INTERLEAVED); + if (err) { + fprintf(stderr, "Failed to set access type: %s\n", snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_hw_params_set_format(capture_handle_, hw_params, + SND_PCM_FORMAT_S16_LE); + if (err) { + fprintf(stderr, "Failed to set format: %s\n", snd_strerror(err)); + exit(-1); + } + + // mono + err = snd_pcm_hw_params_set_channels(capture_handle_, hw_params, 1); + if (err) { + fprintf(stderr, "Failed to set number of channels to 1. %s\n", + snd_strerror(err)); + + err = snd_pcm_hw_params_set_channels(capture_handle_, hw_params, 2); + if (err) { + fprintf(stderr, "Failed to set number of channels to 2. 
%s\n", + snd_strerror(err)); + + exit(-1); + } + actual_channel_count_ = 2; + fprintf(stderr, + "Channel count is set to 2. Will use only 1 channel of it.\n"); + } + + uint32_t actual_sample_rate = expected_sample_rate_; + + int32_t dir = 0; + err = snd_pcm_hw_params_set_rate_near(capture_handle_, hw_params, + &actual_sample_rate, &dir); + if (err) { + fprintf(stderr, "Failed to set sample rate to, %d: %s\n", + expected_sample_rate_, snd_strerror(err)); + exit(-1); + } + actual_sample_rate_ = actual_sample_rate; + + if (actual_sample_rate_ != expected_sample_rate_) { + fprintf(stderr, "Failed to set sample rate to %d\n", expected_sample_rate_); + fprintf(stderr, "Current sample rate is %d\n", actual_sample_rate_); + fprintf(stderr, + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + actual_sample_rate_, expected_sample_rate_); + + float min_freq = std::min(actual_sample_rate_, expected_sample_rate_); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + resampler_ = std::make_unique( + actual_sample_rate_, expected_sample_rate_, lowpass_cutoff, + lowpass_filter_width); + } else { + fprintf(stderr, "Current sample rate: %d\n", actual_sample_rate_); + } + + err = snd_pcm_hw_params(capture_handle_, hw_params); + if (err) { + fprintf(stderr, "Failed to set hw params: %s\n", snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_prepare(capture_handle_); + if (err) { + fprintf(stderr, "Failed to prepare for recording: %s\n", snd_strerror(err)); + exit(-1); + } + + fprintf(stderr, "Recording started!\n"); +} + +Alsa::~Alsa() { snd_pcm_close(capture_handle_); } + +const std::vector &Alsa::Read(int32_t num_samples) { + samples_.resize(num_samples * actual_channel_count_); + + // count is in frames. Each frame contains actual_channel_count_ samples + int32_t count = snd_pcm_readi(capture_handle_, samples_.data(), num_samples); + if (count == -EPIPE) { + static int32_t n = 0; + if (++n > 5) { + fprintf( + stderr, + "Too many overruns. It is very likely that the RTF on your board is " + "larger than 1. Please use ./bin/sherpa-mnn to compute the RTF.\n"); + exit(-1); + } + fprintf(stderr, "XRUN.\n"); + snd_pcm_prepare(capture_handle_); + + static std::vector tmp; + return tmp; + } else if (count < 0) { + fprintf(stderr, "Can't read PCM device: %s\n", snd_strerror(count)); + exit(-1); + } + + samples_.resize(count * actual_channel_count_); + + ToFloat(samples_, actual_channel_count_, &samples1_); + + if (!resampler_) { + return samples1_; + } + + resampler_->Resample(samples1_.data(), samples_.size(), false, &samples2_); + return samples2_; +} + +} // namespace sherpa_mnn + +#endif diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa.h new file mode 100644 index 00000000..c7c6139e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/alsa.h @@ -0,0 +1,46 @@ +// sherpa-mnn/csrc/sherpa-alsa.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ALSA_H_ +#define SHERPA_ONNX_CSRC_ALSA_H_ + +#include +#include + +#include "alsa/asoundlib.h" +#include "sherpa-mnn/csrc/resample.h" + +namespace sherpa_mnn { + +class Alsa { + public: + explicit Alsa(const char *device_name); + ~Alsa(); + + // This is a blocking read. + // + // @param num_samples Number of samples to read. + // + // The returned value is valid until the next call to Read(). 
+ const std::vector &Read(int32_t num_samples); + + int32_t GetExpectedSampleRate() const { return expected_sample_rate_; } + int32_t GetActualSampleRate() const { return actual_sample_rate_; } + + private: + snd_pcm_t *capture_handle_; + int32_t expected_sample_rate_ = 16000; + int32_t actual_sample_rate_; + + int32_t actual_channel_count_ = 1; + + std::unique_ptr resampler_; + std::vector samples_; // directly from the microphone + std::vector samples1_; // normalized version of samples_ + std::vector samples2_; // possibly resampled from samples1_ +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ALSA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-ced-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-ced-impl.h new file mode 100644 index 00000000..0bfd31e5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-ced-impl.h @@ -0,0 +1,111 @@ +// sherpa-mnn/csrc/audio-tagging-ced-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_ + +#include + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/audio-tagging-impl.h" +#include "sherpa-mnn/csrc/audio-tagging-label-file.h" +#include "sherpa-mnn/csrc/audio-tagging.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/math.h" +#include "sherpa-mnn/csrc/offline-ced-model.h" + +namespace sherpa_mnn { + +class AudioTaggingCEDImpl : public AudioTaggingImpl { + public: + explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config) + : config_(config), model_(config.model), labels_(config.labels) { + if (model_.NumEventClasses() != labels_.NumEventClasses()) { + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", + model_.NumEventClasses(), labels_.NumEventClasses()); + exit(-1); + } + } + +#if __ANDROID_API__ >= 9 + explicit AudioTaggingCEDImpl(AAssetManager *mgr, + const AudioTaggingConfig &config) + : config_(config), + model_(mgr, config.model), + labels_(mgr, config.labels) { + if (model_.NumEventClasses() != labels_.NumEventClasses()) { + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", + model_.NumEventClasses(), labels_.NumEventClasses()); + exit(-1); + } + } +#endif + + std::unique_ptr CreateStream() const override { + return std::make_unique(CEDTag{}); + } + + std::vector Compute(OfflineStream *s, + int32_t top_k = -1) const override { + if (top_k < 0) { + top_k = config_.top_k; + } + + int32_t num_event_classes = model_.NumEventClasses(); + + if (top_k > num_event_classes) { + top_k = num_event_classes; + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + // WARNING(fangjun): It is fixed to 64 for CED models + int32_t feat_dim = 64; + std::vector f = s->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + assert(feat_dim * num_frames == static_cast(f.size())); + + std::array shape = {1, num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + MNN::Express::VARP probs = model_.Forward(std::move(x)); + + const float *p = probs->readMap(); + + std::vector top_k_indexes = TopkIndex(p, num_event_classes, top_k); + + std::vector ans(top_k); + + int32_t i = 0; + + for (int32_t index : top_k_indexes) { + ans[i].name = labels_.GetEventName(index); + ans[i].index = index; + ans[i].prob = p[index]; + i += 1; + 
} + + return ans; + } + + private: + AudioTaggingConfig config_; + OfflineCEDModel model_; + AudioTaggingLabels labels_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-impl.cc new file mode 100644 index 00000000..94fcd820 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-impl.cc @@ -0,0 +1,48 @@ +// sherpa-mnn/csrc/audio-tagging-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/audio-tagging-impl.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/audio-tagging-ced-impl.h" +#include "sherpa-mnn/csrc/audio-tagging-zipformer-impl.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +std::unique_ptr AudioTaggingImpl::Create( + const AudioTaggingConfig &config) { + if (!config.model.zipformer.model.empty()) { + return std::make_unique(config); + } else if (!config.model.ced.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE( + "Please specify an audio tagging model! Return a null pointer"); + return nullptr; +} + +#if __ANDROID_API__ >= 9 +std::unique_ptr AudioTaggingImpl::Create( + AAssetManager *mgr, const AudioTaggingConfig &config) { + if (!config.model.zipformer.model.empty()) { + return std::make_unique(mgr, config); + } else if (!config.model.ced.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE( + "Please specify an audio tagging model! Return a null pointer"); + return nullptr; +} +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-impl.h new file mode 100644 index 00000000..855a8fd2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-impl.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/audio-tagging-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ + +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/audio-tagging.h" + +namespace sherpa_mnn { + +class AudioTaggingImpl { + public: + virtual ~AudioTaggingImpl() = default; + + static std::unique_ptr Create( + const AudioTaggingConfig &config); + +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const AudioTaggingConfig &config); +#endif + + virtual std::unique_ptr CreateStream() const = 0; + + virtual std::vector Compute(OfflineStream *s, + int32_t top_k = -1) const = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-label-file.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-label-file.cc new file mode 100644 index 00000000..fe570a76 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-label-file.cc @@ -0,0 +1,87 @@ +// sherpa-mnn/csrc/audio-tagging-label-file.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/audio-tagging-label-file.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + 
+#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) { + std::ifstream is(filename); + Init(is); +} + +#if __ANDROID_API__ >= 9 +AudioTaggingLabels::AudioTaggingLabels(AAssetManager *mgr, + const std::string &filename) { + auto buf = ReadFile(mgr, filename); + std::istrstream is(buf.data(), buf.size()); + Init(is); +} +#endif + +// Format of a label file +/* +index,mid,display_name +0,/m/09x0r,"Speech" +1,/m/05zppz,"Male speech, man speaking" +*/ +void AudioTaggingLabels::Init(std::istream &is) { + std::string line; + std::getline(is, line); // skip the header + + std::string index; + std::string tmp; + std::string name; + + while (std::getline(is, line)) { + index.clear(); + name.clear(); + std::istringstream input2(line); + + std::getline(input2, index, ','); + std::getline(input2, tmp, ','); + std::getline(input2, name); + + std::size_t pos{}; + int32_t i = std::stoi(index, &pos); + if (index.empty() || pos != index.size()) { + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); + exit(-1); + } + + if (i != static_cast(names_.size())) { + SHERPA_ONNX_LOGE( + "Index should be sorted and contiguous. Expected index: %d, given: " + "%d.", + static_cast(names_.size()), i); + } + if (name.empty() || name.front() != '"' || name.back() != '"') { + SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str()); + exit(-1); + } + + names_.emplace_back(name.begin() + 1, name.end() - 1); + } +} + +const std::string &AudioTaggingLabels::GetEventName(int32_t index) const { + return names_.at(index); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-label-file.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-label-file.h new file mode 100644 index 00000000..7c0e90eb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-label-file.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/audio-tagging-label-file.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +namespace sherpa_mnn { + +class AudioTaggingLabels { + public: + explicit AudioTaggingLabels(const std::string &filename); +#if __ANDROID_API__ >= 9 + AudioTaggingLabels(AAssetManager *mgr, const std::string &filename); +#endif + + // Return the event name for the given index. 
+ // The returned reference is valid as long as this object is alive + const std::string &GetEventName(int32_t index) const; + int32_t NumEventClasses() const { return names_.size(); } + + private: + void Init(std::istream &is); + + private: + std::vector names_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-model-config.cc new file mode 100644 index 00000000..ba5c7b1c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-model-config.cc @@ -0,0 +1,60 @@ +// sherpa-mnn/csrc/audio-tagging-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/audio-tagging-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void AudioTaggingModelConfig::Register(ParseOptions *po) { + zipformer.Register(po); + + po->Register("ced-model", &ced, + "Path to CED model. Only need to pass one of --zipformer-model " + "or --ced-model"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool AudioTaggingModelConfig::Validate() const { + if (!zipformer.model.empty() && !zipformer.Validate()) { + return false; + } + + if (!ced.empty() && !FileExists(ced)) { + SHERPA_ONNX_LOGE("CED model file '%s' does not exist", ced.c_str()); + return false; + } + + if (zipformer.model.empty() && ced.empty()) { + SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model"); + return false; + } + + return true; +} + +std::string AudioTaggingModelConfig::ToString() const { + std::ostringstream os; + + os << "AudioTaggingModelConfig("; + os << "zipformer=" << zipformer.ToString() << ", "; + os << "ced=\"" << ced << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-model-config.h new file mode 100644 index 00000000..7577742f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-model-config.h @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/audio-tagging-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct AudioTaggingModelConfig { + struct OfflineZipformerAudioTaggingModelConfig zipformer; + std::string ced; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + AudioTaggingModelConfig() = default; + + AudioTaggingModelConfig( + const OfflineZipformerAudioTaggingModelConfig &zipformer, + const std::string &ced, int32_t num_threads, bool debug, + const std::string &provider) + : zipformer(zipformer), + ced(ced), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-zipformer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-zipformer-impl.h new file mode 100644 index 00000000..3e0cdfd0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging-zipformer-impl.h @@ -0,0 +1,118 @@ +// sherpa-mnn/csrc/audio-tagging-zipformer-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ + +#include + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/audio-tagging-impl.h" +#include "sherpa-mnn/csrc/audio-tagging-label-file.h" +#include "sherpa-mnn/csrc/audio-tagging.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/math.h" +#include "sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.h" + +namespace sherpa_mnn { + +class AudioTaggingZipformerImpl : public AudioTaggingImpl { + public: + explicit AudioTaggingZipformerImpl(const AudioTaggingConfig &config) + : config_(config), model_(config.model), labels_(config.labels) { + if (model_.NumEventClasses() != labels_.NumEventClasses()) { + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", + model_.NumEventClasses(), labels_.NumEventClasses()); + exit(-1); + } + } + +#if __ANDROID_API__ >= 9 + explicit AudioTaggingZipformerImpl(AAssetManager *mgr, + const AudioTaggingConfig &config) + : config_(config), + model_(mgr, config.model), + labels_(mgr, config.labels) { + if (model_.NumEventClasses() != labels_.NumEventClasses()) { + SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)", + model_.NumEventClasses(), labels_.NumEventClasses()); + exit(-1); + } + } +#endif + + std::unique_ptr CreateStream() const override { + return std::make_unique(); + } + + std::vector Compute(OfflineStream *s, + int32_t top_k = -1) const override { + if (top_k < 0) { + 
top_k = config_.top_k; + } + + int32_t num_event_classes = model_.NumEventClasses(); + + if (top_k > num_event_classes) { + top_k = num_event_classes; + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + // WARNING(fangjun): It is fixed to 80 for all models from icefall + int32_t feat_dim = 80; + std::vector f = s->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + + assert(feat_dim * num_frames == static_cast(f.size())); + + std::array shape = {1, num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int x_length_scalar = num_frames; + std::array x_length_shape = {1}; + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &x_length_scalar, 1, + x_length_shape.data(), x_length_shape.size()); + + MNN::Express::VARP probs = model_.Forward(std::move(x), std::move(x_length)); + + const float *p = probs->readMap(); + + std::vector top_k_indexes = TopkIndex(p, num_event_classes, top_k); + + std::vector ans(top_k); + + int32_t i = 0; + + for (int32_t index : top_k_indexes) { + ans[i].name = labels_.GetEventName(index); + ans[i].index = index; + ans[i].prob = p[index]; + i += 1; + } + + return ans; + } + + private: + AudioTaggingConfig config_; + OfflineZipformerAudioTaggingModel model_; + AudioTaggingLabels labels_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging.cc new file mode 100644 index 00000000..119e57f8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging.cc @@ -0,0 +1,87 @@ +// sherpa-mnn/csrc/audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/audio-tagging.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/audio-tagging-impl.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +std::string AudioEvent::ToString() const { + std::ostringstream os; + os << "AudioEvent("; + os << "name=\"" << name << "\", "; + os << "index=" << index << ", "; + os << "prob=" << prob << ")"; + return os.str(); +} + +void AudioTaggingConfig::Register(ParseOptions *po) { + model.Register(po); + po->Register("labels", &labels, "Event label file"); + po->Register("top-k", &top_k, "Top k events to return in the result"); +} + +bool AudioTaggingConfig::Validate() const { + if (!model.Validate()) { + return false; + } + + if (top_k < 1) { + SHERPA_ONNX_LOGE("--top-k should be >= 1. 
Given: %d", top_k); + return false; + } + + if (labels.empty()) { + SHERPA_ONNX_LOGE("Please provide --labels"); + return false; + } + + if (!FileExists(labels)) { + SHERPA_ONNX_LOGE("--labels '%s' does not exist", labels.c_str()); + return false; + } + + return true; +} +std::string AudioTaggingConfig::ToString() const { + std::ostringstream os; + + os << "AudioTaggingConfig("; + os << "model=" << model.ToString() << ", "; + os << "labels=\"" << labels << "\", "; + os << "top_k=" << top_k << ")"; + + return os.str(); +} + +AudioTagging::AudioTagging(const AudioTaggingConfig &config) + : impl_(AudioTaggingImpl::Create(config)) {} + +#if __ANDROID_API__ >= 9 +AudioTagging::AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config) + : impl_(AudioTaggingImpl::Create(mgr, config)) {} +#endif + +AudioTagging::~AudioTagging() = default; + +std::unique_ptr AudioTagging::CreateStream() const { + return impl_->CreateStream(); +} + +std::vector AudioTagging::Compute(OfflineStream *s, + int32_t top_k /*= -1*/) const { + return impl_->Compute(s, top_k); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging.h new file mode 100644 index 00000000..d0152ccb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/audio-tagging.h @@ -0,0 +1,74 @@ +// sherpa-mnn/csrc/audio-tagging.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ +#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/audio-tagging-model-config.h" +#include "sherpa-mnn/csrc/offline-stream.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct AudioTaggingConfig { + AudioTaggingModelConfig model; + std::string labels; + + int32_t top_k = 5; + + AudioTaggingConfig() = default; + + AudioTaggingConfig(const AudioTaggingModelConfig &model, + const std::string &labels, int32_t top_k) + : model(model), labels(labels), top_k(top_k) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct AudioEvent { + std::string name; // name of the event + int32_t index; // index of the event in the label file + float prob; // probability of the event + + std::string ToString() const; +}; + +class AudioTaggingImpl; + +class AudioTagging { + public: + explicit AudioTagging(const AudioTaggingConfig &config); + +#if __ANDROID_API__ >= 9 + AudioTagging(AAssetManager *mgr, const AudioTaggingConfig &config); +#endif + + ~AudioTagging(); + + std::unique_ptr CreateStream() const; + + // If top_k is -1, then config.top_k is used. + // Otherwise, config.top_k is ignored + // + // Return top_k AudioEvent. ans[0].prob is the largest of all returned events. 
+ std::vector Compute(OfflineStream *s, int32_t top_k = -1) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/base64-decode.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/base64-decode.cc new file mode 100644 index 00000000..0f207a37 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/base64-decode.cc @@ -0,0 +1,67 @@ +// sherpa-mnn/csrc/base64-decode.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/base64-decode.h" + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +static int32_t Ord(char c) { + if (c >= 'A' && c <= 'Z') { + return c - 'A'; + } else if (c >= 'a' && c <= 'z') { + return c - 'a' + ('Z' - 'A') + 1; + } else if (c >= '0' && c <= '9') { + return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2; + } else if (c == '+') { + return 62; + } else if (c == '/') { + return 63; + } + + SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c); + + exit(-1); +} + +// see +// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243 +std::string Base64Decode(const std::string &s) { + if (s.empty()) { + SHERPA_ONNX_LOGE("Empty string!"); + exit(-1); + } + + int32_t n = static_cast(s.size()) / 4 * 3; + + std::string ans; + ans.reserve(n); + + int32_t i = 0; + while (i < static_cast(s.size())) { + if (s[i] == '=') { + return " "; + } + + int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4); + ans.push_back(static_cast(first)); + + if (i + 2 < static_cast(s.size()) && s[i + 2] != '=') { + int32_t second = + ((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2); + ans.push_back(static_cast(second)); + + if (i + 3 < static_cast(s.size()) && s[i + 3] != '=') { + int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]); + ans.push_back(static_cast(third)); + } + } + i += 4; + } + + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/base64-decode.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/base64-decode.h new file mode 100644 index 00000000..6a7cdbbf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/base64-decode.h @@ -0,0 +1,19 @@ +// sherpa-mnn/csrc/base64-decode.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_ +#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_ + +#include + +namespace sherpa_mnn { + +/** @param s A base64 encoded string. + * @return Return the decoded string. + */ +std::string Base64Decode(const std::string &s); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_BASE64_DECODE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/bbpe.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/bbpe.cc new file mode 100644 index 00000000..bebbd32b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/bbpe.cc @@ -0,0 +1,61 @@ +// sherpa-mnn/csrc/bbpe.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +// Auto-generated! 
DO NOT EDIT + +#include "sherpa-mnn/csrc/bbpe.h" + +#include +#include +#include + +const std::unordered_map &GetByteBpeTable() { + static const std::unordered_map table = { + {"Ā", 0}, {"ā", 1}, {"Ă", 2}, {"ă", 3}, {"Ą", 4}, {"ą", 5}, + {"Ć", 6}, {"ć", 7}, {"Ĉ", 8}, {"ĉ", 9}, {"Ċ", 10}, {"ċ", 11}, + {"Č", 12}, {"č", 13}, {"Ď", 14}, {"ď", 15}, {"Đ", 16}, {"đ", 17}, + {"Ē", 18}, {"ē", 19}, {"Ĕ", 20}, {"ĕ", 21}, {"Ė", 22}, {"ė", 23}, + {"Ę", 24}, {"ę", 25}, {"Ě", 26}, {"ě", 27}, {"Ĝ", 28}, {"ĝ", 29}, + {"Ğ", 30}, {"ğ", 31}, {" ", 32}, {"!", 33}, {"\"", 34}, {"#", 35}, + {"$", 36}, {"%", 37}, {"&", 38}, {"'", 39}, {"(", 40}, {")", 41}, + {"*", 42}, {"+", 43}, {",", 44}, {"-", 45}, {".", 46}, {"/", 47}, + {"0", 48}, {"1", 49}, {"2", 50}, {"3", 51}, {"4", 52}, {"5", 53}, + {"6", 54}, {"7", 55}, {"8", 56}, {"9", 57}, {":", 58}, {";", 59}, + {"<", 60}, {"=", 61}, {">", 62}, {"?", 63}, {"@", 64}, {"A", 65}, + {"B", 66}, {"C", 67}, {"D", 68}, {"E", 69}, {"F", 70}, {"G", 71}, + {"H", 72}, {"I", 73}, {"J", 74}, {"K", 75}, {"L", 76}, {"M", 77}, + {"N", 78}, {"O", 79}, {"P", 80}, {"Q", 81}, {"R", 82}, {"S", 83}, + {"T", 84}, {"U", 85}, {"V", 86}, {"W", 87}, {"X", 88}, {"Y", 89}, + {"Z", 90}, {"[", 91}, {"\\", 92}, {"]", 93}, {"^", 94}, {"_", 95}, + {"`", 96}, {"a", 97}, {"b", 98}, {"c", 99}, {"d", 100}, {"e", 101}, + {"f", 102}, {"g", 103}, {"h", 104}, {"i", 105}, {"j", 106}, {"k", 107}, + {"l", 108}, {"m", 109}, {"n", 110}, {"o", 111}, {"p", 112}, {"q", 113}, + {"r", 114}, {"s", 115}, {"t", 116}, {"u", 117}, {"v", 118}, {"w", 119}, + {"x", 120}, {"y", 121}, {"z", 122}, {"{", 123}, {"|", 124}, {"}", 125}, + {"~", 126}, {"Ġ", 127}, {"ġ", 128}, {"Ģ", 129}, {"ģ", 130}, {"Ĥ", 131}, + {"ĥ", 132}, {"Ħ", 133}, {"ħ", 134}, {"Ĩ", 135}, {"ĩ", 136}, {"Ī", 137}, + {"ī", 138}, {"Ĭ", 139}, {"ĭ", 140}, {"Į", 141}, {"į", 142}, {"İ", 143}, + {"ı", 144}, {"Ĵ", 145}, {"ĵ", 146}, {"Ķ", 147}, {"ķ", 148}, {"ĸ", 149}, + {"Ĺ", 150}, {"ĺ", 151}, {"Ļ", 152}, {"ļ", 153}, {"Ľ", 154}, {"ľ", 155}, + {"Ł", 156}, {"ł", 157}, {"Ń", 158}, {"ń", 159}, {"Ņ", 160}, {"ņ", 161}, + {"Ň", 162}, {"ň", 163}, {"Ŋ", 164}, {"ŋ", 165}, {"Ō", 166}, {"ō", 167}, + {"Ŏ", 168}, {"ŏ", 169}, {"Ő", 170}, {"ő", 171}, {"Œ", 172}, {"œ", 173}, + {"Ŕ", 174}, {"ŕ", 175}, {"Ŗ", 176}, {"ŗ", 177}, {"Ř", 178}, {"ř", 179}, + {"Ś", 180}, {"ś", 181}, {"Ŝ", 182}, {"ŝ", 183}, {"Ş", 184}, {"ş", 185}, + {"Š", 186}, {"š", 187}, {"Ţ", 188}, {"ţ", 189}, {"Ť", 190}, {"ť", 191}, + {"Ŧ", 192}, {"ŧ", 193}, {"Ũ", 194}, {"ũ", 195}, {"Ū", 196}, {"ū", 197}, + {"Ŭ", 198}, {"ŭ", 199}, {"Ů", 200}, {"ů", 201}, {"Ű", 202}, {"ű", 203}, + {"Ų", 204}, {"ų", 205}, {"Ŵ", 206}, {"ŵ", 207}, {"Ŷ", 208}, {"ŷ", 209}, + {"Ÿ", 210}, {"Ź", 211}, {"ź", 212}, {"Ż", 213}, {"ż", 214}, {"Ž", 215}, + {"ž", 216}, {"ƀ", 217}, {"Ɓ", 218}, {"Ƃ", 219}, {"ƃ", 220}, {"Ƅ", 221}, + {"ƅ", 222}, {"Ɔ", 223}, {"Ƈ", 224}, {"ƈ", 225}, {"Ɖ", 226}, {"Ɗ", 227}, + {"Ƌ", 228}, {"ƌ", 229}, {"ƍ", 230}, {"Ǝ", 231}, {"Ə", 232}, {"Ɛ", 233}, + {"Ƒ", 234}, {"ƒ", 235}, {"Ɠ", 236}, {"Ɣ", 237}, {"ƕ", 238}, {"Ɩ", 239}, + {"Ɨ", 240}, {"Ƙ", 241}, {"ƙ", 242}, {"ƚ", 243}, {"ƛ", 244}, {"Ɯ", 245}, + {"Ɲ", 246}, {"ƞ", 247}, {"Ɵ", 248}, {"Ơ", 249}, {"ơ", 250}, {"Ƣ", 251}, + {"ƣ", 252}, {"Ƥ", 253}, {"ƥ", 254}, {"Ʀ", 255}, {"⁇", 32}, + }; + + return table; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/bbpe.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/bbpe.h new file mode 100644 index 00000000..e76cc468 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/bbpe.h @@ -0,0 +1,16 @@ +// sherpa-mnn/csrc/bbpe.h +// +// Copyright 
(c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_BBPE_H_ +#define SHERPA_ONNX_CSRC_BBPE_H_ +#include +#include +#include + +// It is equivalent to the map BCHAR_TO_BYTE +// from +// https://github.com/k2-fsa/icefall/blob/master/icefall/byte_utils.py#L280 +const std::unordered_map &GetByteBpeTable(); + +#endif // SHERPA_ONNX_CSRC_BBPE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat-test.cc new file mode 100644 index 00000000..22f31d85 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat-test.cc @@ -0,0 +1,254 @@ +// sherpa-mnn/csrc/cat-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/cat.h" + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TEST(Cat, Test1DTensors) { + MNNAllocator* allocator; + + std::array a_shape{3}; + std::array b_shape{6}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Cat(allocator, {&a, &b}, 0); + + const float *pans = ans->readMap(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0]]); + } + + Print1D(&a); + Print1D(&b); + Print1D(&ans); +} + +TEST(Cat, Test2DTensorsDim0) { + MNNAllocator* allocator; + + std::array a_shape{2, 3}; + std::array b_shape{4, 3}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Cat(allocator, {&a, &b}, 0); + + const float *pans = ans->readMap(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]); + } + + Print2D(&a); + Print2D(&b); + Print2D(&ans); +} + +TEST(Cat, Test2DTensorsDim1) { + MNNAllocator* allocator; + + std::array a_shape{4, 3}; + std::array b_shape{4, 2}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Cat(allocator, {&a, &b}, 1); + + const float *pans = ans->readMap(); + + for (int32_t r = 0; r != static_cast(a_shape[0]); ++r) { + for (int32_t i = 0; i != static_cast(a_shape[1]); + ++i, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t i = 0; i != static_cast(b_shape[1]); + ++i, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + 
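+  // At this point ans has shape (4, 5): each output row is the 3 elements of
+  // the corresponding row of a followed by the 2 elements of the same row of
+  // b, which is exactly what the interleaved pointer walk above verifies.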
Print2D(&a); + Print2D(&b); + Print2D(&ans); +} + +TEST(Cat, Test3DTensorsDim0) { + MNNAllocator* allocator; + + std::array a_shape{2, 3, 2}; + std::array b_shape{4, 3, 2}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Cat(allocator, {&a, &b}, 0); + + const float *pans = ans->readMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]); + } + + Print3D(&a); + Print3D(&b); + Print3D(&ans); +} + +TEST(Cat, Test3DTensorsDim1) { + MNNAllocator* allocator; + + std::array a_shape{2, 2, 3}; + std::array b_shape{2, 4, 3}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Cat(allocator, {&a, &b}, 1); + + const float *pans = ans->readMap(); + + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[1] * a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[1] * b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print3D(&a); + Print3D(&b); + Print3D(&ans); +} + +TEST(Cat, Test3DTensorsDim2) { + MNNAllocator* allocator; + + std::array a_shape{2, 3, 4}; + std::array b_shape{2, 3, 5}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Cat(allocator, {&a, &b}, 2); + + const float *pans = ans->readMap(); + + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print3D(&a); + Print3D(&b); + Print3D(&ans); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat.cc new file mode 100644 index 00000000..71b4383c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat.cc @@ -0,0 +1,106 @@ +// sherpa-mnn/csrc/cat.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/cat.h" + +#include +#include +#include +#include + 
+#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +static bool Compare(const std::vector &a, + const std::vector &b, int32_t skip_dim) { + if (a.size() != b.size()) return false; + + for (int32_t i = 0; i != static_cast(a.size()); ++i) { + if (i == skip_dim) continue; + + if (a[i] != b[i]) return false; + } + + return true; +} + +static void PrintShape(const std::vector &a) { + for (auto i : a) { + fprintf(stderr, "%d ", static_cast(i)); + } + fprintf(stderr, "\n"); +} + +template +MNN::Express::VARP Cat(MNNAllocator *allocator, + const std::vector &values, int32_t dim) { + if (values.size() == 1u) { + return Clone(allocator, values[0]); + } + + std::vector v0_shape = + values[0]->getInfo()->dim; + + int total_dim = v0_shape[dim]; + + for (int32_t i = 1; i != static_cast(values.size()); ++i) { + auto s = values[i]->getInfo()->dim; + total_dim += s[dim]; + + bool ret = Compare(v0_shape, s, dim); + if (!ret) { + fprintf(stderr, "Incorrect shape in Cat !\n"); + + fprintf(stderr, "Shape for tensor 0: "); + PrintShape(v0_shape); + + fprintf(stderr, "Shape for tensor %d: ", i); + PrintShape(s); + + exit(-1); + } + } + + std::vector ans_shape; + ans_shape.reserve(v0_shape.size()); + ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim); + ans_shape.push_back(total_dim); + ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1, + v0_shape.data() + v0_shape.size()); + + auto leading_size = static_cast(std::accumulate( + v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies())); + + auto trailing_size = static_cast( + std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1, + std::multiplies())); + + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + T *dst = ans->writeMap(); + + for (int32_t i = 0; i != leading_size; ++i) { + for (auto value : values) { + auto this_dim = value->getInfo()->dim[dim]; + const T *src = value->readMap(); + src += i * this_dim * trailing_size; + + std::copy(src, src + this_dim * trailing_size, dst); + dst += this_dim * trailing_size; + } + } + + return ans; +} + +template MNN::Express::VARP Cat(MNNAllocator *allocator, + const std::vector &values, + int32_t dim); + +template MNN::Express::VARP Cat(MNNAllocator *allocator, + const std::vector &values, + int32_t dim); + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat.h new file mode 100644 index 00000000..9e9de4dc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cat.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/cat.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_CAT_H_ +#define SHERPA_ONNX_CSRC_CAT_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +/** Cat a list of tensors along the given dim. + * + * @param allocator Allocator to allocate space for the returned tensor + * @param values Pointer to a list of tensors. The shape of the tensor must + * be the same except on the dim to be concatenated. 
+ * @param dim The dim along which to concatenate the input tensors + * + * @return Return the concatenated tensor + */ +template +MNN::Express::VARP Cat(MNNAllocator *allocator, + const std::vector &values, int32_t dim); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_CAT_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer-test.cc new file mode 100644 index 00000000..5be0a6aa --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer-test.cc @@ -0,0 +1,150 @@ +// sherpa-mnn/csrc/circular-buffer-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/circular-buffer.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +TEST(CircularBuffer, Push) { + CircularBuffer buffer(10); + EXPECT_EQ(buffer.Size(), 0); + EXPECT_EQ(buffer.Head(), 0); + EXPECT_EQ(buffer.Tail(), 0); + + std::vector a = {0, 1, 2, 3, 4, 5}; + buffer.Push(a.data(), a.size()); + + EXPECT_EQ(buffer.Size(), 6); + EXPECT_EQ(buffer.Head(), 0); + EXPECT_EQ(buffer.Tail(), 6); + + auto c = buffer.Get(0, a.size()); + EXPECT_EQ(a.size(), c.size()); + for (int32_t i = 0; i != a.size(); ++i) { + EXPECT_EQ(a[i], c[i]); + } + + std::vector d = {-6, -7, -8, -9}; + buffer.Push(d.data(), d.size()); + + c = buffer.Get(a.size(), d.size()); + EXPECT_EQ(d.size(), c.size()); + for (int32_t i = 0; i != d.size(); ++i) { + EXPECT_EQ(d[i], c[i]); + } +} + +TEST(CircularBuffer, PushAndPop) { + CircularBuffer buffer(5); + std::vector a = {0, 1, 2, 3}; + buffer.Push(a.data(), a.size()); + + EXPECT_EQ(buffer.Size(), 4); + EXPECT_EQ(buffer.Head(), 0); + EXPECT_EQ(buffer.Tail(), 4); + + buffer.Pop(2); + + EXPECT_EQ(buffer.Size(), 2); + EXPECT_EQ(buffer.Head(), 2); + EXPECT_EQ(buffer.Tail(), 4); + + auto c = buffer.Get(2, 2); + EXPECT_EQ(c.size(), 2); + EXPECT_EQ(c[0], 2); + EXPECT_EQ(c[1], 3); + + a = {10, 20, 30}; + buffer.Push(a.data(), a.size()); + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 2); + EXPECT_EQ(buffer.Tail(), 7); + + c = buffer.Get(2, 5); + EXPECT_EQ(c.size(), 5); + EXPECT_EQ(c[0], 2); + EXPECT_EQ(c[1], 3); + EXPECT_EQ(c[2], 10); + EXPECT_EQ(c[3], 20); + EXPECT_EQ(c[4], 30); + + c = buffer.Get(3, 4); + EXPECT_EQ(c.size(), 4); + EXPECT_EQ(c[0], 3); + EXPECT_EQ(c[1], 10); + EXPECT_EQ(c[2], 20); + EXPECT_EQ(c[3], 30); + + c = buffer.Get(4, 3); + EXPECT_EQ(c.size(), 3); + EXPECT_EQ(c[0], 10); + EXPECT_EQ(c[1], 20); + EXPECT_EQ(c[2], 30); + + buffer.Pop(4); + EXPECT_EQ(buffer.Size(), 1); + EXPECT_EQ(buffer.Head(), 6); + EXPECT_EQ(buffer.Tail(), 7); + + c = buffer.Get(6, 1); + EXPECT_EQ(c.size(), 1); + EXPECT_EQ(c[0], 30); + + a = {100, 200, 300, 400}; + buffer.Push(a.data(), a.size()); + EXPECT_EQ(buffer.Size(), 5); + + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 6); + EXPECT_EQ(buffer.Tail(), 11); + + c = buffer.Get(6, 5); + EXPECT_EQ(c.size(), 5); + EXPECT_EQ(c[0], 30); + EXPECT_EQ(c[1], 100); + EXPECT_EQ(c[2], 200); + EXPECT_EQ(c[3], 300); + EXPECT_EQ(c[4], 400); + + buffer.Pop(3); + EXPECT_EQ(buffer.Size(), 2); + EXPECT_EQ(buffer.Head(), 9); + EXPECT_EQ(buffer.Tail(), 11); + + c = buffer.Get(10, 1); + EXPECT_EQ(c.size(), 1); + EXPECT_EQ(c[0], 400); + + a = {1000, 2000, 3000}; + buffer.Push(a.data(), a.size()); + + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 9); + EXPECT_EQ(buffer.Tail(), 14); + + buffer.Pop(1); + + EXPECT_EQ(buffer.Size(), 4); + EXPECT_EQ(buffer.Head(), 10); + EXPECT_EQ(buffer.Tail(), 
14); + + a = {4000}; + + buffer.Push(a.data(), a.size()); + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 10); + EXPECT_EQ(buffer.Tail(), 15); + + c = buffer.Get(13, 2); + EXPECT_EQ(c.size(), 2); + EXPECT_EQ(c[0], 3000); + EXPECT_EQ(c[1], 4000); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer.cc new file mode 100644 index 00000000..d59b9bbf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer.cc @@ -0,0 +1,181 @@ +// sherpa-mnn/csrc/circular-buffer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/circular-buffer.h" + +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +CircularBuffer::CircularBuffer(int32_t capacity) { + if (capacity <= 0) { + SHERPA_ONNX_LOGE("Please specify a positive capacity. Given: %d\n", + capacity); + exit(-1); + } + buffer_.resize(capacity); +} + +void CircularBuffer::Resize(int32_t new_capacity) { + int32_t capacity = static_cast(buffer_.size()); + if (new_capacity <= capacity) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "new_capacity (%{public}d) <= original capacity (%{public}d). Skip it.", + new_capacity, capacity); +#else + SHERPA_ONNX_LOGE("new_capacity (%d) <= original capacity (%d). Skip it.", + new_capacity, capacity); +#endif + return; + } + + int32_t size = Size(); + if (size == 0) { + buffer_.resize(new_capacity); + return; + } + + std::vector new_buffer(new_capacity); + int32_t start = head_ % capacity; + int32_t dest = head_ % new_capacity; + + if (start + size <= capacity) { + if (dest + size <= new_capacity) { + std::copy(buffer_.begin() + start, buffer_.begin() + start + size, + new_buffer.begin() + dest); + } else { + int32_t part1_size = new_capacity - dest; + + // copy [start, start+part1_size] to new_buffer + std::copy(buffer_.begin() + start, buffer_.begin() + start + part1_size, + new_buffer.begin() + dest); + + // copy [start+part1_size, start+size] to new_buffer + std::copy(buffer_.begin() + start + part1_size, + buffer_.begin() + start + size, new_buffer.begin()); + } + } else { + int32_t part1_size = capacity - start; + int32_t part2_size = size - part1_size; + + // copy [start, start+part1_size] to new_buffer + if (dest + part1_size <= new_capacity) { + std::copy(buffer_.begin() + start, buffer_.begin() + start + part1_size, + new_buffer.begin() + dest); + } else { + int32_t first_part = new_capacity - dest; + std::copy(buffer_.begin() + start, buffer_.begin() + start + first_part, + new_buffer.begin() + dest); + + std::copy(buffer_.begin() + start + first_part, + buffer_.begin() + start + part1_size, new_buffer.begin()); + } + + int32_t new_dest = (dest + part1_size) % new_capacity; + + if (new_dest + part2_size <= new_capacity) { + std::copy(buffer_.begin(), buffer_.begin() + part2_size, + new_buffer.begin() + new_dest); + } else { + int32_t first_part = new_capacity - new_dest; + std::copy(buffer_.begin(), buffer_.begin() + first_part, + new_buffer.begin() + new_dest); + std::copy(buffer_.begin() + first_part, buffer_.begin() + part2_size, + new_buffer.begin()); + } + } + buffer_.swap(new_buffer); +} + +void CircularBuffer::Push(const float *p, int32_t n) { + int32_t capacity = static_cast(buffer_.size()); + int32_t size = Size(); + if (n + size > capacity) { + int32_t new_capacity = std::max(capacity * 2, n + size); +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Overflow! 
n: %{public}d, size: %{public}d, n+size: %{public}d, " + "capacity: %{public}d. Increase " + "capacity to: %{public}d. (Original data is copied. No data loss!)", + n, size, n + size, capacity, new_capacity); +#else + SHERPA_ONNX_LOGE( + "Overflow! n: %d, size: %d, n+size: %d, capacity: %d. Increase " + "capacity to: %d. (Original data is copied. No data loss!)", + n, size, n + size, capacity, new_capacity); +#endif + Resize(new_capacity); + + capacity = new_capacity; + } + + int32_t start = tail_ % capacity; + + tail_ += n; + + if (start + n < capacity) { + std::copy(p, p + n, buffer_.begin() + start); + return; + } + + int32_t part1_size = capacity - start; + + std::copy(p, p + part1_size, buffer_.begin() + start); + + std::copy(p + part1_size, p + n, buffer_.begin()); +} + +std::vector CircularBuffer::Get(int32_t start_index, int32_t n) const { + if (start_index < head_ || start_index >= tail_) { + SHERPA_ONNX_LOGE("Invalid start_index: %d. head_: %d, tail_: %d", + start_index, head_, tail_); + return {}; + } + + int32_t size = Size(); + if (n < 0 || n > size) { + SHERPA_ONNX_LOGE("Invalid n: %d. size: %d", n, size); + return {}; + } + + int32_t capacity = static_cast(buffer_.size()); + + if (start_index - head_ + n > size) { + SHERPA_ONNX_LOGE("Invalid start_index: %d and n: %d. head_: %d, size: %d", + start_index, n, head_, size); + return {}; + } + + int32_t start = start_index % capacity; + + if (start + n < capacity) { + return {buffer_.begin() + start, buffer_.begin() + start + n}; + } + + std::vector ans(n); + + std::copy(buffer_.begin() + start, buffer_.end(), ans.begin()); + + int32_t part1_size = capacity - start; + int32_t part2_size = n - part1_size; + std::copy(buffer_.begin(), buffer_.begin() + part2_size, + ans.begin() + part1_size); + + return ans; +} + +void CircularBuffer::Pop(int32_t n) { + int32_t size = Size(); + if (n < 0 || n > size) { + SHERPA_ONNX_LOGE("Invalid n: %d. size: %d", n, size); + return; + } + + head_ += n; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer.h new file mode 100644 index 00000000..24e02af8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/circular-buffer.h @@ -0,0 +1,61 @@ +// sherpa-mnn/csrc/circular-buffer.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_CIRCULAR_BUFFER_H_ +#define SHERPA_ONNX_CSRC_CIRCULAR_BUFFER_H_ + +#include +#include + +namespace sherpa_mnn { + +class CircularBuffer { + public: + // Capacity of this buffer. Should be large enough. + // If it is full, we just print a message and exit the program. + explicit CircularBuffer(int32_t capacity); + + // Push an array + // + // @param p Pointer to the start address of the array + // @param n Number of elements in the array + // + // Note: If n + Size() > capacity, we print an error message and exit. + void Push(const float *p, int32_t n); + + // @param start_index Should in the range [head_, tail_) + // @param n Number of elements to get + // @return Return a vector of size n containing the requested elements + std::vector Get(int32_t start_index, int32_t n) const; + + // Remove n elements from the buffer + // + // @param n Should be in the range [0, size_] + void Pop(int32_t n); + + // Number of elements in the buffer. 
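+  // It equals Tail() - Head(). Both are linear indices that only grow, so a
+  // caller can drain everything currently buffered with (a usage sketch, not
+  // part of the original header):
+  //   std::vector<float> v = buffer.Get(buffer.Head(), buffer.Size());
+  //   buffer.Pop(buffer.Size());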
+ int32_t Size() const { return tail_ - head_; } + + // Current position of the head + int32_t Head() const { return head_; } + + // Current position of the tail + int32_t Tail() const { return tail_; } + + void Reset() { + head_ = 0; + tail_ = 0; + } + + void Resize(int32_t new_capacity); + + private: + std::vector buffer_; + + int32_t head_ = 0; // linear index; always increasing; never wraps around + int32_t tail_ = 0; // linear index, always increasing; never wraps around. +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_CIRCULAR_BUFFER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph-test.cc new file mode 100644 index 00000000..effdd105 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph-test.cc @@ -0,0 +1,102 @@ +// sherpa-mnn/csrc/context-graph-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/context-graph.h" + +#include // NOLINT +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +static void TestHelper(const std::map &queries, float score, + bool strict_mode) { + std::vector contexts_str( + {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"}); + std::vector> contexts; + std::vector scores; + for (int32_t i = 0; i < contexts_str.size(); ++i) { + contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end()); + scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100); + } + auto context_graph = ContextGraph(contexts, 1, scores); + + for (const auto &iter : queries) { + float total_scores = 0; + auto state = context_graph.Root(); + for (auto q : iter.first) { + auto res = context_graph.ForwardOneStep(state, q, strict_mode); + total_scores += std::get<0>(res); + state = std::get<1>(res); + } + auto res = context_graph.Finalize(state); + EXPECT_EQ(res.second->token, -1); + total_scores += res.first; + EXPECT_EQ(total_scores, iter.second); + } +} + +TEST(ContextGraph, TestBasic) { + auto queries = std::map{ + {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9}, + {"SHED", 6}, {"SHELF", 6}, {"HELL", 2}, + {"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}}; + TestHelper(queries, 0, true); +} + +TEST(ContextGraph, TestBasicNonStrict) { + auto queries = std::map{ + {"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3}, + {"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}}; + TestHelper(queries, 0, false); +} + +TEST(ContextGraph, TestCustomize) { + auto queries = std::map{ + {"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18}, + {"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5}, + {"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}}; + TestHelper(queries, 5, true); +} + +TEST(ContextGraph, TestCustomizeNonStrict) { + auto queries = std::map{ + {"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84}, + {"SHED", 10}, {"SHELF", 10}, {"HELL", 5}, + {"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}}; + TestHelper(queries, 5, false); +} + +TEST(ContextGraph, Benchmark) { + std::random_device rd; + std::mt19937 mt(rd()); + std::uniform_int_distribution char_dist(0, 25); + std::uniform_int_distribution len_dist(3, 8); + for (int32_t num = 10; num <= 10000; num *= 10) { + std::vector> contexts; + for (int32_t i = 0; i < num; ++i) { + std::vector tmp; + int32_t word_len = len_dist(mt); + for (int32_t j = 0; j < word_len; ++j) { + tmp.push_back(char_dist(mt)); + } + contexts.push_back(std::move(tmp)); + } + auto 
start = std::chrono::high_resolution_clock::now(); + auto context_graph = ContextGraph(contexts, 1); + auto stop = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(stop - start); + SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num, + static_cast(duration.count())); + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph.cc new file mode 100644 index 00000000..3aab0fc8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph.cc @@ -0,0 +1,167 @@ +// sherpa-mnn/csrc/context-graph.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/context-graph.h" + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { +void ContextGraph::Build(const std::vector> &token_ids, + const std::vector &scores, + const std::vector &phrases, + const std::vector &ac_thresholds) const { + if (!scores.empty()) { + SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size()); + } + if (!phrases.empty()) { + SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size()); + } + if (!ac_thresholds.empty()) { + SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size()); + } + for (int32_t i = 0; i < static_cast(token_ids.size()); ++i) { + auto node = root_.get(); + float score = scores.empty() ? 0.0f : scores[i]; + score = score == 0.0f ? context_score_ : score; + float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i]; + ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold; + std::string phrase = phrases.empty() ? std::string() : phrases[i]; + + for (int32_t j = 0; j < static_cast(token_ids[i].size()); ++j) { + int32_t token = token_ids[i][j]; + if (0 == node->next.count(token)) { + bool is_end = j == (static_cast(token_ids[i].size()) - 1); + node->next[token] = std::make_unique( + token, score, node->node_score + score, + is_end ? node->node_score + score : 0, j + 1, + is_end ? ac_threshold : 0.0f, is_end, + is_end ? phrase : std::string()); + } else { + float token_score = std::max(score, node->next[token]->token_score); + node->next[token]->token_score = token_score; + float node_score = node->node_score + token_score; + node->next[token]->node_score = node_score; + bool is_end = (j == static_cast(token_ids[i].size()) - 1) || + node->next[token]->is_end; + node->next[token]->output_score = is_end ? node_score : 0.0f; + node->next[token]->is_end = is_end; + if (j == static_cast(token_ids[i].size()) - 1) { + node->next[token]->phrase = phrase; + node->next[token]->ac_threshold = ac_threshold; + } + } + node = node->next[token].get(); + } + } + FillFailOutput(); +} + +std::tuple +ContextGraph::ForwardOneStep(const ContextState *state, int32_t token, + bool strict_mode /*= true*/) const { + const ContextState *node = nullptr; + float score = 0; + if (1 == state->next.count(token)) { + node = state->next.at(token).get(); + score = node->token_score; + } else { + node = state->fail; + while (0 == node->next.count(token)) { + node = node->fail; + if (-1 == node->token) break; // root + } + if (1 == node->next.count(token)) { + node = node->next.at(token).get(); + } + score = node->node_score - state->node_score; + } + + if (!node) { + SHERPA_ONNX_LOGE("Some bad things happened."); + exit(-1); + } + + const ContextState *matched_node = + node->is_end ? node : (node->output != nullptr ? 
node->output : nullptr); + + if (!strict_mode && node->output_score != 0) { + SHERPA_ONNX_CHECK(nullptr != matched_node); + float output_score = + node->is_end ? node->node_score + : (node->output != nullptr ? node->output->node_score + : node->node_score); + return std::make_tuple(score + output_score - node->node_score, root_.get(), + matched_node); + } + return std::make_tuple(score + node->output_score, node, matched_node); +} + +std::pair ContextGraph::Finalize( + const ContextState *state) const { + float score = -state->node_score; + return std::make_pair(score, root_.get()); +} + +std::pair ContextGraph::IsMatched( + const ContextState *state) const { + bool status = false; + const ContextState *node = nullptr; + if (state->is_end) { + status = true; + node = state; + } else { + if (state->output != nullptr) { + status = true; + node = state->output; + } + } + return std::make_pair(status, node); +} + +void ContextGraph::FillFailOutput() const { + std::queue node_queue; + for (auto &kv : root_->next) { + kv.second->fail = root_.get(); + node_queue.push(kv.second.get()); + } + while (!node_queue.empty()) { + auto current_node = node_queue.front(); + node_queue.pop(); + for (auto &kv : current_node->next) { + auto fail = current_node->fail; + if (1 == fail->next.count(kv.first)) { + fail = fail->next.at(kv.first).get(); + } else { + fail = fail->fail; + while (0 == fail->next.count(kv.first)) { + fail = fail->fail; + if (-1 == fail->token) break; + } + if (1 == fail->next.count(kv.first)) + fail = fail->next.at(kv.first).get(); + } + kv.second->fail = fail; + // fill the output arc + auto output = fail; + while (!output->is_end) { + output = output->fail; + if (-1 == output->token) { + output = nullptr; + break; + } + } + kv.second->output = output; + kv.second->output_score += output == nullptr ? 
0 : output->output_score; + node_queue.push(kv.second.get()); + } + } +} +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph.h new file mode 100644 index 00000000..ec4a7500 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/context-graph.h @@ -0,0 +1,92 @@ +// sherpa-mnn/csrc/context-graph.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ +#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/log.h" + +namespace sherpa_mnn { + +class ContextGraph; +using ContextGraphPtr = std::shared_ptr; + +struct ContextState { + int32_t token; + float token_score; + float node_score; + float output_score; + int32_t level; + float ac_threshold; + bool is_end; + std::string phrase; + std::unordered_map> next; + const ContextState *fail = nullptr; + const ContextState *output = nullptr; + + ContextState() = default; + ContextState(int32_t token, float token_score, float node_score, + float output_score, int32_t level = 0, float ac_threshold = 0.0f, + bool is_end = false, const std::string &phrase = {}) + : token(token), + token_score(token_score), + node_score(node_score), + output_score(output_score), + level(level), + ac_threshold(ac_threshold), + is_end(is_end), + phrase(phrase) {} +}; + +class ContextGraph { + public: + ContextGraph() = default; + ContextGraph(const std::vector> &token_ids, + float context_score, float ac_threshold, + const std::vector &scores = {}, + const std::vector &phrases = {}, + const std::vector &ac_thresholds = {}) + : context_score_(context_score), ac_threshold_(ac_threshold) { + root_ = std::make_unique(-1, 0, 0, 0); + root_->fail = root_.get(); + Build(token_ids, scores, phrases, ac_thresholds); + } + + ContextGraph(const std::vector> &token_ids, + float context_score, const std::vector &scores = {}) + : ContextGraph(token_ids, context_score, 0.0f, scores, + std::vector(), std::vector()) {} + + std::tuple ForwardOneStep( + const ContextState *state, int32_t token_id, + bool strict_mode = true) const; + + std::pair IsMatched( + const ContextState *state) const; + + std::pair Finalize( + const ContextState *state) const; + + const ContextState *Root() const { return root_.get(); } + + private: + float context_score_; + float ac_threshold_; + std::unique_ptr root_; + void Build(const std::vector> &token_ids, + const std::vector &scores, + const std::vector &phrases, + const std::vector &ac_thresholds) const; + void FillFailOutput() const; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cppjieba-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cppjieba-test.cc new file mode 100644 index 00000000..f8bd4242 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/cppjieba-test.cc @@ -0,0 +1,144 @@ +// sherpa-mnn/csrc/cppjieba-test.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include +#include // NOLINT +#include +#include + +#include "cppjieba/Jieba.hpp" +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +// Please download dict files form +// https://github.com/csukuangfj/cppjieba/releases/download/sherpa-mnn-2024-04-19/dict.tar.bz2 +const char *const kDictPath = "./dict/jieba.dict.utf8"; +const char *const kHmmPath = "./dict/hmm_model.utf8"; +const char *const 
kUserDictPath = "./dict/user.dict.utf8"; +const char *const kIdfPath = "./dict/idf.utf8"; +const char *const kStopWordPath = "./dict/stop_words.utf8"; + +TEST(CppJieBa, Case1) { + if (!FileExists(kDictPath)) { + SHERPA_ONNX_LOGE("%s does not exist. Skipping test", kDictPath); + return; + } + + cppjieba::Jieba jieba(kDictPath, kHmmPath, kUserDictPath, kIdfPath, + kStopWordPath); + + std::vector words; + std::vector jiebawords; + + std::string s = "他来到了网易杭研大厦。How are you?"; + std::cout << s << std::endl; + std::cout << "[demo] Cut With HMM" << std::endl; + jieba.Cut(s, words, true); + std::cout << limonp::Join(words.begin(), words.end(), "/") << std::endl; + /* + 他来到了网易杭研大厦 + [demo] Cut With HMM + 他/来到/了/网易/杭研/大厦 + */ + s = "小明硕士毕业于中国科学院计算所,后在日本京都大学深造"; + std::cout << s << std::endl; + std::cout << "[demo] CutForSearch" << std::endl; + jieba.CutForSearch(s, words); + std::cout << limonp::Join(words.begin(), words.end(), "/") << std::endl; + /* + 小明硕士毕业于中国科学院计算所,后在日本京都大学深造 + [demo] CutForSearch + 小明/硕士/毕业/于/中国/科学/学院/科学院/中国科学院/计算/计算所/,/后/在/日本/京都/大学/日本京都大学/深造 + */ + std::cout << "[demo] Insert User Word" << std::endl; + jieba.Cut("男默女泪", words); + std::cout << limonp::Join(words.begin(), words.end(), "/") << std::endl; + jieba.InsertUserWord("男默女泪"); + jieba.Cut("男默女泪", words); + std::cout << limonp::Join(words.begin(), words.end(), "/") << std::endl; + /* + [demo] Insert User Word + 男默/女泪 + 男默女泪 + */ + std::cout << "[demo] CutForSearch Word With Offset" << std::endl; + jieba.CutForSearch(s, jiebawords, true); + std::cout << jiebawords << std::endl; + /* +[demo] CutForSearch Word With Offset +[{"word": "小明", "offset": 0}, {"word": "硕士", "offset": 6}, {"word": "毕业", +"offset": 12}, {"word": "于", "offset": 18}, {"word": "中国", "offset": 21}, +{"word": "科学", "offset": 27}, {"word": "学院", "offset": 30}, {"word": +"科学院", "offset": 27}, {"word": "中国科学院", "offset": 21}, {"word": "计算", +"offset": 36}, {"word": "计算所", "offset": 36}, {"word": ",", "offset": 45}, +{"word": "后", "offset": 48}, {"word": "在", "offset": 51}, {"word": "日本", +"offset": 54}, {"word": "京都", "offset": 60}, {"word": "大学", "offset": 66}, +{"word": "日本京都大学", "offset": 54}, {"word": " 深造", "offset": 72}] + */ + // see more test at + // https://github.com/yanyiwu/cppjieba/blob/master/test/demo.cpp +} + +TEST(CppJieBa, Case2) { + if (!FileExists(kDictPath)) { + SHERPA_ONNX_LOGE("%s does not exist. 
Skipping test", kDictPath); + return; + } + + cppjieba::Jieba jieba(kDictPath, kHmmPath, kUserDictPath, kIdfPath, + kStopWordPath); + std::string s = + "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如" + "涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感" + "受着生命的奇迹与温柔"; + std::vector words; + bool is_hmm = true; + jieba.Cut(s, words, is_hmm); + { + std::ostringstream os; + std::string sep = ""; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + + std::cout << os.str() << "\n"; + } + /* +当_夜幕降临_,_星光点点_,_伴随_着_微风_拂面_, +_我_在_静谧_中_感受_着_时光_的_流转_, +_思念_如_涟漪_荡漾_,_梦境_如_画卷_展开_,_我_与_自然_融为一体_, +_沉静_在_这_片_宁静_的_美丽_之中_,_感受_着_生命_的_奇迹_与_温柔 + */ + s = "这里有:红的、绿的、蓝的;各种各样的颜色都有!你想要什么呢?测试."; + std::regex punct_re(":|、|;"); + std::string s2 = std::regex_replace(s, punct_re, ","); + + std::regex punct_re2("[.]"); + s2 = std::regex_replace(s2, punct_re2, "。"); + + std::regex punct_re3("[?]"); + s2 = std::regex_replace(s2, punct_re3, "?"); + + std::regex punct_re4("[!]"); + s2 = std::regex_replace(s2, punct_re4, "!"); + std::cout << s << "\n" << s2 << "\n"; + + words.clear(); + jieba.Cut(s2, words, is_hmm); + { + std::ostringstream os; + std::string sep = ""; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + + std::cout << os.str() << "\n"; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/display.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/display.h new file mode 100644 index 00000000..ec82f079 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/display.h @@ -0,0 +1,88 @@ +// sherpa-mnn/csrc/display.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_DISPLAY_H_ +#define SHERPA_ONNX_CSRC_DISPLAY_H_ +#include + +#include + +namespace sherpa_mnn { + +class Display { + public: + explicit Display(int32_t max_word_per_line = 60) + : max_word_per_line_(max_word_per_line) {} + + void Print(int32_t segment_id, const std::string &s) { +#ifdef _MSC_VER + if (segment_id != -1) { + fprintf(stderr, "%d:%s\n", segment_id, s.c_str()); + } else { + fprintf(stderr, "%s\n", s.c_str()); + } + return; +#endif + if (last_segment_ == segment_id) { + Clear(); + } else { + if (last_segment_ != -1) { + fprintf(stderr, "\n\r"); + } + last_segment_ = segment_id; + num_previous_lines_ = 0; + } + + if (segment_id != -1) { + fprintf(stderr, "\r%d:", segment_id); + } + + int32_t i = 0; + for (size_t n = 0; n < s.size();) { + if (s[n] > 0 && s[n] < 0x7f) { + fprintf(stderr, "%c", s[n]); + ++n; + } else { + // Each Chinese character occupies 3 bytes for UTF-8 encoding. 
+ std::string tmp(s.begin() + n, s.begin() + n + 3); + fprintf(stderr, "%s", tmp.data()); + n += 3; + } + + ++i; + if (i >= max_word_per_line_ && n + 1 < s.size() && + (s[n] == ' ' || s[n] < 0)) { + fprintf(stderr, "\n\r "); + ++num_previous_lines_; + i = 0; + } + } + } + + private: + // Clear the output for the current segment + void Clear() { + ClearCurrentLine(); + while (num_previous_lines_ > 0) { + GoUpOneLine(); + ClearCurrentLine(); + --num_previous_lines_; + } + } + + // Clear the current line + void ClearCurrentLine() const { fprintf(stderr, "\33[2K\r"); } + + // Move the cursor to the previous line + void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); } + + private: + int32_t max_word_per_line_; + int32_t num_previous_lines_ = 0; + int32_t last_segment_ = -1; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_DISPLAY_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/endpoint.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/endpoint.cc new file mode 100644 index 00000000..caa61885 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/endpoint.cc @@ -0,0 +1,96 @@ +// sherpa-mnn/csrc/endpoint.cc +// +// Copyright (c) 2022 (authors: Pingfeng Luo) +// 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/endpoint.h" + +#include + +#include "sherpa-mnn/csrc/log.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +static bool RuleActivated(const EndpointRule &rule, + const std::string &rule_name, float trailing_silence, + float utterance_length) { + bool contain_nonsilence = utterance_length > trailing_silence; + bool ans = (contain_nonsilence || !rule.must_contain_nonsilence) && + trailing_silence >= rule.min_trailing_silence && + utterance_length >= rule.min_utterance_length; + if (ans) { + SHERPA_ONNX_LOG(DEBUG) << "Endpointing rule " << rule_name << " activated: " + << (contain_nonsilence ? "true" : "false") << ',' + << trailing_silence << ',' << utterance_length; + } + return ans; +} + +static void RegisterEndpointRule(ParseOptions *po, EndpointRule *rule, + const std::string &rule_name) { + po->Register( + rule_name + "-must-contain-nonsilence", &rule->must_contain_nonsilence, + "If True, for this endpointing " + rule_name + + " to apply there must be nonsilence in the best-path traceback. " + "For decoding, a non-blank token is considered as non-silence"); + po->Register(rule_name + "-min-trailing-silence", &rule->min_trailing_silence, + "This endpointing " + rule_name + + " requires duration of trailing silence in seconds) to " + "be >= this value."); + po->Register(rule_name + "-min-utterance-length", &rule->min_utterance_length, + "This endpointing " + rule_name + + " requires utterance-length (in seconds) to be >= this " + "value."); +} + +std::string EndpointRule::ToString() const { + std::ostringstream os; + + os << "EndpointRule("; + os << "must_contain_nonsilence=" + << (must_contain_nonsilence ? 
"True" : "False") << ", "; + os << "min_trailing_silence=" << min_trailing_silence << ", "; + os << "min_utterance_length=" << min_utterance_length << ")"; + + return os.str(); +} + +void EndpointConfig::Register(ParseOptions *po) { + RegisterEndpointRule(po, &rule1, "rule1"); + RegisterEndpointRule(po, &rule2, "rule2"); + RegisterEndpointRule(po, &rule3, "rule3"); +} + +std::string EndpointConfig::ToString() const { + std::ostringstream os; + + os << "EndpointConfig("; + os << "rule1=" << rule1.ToString() << ", "; + os << "rule2=" << rule2.ToString() << ", "; + os << "rule3=" << rule3.ToString() << ")"; + + return os.str(); +} + +bool Endpoint::IsEndpoint(int32_t num_frames_decoded, + int32_t trailing_silence_frames, + float frame_shift_in_seconds) const { + float utterance_length = + static_cast(num_frames_decoded) * frame_shift_in_seconds; + + float trailing_silence = + static_cast(trailing_silence_frames) * frame_shift_in_seconds; + + if (RuleActivated(config_.rule1, "rule1", trailing_silence, + utterance_length) || + RuleActivated(config_.rule2, "rule2", trailing_silence, + utterance_length) || + RuleActivated(config_.rule3, "rule3", trailing_silence, + utterance_length)) { + return true; + } + return false; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/endpoint.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/endpoint.h new file mode 100644 index 00000000..e7812d28 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/endpoint.h @@ -0,0 +1,76 @@ +// sherpa-mnn/csrc/endpoint.h +// +// Copyright (c) 2022 (authors: Pingfeng Luo) +// 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ENDPOINT_H_ +#define SHERPA_ONNX_CSRC_ENDPOINT_H_ + +#include +#include + +namespace sherpa_mnn { + +struct EndpointRule { + // If True, for this endpointing rule to apply there must + // be nonsilence in the best-path traceback. + // For decoding, a non-blank token is considered as non-silence + bool must_contain_nonsilence = true; + // This endpointing rule requires duration of trailing silence + // (in seconds) to be >= this value. + float min_trailing_silence = 2.0; + // This endpointing rule requires utterance-length (in seconds) + // to be >= this value. + float min_utterance_length = 0.0f; + + EndpointRule() = default; + + EndpointRule(bool must_contain_nonsilence, float min_trailing_silence, + float min_utterance_length) + : must_contain_nonsilence(must_contain_nonsilence), + min_trailing_silence(min_trailing_silence), + min_utterance_length(min_utterance_length) {} + + std::string ToString() const; +}; + +class ParseOptions; + +struct EndpointConfig { + // For default setting, + // rule1 times out after 2.4 seconds of silence, even if we decoded nothing. + // rule2 times out after 1.2 seconds of silence after decoding something. + // rule3 times out after the utterance is 20 seconds long, regardless of + // anything else. + EndpointRule rule1; + EndpointRule rule2; + EndpointRule rule3; + + void Register(ParseOptions *po); + + EndpointConfig() + : rule1{false, 2.4, 0}, rule2{true, 1.2, 0}, rule3{false, 0, 20} {} + + EndpointConfig(const EndpointRule &rule1, const EndpointRule &rule2, + const EndpointRule &rule3) + : rule1(rule1), rule2(rule2), rule3(rule3) {} + + std::string ToString() const; +}; + +class Endpoint { + public: + explicit Endpoint(const EndpointConfig &config) : config_(config) {} + + /// This function returns true if this set of endpointing rules thinks we + /// should terminate decoding. 
+ bool IsEndpoint(int32_t num_frames_decoded, int32_t trailing_silence_frames, + float frame_shift_in_seconds) const; + + private: + EndpointConfig config_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ENDPOINT_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-config.cc new file mode 100644 index 00000000..7c0ed8f5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-config.cc @@ -0,0 +1,45 @@ +// sherpa-mnn/csrc/fast-clustering-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/fast-clustering-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { +std::string FastClusteringConfig::ToString() const { + std::ostringstream os; + + os << "FastClusteringConfig("; + os << "num_clusters=" << num_clusters << ", "; + os << "threshold=" << threshold << ")"; + + return os.str(); +} + +void FastClusteringConfig::Register(ParseOptions *po) { + po->Register( + "num-clusters", &num_clusters, + "Number of cluster. If greater than 0, then cluster threshold is " + "ignored. Please provide it if you know the actual number of " + "clusters in advance."); + + po->Register("cluster-threshold", &threshold, + "If num_clusters is not specified, then it specifies the " + "distance threshold for clustering. smaller value -> more " + "clusters. larger value -> fewer clusters"); +} + +bool FastClusteringConfig::Validate() const { + if (num_clusters < 1 && threshold < 0) { + SHERPA_ONNX_LOGE("Please provide either num_clusters or threshold"); + return false; + } + + return true; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-config.h new file mode 100644 index 00000000..3c1c0b3b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-config.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/fast-clustering-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ +#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct FastClusteringConfig { + // If greater than 0, then threshold is ignored. + // + // We strongly recommend that you set it if you know the number of clusters + // in advance + int32_t num_clusters = -1; + + // distance threshold. + // + // The smaller, the more clusters it will generate. + // The larger, the fewer clusters it will generate. 
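+  //
+  // The distance is cosine dissimilarity, i.e. 1 - cosine similarity (see
+  // fast-clustering.cc), so with the default of 0.5 two embeddings end up in
+  // the same cluster only if, roughly, their cosine similarity is >= 0.5.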
+ float threshold = 0.5; + + FastClusteringConfig() = default; + + FastClusteringConfig(int32_t num_clusters, float threshold) + : num_clusters(num_clusters), threshold(threshold) {} + + std::string ToString() const; + + void Register(ParseOptions *po); + bool Validate() const; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-test.cc new file mode 100644 index 00000000..5c0e4e9b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering-test.cc @@ -0,0 +1,69 @@ +// sherpa-mnn/csrc/fast-clustering-test.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/fast-clustering.h" + +#include + +#include "gtest/gtest.h" + +namespace sherpa_mnn { + +TEST(FastClustering, TestTwoClusters) { + std::vector features = { + // point 0 + 0.1, + 0.1, + // point 2 + 0.4, + -0.5, + // point 3 + 0.6, + -0.7, + // point 1 + 0.2, + 0.3, + }; + + FastClusteringConfig config; + config.num_clusters = 2; + + FastClustering clustering(config); + auto labels = clustering.Cluster(features.data(), 4, 2); + int32_t k = 0; + for (auto i : labels) { + std::cout << "point " << k << ": label " << i << "\n"; + ++k; + } +} + +TEST(FastClustering, TestClusteringWithThreshold) { + std::vector features = { + // point 0 + 0.1, + 0.1, + // point 2 + 0.4, + -0.5, + // point 3 + 0.6, + -0.7, + // point 1 + 0.2, + 0.3, + }; + + FastClusteringConfig config; + config.threshold = 0.5; + + FastClustering clustering(config); + auto labels = clustering.Cluster(features.data(), 4, 2); + int32_t k = 0; + for (auto i : labels) { + std::cout << "point " << k << ": label " << i << "\n"; + ++k; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering.cc new file mode 100644 index 00000000..69c3357a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering.cc @@ -0,0 +1,83 @@ +// sherpa-mnn/csrc/fast-clustering.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/fast-clustering.h" + +#include + +#include "Eigen/Dense" +#include "fastcluster-all-in-one.h" // NOLINT + +namespace sherpa_mnn { + +class FastClustering::Impl { + public: + explicit Impl(const FastClusteringConfig &config) : config_(config) {} + + std::vector Cluster(float *features, int32_t num_rows, + int32_t num_cols) const { + if (num_rows <= 0) { + return {}; + } + + if (num_rows == 1) { + return {0}; + } + + Eigen::Map< + Eigen::Matrix> + m(features, num_rows, num_cols); + m.rowwise().normalize(); + + std::vector distance((num_rows * (num_rows - 1)) / 2); + + int32_t k = 0; + for (int32_t i = 0; i != num_rows; ++i) { + auto v = m.row(i); + for (int32_t j = i + 1; j != num_rows; ++j) { + double cosine_similarity = v.dot(m.row(j)); + double consine_dissimilarity = 1 - cosine_similarity; + + if (consine_dissimilarity < 0) { + consine_dissimilarity = 0; + } + + distance[k] = consine_dissimilarity; + ++k; + } + } + + std::vector merge(2 * (num_rows - 1)); + std::vector height(num_rows - 1); + + fastclustercpp::hclust_fast(num_rows, distance.data(), + fastclustercpp::HCLUST_METHOD_COMPLETE, + merge.data(), height.data()); + + std::vector labels(num_rows); + if (config_.num_clusters > 0) { + fastclustercpp::cutree_k(num_rows, merge.data(), config_.num_clusters, + labels.data()); + } else { + 
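+      // No fixed number of clusters was requested, so cut the dendrogram by
+      // the distance threshold instead of by cluster count.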
fastclustercpp::cutree_cdist(num_rows, merge.data(), height.data(), + config_.threshold, labels.data()); + } + + return labels; + } + + private: + FastClusteringConfig config_; +}; + +FastClustering::FastClustering(const FastClusteringConfig &config) + : impl_(std::make_unique(config)) {} + +FastClustering::~FastClustering() = default; + +std::vector FastClustering::Cluster(float *features, int32_t num_rows, + int32_t num_cols) const { + return impl_->Cluster(features, num_rows, num_cols); +} +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering.h new file mode 100644 index 00000000..0558be15 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fast-clustering.h @@ -0,0 +1,43 @@ +// sherpa-mnn/csrc/fast-clustering.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ +#define SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ + +#include +#include + +#include "sherpa-mnn/csrc/fast-clustering-config.h" + +namespace sherpa_mnn { + +class FastClustering { + public: + explicit FastClustering(const FastClusteringConfig &config); + ~FastClustering(); + + /** + * @param features Pointer to a 2-D feature matrix in row major. Each row + * is a feature frame. It is changed in-place. We will + * convert each feature frame to a normalized vector. + * That is, the L2-norm of each vector will be equal to 1. + * It uses cosine dissimilarity, + * which is 1 - (cosine similarity) + * @param num_rows Number of feature frames + * @param num-cols The feature dimension. + * + * @return Return a vector of size num_rows. ans[i] contains the label + * for the i-th feature frame, i.e., the i-th row of the feature + * matrix. + */ + std::vector Cluster(float *features, int32_t num_rows, + int32_t num_cols) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_FAST_CLUSTERING_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/features.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/features.cc new file mode 100644 index 00000000..fb6012d9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/features.cc @@ -0,0 +1,271 @@ +// sherpa-mnn/csrc/features.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/features.h" + +#include +#include +#include // NOLINT +#include +#include + +#include "kaldi-native-fbank/csrc/online-feature.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/resample.h" + +namespace sherpa_mnn { + +void FeatureExtractorConfig::Register(ParseOptions *po) { + po->Register("sample-rate", &sampling_rate, + "Sampling rate of the input waveform. " + "Note: You can have a different " + "sample rate for the input waveform. We will do resampling " + "inside the feature extractor"); + + po->Register("feat-dim", &feature_dim, + "Feature dimension. Must match the one expected by the model. " + "Not used by whisper and CED models"); + + po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins"); + + po->Register("high-freq", &high_freq, + "High cutoff frequency for mel bins " + "(if <= 0, offset from Nyquist)"); + + po->Register("dither", &dither, + "Dithering constant (0.0 means no dither). 
" + "By default the audio samples are in range [-1,+1], " + "so 0.00003 is a good value, " + "equivalent to the default 1.0 from kaldi"); +} + +std::string FeatureExtractorConfig::ToString() const { + std::ostringstream os; + + os << "FeatureExtractorConfig("; + os << "sampling_rate=" << sampling_rate << ", "; + os << "feature_dim=" << feature_dim << ", "; + os << "low_freq=" << low_freq << ", "; + os << "high_freq=" << high_freq << ", "; + os << "dither=" << dither << ", "; + os << "normalize_samples=" << (normalize_samples ? "True" : "False") << ", "; + os << "snip_edges=" << (snip_edges ? "True" : "False") << ")"; + + return os.str(); +} + +class FeatureExtractor::Impl { + public: + explicit Impl(const FeatureExtractorConfig &config) : config_(config) { + if (config_.is_mfcc) { + InitMfcc(); + } else { + InitFbank(); + } + } + + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { + if (config_.normalize_samples) { + AcceptWaveformImpl(sampling_rate, waveform, n); + } else { + std::vector buf(n); + for (int32_t i = 0; i != n; ++i) { + buf[i] = waveform[i] * 32768; + } + AcceptWaveformImpl(sampling_rate, buf.data(), n); + } + } + + void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform, + int32_t n) { + std::lock_guard lock(mutex_); + + if (resampler_) { + if (sampling_rate != resampler_->GetInputSamplingRate()) { + SHERPA_ONNX_LOGE( + "You changed the input sampling rate!! Expected: %d, given: " + "%d", + resampler_->GetInputSamplingRate(), sampling_rate); + exit(-1); + } + + std::vector samples; + resampler_->Resample(waveform, n, false, &samples); + if (fbank_) { + fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), + samples.size()); + } else { + mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(), + samples.size()); + } + return; + } + + if (sampling_rate != config_.sampling_rate) { + SHERPA_ONNX_LOGE( + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + sampling_rate, static_cast(config_.sampling_rate)); + + float min_freq = std::min(sampling_rate, config_.sampling_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + resampler_ = std::make_unique( + sampling_rate, config_.sampling_rate, lowpass_cutoff, + lowpass_filter_width); + + std::vector samples; + resampler_->Resample(waveform, n, false, &samples); + if (fbank_) { + fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), + samples.size()); + } else { + mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(), + samples.size()); + } + return; + } + + if (fbank_) { + fbank_->AcceptWaveform(sampling_rate, waveform, n); + } else { + mfcc_->AcceptWaveform(sampling_rate, waveform, n); + } + } + + void InputFinished() const { + std::lock_guard lock(mutex_); + fbank_->InputFinished(); + } + + int32_t NumFramesReady() const { + std::lock_guard lock(mutex_); + return fbank_->NumFramesReady(); + } + + bool IsLastFrame(int32_t frame) const { + std::lock_guard lock(mutex_); + return fbank_->IsLastFrame(frame); + } + + std::vector GetFrames(int32_t frame_index, int32_t n) { + std::lock_guard lock(mutex_); + if (frame_index + n > fbank_->NumFramesReady()) { + SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n, + fbank_->NumFramesReady()); + exit(-1); + } + + int32_t discard_num = frame_index - last_frame_index_; + if (discard_num < 0) { + SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d", + last_frame_index_, frame_index); + exit(-1); + } + fbank_->Pop(discard_num); + + int32_t feature_dim = 
fbank_->Dim(); + std::vector features(feature_dim * n); + + float *p = features.data(); + + for (int32_t i = 0; i != n; ++i) { + const float *f = fbank_->GetFrame(i + frame_index); + std::copy(f, f + feature_dim, p); + p += feature_dim; + } + + last_frame_index_ = frame_index; + + return features; + } + + int32_t FeatureDim() const { + return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; + } + + private: + void InitFbank() { + opts_.frame_opts.dither = config_.dither; + opts_.frame_opts.snip_edges = config_.snip_edges; + opts_.frame_opts.samp_freq = config_.sampling_rate; + opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; + opts_.frame_opts.frame_length_ms = config_.frame_length_ms; + opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; + opts_.frame_opts.preemph_coeff = config_.preemph_coeff; + opts_.frame_opts.window_type = config_.window_type; + + opts_.mel_opts.num_bins = config_.feature_dim; + + opts_.mel_opts.high_freq = config_.high_freq; + opts_.mel_opts.low_freq = config_.low_freq; + + opts_.mel_opts.is_librosa = config_.is_librosa; + + fbank_ = std::make_unique(opts_); + } + void InitMfcc() { + mfcc_opts_.frame_opts.dither = config_.dither; + mfcc_opts_.frame_opts.snip_edges = config_.snip_edges; + mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate; + mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; + mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms; + mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; + mfcc_opts_.frame_opts.preemph_coeff = config_.preemph_coeff; + mfcc_opts_.frame_opts.window_type = config_.window_type; + + mfcc_opts_.mel_opts.num_bins = config_.feature_dim; + + mfcc_opts_.mel_opts.high_freq = config_.high_freq; + mfcc_opts_.mel_opts.low_freq = config_.low_freq; + + mfcc_opts_.mel_opts.is_librosa = config_.is_librosa; + + mfcc_opts_.num_ceps = config_.num_ceps; + mfcc_opts_.use_energy = config_.use_energy; + + mfcc_ = std::make_unique(mfcc_opts_); + } + + private: + std::unique_ptr fbank_; + std::unique_ptr mfcc_; + knf::FbankOptions opts_; + knf::MfccOptions mfcc_opts_; + FeatureExtractorConfig config_; + mutable std::mutex mutex_; + std::unique_ptr resampler_; + int32_t last_frame_index_ = 0; +}; + +FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) + : impl_(std::make_unique(config)) {} + +FeatureExtractor::~FeatureExtractor() = default; + +void FeatureExtractor::AcceptWaveform(int32_t sampling_rate, + const float *waveform, int32_t n) const { + impl_->AcceptWaveform(sampling_rate, waveform, n); +} + +void FeatureExtractor::InputFinished() const { impl_->InputFinished(); } + +int32_t FeatureExtractor::NumFramesReady() const { + return impl_->NumFramesReady(); +} + +bool FeatureExtractor::IsLastFrame(int32_t frame) const { + return impl_->IsLastFrame(frame); +} + +std::vector FeatureExtractor::GetFrames(int32_t frame_index, + int32_t n) const { + return impl_->GetFrames(frame_index, n); +} + +int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); } + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/features.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/features.h new file mode 100644 index 00000000..2a0f36c2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/features.h @@ -0,0 +1,137 @@ +// sherpa-mnn/csrc/features.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_FEATURES_H_ +#define SHERPA_ONNX_CSRC_FEATURES_H_ + +#include +#include +#include + +#include 
"sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct FeatureExtractorConfig { + // Sampling rate used by the feature extractor. If it is different from + // the sampling rate of the input waveform, we will do resampling inside. + int32_t sampling_rate = 16000; + + // num_mel_bins + // + // Note: for mfcc, this value is also for num_mel_bins. + // The actual feature dimension is actuall num_ceps + int32_t feature_dim = 80; + + // minimal frequency for Mel-filterbank, in Hz + float low_freq = 20.0f; + + // maximal frequency of Mel-filterbank + // in Hz; negative value is subtracted from Nyquist freq.: + // i.e. for sampling_rate 16000 / 2 - 400 = 7600Hz + // + // Please see + // https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27 + // and + // https://github.com/k2-fsa/sherpa-mnn/issues/514 + float high_freq = -400.0f; + + // dithering constant, useful for signals with hard-zeroes in non-speech parts + // this prevents large negative values in log-mel filterbanks + // + // In k2, audio samples are in range [-1..+1], in kaldi the range was + // [-32k..+32k], so the value 0.00003 is equivalent to kaldi default 1.0 + // + float dither = 0.0f; // dithering disabled by default + + // Set internally by some models, e.g., paraformer sets it to false. + // This parameter is not exposed to users from the commandline + // If true, the feature extractor expects inputs to be normalized to + // the range [-1, 1]. + // If false, we will multiply the inputs by 32768 + bool normalize_samples = true; + + bool snip_edges = false; + float frame_shift_ms = 10.0f; // in milliseconds. + float frame_length_ms = 25.0f; // in milliseconds. + bool is_librosa = false; + bool remove_dc_offset = true; // Subtract mean of wave before FFT. + float preemph_coeff = 0.97f; // Preemphasis coefficient. + std::string window_type = "povey"; // e.g. Hamming window + + // For models from NeMo + // This option is not exposed and is set internally when loading models. + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string nemo_normalize_type; + + // for MFCC + int32_t num_ceps = 13; + bool use_energy = true; + + bool is_mfcc = false; + + std::string ToString() const; + + void Register(ParseOptions *po); +}; + +class FeatureExtractor { + public: + explicit FeatureExtractor(const FeatureExtractorConfig &config = {}); + ~FeatureExtractor(); + + /** + @param sampling_rate The sampling_rate of the input waveform. If it does + not equal to config.sampling_rate, we will do + resampling inside. + @param waveform Pointer to a 1-D array of size n. It must be normalized to + the range [-1, 1]. + @param n Number of entries in waveform + */ + void AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const; + + /** + * InputFinished() tells the class you won't be providing any + * more waveform. This will help flush out the last frame or two + * of features, in the case where snip-edges == false; it also + * affects the return value of IsLastFrame(). + */ + void InputFinished() const; + + int32_t NumFramesReady() const; + + /** Note: IsLastFrame() will only ever return true if you have called + * InputFinished() (and this frame is the last frame). 
+ */ + bool IsLastFrame(int32_t frame) const; + + /** Get n frames starting from the given frame index. + * + * @param frame_index The starting frame index + * @param n Number of frames to get. + * @return Return a 2-D tensor of shape (n, feature_dim). + * which is flattened into a 1-D vector (flattened in row major) + */ + std::vector GetFrames(int32_t frame_index, int32_t n) const; + + /// Return feature dim of this extractor + int32_t FeatureDim() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_FEATURES_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/file-utils.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/file-utils.cc new file mode 100644 index 00000000..2a2308db --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/file-utils.cc @@ -0,0 +1,84 @@ +// sherpa-mnn/csrc/file-utils.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/file-utils.h" + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +bool FileExists(const std::string &filename) { + return std::ifstream(filename).good(); +} + +void AssertFileExists(const std::string &filename) { + if (!FileExists(filename)) { + SHERPA_ONNX_LOGE("filename '%s' does not exist", filename.c_str()); + exit(-1); + } +} + +std::vector ReadFile(const std::string &filename) { + std::ifstream input(filename, std::ios::binary); + std::vector buffer(std::istreambuf_iterator(input), {}); + return buffer; +} + +#if __ANDROID_API__ >= 9 +std::vector ReadFile(AAssetManager *mgr, const std::string &filename) { + AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER); + if (!asset) { + __android_log_print(ANDROID_LOG_FATAL, "sherpa-mnn", + "Read binary file: Load %s failed", filename.c_str()); + exit(-1); + } + + auto p = reinterpret_cast(AAsset_getBuffer(asset)); + size_t asset_length = AAsset_getLength(asset); + + std::vector buffer(p, p + asset_length); + AAsset_close(asset); + + return buffer; +} +#endif + +#if __OHOS__ +std::vector ReadFile(NativeResourceManager *mgr, + const std::string &filename) { + std::unique_ptr fp( + OH_ResourceManager_OpenRawFile(mgr, filename.c_str()), + OH_ResourceManager_CloseRawFile); + + if (!fp) { + std::ostringstream os; + os << "Read file '" << filename << "' failed."; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + return {}; + } + + auto len = static_cast(OH_ResourceManager_GetRawFileSize(fp.get())); + + std::vector buffer(len); + + int32_t n = OH_ResourceManager_ReadRawFile(fp.get(), buffer.data(), len); + + if (n != len) { + std::ostringstream os; + os << "Read file '" << filename << "' failed. Number of bytes read: " << n + << ". 
Expected bytes to read: " << len; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + return {}; + } + + return buffer; +} +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/file-utils.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/file-utils.h new file mode 100644 index 00000000..cad8193d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/file-utils.h @@ -0,0 +1,49 @@ +// sherpa-mnn/csrc/file-utils.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_FILE_UTILS_H_ +#define SHERPA_ONNX_CSRC_FILE_UTILS_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +namespace sherpa_mnn { + +/** Check whether a given path is a file or not + * + * @param filename Path to check. + * @return Return true if the given path is a file; return false otherwise. + */ +bool FileExists(const std::string &filename); + +/** Abort if the file does not exist. + * + * @param filename The file to check. + */ +void AssertFileExists(const std::string &filename); + +std::vector ReadFile(const std::string &filename); + +#if __ANDROID_API__ >= 9 +std::vector ReadFile(AAssetManager *mgr, const std::string &filename); +#endif + +#if __OHOS__ +std::vector ReadFile(NativeResourceManager *mgr, + const std::string &filename); +#endif + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fst-utils.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fst-utils.cc new file mode 100644 index 00000000..274823a1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fst-utils.cc @@ -0,0 +1,53 @@ +// sherpa-mnn/csrc/fst-utils.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/fst-utils.h" + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +// This function is copied from kaldi. +// +// @param filename Path to a StdVectorFst or StdConstFst graph +// @return The caller should free the returned pointer using `delete` to +// avoid memory leak. +fst::Fst *ReadGraph(const std::string &filename) { + // read decoding network FST + std::ifstream is(filename, std::ios::binary); + if (!is.good()) { + SHERPA_ONNX_LOGE("Could not open decoding-graph FST %s", filename.c_str()); + } + + fst::FstHeader hdr; + if (!hdr.Read(is, "")) { + SHERPA_ONNX_LOGE("Reading FST: error reading FST header."); + } + + if (hdr.ArcType() != fst::StdArc::Type()) { + SHERPA_ONNX_LOGE("FST with arc type %s not supported", + hdr.ArcType().c_str()); + } + fst::FstReadOptions ropts("", &hdr); + + fst::Fst *decode_fst = nullptr; + + if (hdr.FstType() == "vector") { + decode_fst = fst::VectorFst::Read(is, ropts); + } else if (hdr.FstType() == "const") { + decode_fst = fst::ConstFst::Read(is, ropts); + } else { + SHERPA_ONNX_LOGE("Reading FST: unsupported FST type: %s", + hdr.FstType().c_str()); + } + + if (decode_fst == nullptr) { // fst code will warn. 
+ SHERPA_ONNX_LOGE("Error reading FST (after reading header)."); + return nullptr; + } else { + return decode_fst; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fst-utils.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fst-utils.h new file mode 100644 index 00000000..aaac84d3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/fst-utils.h @@ -0,0 +1,18 @@ +// sherpa-mnn/csrc/fst-utils.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_FST_UTILS_H_ +#define SHERPA_ONNX_CSRC_FST_UTILS_H_ + +#include + +#include "fst/fstlib.h" + +namespace sherpa_mnn { + +fst::Fst *ReadGraph(const std::string &filename); + +} + +#endif // SHERPA_ONNX_CSRC_FST_UTILS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hifigan-vocoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hifigan-vocoder.cc new file mode 100644 index 00000000..86cdb024 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hifigan-vocoder.cc @@ -0,0 +1,106 @@ +// sherpa-mnn/csrc/hifigan-vocoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/hifigan-vocoder.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" + +namespace sherpa_mnn { + +class HifiganVocoder::Impl { + public: + explicit Impl(int32_t num_threads, const std::string &provider, + const std::string &model) + : + sess_opts_(GetSessionOptions(num_threads, provider)), + allocator_{} { + auto buf = ReadFile(model); + Init(buf.data(), buf.size()); + } + + template + explicit Impl(Manager *mgr, int32_t num_threads, const std::string &provider, + const std::string &model) + : + sess_opts_(GetSessionOptions(num_threads, provider)), + allocator_{} { + auto buf = ReadFile(mgr, model); + Init(buf.data(), buf.size()); + } + + MNN::Express::VARP Run(MNN::Express::VARP mel) const { + auto out = sess_->onForward({mel}); + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + } + + private: + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; +}; + +HifiganVocoder::HifiganVocoder(int32_t num_threads, const std::string &provider, + const std::string &model) + : impl_(std::make_unique(num_threads, provider, model)) {} + +template +HifiganVocoder::HifiganVocoder(Manager *mgr, int32_t num_threads, + const std::string &provider, + const std::string &model) + : impl_(std::make_unique(mgr, num_threads, provider, model)) {} + +HifiganVocoder::~HifiganVocoder() = default; + +MNN::Express::VARP HifiganVocoder::Run(MNN::Express::VARP mel) const { + return impl_->Run(std::move(mel)); +} + +#if __ANDROID_API__ >= 9 +template HifiganVocoder::HifiganVocoder(AAssetManager *mgr, int32_t num_threads, + const std::string &provider, + const std::string &model); +#endif + +#if 
__OHOS__ +template HifiganVocoder::HifiganVocoder(NativeResourceManager *mgr, + int32_t num_threads, + const std::string &provider, + const std::string &model); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hifigan-vocoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hifigan-vocoder.h new file mode 100644 index 00000000..19020e64 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hifigan-vocoder.h @@ -0,0 +1,38 @@ +// sherpa-mnn/csrc/hifigan-vocoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_ +#define SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +class HifiganVocoder { + public: + ~HifiganVocoder(); + + HifiganVocoder(int32_t num_threads, const std::string &provider, + const std::string &model); + + template + HifiganVocoder(Manager *mgr, int32_t num_threads, const std::string &provider, + const std::string &model); + + /** @param mel A float32 tensor of shape (batch_size, feat_dim, num_frames). + * @return Return a float32 tensor of shape (batch_size, num_samples). + */ + MNN::Express::VARP Run(MNN::Express::VARP mel) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hypothesis.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hypothesis.cc new file mode 100644 index 00000000..361f406d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hypothesis.cc @@ -0,0 +1,81 @@ +/** + * Copyright (c) 2023 Xiaomi Corporation + * Copyright (c) 2023 Pingfeng Luo + */ + +#include "sherpa-mnn/csrc/hypothesis.h" + +#include +#include + +namespace sherpa_mnn { + +void Hypotheses::Add(Hypothesis hyp) { + auto key = hyp.Key(); + auto it = hyps_dict_.find(key); + if (it == hyps_dict_.end()) { + hyps_dict_[key] = std::move(hyp); + } else { + it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + } +} + +Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { + if (length_norm == false) { + return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), + [](const auto &left, auto &right) -> bool { + return left.second.TotalLogProb() < + right.second.TotalLogProb(); + }) + ->second; + } else { + // for length_norm is true + return std::max_element( + hyps_dict_.begin(), hyps_dict_.end(), + [](const auto &left, const auto &right) -> bool { + return left.second.TotalLogProb() / left.second.ys.size() < + right.second.TotalLogProb() / right.second.ys.size(); + }) + ->second; + } +} + +std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { + k = std::max(k, 1); + k = std::min(k, Size()); + + std::vector all_hyps = Vec(); + + if (length_norm == false) { + std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), + [](const auto &a, const auto &b) { + return a.TotalLogProb() > b.TotalLogProb(); + }); + } else { + // for length_norm is true + std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), + [](const auto &a, const auto &b) { + return a.TotalLogProb() / a.ys.size() > + b.TotalLogProb() / b.ys.size(); + }); + } + + return {all_hyps.begin(), all_hyps.begin() + k}; +} + +const std::vector GetHypsRowSplits( + const std::vector &hyps) { + std::vector row_splits; + row_splits.reserve(hyps.size() + 1); + + row_splits.push_back(0); + int32_t s = 0; + for (const auto &h : hyps) { + s += h.Size(); + 
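+    // Running prefix sum: for hyps of sizes 2, 3 and 1 (hypothetical values),
+    // row_splits becomes [0, 2, 5, 6], and row i spans
+    // [row_splits[i], row_splits[i+1]).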
row_splits.push_back(s); + } + + return row_splits; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hypothesis.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hypothesis.h new file mode 100644 index 00000000..d8196d15 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/hypothesis.h @@ -0,0 +1,165 @@ +/** + * Copyright (c) 2023 Xiaomi Corporation + * Copyright (c) 2023 Pingfeng Luo + * + */ + +#ifndef SHERPA_ONNX_CSRC_HYPOTHESIS_H_ +#define SHERPA_ONNX_CSRC_HYPOTHESIS_H_ + +#include +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/context-graph.h" +#include "sherpa-mnn/csrc/math.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +struct Hypothesis { + // The predicted tokens so far. Newly predicated tokens are appended. + std::vector ys; + + // timestamps[i] contains the frame number after subsampling + // on which ys[i] is decoded. + std::vector timestamps; + + // The acoustic probability for each token in ys. + // Used for keyword spotting task. + // For transducer mofified beam-search and greedy-search, + // this is filled with log_posterior scores. + std::vector ys_probs; + + // lm_probs[i] contains the lm score for each token in ys. + // Used only in transducer mofified beam-search. + // Elements filled only if LM is used. + std::vector lm_probs; + + // context_scores[i] contains the context-graph score for each token in ys. + // Used only in transducer mofified beam-search. + // Elements filled only if `ContextGraph` is used. + std::vector context_scores; + + // The total score of ys in log space. + // It contains only acoustic scores + double log_prob = 0; + + // LM log prob if any. + double lm_log_prob = 0; + + // the nn lm score for next token given the current ys, + // when using shallow fusion + CopyableOrtValue nn_lm_scores; + + // cur scored tokens by RNN LM, when rescoring + int32_t cur_scored_pos = 0; + + // the nn lm states + std::vector nn_lm_states; + + const ContextState *context_state; + + // TODO(fangjun): Make it configurable + // the minimum of tokens in a chunk for streaming RNN LM + int32_t lm_rescore_min_chunk = 2; // a const + + int32_t num_trailing_blanks = 0; + + Hypothesis() = default; + Hypothesis(const std::vector &ys, double log_prob, + const ContextState *context_state = nullptr) + : ys(ys), log_prob(log_prob), context_state(context_state) {} + + double TotalLogProb() const { return log_prob + lm_log_prob; } + + // If two Hypotheses have the same `Key`, then they contain + // the same token sequence. + std::string Key() const { + // TODO(fangjun): Use a hash function? + std::ostringstream os; + std::string sep; + for (auto i : ys) { + os << sep << i; + sep = "-"; + } + return os.str(); + } + + // For debugging + std::string ToString() const { + std::ostringstream os; + os << "(" << Key() << ", " << log_prob << ")"; + return os.str(); + } +}; + +class Hypotheses { + public: + Hypotheses() = default; + + explicit Hypotheses(std::vector hyps) { + for (auto &h : hyps) { + hyps_dict_[h.Key()] = std::move(h); + } + } + + explicit Hypotheses(std::unordered_map hyps_dict) + : hyps_dict_(std::move(hyps_dict)) {} + + // Add hyp to this object. If it already exists, its log_prob + // is updated with the given hyp using log-sum-exp. + void Add(Hypothesis hyp); + + // Get the hyp that has the largest log_prob. + // If length_norm is true, hyp's log_prob is divided by + // len(hyp.ys) before comparison. 
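+  // Example (hypothetical scores): hyp A has log_prob -3.0 with 1 token and
+  // hyp B has log_prob -4.0 with 2 tokens. Without length_norm A wins
+  // (-3.0 > -4.0); with length_norm B wins (-4.0 / 2 = -2.0 > -3.0 / 1).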
+ Hypothesis GetMostProbable(bool length_norm) const; + + // Get the k hyps that have the largest log_prob. + // If length_norm is true, hyp's log_prob is divided by + // len(hyp.ys) before comparison. + std::vector GetTopK(int32_t k, bool length_norm) const; + + int32_t Size() const { return hyps_dict_.size(); } + + std::string ToString() const { + std::ostringstream os; + for (const auto &p : hyps_dict_) { + os << p.second.ToString() << "\n"; + } + return os.str(); + } + + auto begin() const { return hyps_dict_.begin(); } + auto end() const { return hyps_dict_.end(); } + + auto begin() { return hyps_dict_.begin(); } + auto end() { return hyps_dict_.end(); } + + void Clear() { hyps_dict_.clear(); } + + // Return a list of hyps contained in this object. + std::vector Vec() const { + std::vector ans; + ans.reserve(hyps_dict_.size()); + for (const auto &p : hyps_dict_) { + ans.push_back(p.second); + } + return ans; + } + + private: + using Map = std ::unordered_map; + Map hyps_dict_; +}; + +const std::vector GetHypsRowSplits( + const std::vector &hyps); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_HYPOTHESIS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/jieba-lexicon.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/jieba-lexicon.cc new file mode 100644 index 00000000..6898c83f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/jieba-lexicon.cc @@ -0,0 +1,351 @@ +// sherpa-mnn/csrc/jieba-lexicon.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/jieba-lexicon.h" + +#include +#include // NOLINT +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "cppjieba/Jieba.hpp" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +static bool IsPunct(const std::string &s) { + static const std::unordered_set puncts = { + ",", ".", "!", "?", ":", "\"", "'", ",", + "。", "!", "?", "“", "”", "‘", "’", + }; + return puncts.count(s); +} + +class JiebaLexicon::Impl { + public: + Impl(const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, bool debug) + : debug_(debug) { + std::string dict = dict_dir + "/jieba.dict.utf8"; + std::string hmm = dict_dir + "/hmm_model.utf8"; + std::string user_dict = dict_dir + "/user.dict.utf8"; + std::string idf = dict_dir + "/idf.utf8"; + std::string stop_word = dict_dir + "/stop_words.utf8"; + + AssertFileExists(dict); + AssertFileExists(hmm); + AssertFileExists(user_dict); + AssertFileExists(idf); + AssertFileExists(stop_word); + + jieba_ = + std::make_unique(dict, hmm, user_dict, idf, stop_word); + + { + std::ifstream is(tokens); + InitTokens(is); + } + + { + std::ifstream is(lexicon); + InitLexicon(is); + } + } + + template + Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, bool debug) + : debug_(debug) { + std::string dict = dict_dir + "/jieba.dict.utf8"; + std::string hmm = dict_dir + "/hmm_model.utf8"; + std::string user_dict = dict_dir + "/user.dict.utf8"; + std::string idf = dict_dir + "/idf.utf8"; + std::string stop_word = dict_dir + "/stop_words.utf8"; + + AssertFileExists(dict); + AssertFileExists(hmm); + AssertFileExists(user_dict); + AssertFileExists(idf); + 
AssertFileExists(stop_word); + + jieba_ = + std::make_unique(dict, hmm, user_dict, idf, stop_word); + + { + auto buf = ReadFile(mgr, tokens); + std::istrstream is(buf.data(), buf.size()); + + InitTokens(is); + } + + { + auto buf = ReadFile(mgr, lexicon); + std::istrstream is(buf.data(), buf.size()); + InitLexicon(is); + } + } + + std::vector ConvertTextToTokenIds(const std::string &text) const { + // see + // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 + std::regex punct_re{":|、|;"}; + std::string s = std::regex_replace(text, punct_re, ","); + + std::regex punct_re2("[.]"); + s = std::regex_replace(s, punct_re2, "。"); + + std::regex punct_re3("[?]"); + s = std::regex_replace(s, punct_re3, "?"); + + std::regex punct_re4("[!]"); + s = std::regex_replace(s, punct_re4, "!"); + + std::vector words; + bool is_hmm = true; + jieba_->Cut(text, words, is_hmm); + + if (debug_) { +#if __OHOS__ + SHERPA_ONNX_LOGE("input text:\n%{public}s", text.c_str()); + SHERPA_ONNX_LOGE("after replacing punctuations:\n%{public}s", s.c_str()); +#else + SHERPA_ONNX_LOGE("input text:\n%s", text.c_str()); + SHERPA_ONNX_LOGE("after replacing punctuations:\n%s", s.c_str()); +#endif + + std::ostringstream os; + std::string sep = ""; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("after jieba processing:\n%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("after jieba processing:\n%s", os.str().c_str()); +#endif + } + + // remove spaces after punctuations + std::vector words2 = std::move(words); + words.reserve(words2.size()); + + for (int32_t i = 0; i < words2.size(); ++i) { + if (i == 0) { + words.push_back(std::move(words2[i])); + } else if (words2[i] == " ") { + if (words.back() == " " || IsPunct(words.back())) { + continue; + } else { + words.push_back(std::move(words2[i])); + } + } else if (IsPunct(words2[i])) { + if (words.back() == " " || IsPunct(words.back())) { + continue; + } else { + words.push_back(std::move(words2[i])); + } + } else { + words.push_back(std::move(words2[i])); + } + } + + if (debug_) { + std::ostringstream os; + std::string sep = ""; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("after removing spaces after punctuations:\n%{public}s", + os.str().c_str()); +#else + SHERPA_ONNX_LOGE("after removing spaces after punctuations:\n%s", + os.str().c_str()); +#endif + } + + std::vector ans; + std::vector this_sentence; + + for (const auto &w : words) { + auto ids = ConvertWordToIds(w); + if (ids.empty()) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Ignore OOV '%{public}s'", w.c_str()); +#else + SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str()); +#endif + continue; + } + + this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); + + if (IsPunct(w)) { + ans.emplace_back(std::move(this_sentence)); + this_sentence = {}; + } + } // for (const auto &w : words) + + if (!this_sentence.empty()) { + ans.emplace_back(std::move(this_sentence)); + } + + return ans; + } + + private: + std::vector ConvertWordToIds(const std::string &w) const { + if (word2ids_.count(w)) { + return word2ids_.at(w); + } + + if (token2id_.count(w)) { + return {token2id_.at(w)}; + } + + std::vector ans; + + std::vector words = SplitUtf8(w); + for (const auto &word : words) { + if (word2ids_.count(word)) { + auto ids = ConvertWordToIds(word); + ans.insert(ans.end(), ids.begin(), ids.end()); + } + } + + return ans; + } + + void InitTokens(std::istream &is) { + token2id_ = ReadTokens(is); + + 
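+    // Alias ASCII punctuation with its full-width counterpart so that
+    // whichever form exists in tokens.txt also covers the other form.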
std::vector> puncts = { + {",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}, {":", ":"}, + {"\"", "“"}, {"\"", "”"}, {"'", "‘"}, {"'", "’"}, {";", ";"}, + }; + + for (const auto &p : puncts) { + if (token2id_.count(p.first) && !token2id_.count(p.second)) { + token2id_[p.second] = token2id_[p.first]; + } + + if (!token2id_.count(p.first) && token2id_.count(p.second)) { + token2id_[p.first] = token2id_[p.second]; + } + } + + if (!token2id_.count("、") && token2id_.count(",")) { + token2id_["、"] = token2id_[","]; + } + + if (!token2id_.count(";") && token2id_.count(",")) { + token2id_[";"] = token2id_[","]; + } + } + + void InitLexicon(std::istream &is) { + std::string word; + std::vector token_list; + std::string line; + std::string phone; + int32_t line_num = 0; + + while (std::getline(is, line)) { + ++line_num; + + std::istringstream iss(line); + + token_list.clear(); + + iss >> word; + ToLowerCase(&word); + + if (word2ids_.count(word)) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Duplicated word: %{public}s at line %{public}d:%{public}s. Ignore " + "it.", + word.c_str(), line_num, line.c_str()); +#else + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.", + word.c_str(), line_num, line.c_str()); +#endif + continue; + } + + while (iss >> phone) { + token_list.push_back(std::move(phone)); + } + + std::vector ids = ConvertTokensToIds(token2id_, token_list); + if (ids.empty()) { + continue; + } + + word2ids_.insert({std::move(word), std::move(ids)}); + } + } + + private: + // lexicon.txt is saved in word2ids_ + std::unordered_map> word2ids_; + + // tokens.txt is saved in token2id_ + std::unordered_map token2id_; + + std::unique_ptr jieba_; + bool debug_ = false; +}; + +JiebaLexicon::~JiebaLexicon() = default; + +JiebaLexicon::JiebaLexicon(const std::string &lexicon, + const std::string &tokens, + const std::string &dict_dir, bool debug) + : impl_(std::make_unique(lexicon, tokens, dict_dir, debug)) {} + +template +JiebaLexicon::JiebaLexicon(Manager *mgr, const std::string &lexicon, + const std::string &tokens, + const std::string &dict_dir, bool debug) + : impl_(std::make_unique(mgr, lexicon, tokens, dict_dir, debug)) {} + +std::vector JiebaLexicon::ConvertTextToTokenIds( + const std::string &text, const std::string & /*unused_voice = ""*/) const { + return impl_->ConvertTextToTokenIds(text); +} + +#if __ANDROID_API__ >= 9 +template JiebaLexicon::JiebaLexicon(AAssetManager *mgr, + const std::string &lexicon, + const std::string &tokens, + const std::string &dict_dir, bool debug); +#endif + +#if __OHOS__ +template JiebaLexicon::JiebaLexicon(NativeResourceManager *mgr, + const std::string &lexicon, + const std::string &tokens, + const std::string &dict_dir, bool debug); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/jieba-lexicon.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/jieba-lexicon.h new file mode 100644 index 00000000..b6abf472 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/jieba-lexicon.h @@ -0,0 +1,40 @@ +// sherpa-mnn/csrc/jieba-lexicon.h +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_JIEBA_LEXICON_H_ +#define SHERPA_ONNX_CSRC_JIEBA_LEXICON_H_ + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts-frontend.h" + +namespace sherpa_mnn { + +class JiebaLexicon : public OfflineTtsFrontend { + public: + ~JiebaLexicon() override; + + JiebaLexicon(const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, bool debug); + + template + 
JiebaLexicon(Manager *mgr, const std::string &lexicon, + const std::string &tokens, const std::string &dict_dir, + bool debug); + + std::vector ConvertTextToTokenIds( + const std::string &text, + const std::string &unused_voice = "") const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_JIEBA_LEXICON_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-impl.cc new file mode 100644 index 00000000..c27c4636 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-impl.cc @@ -0,0 +1,51 @@ +// sherpa-mnn/csrc/keyword-spotter-impl.cc +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/keyword-spotter-impl.h" + +#include "sherpa-mnn/csrc/keyword-spotter-transducer-impl.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +namespace sherpa_mnn { + +std::unique_ptr KeywordSpotterImpl::Create( + const KeywordSpotterConfig &config) { + if (!config.model_config.transducer.encoder.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a model"); + exit(-1); +} + +template +std::unique_ptr KeywordSpotterImpl::Create( + Manager *mgr, const KeywordSpotterConfig &config) { + if (!config.model_config.transducer.encoder.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE("Please specify a model"); + exit(-1); +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr KeywordSpotterImpl::Create( + AAssetManager *mgr, const KeywordSpotterConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr KeywordSpotterImpl::Create( + NativeResourceManager *mgr, const KeywordSpotterConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-impl.h new file mode 100644 index 00000000..23402ea3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-impl.h @@ -0,0 +1,44 @@ +// sherpa-mnn/csrc/keyword-spotter-impl.h +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_ +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/keyword-spotter.h" +#include "sherpa-mnn/csrc/online-stream.h" + +namespace sherpa_mnn { + +class KeywordSpotterImpl { + public: + static std::unique_ptr Create( + const KeywordSpotterConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const KeywordSpotterConfig &config); + + virtual ~KeywordSpotterImpl() = default; + + virtual std::unique_ptr CreateStream() const = 0; + + virtual std::unique_ptr CreateStream( + const std::string &keywords) const = 0; + + virtual bool IsReady(OnlineStream *s) const = 0; + + virtual void Reset(OnlineStream *s) const = 0; + + virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; + + virtual KeywordResult GetResult(OnlineStream *s) const = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-transducer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-transducer-impl.h new file mode 100644 index 00000000..f49f6a72 --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter-transducer-impl.h @@ -0,0 +1,370 @@ +// sherpa-mnn/csrc/keyword-spotter-transducer-impl.h +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_ +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_ + +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/keyword-spotter-impl.h" +#include "sherpa-mnn/csrc/keyword-spotter.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/transducer-keyword-decoder.h" +#include "sherpa-mnn/csrc/utils.h" + +namespace sherpa_mnn { + +static KeywordResult Convert(const TransducerKeywordResult &src, + const SymbolTable &sym_table, float frame_shift_ms, + int32_t subsampling_factor, + int32_t frames_since_start) { + KeywordResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + r.keyword = src.keyword; + bool from_tokens = src.keyword.empty(); + + for (auto i : src.tokens) { + auto sym = sym_table[i]; + if (from_tokens) { + r.keyword.append(sym); + } + r.tokens.push_back(std::move(sym)); + } + if (from_tokens && r.keyword.size()) { + r.keyword = r.keyword.substr(1); + } + + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + r.start_time = frames_since_start * frame_shift_ms / 1000.; + + return r; +} + +class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { + public: + explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(config.model_config)) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + + if (sym_.Contains("")) { + unk_id_ = sym_[""]; + } + + model_->SetFeatureDim(config.feat_config.feature_dim); + + if (config.keywords_buf.empty()) { + InitKeywords(); + } else { + InitKeywordsFromBufStr(); + } + + decoder_ = std::make_unique( + model_.get(), config_.max_active_paths, config_.num_trailing_blanks, + unk_id_); + } + + template + KeywordSpotterTransducerImpl(Manager *mgr, const KeywordSpotterConfig &config) + : config_(config), + model_(OnlineTransducerModel::Create(mgr, config.model_config)), + sym_(mgr, config.model_config.tokens) { + if (sym_.Contains("")) { + unk_id_ = sym_[""]; + } + + model_->SetFeatureDim(config.feat_config.feature_dim); + + InitKeywords(mgr); + + decoder_ = std::make_unique( + model_.get(), config_.max_active_paths, config_.num_trailing_blanks, + unk_id_); + } + + std::unique_ptr CreateStream() const override { + auto stream = + std::make_unique(config_.feat_config, keywords_graph_); + InitOnlineStream(stream.get()); + return stream; + } + + std::unique_ptr CreateStream( + const std::string &keywords) const override { + auto kws = std::regex_replace(keywords, std::regex("/"), "\n"); + std::istringstream is(kws); + + std::vector> current_ids; + std::vector current_kws; + std::vector current_scores; + std::vector current_thresholds; + + if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores, + ¤t_thresholds)) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Encode 
keywords %{public}s failed.", keywords.c_str()); +#else + SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str()); +#endif + return nullptr; + } + + int32_t num_kws = current_ids.size(); + int32_t num_default_kws = keywords_id_.size(); + + current_ids.insert(current_ids.end(), keywords_id_.begin(), + keywords_id_.end()); + + if (!current_kws.empty() && !keywords_.empty()) { + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end()); + } else if (!current_kws.empty() && keywords_.empty()) { + current_kws.insert(current_kws.end(), num_default_kws, std::string()); + } else if (current_kws.empty() && !keywords_.empty()) { + current_kws.insert(current_kws.end(), num_kws, std::string()); + current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end()); + } else { + // Do nothing. + } + + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_kws, + config_.keywords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_kws, + config_.keywords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing. + } + + if (!current_thresholds.empty() && !thresholds_.empty()) { + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), + thresholds_.end()); + } else if (!current_thresholds.empty() && thresholds_.empty()) { + current_thresholds.insert(current_thresholds.end(), num_default_kws, + config_.keywords_threshold); + } else if (current_thresholds.empty() && !thresholds_.empty()) { + current_thresholds.insert(current_thresholds.end(), num_kws, + config_.keywords_threshold); + current_thresholds.insert(current_thresholds.end(), thresholds_.begin(), + thresholds_.end()); + } else { + // Do nothing. 
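+      // Neither per-call boost scores nor default boost scores were provided;
+      // current_scores stays empty and only the global config_.keywords_score
+      // passed to the ContextGraph below applies (assumed fallback behaviour).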
+ } + + auto keywords_graph = std::make_shared( + current_ids, config_.keywords_score, config_.keywords_threshold, + current_scores, current_kws, current_thresholds); + + auto stream = + std::make_unique(config_.feat_config, keywords_graph); + InitOnlineStream(stream.get()); + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + void Reset(OnlineStream *s) const override { InitOnlineStream(s); } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i < n; ++i) { + auto s = ss[i]; + auto r = s->GetKeywordResult(true); + int32_t num_trailing_blanks = r.num_trailing_blanks; + // assume subsampling_factor is 4 + // assume frameshift is 0.01 second + float trailing_slience = num_trailing_blanks * 4 * 0.01; + + // it resets automatically after detecting 1.5 seconds of silence + float threshold = 1.5; + if (trailing_slience > threshold) { + Reset(s); + } + } + + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = ss[0]->FeatureDim(); + + std::vector results(n); + std::vector features_vec(n * chunk_size * feature_dim); + std::vector> states_vec(n); + std::vector all_processed_frames(n); + + for (int32_t i = 0; i != n; ++i) { + SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr); + + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); + std::vector features = + ss[i]->GetFrames(num_processed_frames, chunk_size); + + // Question: should num_processed_frames include chunk_shift? + ss[i]->GetNumProcessedFrames() += chunk_shift; + + std::copy(features.begin(), features.end(), + features_vec.data() + i * chunk_size * feature_dim); + + results[i] = std::move(ss[i]->GetKeywordResult()); + states_vec[i] = std::move(ss[i]->GetStates()); + all_processed_frames[i] = num_processed_frames; + } + + MNNAllocator* memory_info = nullptr; + + std::array x_shape{n, chunk_size, feature_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, features_vec.data(), + features_vec.size(), x_shape.data(), + x_shape.size()); + + std::array processed_frames_shape{ + static_cast(all_processed_frames.size())}; + + MNN::Express::VARP processed_frames = MNNUtilsCreateTensor( + memory_info, all_processed_frames.data(), all_processed_frames.size(), + processed_frames_shape.data(), processed_frames_shape.size()); + + auto states = model_->StackStates(states_vec); + + auto pair = model_->RunEncoder(std::move(x), std::move(states), + std::move(processed_frames)); + + decoder_->Decode(std::move(pair.first), ss, &results); + + std::vector> next_states = + model_->UnStackStates(pair.second); + + for (int32_t i = 0; i != n; ++i) { + ss[i]->SetKeywordResult(results[i]); + ss[i]->SetStates(std::move(next_states[i])); + } + } + + KeywordResult GetResult(OnlineStream *s) const override { + TransducerKeywordResult decoder_result = s->GetKeywordResult(true); + + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetNumFramesSinceStart()); + } + + private: + void InitKeywords(std::istream &is) { + if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_, + &thresholds_)) { + SHERPA_ONNX_LOGE("Encode keywords failed."); + exit(-1); + } + keywords_graph_ = std::make_shared( + keywords_id_, config_.keywords_score, config_.keywords_threshold, + 
boost_scores_, keywords_, thresholds_); + } + + void InitKeywords() { +#ifdef SHERPA_ONNX_ENABLE_WASM_KWS + // Due to the limitations of the wasm file system, + // the keyword_file variable is directly parsed as a string of keywords + // if WASM KWS on + std::istringstream is(config_.keywords_file); + InitKeywords(is); +#else + // each line in keywords_file contains space-separated words + std::ifstream is(config_.keywords_file); + if (!is) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Open keywords file failed: %{public}s", + config_.keywords_file.c_str()); +#else + SHERPA_ONNX_LOGE("Open keywords file failed: %s", + config_.keywords_file.c_str()); +#endif + exit(-1); + } + InitKeywords(is); +#endif + } + + template + void InitKeywords(Manager *mgr) { + // each line in keywords_file contains space-separated words + + auto buf = ReadFile(mgr, config_.keywords_file); + + std::istrstream is(buf.data(), buf.size()); + + if (!is) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Open keywords file failed: %{public}s", + config_.keywords_file.c_str()); +#else + SHERPA_ONNX_LOGE("Open keywords file failed: %s", + config_.keywords_file.c_str()); +#endif + exit(-1); + } + InitKeywords(is); + } + + void InitKeywordsFromBufStr() { + // keywords_buf's content is supposed to be same as the keywords_file's + std::istringstream is(config_.keywords_buf); + InitKeywords(is); + } + + void InitOnlineStream(OnlineStream *stream) const { + auto r = decoder_->GetEmptyResult(); + SHERPA_ONNX_CHECK_EQ(r.hyps.Size(), 1); + + SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr); + r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root(); + + stream->SetKeywordResult(r); + stream->SetStates(model_->GetEncoderInitStates()); + } + + private: + KeywordSpotterConfig config_; + std::vector> keywords_id_; + std::vector boost_scores_; + std::vector thresholds_; + std::vector keywords_; + ContextGraphPtr keywords_graph_; + std::unique_ptr model_; + std::unique_ptr decoder_; + SymbolTable sym_; + int32_t unk_id_ = -1; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter.cc new file mode 100644 index 00000000..fe292eb8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter.cc @@ -0,0 +1,187 @@ +// sherpa-mnn/csrc/keyword-spotter.cc +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/keyword-spotter.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/keyword-spotter-impl.h" + +namespace sherpa_mnn { + +std::string KeywordResult::AsJsonString() const { + std::ostringstream os; + os << "{"; + os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time + << ", "; + + os << "\"keyword\"" + << ": "; + os << "\"" << keyword << "\"" + << ", "; + + os << "\"" + << "timestamps" + << "\"" + << ": "; + os << "["; + + std::string sep = ""; + for (auto t : timestamps) { + os << sep << std::fixed << std::setprecision(2) << t; + sep = ", "; + } + os << "], "; + + os << "\"" + << "tokens" + << "\"" + << ":"; + os << "["; + + sep = ""; + auto oldFlags = os.flags(); + for (const auto &t : tokens) { + if (t.size() == 1 && static_cast(t[0]) > 0x7f) { + const uint8_t *p = 
reinterpret_cast(t.c_str()); + os << sep << "\"" + << "<0x" << std::hex << std::uppercase << static_cast(p[0]) + << ">" + << "\""; + os.flags(oldFlags); + } else { + os << sep << "\"" << t << "\""; + } + sep = ", "; + } + os << "]"; + os << "}"; + + return os.str(); +} + +void KeywordSpotterConfig::Register(ParseOptions *po) { + feat_config.Register(po); + model_config.Register(po); + + po->Register("max-active-paths", &max_active_paths, + "beam size used in modified beam search."); + po->Register("num-trailing-blanks", &num_trailing_blanks, + "The number of trailing blanks should have after the keyword."); + po->Register("keywords-score", &keywords_score, + "The bonus score for each token in context word/phrase."); + po->Register("keywords-threshold", &keywords_threshold, + "The acoustic threshold (probability) to trigger the keywords."); + po->Register( + "keywords-file", &keywords_file, + "The file containing keywords, one word/phrase per line, and for each" + "phrase the bpe/cjkchar are separated by a space. For example: " + "▁HE LL O ▁WORLD" + "你 好 世 界"); +} + +bool KeywordSpotterConfig::Validate() const { + if (!keywords_file.empty() && !keywords_buf.empty()) { + SHERPA_ONNX_LOGE( + "you can not provide a keywords_buf and a keywords file: '%s', " + "at the same time, which is confusing", + keywords_file.c_str()); + return false; + } + + if (keywords_file.empty() && keywords_buf.empty()) { + SHERPA_ONNX_LOGE( + "Please provide either a keywords-file or the keywords-buf"); + return false; + } + +#ifndef SHERPA_ONNX_ENABLE_WASM_KWS + // due to the limitations of the wasm file system, + // keywords file will be packaged into the sherpa-mnn-wasm-kws-main.data file + // Solution: take keyword_file variable is directly + // parsed as a string of keywords + if (keywords_buf.empty() && !std::ifstream(keywords_file.c_str()).good()) { + SHERPA_ONNX_LOGE("Keywords file '%s' does not exist.", + keywords_file.c_str()); + return false; + } +#endif + + return model_config.Validate(); +} + +std::string KeywordSpotterConfig::ToString() const { + std::ostringstream os; + + os << "KeywordSpotterConfig("; + os << "feat_config=" << feat_config.ToString() << ", "; + os << "model_config=" << model_config.ToString() << ", "; + os << "max_active_paths=" << max_active_paths << ", "; + os << "num_trailing_blanks=" << num_trailing_blanks << ", "; + os << "keywords_score=" << keywords_score << ", "; + os << "keywords_threshold=" << keywords_threshold << ", "; + os << "keywords_file=\"" << keywords_file << "\")"; + + return os.str(); +} + +KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config) + : impl_(KeywordSpotterImpl::Create(config)) {} + +template +KeywordSpotter::KeywordSpotter(Manager *mgr, const KeywordSpotterConfig &config) + : impl_(KeywordSpotterImpl::Create(mgr, config)) {} + +KeywordSpotter::~KeywordSpotter() = default; + +std::unique_ptr KeywordSpotter::CreateStream() const { + return impl_->CreateStream(); +} + +std::unique_ptr KeywordSpotter::CreateStream( + const std::string &keywords) const { + return impl_->CreateStream(keywords); +} + +bool KeywordSpotter::IsReady(OnlineStream *s) const { + return impl_->IsReady(s); +} + +void KeywordSpotter::Reset(OnlineStream *s) const { impl_->Reset(s); } + +void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const { + impl_->DecodeStreams(ss, n); +} + +KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const { + return impl_->GetResult(s); +} + +#if __ANDROID_API__ >= 9 +template KeywordSpotter::KeywordSpotter(AAssetManager 
*mgr, + const KeywordSpotterConfig &config); +#endif + +#if __OHOS__ +template KeywordSpotter::KeywordSpotter(NativeResourceManager *mgr, + const KeywordSpotterConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter.h new file mode 100644 index 00000000..23bc4099 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/keyword-spotter.h @@ -0,0 +1,150 @@ +// sherpa-mnn/csrc/keyword-spotter.h +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_ +#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/features.h" +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/online-transducer-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct KeywordResult { + /// The triggered keyword. + /// For English, it consists of space separated words. + /// For Chinese, it consists of Chinese words without spaces. + /// Example 1: "hello world" + /// Example 2: "你好世界" + std::string keyword; + + /// Decoded results at the token level. + /// For instance, for BPE-based models it consists of a list of BPE tokens. + std::vector tokens; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + std::vector timestamps; + + /// Starting time of this segment. + /// When an endpoint is detected, it will change + float start_time = 0; + + /** Return a json string. + * + * The returned string contains: + * { + * "keyword": "The triggered keyword", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "start_time": x, + * } + */ + std::string AsJsonString() const; +}; + +struct KeywordSpotterConfig { + FeatureExtractorConfig feat_config; + OnlineModelConfig model_config; + + int32_t max_active_paths = 4; + + int32_t num_trailing_blanks = 1; + + float keywords_score = 1.0; + + float keywords_threshold = 0.25; + + std::string keywords_file; + + /// if keywords_buf is non-empty, + /// the keywords will be loaded from the buffer instead of from the + /// "keywrods_file" + std::string keywords_buf; + + KeywordSpotterConfig() = default; + + KeywordSpotterConfig(const FeatureExtractorConfig &feat_config, + const OnlineModelConfig &model_config, + int32_t max_active_paths, int32_t num_trailing_blanks, + float keywords_score, float keywords_threshold, + const std::string &keywords_file) + : feat_config(feat_config), + model_config(model_config), + max_active_paths(max_active_paths), + num_trailing_blanks(num_trailing_blanks), + keywords_score(keywords_score), + keywords_threshold(keywords_threshold), + keywords_file(keywords_file) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class KeywordSpotterImpl; + +class KeywordSpotter { + public: + explicit KeywordSpotter(const KeywordSpotterConfig &config); + + template + KeywordSpotter(Manager *mgr, const KeywordSpotterConfig &config); + + ~KeywordSpotter(); + + /** Create a stream for decoding. + * + */ + std::unique_ptr CreateStream() const; + + /** Create a stream for decoding. + * + * @param The keywords for this string, it might contain several keywords, + * the keywords are separated by "/". In each of the keywords, there + * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). 
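+   *        (If any keyword cannot be encoded with the model's token table,
+   *        an error is logged and a null pointer is returned.)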
+ * For example, keywords I LOVE YOU and HELLO WORLD, looks like: + * + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" + */ + std::unique_ptr CreateStream(const std::string &keywords) const; + + /** + * Return true if the given stream has enough frames for decoding. + * Return false otherwise + */ + bool IsReady(OnlineStream *s) const; + + // Remember to call it after detecting a keyword + void Reset(OnlineStream *s) const; + + /** Decode a single stream. */ + void DecodeStream(OnlineStream *s) const { + OnlineStream *ss[1] = {s}; + DecodeStreams(ss, 1); + } + + /** Decode multiple streams in parallel + * + * @param ss Pointer array containing streams to be decoded. + * @param n Number of streams in `ss`. + */ + void DecodeStreams(OnlineStream **ss, int32_t n) const; + + KeywordResult GetResult(OnlineStream *s) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/kokoro-multi-lang-lexicon.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/kokoro-multi-lang-lexicon.cc new file mode 100644 index 00000000..87164302 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/kokoro-multi-lang-lexicon.cc @@ -0,0 +1,525 @@ +// sherpa-mnn/csrc/kokoro-multi-lang-lexicon.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/kokoro-multi-lang-lexicon.h" + +#include +#include // NOLINT +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include + +#include "cppjieba/Jieba.hpp" +#include "espeak-ng/speak_lib.h" +#include "phoneme_ids.hpp" +#include "phonemize.hpp" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +void CallPhonemizeEspeak(const std::string &text, + piper::eSpeakPhonemeConfig &config, // NOLINT + std::vector> *phonemes); + +class KokoroMultiLangLexicon::Impl { + public: + Impl(const std::string &tokens, const std::string &lexicon, + const std::string &dict_dir, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &meta_data, bool debug) + : meta_data_(meta_data), debug_(debug) { + InitTokens(tokens); + + InitLexicon(lexicon); + + InitJieba(dict_dir); + + InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc + } + + template + Impl(Manager *mgr, const std::string &tokens, const std::string &lexicon, + const std::string &dict_dir, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &meta_data, bool debug) + : meta_data_(meta_data), debug_(debug) { + InitTokens(mgr, tokens); + + InitLexicon(mgr, lexicon); + + // we assume you have copied dict_dir and data_dir from assets to some path + InitJieba(dict_dir); + + InitEspeak(data_dir); // See ./piper-phonemize-lexicon.cc + } + + std::vector ConvertTextToTokenIds(const std::string &_text) const { + std::string text = ToLowerCase(_text); + if (debug_) { + SHERPA_ONNX_LOGE("After converting to lowercase:\n%s", text.c_str()); + } + + std::vector> replace_str_pairs = { + {",", ","}, {":", ","}, {"、", ","}, {";", ";"}, {":", ":"}, + {"。", "."}, {"?", "?"}, {"!", "!"}, {"\\s+", " "}, + }; + for (const auto &p : replace_str_pairs) { + std::regex re(p.first); + text = std::regex_replace(text, re, p.second); + } + + if (debug_) { + SHERPA_ONNX_LOGE("After replacing 
punctuations and merging spaces:\n%s", + text.c_str()); + } + + // https://en.cppreference.com/w/cpp/regex + // https://stackoverflow.com/questions/37989081/how-to-use-unicode-range-in-c-regex + std::string expr_chinese = "([\\u4e00-\\u9fff]+)"; + std::string expr_not_chinese = "([^\\u4e00-\\u9fff]+)"; + + std::string expr_both = expr_chinese + "|" + expr_not_chinese; + + auto ws = ToWideString(text); + std::wstring wexpr_both = ToWideString(expr_both); + std::wregex we_both(wexpr_both); + + std::wstring wexpr_zh = ToWideString(expr_chinese); + std::wregex we_zh(wexpr_zh); + + auto begin = std::wsregex_iterator(ws.begin(), ws.end(), we_both); + auto end = std::wsregex_iterator(); + + std::vector ans; + + for (std::wsregex_iterator i = begin; i != end; ++i) { + std::wsmatch match = *i; + std::wstring match_str = match.str(); + + auto ms = ToString(match_str); + uint8_t c = reinterpret_cast(ms.data())[0]; + + std::vector> ids_vec; + if (std::regex_match(match_str, we_zh)) { + if (debug_) { + SHERPA_ONNX_LOGE("Chinese: %s", ms.c_str()); + } + ids_vec = ConvertChineseToTokenIDs(ms); + } else { + if (debug_) { + SHERPA_ONNX_LOGE("Non-Chinese: %s", ms.c_str()); + } + + ids_vec = ConvertEnglishToTokenIDs(ms, meta_data_.voice); + } + + for (const auto &ids : ids_vec) { + if (ids.size() > 10 + 2) { + ans.emplace_back(ids); + } else { + if (ans.empty()) { + ans.emplace_back(ids); + } else { + if (ans.back().tokens.size() + ids.size() < 50) { + ans.back().tokens.back() = ids[1]; + ans.back().tokens.insert(ans.back().tokens.end(), ids.begin() + 2, + ids.end()); + } else { + ans.emplace_back(ids); + } + } + } + } + } + + if (debug_) { + for (const auto &v : ans) { + std::ostringstream os; + os << "\n"; + std::string sep; + for (auto i : v.tokens) { + os << sep << i; + sep = " "; + } + os << "\n"; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + } + + return ans; + } + + private: + bool IsPunctuation(const std::string &text) const { + if (text == ";" || text == ":" || text == "," || text == "." || + text == "!" || text == "?" 
|| text == "—" || text == "…" || + text == "\"" || text == "(" || text == ")" || text == "“" || + text == "”") { + return true; + } + + return false; + } + + std::vector ConvertWordToIds(const std::string &w) const { + std::vector ans; + if (word2ids_.count(w)) { + ans = word2ids_.at(w); + return ans; + } + + std::vector words = SplitUtf8(w); + for (const auto &word : words) { + if (word2ids_.count(word)) { + auto ids = ConvertWordToIds(word); + ans.insert(ans.end(), ids.begin(), ids.end()); + } else { + if (debug_) { + SHERPA_ONNX_LOGE("Skip OOV: '%s'", word.c_str()); + } + } + } + + return ans; + } + + std::vector> ConvertChineseToTokenIDs( + const std::string &text) const { + bool is_hmm = true; + + std::vector words; + jieba_->Cut(text, words, is_hmm); + if (debug_) { + std::ostringstream os; + os << "After jieba processing:\n"; + + std::string sep; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + std::vector> ans; + std::vector this_sentence; + int32_t max_len = meta_data_.max_token_len; + + this_sentence.push_back(0); + for (const auto &w : words) { + auto ids = ConvertWordToIds(w); + if (this_sentence.size() + ids.size() > max_len - 2) { + this_sentence.push_back(0); + ans.push_back(std::move(this_sentence)); + + this_sentence.push_back(0); + } + + this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); + } + + if (this_sentence.size() > 1) { + this_sentence.push_back(0); + ans.push_back(std::move(this_sentence)); + } + + if (debug_) { + for (const auto &v : ans) { + std::ostringstream os; + os << "\n"; + std::string sep; + for (auto i : v) { + os << sep << i; + sep = " "; + } + os << "\n"; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + } + + return ans; + } + + std::vector> ConvertEnglishToTokenIDs( + const std::string &text, const std::string &voice) const { + std::vector words = SplitUtf8(text); + if (debug_) { + std::ostringstream os; + os << "After splitting to words: "; + std::string sep; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + std::vector> ans; + int32_t max_len = meta_data_.max_token_len; + std::vector this_sentence; + + int32_t space_id = token2id_.at(" "); + + this_sentence.push_back(0); + + for (const auto &word : words) { + if (IsPunctuation(word)) { + this_sentence.push_back(token2id_.at(word)); + + if (this_sentence.size() > max_len - 2) { + // this sentence is too long, split it + this_sentence.push_back(0); + ans.push_back(std::move(this_sentence)); + + this_sentence.push_back(0); + continue; + } + + if (word == "." || word == "!" || word == "?" || word == ";") { + // Note: You can add more punctuations here to split the text + // into sentences. 
We just use four here: .!?; + this_sentence.push_back(0); + ans.push_back(std::move(this_sentence)); + + this_sentence.push_back(0); + } + } else if (word2ids_.count(word)) { + const auto &ids = word2ids_.at(word); + if (this_sentence.size() + ids.size() + 3 > max_len - 2) { + this_sentence.push_back(0); + ans.push_back(std::move(this_sentence)); + + this_sentence.push_back(0); + } + + this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); + this_sentence.push_back(space_id); + } else { + if (debug_) { + SHERPA_ONNX_LOGE("Use espeak-ng to handle the OOV: '%s'", + word.c_str()); + } + + piper::eSpeakPhonemeConfig config; + + config.voice = voice; + + std::vector> phonemes; + + CallPhonemizeEspeak(word, config, &phonemes); + // Note phonemes[i] contains a vector of unicode codepoints; + // we need to convert them to utf8 + + std::wstring_convert, char32_t> conv; + + std::vector ids; + for (const auto &v : phonemes) { + for (const auto p : v) { + auto token = conv.to_bytes(p); + if (token2id_.count(token)) { + ids.push_back(token2id_.at(token)); + } else { + if (debug_) { + SHERPA_ONNX_LOGE("Skip OOV token '%s' from '%s'", token.c_str(), + word.c_str()); + } + } + } + } + + if (this_sentence.size() + ids.size() + 3 > max_len - 2) { + this_sentence.push_back(0); + ans.push_back(std::move(this_sentence)); + + this_sentence.push_back(0); + } + + this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); + this_sentence.push_back(space_id); + } + } + + if (this_sentence.size() > 1) { + this_sentence.push_back(0); + ans.push_back(std::move(this_sentence)); + } + + if (debug_) { + for (const auto &v : ans) { + std::ostringstream os; + os << "\n"; + std::string sep; + for (auto i : v) { + os << sep << i; + sep = " "; + } + os << "\n"; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + } + + return ans; + } + + void InitTokens(const std::string &tokens) { + std::ifstream is(tokens); + InitTokens(is); + } + + template + void InitTokens(Manager *mgr, const std::string &tokens) { + auto buf = ReadFile(mgr, tokens); + + std::istrstream is(buf.data(), buf.size()); + InitTokens(is); + } + + void InitTokens(std::istream &is) { + token2id_ = ReadTokens(is); // defined in ./symbol-table.cc + } + + void InitLexicon(const std::string &lexicon) { + std::vector files; + SplitStringToVector(lexicon, ",", false, &files); + for (const auto &f : files) { + std::ifstream is(f); + InitLexicon(is); + } + } + + template + void InitLexicon(Manager *mgr, const std::string &lexicon) { + std::vector files; + SplitStringToVector(lexicon, ",", false, &files); + for (const auto &f : files) { + auto buf = ReadFile(mgr, f); + + std::istrstream is(buf.data(), buf.size()); + InitLexicon(is); + } + } + + void InitLexicon(std::istream &is) { + std::string word; + std::vector token_list; + std::string token; + + std::string line; + int32_t line_num = 0; + int32_t num_warn = 0; + while (std::getline(is, line)) { + ++line_num; + std::istringstream iss(line); + + token_list.clear(); + iss >> word; + ToLowerCase(&word); + + if (word2ids_.count(word)) { + num_warn += 1; + if (num_warn < 10) { + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.", + word.c_str(), line_num, line.c_str()); + } + continue; + } + + while (iss >> token) { + token_list.push_back(std::move(token)); + } + + std::vector ids = ConvertTokensToIds(token2id_, token_list); + + if (ids.empty()) { + SHERPA_ONNX_LOGE( + "Invalid pronunciation for word '%s' at line %d:%s. 
Ignore it", + word.c_str(), line_num, line.c_str()); + continue; + } + + word2ids_.insert({std::move(word), std::move(ids)}); + } + } + + void InitJieba(const std::string &dict_dir) { + std::string dict = dict_dir + "/jieba.dict.utf8"; + std::string hmm = dict_dir + "/hmm_model.utf8"; + std::string user_dict = dict_dir + "/user.dict.utf8"; + std::string idf = dict_dir + "/idf.utf8"; + std::string stop_word = dict_dir + "/stop_words.utf8"; + + AssertFileExists(dict); + AssertFileExists(hmm); + AssertFileExists(user_dict); + AssertFileExists(idf); + AssertFileExists(stop_word); + + jieba_ = + std::make_unique(dict, hmm, user_dict, idf, stop_word); + } + + private: + OfflineTtsKokoroModelMetaData meta_data_; + + // word to token IDs + std::unordered_map> word2ids_; + + // tokens.txt is saved in token2id_ + std::unordered_map token2id_; + + std::unique_ptr jieba_; + bool debug_ = false; +}; + +KokoroMultiLangLexicon::~KokoroMultiLangLexicon() = default; + +KokoroMultiLangLexicon::KokoroMultiLangLexicon( + const std::string &tokens, const std::string &lexicon, + const std::string &dict_dir, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &meta_data, bool debug) + : impl_(std::make_unique(tokens, lexicon, dict_dir, data_dir, + meta_data, debug)) {} + +template +KokoroMultiLangLexicon::KokoroMultiLangLexicon( + Manager *mgr, const std::string &tokens, const std::string &lexicon, + const std::string &dict_dir, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &meta_data, bool debug) + : impl_(std::make_unique(mgr, tokens, lexicon, dict_dir, data_dir, + meta_data, debug)) {} + +std::vector KokoroMultiLangLexicon::ConvertTextToTokenIds( + const std::string &text, const std::string & /*unused_voice = ""*/) const { + return impl_->ConvertTextToTokenIds(text); +} + +#if __ANDROID_API__ >= 9 +template KokoroMultiLangLexicon::KokoroMultiLangLexicon( + AAssetManager *mgr, const std::string &tokens, const std::string &lexicon, + const std::string &dict_dir, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &meta_data, bool debug); +#endif + +#if __OHOS__ +template KokoroMultiLangLexicon::KokoroMultiLangLexicon( + NativeResourceManager *mgr, const std::string &tokens, + const std::string &lexicon, const std::string &dict_dir, + const std::string &data_dir, const OfflineTtsKokoroModelMetaData &meta_data, + bool debug); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/kokoro-multi-lang-lexicon.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/kokoro-multi-lang-lexicon.h new file mode 100644 index 00000000..6b5dc8eb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/kokoro-multi-lang-lexicon.h @@ -0,0 +1,45 @@ +// sherpa-mnn/csrc/kokoro-multi-lang-lexicon.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_KOKORO_MULTI_LANG_LEXICON_H_ +#define SHERPA_ONNX_CSRC_KOKORO_MULTI_LANG_LEXICON_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-kokoro-model-meta-data.h" + +namespace sherpa_mnn { + +class KokoroMultiLangLexicon : public OfflineTtsFrontend { + public: + ~KokoroMultiLangLexicon() override; + + KokoroMultiLangLexicon(const std::string &tokens, const std::string &lexicon, + const std::string &dict_dir, + const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &meta_data, + bool debug); + + template + KokoroMultiLangLexicon(Manager *mgr, const std::string &tokens, + const std::string &lexicon, + 
const std::string &dict_dir, + const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &meta_data, + bool debug); + + std::vector ConvertTextToTokenIds( + const std::string &text, const std::string &voice = "") const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_KOKORO_MULTI_LANG_LEXICON_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/lexicon.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/lexicon.cc new file mode 100644 index 00000000..8c31199b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/lexicon.cc @@ -0,0 +1,408 @@ +// sherpa-mnn/csrc/lexicon.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/lexicon.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +static std::vector ProcessHeteronyms( + const std::vector &words) { + std::vector ans; + ans.reserve(words.size()); + + int32_t num_words = static_cast(words.size()); + int32_t i = 0; + int32_t prev = -1; + while (i < num_words) { + // start of a phrase #$| + if ((i + 2 < num_words) && words[i] == "#" && words[i + 1] == "$" && + words[i + 2] == "|") { + if (prev == -1) { + prev = i + 3; + } + i = i + 3; + continue; + } + + // end of a phrase |$# + if ((i + 2 < num_words) && words[i] == "|" && words[i + 1] == "$" && + words[i + 2] == "#") { + if (prev != -1) { + std::ostringstream os; + for (int32_t k = prev; k < i; ++k) { + if (words[k] != "|" && words[k] != "$" && words[k] != "#") { + os << words[k]; + } + } + ans.push_back(os.str()); + + prev = -1; + } + + i += 3; + continue; + } + + if (prev == -1) { + // not inside a phrase + ans.push_back(words[i]); + } + + ++i; + } + + return ans; +} + +std::vector ConvertTokensToIds( + const std::unordered_map &token2id, + const std::vector &tokens) { + std::vector ids; + ids.reserve(tokens.size()); + for (const auto &s : tokens) { + if (!token2id.count(s)) { + return {}; + } + int32_t id = token2id.at(s); + ids.push_back(id); + } + + return ids; +} + +Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, + const std::string &punctuations, const std::string &language, + bool debug /*= false*/) + : debug_(debug) { + InitLanguage(language); + + { + std::ifstream is(tokens); + InitTokens(is); + } + + { + std::ifstream is(lexicon); + InitLexicon(is); + } + + InitPunctuations(punctuations); +} + +template +Lexicon::Lexicon(Manager *mgr, const std::string &lexicon, + const std::string &tokens, const std::string &punctuations, + const std::string &language, bool debug /*= false*/ + ) + : debug_(debug) { + InitLanguage(language); + + { + auto buf = ReadFile(mgr, tokens); + std::istrstream is(buf.data(), buf.size()); + InitTokens(is); + } + + { + auto buf = ReadFile(mgr, lexicon); + std::istrstream is(buf.data(), buf.size()); + InitLexicon(is); + } + + InitPunctuations(punctuations); +} + +std::vector Lexicon::ConvertTextToTokenIds( + const std::string &text, const std::string & /*voice*/ /*= ""*/) const { + switch (language_) { + case Language::kChinese: + return ConvertTextToTokenIdsChinese(text); + case Language::kNotChinese: + return 
ConvertTextToTokenIdsNotChinese(text); + default: + SHERPA_ONNX_LOGE("Unknown language: %d", static_cast(language_)); + exit(-1); + } + + return {}; +} + +std::vector Lexicon::ConvertTextToTokenIdsChinese( + const std::string &_text) const { + std::string text(_text); + ToLowerCase(&text); + + std::vector words = SplitUtf8(text); + words = ProcessHeteronyms(words); + + if (debug_) { + std::ostringstream os; + + os << "Input text in string: " << text << "\n"; + os << "Input text in bytes:"; + for (uint8_t c : text) { + os << " 0x" << std::setfill('0') << std::setw(2) << std::right << std::hex + << c; + } + os << "\n"; + os << "After splitting to words:"; + for (const auto &w : words) { + os << " " << w; + } + os << "\n"; + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + std::vector ans; + std::vector this_sentence; + + int32_t blank = -1; + if (token2id_.count(" ")) { + blank = token2id_.at(" "); + } + + int32_t sil = -1; + int32_t eos = -1; + if (token2id_.count("sil")) { + sil = token2id_.at("sil"); + eos = token2id_.at("eos"); + } + + int32_t pad = -1; + if (token2id_.count("#0")) { + pad = token2id_.at("#0"); + } + + if (sil != -1) { + this_sentence.push_back(sil); + } + + for (const auto &w : words) { + if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" || + w == "。" || w == ";" || w == "!" || w == "?" || w == ":" || + w == "”" || + // not sentence break + w == "," || w == "“" || w == "," || w == "、") { + if (punctuations_.count(w)) { + if (token2id_.count(w)) { + this_sentence.push_back(token2id_.at(w)); + } else if (pad != -1) { + this_sentence.push_back(pad); + } else if (sil != -1) { + this_sentence.push_back(sil); + } + } + + if (w != "," && w != "“" && w != "," && w != "、") { + if (eos != -1) { + this_sentence.push_back(eos); + } + ans.emplace_back(std::move(this_sentence)); + this_sentence = {}; + + if (sil != -1) { + this_sentence.push_back(sil); + } + } + continue; + } + + if (!word2ids_.count(w)) { + SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); + continue; + } + + const auto &token_ids = word2ids_.at(w); + this_sentence.insert(this_sentence.end(), token_ids.begin(), + token_ids.end()); + if (blank != -1) { + this_sentence.push_back(blank); + } + } + + if (sil != -1) { + this_sentence.push_back(sil); + } + + if (eos != -1) { + this_sentence.push_back(eos); + } + ans.emplace_back(std::move(this_sentence)); + + return ans; +} + +std::vector Lexicon::ConvertTextToTokenIdsNotChinese( + const std::string &_text) const { + std::string text(_text); + ToLowerCase(&text); + + std::vector words = SplitUtf8(text); + + if (debug_) { + std::ostringstream os; + + os << "Input text (lowercase) in string: " << text << "\n"; + os << "Input text in bytes:"; + for (uint8_t c : text) { + os << " 0x" << std::setfill('0') << std::setw(2) << std::right << std::hex + << c; + } + os << "\n"; + os << "After splitting to words:"; + for (const auto &w : words) { + os << " " << w; + } + os << "\n"; + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + int32_t blank = token2id_.at(" "); + + std::vector ans; + std::vector this_sentence; + + for (const auto &w : words) { + if (w == "." || w == ";" || w == "!" || w == "?" 
|| w == "-" || w == ":" || + // not sentence break + w == ",") { + if (punctuations_.count(w)) { + this_sentence.push_back(token2id_.at(w)); + } + + if (w != ",") { + this_sentence.push_back(blank); + ans.emplace_back(std::move(this_sentence)); + this_sentence = {}; + } + + continue; + } + + if (!word2ids_.count(w)) { + SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); + continue; + } + + const auto &token_ids = word2ids_.at(w); + this_sentence.insert(this_sentence.end(), token_ids.begin(), + token_ids.end()); + this_sentence.push_back(blank); + } + + if (!this_sentence.empty()) { + // remove the last blank + this_sentence.resize(this_sentence.size() - 1); + } + + if (!this_sentence.empty()) { + ans.emplace_back(std::move(this_sentence)); + } + + return ans; +} + +void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); } + +void Lexicon::InitLanguage(const std::string &_lang) { + std::string lang(_lang); + ToLowerCase(&lang); + if (lang == "chinese") { + language_ = Language::kChinese; + } else if (!lang.empty()) { + language_ = Language::kNotChinese; + } else { + SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str()); + exit(-1); + } +} + +void Lexicon::InitLexicon(std::istream &is) { + std::string word; + std::vector token_list; + std::string line; + std::string phone; + + while (std::getline(is, line)) { + std::istringstream iss(line); + + token_list.clear(); + + iss >> word; + ToLowerCase(&word); + + if (word2ids_.count(word)) { + SHERPA_ONNX_LOGE("Duplicated word: %s. Ignore it.", word.c_str()); + continue; + } + + while (iss >> phone) { + token_list.push_back(std::move(phone)); + } + + std::vector ids = ConvertTokensToIds(token2id_, token_list); + if (ids.empty()) { + continue; + } + + word2ids_.insert({std::move(word), std::move(ids)}); + } +} + +void Lexicon::InitPunctuations(const std::string &punctuations) { + std::vector punctuation_list; + SplitStringToVector(punctuations, " ", false, &punctuation_list); + for (auto &s : punctuation_list) { + punctuations_.insert(std::move(s)); + } +} + +#if __ANDROID_API__ >= 9 +template Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, + const std::string &tokens, + const std::string &punctuations, + const std::string &language, bool debug = false); +#endif + +#if __OHOS__ +template Lexicon::Lexicon(NativeResourceManager *mgr, + const std::string &lexicon, const std::string &tokens, + const std::string &punctuations, + const std::string &language, bool debug = false); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/lexicon.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/lexicon.h new file mode 100644 index 00000000..112615be --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/lexicon.h @@ -0,0 +1,66 @@ +// sherpa-mnn/csrc/lexicon.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_LEXICON_H_ +#define SHERPA_ONNX_CSRC_LEXICON_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts-frontend.h" + +namespace sherpa_mnn { + +class Lexicon : public OfflineTtsFrontend { + public: + Lexicon() = default; // for subclasses + // + // Note: for models from piper, we won't use this class. 
+ Lexicon(const std::string &lexicon, const std::string &tokens, + const std::string &punctuations, const std::string &language, + bool debug = false); + + template + Lexicon(Manager *mgr, const std::string &lexicon, const std::string &tokens, + const std::string &punctuations, const std::string &language, + bool debug = false); + + std::vector ConvertTextToTokenIds( + const std::string &text, const std::string &voice = "") const override; + + private: + std::vector ConvertTextToTokenIdsNotChinese( + const std::string &text) const; + + std::vector ConvertTextToTokenIdsChinese( + const std::string &text) const; + + void InitLanguage(const std::string &lang); + void InitTokens(std::istream &is); + void InitLexicon(std::istream &is); + void InitPunctuations(const std::string &punctuations); + + private: + enum class Language { + kNotChinese, + kChinese, + kUnknown, + }; + + private: + std::unordered_map> word2ids_; + std::unordered_set punctuations_; + std::unordered_map token2id_; + Language language_ = Language::kUnknown; + bool debug_ = false; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_LEXICON_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/log.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/log.cc new file mode 100644 index 00000000..1c4d0b25 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/log.cc @@ -0,0 +1,122 @@ +// sherpa-mnn/csrc/log.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/log.h" + +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H +#include // To get stack trace in error messages. +#ifdef SHERPA_ONNX_HAVE_CXXABI_H +#include // For name demangling. +// Useful to decode the stack trace, but only used if we have execinfo.h +#endif // SHERPA_ONNX_HAVE_CXXABI_H +#endif // SHERPA_ONNX_HAVE_EXECINFO_H + +#include + +#include +#include +#include + +namespace sherpa_mnn { + +std::string GetDateTimeStr() { + std::ostringstream os; + std::time_t t = std::time(nullptr); + std::tm tm = *std::localtime(&t); + os << std::put_time(&tm, "%F %T"); // yyyy-mm-dd hh:mm:ss + return os.str(); +} + +static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin, + std::size_t *end) { + // Find the first '_' with leading ' ' or '('. + *begin = std::string::npos; + for (std::size_t i = 1; i < trace_name.size(); ++i) { + if (trace_name[i] != '_') { + continue; + } + if (trace_name[i - 1] == ' ' || trace_name[i - 1] == '(') { + *begin = i; + break; + } + } + if (*begin == std::string::npos) { + return false; + } + *end = trace_name.find_first_of(" +", *begin); + return *end != std::string::npos; +} + +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H +static std::string Demangle(const std::string &trace_name) { +#ifndef SHERPA_ONNX_HAVE_CXXABI_H + return trace_name; +#else // SHERPA_ONNX_HAVE_CXXABI_H + // Try demangle the symbol. We are trying to support the following formats + // produced by different platforms: + // + // Linux: + // ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d] + // + // Mac: + // 0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813 + // + // We want to extract the name e.g., '_ZN5kaldi13UnitTestErrorEv' and + // demangle it info a readable name like kaldi::UnitTextError. 
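+  //
+  // Illustrative call (assumes cxxabi.h is available): for the Linux sample
+  // above, abi::__cxa_demangle("_ZN5kaldi13UnitTestErrorEv", nullptr,
+  // nullptr, &status) returns a malloc'ed "kaldi::UnitTestError()", which
+  // is then spliced back into the trace line below.
+  //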
+ std::size_t begin, end; + if (!LocateSymbolRange(trace_name, &begin, &end)) { + return trace_name; + } + std::string symbol = trace_name.substr(begin, end - begin); + int status; + char *demangled_name = abi::__cxa_demangle(symbol.c_str(), 0, 0, &status); + if (status == 0 && demangled_name != nullptr) { + symbol = demangled_name; + free(demangled_name); + } + return trace_name.substr(0, begin) + symbol + + trace_name.substr(end, std::string::npos); +#endif // SHERPA_ONNX_HAVE_CXXABI_H +} +#endif // SHERPA_ONNX_HAVE_EXECINFO_H + +std::string GetStackTrace() { + std::string ans; +#ifdef SHERPA_ONNX_HAVE_EXECINFO_H + constexpr const std::size_t kMaxTraceSize = 50; + constexpr const std::size_t kMaxTracePrint = 50; // Must be even. + // Buffer for the trace. + void *trace[kMaxTraceSize]; + // Get the trace. + std::size_t size = backtrace(trace, kMaxTraceSize); + // Get the trace symbols. + char **trace_symbol = backtrace_symbols(trace, size); + if (trace_symbol == nullptr) return ans; + + // Compose a human-readable backtrace string. + ans += "[ Stack-Trace: ]\n"; + if (size <= kMaxTracePrint) { + for (std::size_t i = 0; i < size; ++i) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + } else { // Print out first+last (e.g.) 5. + for (std::size_t i = 0; i < kMaxTracePrint / 2; ++i) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + ans += ".\n.\n.\n"; + for (std::size_t i = size - kMaxTracePrint / 2; i < size; ++i) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + if (size == kMaxTraceSize) + ans += ".\n.\n.\n"; // Stack was too long, probably a bug. + } + + // We must free the array of pointers allocated by backtrace_symbols(), + // but not the strings themselves. + free(trace_symbol); +#endif // SHERPA_ONNX_HAVE_EXECINFO_H + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/log.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/log.h new file mode 100644 index 00000000..55234644 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/log.h @@ -0,0 +1,378 @@ +// sherpa-mnn/csrc/log.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_LOG_H_ +#define SHERPA_ONNX_CSRC_LOG_H_ + +#include + +#include // NOLINT +#include +#include + +namespace sherpa_mnn { + +#if SHERPA_ONNX_ENABLE_CHECK + +#if defined(NDEBUG) +constexpr bool kDisableDebug = true; +#else +constexpr bool kDisableDebug = false; +#endif + +enum class LogLevel { + kTrace = 0, + kDebug = 1, + kInfo = 2, + kWarning = 3, + kError = 4, + kFatal = 5, // print message and abort the program +}; + +// They are used in SHERPA_ONNX_LOG(xxx), so their names +// do not follow the google c++ code style +// +// You can use them in the following way: +// +// SHERPA_ONNX_LOG(TRACE) << "some message"; +// SHERPA_ONNX_LOG(DEBUG) << "some message"; +#ifndef _MSC_VER +constexpr LogLevel TRACE = LogLevel::kTrace; +constexpr LogLevel DEBUG = LogLevel::kDebug; +constexpr LogLevel INFO = LogLevel::kInfo; +constexpr LogLevel WARNING = LogLevel::kWarning; +constexpr LogLevel ERROR = LogLevel::kError; +constexpr LogLevel FATAL = LogLevel::kFatal; +#else +#define TRACE LogLevel::kTrace +#define DEBUG LogLevel::kDebug +#define INFO LogLevel::kInfo +#define WARNING LogLevel::kWarning +#define ERROR LogLevel::kError +#define FATAL LogLevel::kFatal +#endif + +std::string GetStackTrace(); + +/* Return the current log level. + + + If the current log level is TRACE, then all logged messages are printed out. 
+ + If the current log level is DEBUG, log messages with "TRACE" level are not + shown and all other levels are printed out. + + Similarly, if the current log level is INFO, log message with "TRACE" and + "DEBUG" are not shown and all other levels are printed out. + + If it is FATAL, then only FATAL messages are shown. + */ +inline LogLevel GetCurrentLogLevel() { + static LogLevel log_level = INFO; + static std::once_flag init_flag; + std::call_once(init_flag, []() { + const char *env_log_level = std::getenv("SHERPA_ONNX_LOG_LEVEL"); + if (env_log_level == nullptr) return; + + std::string s = env_log_level; + if (s == "TRACE") + log_level = TRACE; + else if (s == "DEBUG") + log_level = DEBUG; + else if (s == "INFO") + log_level = INFO; + else if (s == "WARNING") + log_level = WARNING; + else if (s == "ERROR") + log_level = ERROR; + else if (s == "FATAL") + log_level = FATAL; + else + fprintf(stderr, + "Unknown SHERPA_ONNX_LOG_LEVEL: %s" + "\nSupported values are: " + "TRACE, DEBUG, INFO, WARNING, ERROR, FATAL", + s.c_str()); + }); + return log_level; +} + +inline bool EnableAbort() { + static std::once_flag init_flag; + static bool enable_abort = false; + std::call_once(init_flag, []() { + enable_abort = (std::getenv("SHERPA_ONNX_ABORT") != nullptr); + }); + return enable_abort; +} + +class Logger { + public: + Logger(const char *filename, const char *func_name, uint32_t line_num, + LogLevel level) + : filename_(filename), + func_name_(func_name), + line_num_(line_num), + level_(level) { + cur_level_ = GetCurrentLogLevel(); + switch (level) { + case TRACE: + if (cur_level_ <= TRACE) fprintf(stderr, "[T] "); + break; + case DEBUG: + if (cur_level_ <= DEBUG) fprintf(stderr, "[D] "); + break; + case INFO: + if (cur_level_ <= INFO) fprintf(stderr, "[I] "); + break; + case WARNING: + if (cur_level_ <= WARNING) fprintf(stderr, "[W] "); + break; + case ERROR: + if (cur_level_ <= ERROR) fprintf(stderr, "[E] "); + break; + case FATAL: + if (cur_level_ <= FATAL) fprintf(stderr, "[F] "); + break; + } + + if (cur_level_ <= level_) { + fprintf(stderr, "%s:%u:%s ", filename, line_num, func_name); + } + } + + ~Logger() noexcept(false) { + static constexpr const char *kErrMsg = R"( + Some bad things happened. Please read the above error messages and stack + trace. If you are using Python, the following command may be helpful: + + gdb --args python /path/to/your/code.py + + (You can use `gdb` to debug the code. Please consider compiling + a debug version of sherpa_mnn.). + + If you are unable to fix it, please open an issue at: + + https://github.com/csukuangfj/kaldi-native-fbank/issues/new + )"; + if (level_ == FATAL) { + fprintf(stderr, "\n"); + std::string stack_trace = GetStackTrace(); + if (!stack_trace.empty()) { + fprintf(stderr, "\n\n%s\n", stack_trace.c_str()); + } + + fflush(nullptr); + +#ifndef __ANDROID_API__ + if (EnableAbort()) { + // NOTE: abort() will terminate the program immediately without + // printing the Python stack backtrace. + abort(); + } + + throw std::runtime_error(kErrMsg); +#else + abort(); +#endif + } + } + + const Logger &operator<<(bool b) const { + if (cur_level_ <= level_) { + fprintf(stderr, b ? 
"true" : "false"); + } + return *this; + } + + const Logger &operator<<(int8_t i) const { + if (cur_level_ <= level_) fprintf(stderr, "%d", i); + return *this; + } + + const Logger &operator<<(const char *s) const { + if (cur_level_ <= level_) fprintf(stderr, "%s", s); + return *this; + } + + const Logger &operator<<(int32_t i) const { + if (cur_level_ <= level_) fprintf(stderr, "%d", i); + return *this; + } + + const Logger &operator<<(uint32_t i) const { + if (cur_level_ <= level_) fprintf(stderr, "%u", i); + return *this; + } + + const Logger &operator<<(uint i) const { + if (cur_level_ <= level_) + fprintf(stderr, "%llu", (long long unsigned int)i); // NOLINT + return *this; + } + + const Logger &operator<<(int i) const { + if (cur_level_ <= level_) + fprintf(stderr, "%lli", (long long int)i); // NOLINT + return *this; + } + + const Logger &operator<<(float f) const { + if (cur_level_ <= level_) fprintf(stderr, "%f", f); + return *this; + } + + const Logger &operator<<(double d) const { + if (cur_level_ <= level_) fprintf(stderr, "%f", d); + return *this; + } + + template + const Logger &operator<<(const T &t) const { + // require T overloads operator<< + std::ostringstream os; + os << t; + return *this << os.str().c_str(); + } + + // specialization to fix compile error: `stringstream << nullptr` is ambiguous + const Logger &operator<<(const std::nullptr_t &null) const { + if (cur_level_ <= level_) *this << "(null)"; + return *this; + } + + private: + const char *filename_; + const char *func_name_; + uint32_t line_num_; + LogLevel level_; + LogLevel cur_level_; +}; +#endif // SHERPA_ONNX_ENABLE_CHECK + +class Voidifier { + public: +#if SHERPA_ONNX_ENABLE_CHECK + void operator&(const Logger &) const {} +#endif +}; +#if !defined(SHERPA_ONNX_ENABLE_CHECK) +template +const Voidifier &operator<<(const Voidifier &v, T &&) { + return v; +} +#endif + +} // namespace sherpa_mnn + +#define SHERPA_ONNX_STATIC_ASSERT(x) static_assert(x, "") + +#ifdef SHERPA_ONNX_ENABLE_CHECK + +#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) || \ + defined(__PRETTY_FUNCTION__) +// for clang and GCC +#define SHERPA_ONNX_FUNC __PRETTY_FUNCTION__ +#else +// for other compilers +#define SHERPA_ONNX_FUNC __func__ +#endif + +#define SHERPA_ONNX_CHECK(x) \ + (x) ? (void)0 \ + : ::sherpa_mnn::Voidifier() & \ + ::sherpa_mnn::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \ + ::sherpa_mnn::FATAL) \ + << "Check failed: " << #x << " " + +// WARNING: x and y may be evaluated multiple times, but this happens only +// when the check fails. Since the program aborts if it fails, we don't think +// the extra evaluation of x and y matters. +// +// CAUTION: we recommend the following use case: +// +// auto x = Foo(); +// auto y = Bar(); +// SHERPA_ONNX_CHECK_EQ(x, y) << "Some message"; +// +// And please avoid +// +// SHERPA_ONNX_CHECK_EQ(Foo(), Bar()); +// +// if `Foo()` or `Bar()` causes some side effects, e.g., changing some +// local static variables or global variables. +#define _SHERPA_ONNX_CHECK_OP(x, y, op) \ + ((x)op(y)) ? (void)0 \ + : ::sherpa_mnn::Voidifier() & \ + ::sherpa_mnn::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, \ + ::sherpa_mnn::FATAL) \ + << "Check failed: " << #x << " " << #op << " " << #y \ + << " (" << (x) << " vs. 
" << (y) << ") " + +#define SHERPA_ONNX_CHECK_EQ(x, y) _SHERPA_ONNX_CHECK_OP(x, y, ==) +#define SHERPA_ONNX_CHECK_NE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, !=) +#define SHERPA_ONNX_CHECK_LT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <) +#define SHERPA_ONNX_CHECK_LE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, <=) +#define SHERPA_ONNX_CHECK_GT(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >) +#define SHERPA_ONNX_CHECK_GE(x, y) _SHERPA_ONNX_CHECK_OP(x, y, >=) + +#define SHERPA_ONNX_LOG(x) \ + ::sherpa_mnn::Logger(__FILE__, SHERPA_ONNX_FUNC, __LINE__, ::sherpa_mnn::x) + +// ------------------------------------------------------------ +// For debug check +// ------------------------------------------------------------ +// If you define the macro "-D NDEBUG" while compiling kaldi-native-fbank, +// the following macros are in fact empty and does nothing. + +#define SHERPA_ONNX_DCHECK(x) \ + ::sherpa_mnn::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK(x) + +#define SHERPA_ONNX_DCHECK_EQ(x, y) \ + ::sherpa_mnn::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_EQ(x, y) + +#define SHERPA_ONNX_DCHECK_NE(x, y) \ + ::sherpa_mnn::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_NE(x, y) + +#define SHERPA_ONNX_DCHECK_LT(x, y) \ + ::sherpa_mnn::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LT(x, y) + +#define SHERPA_ONNX_DCHECK_LE(x, y) \ + ::sherpa_mnn::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_LE(x, y) + +#define SHERPA_ONNX_DCHECK_GT(x, y) \ + ::sherpa_mnn::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GT(x, y) + +#define SHERPA_ONNX_DCHECK_GE(x, y) \ + ::sherpa_mnn::kDisableDebug ? (void)0 : SHERPA_ONNX_CHECK_GE(x, y) + +#define SHERPA_ONNX_DLOG(x) \ + ::sherpa_mnn::kDisableDebug \ + ? (void)0 \ + : ::sherpa_mnn::Voidifier() & SHERPA_ONNX_LOG(x) + +#else + +#define SHERPA_ONNX_CHECK(x) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_LOG(x) ::sherpa_mnn::Voidifier() + +#define SHERPA_ONNX_CHECK_EQ(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_CHECK_NE(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_CHECK_LT(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_CHECK_LE(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_CHECK_GT(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_CHECK_GE(x, y) ::sherpa_mnn::Voidifier() + +#define SHERPA_ONNX_DCHECK(x) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_DLOG(x) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_DCHECK_EQ(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_DCHECK_NE(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_DCHECK_LT(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_DCHECK_LE(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_DCHECK_GT(x, y) ::sherpa_mnn::Voidifier() +#define SHERPA_ONNX_DCHECK_GE(x, y) ::sherpa_mnn::Voidifier() + +#endif // SHERPA_ONNX_CHECK_NE + +#endif // SHERPA_ONNX_CSRC_LOG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/macros.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/macros.h new file mode 100644 index 00000000..f11aa528 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/macros.h @@ -0,0 +1,188 @@ +// sherpa-mnn/csrc/macros.h +// +// Copyright 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_MACROS_H_ +#define SHERPA_ONNX_CSRC_MACROS_H_ +#include +#include + +#include +#if __OHOS__ +#include "hilog/log.h" + +#undef LOG_DOMAIN +#undef LOG_TAG + +// https://gitee.com/openharmony/docs/blob/145a084f0b742e4325915e32f8184817927d1251/en/contribute/OpenHarmony-Log-guide.md#hilog-api-usage-specifications +#define LOG_DOMAIN 0x6666 +#define LOG_TAG "sherpa_mnn" +#endif + +#if __ANDROID_API__ >= 8 +#include "android/log.h" 
+#define SHERPA_ONNX_LOGE(...) \ + do { \ + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \ + static_cast(__LINE__)); \ + fprintf(stderr, ##__VA_ARGS__); \ + fprintf(stderr, "\n"); \ + __android_log_print(ANDROID_LOG_WARN, "sherpa-mnn", ##__VA_ARGS__); \ + } while (0) +#elif defined(__OHOS__) +#define SHERPA_ONNX_LOGE(...) OH_LOG_INFO(LOG_APP, ##__VA_ARGS__) +#elif SHERPA_ONNX_ENABLE_WASM +#define SHERPA_ONNX_LOGE(...) \ + do { \ + fprintf(stdout, "%s:%s:%d ", __FILE__, __func__, \ + static_cast(__LINE__)); \ + fprintf(stdout, ##__VA_ARGS__); \ + fprintf(stdout, "\n"); \ + } while (0) +#else +#define SHERPA_ONNX_LOGE(...) \ + do { \ + fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \ + static_cast(__LINE__)); \ + fprintf(stderr, ##__VA_ARGS__); \ + fprintf(stderr, "\n"); \ + } while (0) +#endif + +#define SHERPA_ONNX_EXIT(code) exit(code) + +// Read an integer +#define SHERPA_ONNX_READ_META_DATA(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + \ + dst = atoi(value.c_str()); \ + if (dst < 0) { \ + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } while (0) + +#define SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(dst, src_key, default_value) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + dst = default_value; \ + } else { \ + dst = atoi(value.c_str()); \ + if (dst < 0) { \ + SHERPA_ONNX_LOGE("Invalid value %d for '%s'", dst, src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } \ + } while (0) + +// read a vector of integers +#define SHERPA_ONNX_READ_META_DATA_VEC(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + \ + bool ret = SplitStringToIntegers(value.c_str(), ",", true, &dst); \ + if (!ret) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } while (0) + +// read a vector of floats +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + \ + bool ret = SplitStringToFloats(value.c_str(), ",", true, &dst); \ + if (!ret) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'", value.c_str(), src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } while (0) + +// read a vector of strings +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + SplitStringToVector(value.c_str(), ",", false, &dst); \ + \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. 
Empty vector!", \ + value.c_str(), src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } while (0) + +// read a vector of strings separated by sep +#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + SplitStringToVector(value.c_str(), sep, false, &dst); \ + \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value '%s' for '%s'. Empty vector!", \ + value.c_str(), src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } while (0) + +// Read a string +#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + \ + dst = std::move(value); \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } while (0) + +#define SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(dst, src_key) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + \ + dst = std::move(value); \ + } while (0) + +#define SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(dst, src_key, \ + default_value) \ + do { \ + auto value = LookupCustomModelMetaData(meta_data, src_key, allocator); \ + if (value.empty()) { \ + dst = default_value; \ + } else { \ + dst = std::move(value); \ + if (dst.empty()) { \ + SHERPA_ONNX_LOGE("Invalid value for '%s'\n", src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } \ + } while (0) + +#endif // SHERPA_ONNX_CSRC_MACROS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/math.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/math.h new file mode 100644 index 00000000..cc0b1e59 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/math.h @@ -0,0 +1,135 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Daniel Povey) + * Copyright (c) 2023 (Pingfeng Luo) + * + */ +// This file is copied from k2/csrc/utils.h +#ifndef SHERPA_ONNX_CSRC_MATH_H_ +#define SHERPA_ONNX_CSRC_MATH_H_ + +#include +#include +#include +#include +#include + +namespace sherpa_mnn { + +// logf(FLT_EPSILON) +#define SHERPA_ONNX_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f + +// log(DBL_EPSILON) +#define SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE \ + -36.0436533891171535515240975655615329742431640625 + +template +struct LogAdd; + +template <> +struct LogAdd { + double operator()(double x, double y) const { + double diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) { + double res; + res = x + log1p(exp(diff)); + return res; + } + + return x; // return the larger one. + } +}; + +template <> +struct LogAdd { + float operator()(float x, float y) const { + float diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= SHERPA_ONNX_MIN_LOG_DIFF_DOUBLE) { + float res; + res = x + log1pf(expf(diff)); + return res; + } + + return x; // return the larger one. 
+ } +}; + +template +void LogSoftmax(T *input, int32_t input_len) { + assert(input); + + T m = *std::max_element(input, input + input_len); + + T sum = 0.0; + for (int32_t i = 0; i < input_len; i++) { + sum += exp(input[i] - m); + } + + T offset = m + log(sum); + for (int32_t i = 0; i < input_len; i++) { + input[i] -= offset; + } +} + +template +void LogSoftmax(T *in, int32_t w, int32_t h) { + for (int32_t i = 0; i != h; ++i) { + LogSoftmax(in, w); + in += w; + } +} + +template +void SubtractBlank(T *in, int32_t w, int32_t h, int32_t blank_idx, + float blank_penalty) { + for (int32_t i = 0; i != h; ++i) { + in[blank_idx] -= blank_penalty; + in += w; + } +} + +template +std::vector TopkIndex(const T *vec, int32_t size, int32_t topk) { + std::vector vec_index(size); + std::iota(vec_index.begin(), vec_index.end(), 0); + + std::partial_sort(vec_index.begin(), vec_index.begin() + topk, + vec_index.end(), [vec](int32_t index_1, int32_t index_2) { + return vec[index_1] > vec[index_2]; + }); + + int32_t k_num = std::min(size, topk); + return {vec_index.begin(), vec_index.begin() + k_num}; +} + +template +std::vector TopkIndex(const std::vector> &vec, + int32_t topk) { + std::vector flatten; + flatten.reserve(vec.size() * vec[0].size()); + for (const auto &v : vec) { + flatten.insert(flatten.end(), v.begin(), v.end()); + } + + return TopkIndex(flatten.data(), flatten.size(), topk); +} + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_MATH_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/melo-tts-lexicon.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/melo-tts-lexicon.cc new file mode 100644 index 00000000..f2ac527e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/melo-tts-lexicon.cc @@ -0,0 +1,427 @@ +// sherpa-mnn/csrc/melo-tts-lexicon.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/melo-tts-lexicon.h" + +#include +#include // NOLINT +#include +#include +#include +#include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "cppjieba/Jieba.hpp" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class MeloTtsLexicon::Impl { + public: + Impl(const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, bool debug) + : meta_data_(meta_data), debug_(debug) { + std::string dict = dict_dir + "/jieba.dict.utf8"; + std::string hmm = dict_dir + "/hmm_model.utf8"; + std::string user_dict = dict_dir + "/user.dict.utf8"; + std::string idf = dict_dir + "/idf.utf8"; + std::string stop_word = dict_dir + "/stop_words.utf8"; + + AssertFileExists(dict); + AssertFileExists(hmm); + AssertFileExists(user_dict); + AssertFileExists(idf); + AssertFileExists(stop_word); + + jieba_ = + std::make_unique(dict, hmm, user_dict, idf, stop_word); + + { + std::ifstream is(tokens); + InitTokens(is); + } + + { + std::ifstream is(lexicon); + InitLexicon(is); + } + } + + Impl(const std::string &lexicon, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data, bool debug) + : meta_data_(meta_data), debug_(debug) { + { + std::ifstream is(tokens); + InitTokens(is); + } + + { + std::ifstream is(lexicon); + InitLexicon(is); + } + } + + template + Impl(Manager *mgr, const 
std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, bool debug) + : meta_data_(meta_data), debug_(debug) { + std::string dict = dict_dir + "/jieba.dict.utf8"; + std::string hmm = dict_dir + "/hmm_model.utf8"; + std::string user_dict = dict_dir + "/user.dict.utf8"; + std::string idf = dict_dir + "/idf.utf8"; + std::string stop_word = dict_dir + "/stop_words.utf8"; + + AssertFileExists(dict); + AssertFileExists(hmm); + AssertFileExists(user_dict); + AssertFileExists(idf); + AssertFileExists(stop_word); + + jieba_ = + std::make_unique(dict, hmm, user_dict, idf, stop_word); + + { + auto buf = ReadFile(mgr, tokens); + + std::istrstream is(buf.data(), buf.size()); + InitTokens(is); + } + + { + auto buf = ReadFile(mgr, lexicon); + + std::istrstream is(buf.data(), buf.size()); + InitLexicon(is); + } + } + + template + Impl(Manager *mgr, const std::string &lexicon, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data, bool debug) + : meta_data_(meta_data), debug_(debug) { + { + auto buf = ReadFile(mgr, tokens); + + std::istrstream is(buf.data(), buf.size()); + InitTokens(is); + } + + { + auto buf = ReadFile(mgr, lexicon); + + std::istrstream is(buf.data(), buf.size()); + InitLexicon(is); + } + } + + std::vector ConvertTextToTokenIds(const std::string &_text) const { + std::string text = ToLowerCase(_text); + // see + // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 + std::regex punct_re{":|、|;"}; + std::string s = std::regex_replace(text, punct_re, ","); + + std::regex punct_re2("。"); + s = std::regex_replace(s, punct_re2, "."); + + std::regex punct_re3("?"); + s = std::regex_replace(s, punct_re3, "?"); + + std::regex punct_re4("!"); + s = std::regex_replace(s, punct_re4, "!"); + + std::vector words; + if (jieba_) { + bool is_hmm = true; + jieba_->Cut(text, words, is_hmm); + + if (debug_) { + std::ostringstream os; + std::string sep = ""; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } +#if __OHOS__ + SHERPA_ONNX_LOGE("input text: %{public}s", text.c_str()); + SHERPA_ONNX_LOGE("after replacing punctuations: %{public}s", s.c_str()); + + SHERPA_ONNX_LOGE("after jieba processing: %{public}s", + os.str().c_str()); +#else + SHERPA_ONNX_LOGE("input text: %s", text.c_str()); + SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str()); + + SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str()); +#endif + } + } else { + words = SplitUtf8(text); + + if (debug_) { + fprintf(stderr, "Input text in string (lowercase): %s\n", text.c_str()); + fprintf(stderr, "Input text in bytes (lowercase):"); + for (int8_t c : text) { + fprintf(stderr, " %02x", c); + } + fprintf(stderr, "\n"); + fprintf(stderr, "After splitting to words:"); + for (const auto &w : words) { + fprintf(stderr, " %s", w.c_str()); + } + fprintf(stderr, "\n"); + } + } + + std::vector ans; + TokenIDs this_sentence; + + for (const auto &w : words) { + auto ids = ConvertWordToIds(w); + if (ids.tokens.empty()) { + SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str()); + continue; + } + + this_sentence.tokens.insert(this_sentence.tokens.end(), + ids.tokens.begin(), ids.tokens.end()); + this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(), + ids.tones.end()); + + if (w == "." || w == "!" || w == "?" || w == "," || w == "。" || + w == "!" || w == "?" 
|| w == ",") { + ans.push_back(std::move(this_sentence)); + this_sentence = {}; + } + } // for (const auto &w : words) + + if (!this_sentence.tokens.empty()) { + ans.push_back(std::move(this_sentence)); + } + + return ans; + } + + private: + TokenIDs ConvertWordToIds(const std::string &w) const { + if (word2ids_.count(w)) { + return word2ids_.at(w); + } + + if (token2id_.count(w)) { + return {{token2id_.at(w)}, {0}}; + } + + TokenIDs ans; + + std::vector words = SplitUtf8(w); + for (const auto &word : words) { + if (word2ids_.count(word)) { + auto ids = ConvertWordToIds(word); + ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(), + ids.tokens.end()); + ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end()); + } else { + // If the lexicon does not contain the word, we split the word into + // characters. + // + // For instance, if the word is TTS and it is does not exist + // in the lexicon, we split it into 3 characters: T T S + std::string s; + for (char c : word) { + s = c; + if (word2ids_.count(s)) { + const auto &t = word2ids_.at(s); + ans.tokens.insert(ans.tokens.end(), t.tokens.begin(), + t.tokens.end()); + ans.tones.insert(ans.tones.end(), t.tones.begin(), t.tones.end()); + } + } + } + } + + return ans; + } + + void InitTokens(std::istream &is) { + token2id_ = ReadTokens(is); + token2id_[" "] = token2id_["_"]; + + std::vector> puncts = { + {",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}}; + + for (const auto &p : puncts) { + if (token2id_.count(p.first) && !token2id_.count(p.second)) { + token2id_[p.second] = token2id_[p.first]; + } + + if (!token2id_.count(p.first) && token2id_.count(p.second)) { + token2id_[p.first] = token2id_[p.second]; + } + } + + if (!token2id_.count("、") && token2id_.count(",")) { + token2id_["、"] = token2id_[","]; + } + } + + void InitLexicon(std::istream &is) { + std::string word; + std::vector token_list; + + std::vector phone_list; + std::vector tone_list; + + std::string line; + std::string phone; + int32_t line_num = 0; + + while (std::getline(is, line)) { + ++line_num; + + std::istringstream iss(line); + + token_list.clear(); + phone_list.clear(); + tone_list.clear(); + + iss >> word; + ToLowerCase(&word); + + if (word2ids_.count(word)) { + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. 
Ignore it.", + word.c_str(), line_num, line.c_str()); + continue; + } + + while (iss >> phone) { + token_list.push_back(std::move(phone)); + } + + if ((token_list.size() & 1) != 0) { + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); + exit(-1); + } + + int32_t num_phones = token_list.size() / 2; + phone_list.reserve(num_phones); + tone_list.reserve(num_phones); + + for (int32_t i = 0; i != num_phones; ++i) { + phone_list.push_back(std::move(token_list[i])); + tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr)); + if (tone_list.back() < 0 || tone_list.back() > 50) { + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); + exit(-1); + } + } + + std::vector ids = ConvertTokensToIds(token2id_, phone_list); + if (ids.empty()) { + continue; + } + + if (ids.size() != num_phones) { + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); + exit(-1); + } + + std::vector ids64{ids.begin(), ids.end()}; + + word2ids_.insert( + {std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}}); + } + + // For Chinese+English MeloTTS + word2ids_["呣"] = word2ids_["母"]; + word2ids_["嗯"] = word2ids_["恩"]; + } + + private: + // lexicon.txt is saved in word2ids_ + std::unordered_map word2ids_; + + // tokens.txt is saved in token2id_ + std::unordered_map token2id_; + + OfflineTtsVitsModelMetaData meta_data_; + + std::unique_ptr jieba_; + bool debug_ = false; +}; + +MeloTtsLexicon::~MeloTtsLexicon() = default; + +MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon, + const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, + bool debug) + : impl_(std::make_unique(lexicon, tokens, dict_dir, meta_data, + debug)) {} + +MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon, + const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data, + bool debug) + : impl_(std::make_unique(lexicon, tokens, meta_data, debug)) {} + +template +MeloTtsLexicon::MeloTtsLexicon(Manager *mgr, const std::string &lexicon, + const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, + bool debug) + : impl_(std::make_unique(mgr, lexicon, tokens, dict_dir, meta_data, + debug)) {} + +template +MeloTtsLexicon::MeloTtsLexicon(Manager *mgr, const std::string &lexicon, + const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data, + bool debug) + : impl_(std::make_unique(mgr, lexicon, tokens, meta_data, debug)) {} + +std::vector MeloTtsLexicon::ConvertTextToTokenIds( + const std::string &text, const std::string & /*unused_voice = ""*/) const { + return impl_->ConvertTextToTokenIds(text); +} + +#if __ANDROID_API__ >= 9 +template MeloTtsLexicon::MeloTtsLexicon( + AAssetManager *mgr, const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, const OfflineTtsVitsModelMetaData &meta_data, + bool debug); + +template MeloTtsLexicon::MeloTtsLexicon( + AAssetManager *mgr, const std::string &lexicon, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data, bool debug); +#endif + +#if __OHOS__ +template MeloTtsLexicon::MeloTtsLexicon( + NativeResourceManager *mgr, const std::string &lexicon, + const std::string &tokens, const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, bool debug); + +template MeloTtsLexicon::MeloTtsLexicon( + NativeResourceManager *mgr, const std::string &lexicon, + const std::string &tokens, const OfflineTtsVitsModelMetaData &meta_data, + bool debug); +#endif + +} // 
namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/melo-tts-lexicon.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/melo-tts-lexicon.h new file mode 100644 index 00000000..3ede1687 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/melo-tts-lexicon.h @@ -0,0 +1,48 @@ +// sherpa-mnn/csrc/melo-tts-lexicon.h +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ +#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h" + +namespace sherpa_mnn { + +class MeloTtsLexicon : public OfflineTtsFrontend { + public: + ~MeloTtsLexicon() override; + MeloTtsLexicon(const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, bool debug); + + MeloTtsLexicon(const std::string &lexicon, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data, bool debug); + + template + MeloTtsLexicon(Manager *mgr, const std::string &lexicon, + const std::string &tokens, const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, bool debug); + + template + MeloTtsLexicon(Manager *mgr, const std::string &lexicon, + const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data, bool debug); + + std::vector ConvertTextToTokenIds( + const std::string &text, + const std::string &unused_voice = "") const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/microphone.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/microphone.cc new file mode 100644 index 00000000..67db0022 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/microphone.cc @@ -0,0 +1,30 @@ +// sherpa-mnn/csrc/microphone.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/microphone.h" + +#include +#include + +#include "portaudio.h" // NOLINT + +namespace sherpa_mnn { + +Microphone::Microphone() { + PaError err = Pa_Initialize(); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(-1); + } +} + +Microphone::~Microphone() { + PaError err = Pa_Terminate(); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(-1); + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/microphone.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/microphone.h new file mode 100644 index 00000000..158524ed --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/microphone.h @@ -0,0 +1,18 @@ +// sherpa-mnn/csrc/microphone.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_MICROPHONE_H_ +#define SHERPA_ONNX_CSRC_MICROPHONE_H_ + +namespace sherpa_mnn { + +class Microphone { + public: + Microphone(); + ~Microphone(); +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_MICROPHONE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ced-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ced-model.cc new file mode 100644 index 00000000..ebb99f64 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ced-model.cc @@ -0,0 +1,110 @@ +// sherpa-mnn/csrc/offline-ced-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-ced-model.h" + 
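A minimal usage sketch for the MeloTtsLexicon frontend above (illustrative only, not part of this diff): the function name, file paths, and sample sentence are placeholders, and the metadata struct is assumed to be default-constructible for brevity.

#include "sherpa-mnn/csrc/melo-tts-lexicon.h"

void MeloTtsLexiconExample() {
  // Placeholder metadata; a real caller fills this from the VITS model file.
  sherpa_mnn::OfflineTtsVitsModelMetaData meta;

  // lexicon.txt / tokens.txt / dict are placeholder paths for a MeloTTS
  // zh+en model; dict must contain jieba.dict.utf8, hmm_model.utf8, etc.
  sherpa_mnn::MeloTtsLexicon frontend("lexicon.txt", "tokens.txt", "dict",
                                      meta, /*debug=*/false);

  // One TokenIDs entry per sentence; tokens and tones are parallel arrays
  // with one element per phone.
  auto sentences = frontend.ConvertTextToTokenIds("今天天气不错. How are you?");
  for (const auto &s : sentences) {
    (void)s;  // s.tokens / s.tones are what the acoustic model consumes.
  }
}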
+#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class OfflineCEDModel::Impl { + public: + explicit Impl(const AudioTaggingModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.ced); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.ced); + Init(buf.data(), buf.size()); + } +#endif + + MNN::Express::VARP Forward(MNN::Express::VARP features) { + features = Transpose12(allocator_, features); + + auto ans = sess_->onForward({features}); + return std::move(ans[0]); + } + + int32_t NumEventClasses() const { return num_event_classes_; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + // get num_event_classes from the output[0].shape, + // which is (N, num_event_classes) + //num_event_classes_ = + // sess_->GetOutputTypeInfo(0)->getInfo()->dim[1]; + } + + private: + AudioTaggingModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t num_event_classes_ = 0; +}; + +OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr, + const AudioTaggingModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OfflineCEDModel::~OfflineCEDModel() = default; + +MNN::Express::VARP OfflineCEDModel::Forward(MNN::Express::VARP features) const { + return impl_->Forward(std::move(features)); +} + +int32_t OfflineCEDModel::NumEventClasses() const { + return impl_->NumEventClasses(); +} + +MNNAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); } + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ced-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ced-model.h new file mode 100644 index 00000000..61adfb06 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ced-model.h @@ -0,0 +1,56 @@ +// sherpa-mnn/csrc/offline-ced-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_ +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/audio-tagging-model-config.h" + +namespace sherpa_mnn { + +/** This class implements the CED model from 
+ * https://github.com/RicherMans/CED/blob/main/export_onnx.py + */ +class OfflineCEDModel { + public: + explicit OfflineCEDModel(const AudioTaggingModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config); +#endif + + ~OfflineCEDModel(); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * + * @return Return a tensor + * - probs: A 2-D tensor of shape (N, num_event_classes). + */ + MNN::Express::VARP Forward(MNN::Express::VARP features) const; + + /** Return the number of event classes of the model + */ + int32_t NumEventClasses() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model-meta-data.h new file mode 100644 index 00000000..de63731a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model-meta-data.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/offline-ct-transformer-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_mnn { + +struct OfflineCtTransformerModelMetaData { + std::unordered_map token2id; + std::unordered_map punct2id; + std::vector id2punct; + + int32_t unk_id; + int32_t dot_id; + int32_t comma_id; + int32_t quest_id; + int32_t pause_id; + int32_t underline_id; + int32_t num_punctuations; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model.cc new file mode 100644 index 00000000..b645f618 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model.cc @@ -0,0 +1,162 @@ +// sherpa-mnn/csrc/offline-ct-transformer-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-ct-transformer-model.h" + +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineCtTransformerModel::Impl { + public: + explicit Impl(const OfflinePunctuationModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.ct_transformer); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OfflinePunctuationModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.ct_transformer); + Init(buf.data(), buf.size()); + } +#endif + + MNN::Express::VARP Forward(MNN::Express::VARP text, MNN::Express::VARP text_len) { + std::vector inputs = {std::move(text), std::move(text_len)}; + + auto ans = + sess_->onForward(inputs); + return std::move(ans[0]); + } + + MNNAllocator *Allocator() { return allocator_; } + + const OfflineCtTransformerModelMetaData & metaData() const { + return meta_data_; + } + + private: + 
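As a side note (illustrative only, not part of this diff): the Forward() method of the punctuation model takes int32 token IDs. A rough sketch of preparing that input is below; PunctTokensToIds is a hypothetical helper, and it assumes token2id maps token strings to int32_t IDs with unknown tokens falling back to unk_id.

#include <cstdint>
#include <string>
#include <vector>

#include "sherpa-mnn/csrc/offline-ct-transformer-model-meta-data.h"

// Sketch: convert tokens to the IDs consumed by the punctuation model.
// 'meta' is an OfflineCtTransformerModelMetaData populated from the model.
std::vector<int32_t> PunctTokensToIds(
    const sherpa_mnn::OfflineCtTransformerModelMetaData &meta,
    const std::vector<std::string> &tokens) {
  std::vector<int32_t> ids;
  ids.reserve(tokens.size());
  for (const auto &t : tokens) {
    auto it = meta.token2id.find(t);
    ids.push_back(it != meta.token2id.end() ? it->second : meta.unk_id);
  }
  return ids;
}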
void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + + MNNAllocator* allocator; // used in the macro below + + std::vector tokens; + SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(tokens, "tokens", "|"); + + int32_t vocab_size = 0; + SHERPA_ONNX_READ_META_DATA(vocab_size, "vocab_size"); + if (static_cast(tokens.size()) != vocab_size) { + SHERPA_ONNX_LOGE("tokens.size() %d != vocab_size %d", + static_cast(tokens.size()), vocab_size); + exit(-1); + } + + SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(meta_data_.id2punct, + "punctuations", "|"); + + std::string unk_symbol; + SHERPA_ONNX_READ_META_DATA_STR(unk_symbol, "unk_symbol"); + + // output shape is (N, T, num_punctuations) + //meta_data_.num_punctuations = + // sess_->GetOutputTypeInfo(0)->getInfo()->dim[2]; + + int32_t i = 0; + for (const auto &t : tokens) { + meta_data_.token2id[t] = i; + i += 1; + } + + i = 0; + for (const auto &p : meta_data_.id2punct) { + meta_data_.punct2id[p] = i; + i += 1; + } + + meta_data_.unk_id = meta_data_.token2id.at(unk_symbol); + + meta_data_.dot_id = meta_data_.punct2id.at("。"); + meta_data_.comma_id = meta_data_.punct2id.at(","); + meta_data_.quest_id = meta_data_.punct2id.at("?"); + meta_data_.pause_id = meta_data_.punct2id.at("、"); + meta_data_.underline_id = meta_data_.punct2id.at("_"); + + if (config_.debug) { + std::ostringstream os; + os << "vocab_size: " << meta_data_.token2id.size() << "\n"; + os << "num_punctuations: " << meta_data_.num_punctuations << "\n"; + os << "punctuations: "; + for (const auto &s : meta_data_.id2punct) { + os << s << " "; + } + os << "\n"; + SHERPA_ONNX_LOGE("\n%s\n", os.str().c_str()); + } + } + + private: + OfflinePunctuationModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineCtTransformerModelMetaData meta_data_; +}; + +OfflineCtTransformerModel::OfflineCtTransformerModel( + const OfflinePunctuationModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OfflineCtTransformerModel::OfflineCtTransformerModel( + AAssetManager *mgr, const OfflinePunctuationModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OfflineCtTransformerModel::~OfflineCtTransformerModel() = default; + +MNN::Express::VARP OfflineCtTransformerModel::Forward(MNN::Express::VARP text, + MNN::Express::VARP text_len) const { + return impl_->Forward(std::move(text), std::move(text_len)); +} + +MNNAllocator *OfflineCtTransformerModel::Allocator() const { + return impl_->Allocator(); +} + +const OfflineCtTransformerModelMetaData & +OfflineCtTransformerModel::metaData() const { + return impl_->metaData(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model.h new file mode 100644 index 00000000..65f31f00 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ct-transformer-model.h @@ -0,0 +1,59 @@ +// sherpa-mnn/csrc/offline-ct-transformer-model.h +// +// 
Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_ +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-ct-transformer-model-meta-data.h" +#include "sherpa-mnn/csrc/offline-punctuation-model-config.h" + +namespace sherpa_mnn { + +/** This class implements + * https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/python/onnxruntime/funasr_onnx/punc_bin.py#L17 + * from FunASR + */ +class OfflineCtTransformerModel { + public: + explicit OfflineCtTransformerModel( + const OfflinePunctuationModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineCtTransformerModel(AAssetManager *mgr, + const OfflinePunctuationModelConfig &config); +#endif + + ~OfflineCtTransformerModel(); + + /** Run the forward method of the model. + * + * @param text A tensor of shape (N, T) of dtype int32. + * @param text_len A tensor of shape (N) of dtype int32. + * + * @return Return a tensor + * - punctuation_ids: A 2-D tensor of shape (N, T). + */ + MNN::Express::VARP Forward(MNN::Express::VARP text, MNN::Express::VARP text_len) const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + const OfflineCtTransformerModelMetaData & metaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-decoder.h new file mode 100644 index 00000000..0f330f17 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-decoder.h @@ -0,0 +1,50 @@ +// sherpa-mnn/csrc/offline-ctc-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +struct OfflineCtcDecoderResult { + /// The decoded token IDs + std::vector tokens; + + /// The decoded word IDs + /// Note: tokens.size() is usually not equal to words.size() + /// words is empty for greedy search decoding. + /// it is not empty when an HL graph or an HLG graph is used. + std::vector words; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + /// Note: The index is after subsampling + /// + /// tokens.size() == timestamps.size() + std::vector timestamps; +}; + +class OfflineCtcDecoder { + public: + virtual ~OfflineCtcDecoder() = default; + + /** Run CTC decoding given the output from the encoder model. + * + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing + * log_probs. + * @param log_probs_length A 1-D tensor of shape (N,) containing number + * of valid frames in log_probs before padding. + * + * @return Return a vector of size `N` containing the decoded results. 
+ */ + virtual std::vector Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP log_probs_length) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder-config.cc new file mode 100644 index 00000000..8e8fa0c0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder-config.cc @@ -0,0 +1,43 @@ +// sherpa-mnn/csrc/offline-ctc-fst-decoder-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +std::string OfflineCtcFstDecoderConfig::ToString() const { + std::ostringstream os; + + os << "OfflineCtcFstDecoderConfig("; + os << "graph=\"" << graph << "\", "; + os << "max_active=" << max_active << ")"; + + return os.str(); +} + +void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { + std::string prefix = "ctc"; + ParseOptions p(prefix, po); + + p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst"); + + p.Register("max-active", &max_active, + "Decoder max active states. Larger->slower; more accurate"); +} + +bool OfflineCtcFstDecoderConfig::Validate() const { + if (!graph.empty() && !FileExists(graph)) { + SHERPA_ONNX_LOGE("graph: '%s' does not exist", graph.c_str()); + return false; + } + return true; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h new file mode 100644 index 00000000..13c5c615 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h @@ -0,0 +1,32 @@ +// sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineCtcFstDecoderConfig { + // Path to H.fst, HL.fst or HLG.fst + std::string graph; + int32_t max_active = 3000; + + OfflineCtcFstDecoderConfig() = default; + + OfflineCtcFstDecoderConfig(const std::string &graph, int32_t max_active) + : graph(graph), max_active(max_active) {} + + std::string ToString() const; + + void Register(ParseOptions *po); + bool Validate() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder.cc new file mode 100644 index 00000000..be79e8e0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder.cc @@ -0,0 +1,119 @@ +// sherpa-mnn/csrc/offline-ctc-fst-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-ctc-fst-decoder.h" + +#include +#include + +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/decodable-ctc.h" +#include "kaldi-decoder/csrc/eigen.h" +#include "kaldi-decoder/csrc/faster-decoder.h" +#include "sherpa-mnn/csrc/fst-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +/** + * @param decoder + * @param p Pointer to a 2-d array of shape (num_frames, vocab_size) + * @param 
num_frames Number of rows in the 2-d array. + * @param vocab_size Number of columns in the 2-d array. + * @return Return the decoded result. + */ +static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, + const float *p, int32_t num_frames, + int32_t vocab_size) { + OfflineCtcDecoderResult r; + kaldi_decoder::DecodableCtc decodable(p, num_frames, vocab_size); + + decoder->Decode(&decodable); + + if (!decoder->ReachedFinal()) { + SHERPA_ONNX_LOGE("Not reached final!"); + return r; + } + + fst::VectorFst decoded; // linear FST. + decoder->GetBestPath(&decoded); + + if (decoded.NumStates() == 0) { + SHERPA_ONNX_LOGE("Empty best path!"); + return r; + } + + auto cur_state = decoded.Start(); + + int32_t blank_id = 0; + + for (int32_t t = 0, prev = -1; decoded.NumArcs(cur_state) == 1; ++t) { + fst::ArcIterator> iter(decoded, cur_state); + const auto &arc = iter.Value(); + + cur_state = arc.nextstate; + + if (arc.ilabel == prev) { + continue; + } + + // 0 is epsilon here + if (arc.ilabel == 0 || arc.ilabel == blank_id + 1) { + prev = arc.ilabel; + continue; + } + + // -1 here since the input labels are incremented during graph + // construction + r.tokens.push_back(arc.ilabel - 1); + if (arc.olabel != 0) { + r.words.push_back(arc.olabel); + } + + r.timestamps.push_back(t); + prev = arc.ilabel; + } + + return r; +} + +OfflineCtcFstDecoder::OfflineCtcFstDecoder( + const OfflineCtcFstDecoderConfig &config) + : config_(config), fst_(ReadGraph(config_.graph)) {} + +std::vector OfflineCtcFstDecoder::Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP log_probs_length) { + std::vector shape = log_probs->getInfo()->dim; + + assert(static_cast(shape.size()) == 3); + int32_t batch_size = shape[0]; + int32_t T = shape[1]; + int32_t vocab_size = shape[2]; + + std::vector length_shape = + log_probs_length->getInfo()->dim; + assert(static_cast(length_shape.size()) == 1); + + assert(shape[0] == length_shape[0]); + + kaldi_decoder::FasterDecoderOptions opts; + opts.max_active = config_.max_active; + kaldi_decoder::FasterDecoder faster_decoder(*fst_, opts); + + const float *start = log_probs->readMap(); + + std::vector ans; + ans.reserve(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + const float *p = start + i * T * vocab_size; + int32_t num_frames = log_probs_length->readMap()[i]; + auto r = DecodeOne(&faster_decoder, p, num_frames, vocab_size); + ans.push_back(std::move(r)); + } + + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder.h new file mode 100644 index 00000000..5f33e410 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-fst-decoder.h @@ -0,0 +1,33 @@ +// sherpa-mnn/csrc/offline-ctc-fst-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ + +#include +#include + +#include "fst/fst.h" +#include "sherpa-mnn/csrc/offline-ctc-decoder.h" +#include "sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +class OfflineCtcFstDecoder : public OfflineCtcDecoder { + public: + explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config); + + std::vector Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP log_probs_length) override; + + private: + OfflineCtcFstDecoderConfig config_; + + std::unique_ptr> fst_; 
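Aside (illustrative only, not part of this diff): DecodeOne() above applies the usual CTC collapse rule, with graph input labels shifted by +1 so that 0 can act as epsilon. The same collapse on a plain label sequence looks like the following sketch; CollapseCtc is a hypothetical helper, not an API of this code base.

#include <cstdint>
#include <vector>

// Collapse a frame-level CTC label sequence: drop repeats, then drop blanks.
// Here labels are raw token IDs with blank == 0, i.e. already shifted back
// by -1 relative to the FST input labels used in DecodeOne().
std::vector<int32_t> CollapseCtc(const std::vector<int32_t> &labels,
                                 int32_t blank_id = 0) {
  std::vector<int32_t> out;
  int32_t prev = -1;
  for (int32_t y : labels) {
    if (y != blank_id && y != prev) {
      out.push_back(y);
    }
    prev = y;
  }
  return out;
}

// Example: CollapseCtc({0, 3, 3, 0, 5, 5, 0}) yields {3, 5}.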
+}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.cc new file mode 100644 index 00000000..847dd64c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.cc @@ -0,0 +1,54 @@ +// sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +std::vector OfflineCtcGreedySearchDecoder::Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP log_probs_length) { + std::vector shape = log_probs->getInfo()->dim; + int32_t batch_size = static_cast(shape[0]); + int32_t num_frames = static_cast(shape[1]); + int32_t vocab_size = static_cast(shape[2]); + + const int *p_log_probs_length = log_probs_length->readMap(); + + std::vector ans; + ans.reserve(batch_size); + + for (int32_t b = 0; b != batch_size; ++b) { + const float *p_log_probs = + log_probs->readMap() + b * num_frames * vocab_size; + + OfflineCtcDecoderResult r; + int prev_id = -1; + + for (int32_t t = 0; t != static_cast(p_log_probs_length[b]); ++t) { + auto y = static_cast(std::distance( + static_cast(p_log_probs), + std::max_element( + static_cast(p_log_probs), + static_cast(p_log_probs) + vocab_size))); + p_log_probs += vocab_size; + + if (y != blank_id_ && y != prev_id) { + r.tokens.push_back(y); + r.timestamps.push_back(t); + } + prev_id = y; + } // for (int32_t t = 0; ...) + + ans.push_back(std::move(r)); + } + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h new file mode 100644 index 00000000..cb3bcebb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/offline-ctc-decoder.h" + +namespace sherpa_mnn { + +class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder { + public: + explicit OfflineCtcGreedySearchDecoder(int32_t blank_id) + : blank_id_(blank_id) {} + + std::vector Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP log_probs_length) override; + + private: + int32_t blank_id_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-model.cc new file mode 100644 index 00000000..068666c3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-model.cc @@ -0,0 +1,221 @@ +// sherpa-mnn/csrc/offline-ctc-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-ctc-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include 
"sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.h" +#include "sherpa-mnn/csrc/offline-tdnn-ctc-model.h" +#include "sherpa-mnn/csrc/offline-telespeech-ctc-model.h" +#include "sherpa-mnn/csrc/offline-wenet-ctc-model.h" +#include "sherpa-mnn/csrc/offline-zipformer-ctc-model.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace { + +enum class ModelType : std::uint8_t { + kEncDecCTCModelBPE, + kEncDecCTCModel, + kEncDecHybridRNNTCTCBPEModel, + kTdnn, + kZipformerCtc, + kWenetCtc, + kTeleSpeechCtc, + kUnknown, +}; + +} // namespace + +namespace sherpa_mnn { + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + MNNEnv env; + std::shared_ptr sess_opts; + + + + auto sess = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts)); + + MNNMeta meta_data = sess->getInfo()->metaData; + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; + auto model_type = + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "If you are using models from NeMo, please refer to\n" + "https://huggingface.co/csukuangfj/" + "sherpa-mnn-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py\n" + "or " + "https://github.com/k2-fsa/sherpa-mnn/tree/master/scripts/nemo/" + "fast-conformer-hybrid-transducer-ctc\n" + "If you are using models from WeNet, please refer to\n" + "https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/wenet/" + "run.sh\n" + "If you are using models from TeleSpeech, please refer to\n" + "https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/tele-speech/" + "add-metadata.py" + "\n" + "for how to add metadta to model.onnx\n"); + return ModelType::kUnknown; + } + + if (model_type == "EncDecCTCModelBPE") { + return ModelType::kEncDecCTCModelBPE; + } else if (model_type == "EncDecCTCModel") { + return ModelType::kEncDecCTCModel; + } else if (model_type == "EncDecHybridRNNTCTCBPEModel") { + return ModelType::kEncDecHybridRNNTCTCBPEModel; + } else if (model_type == "tdnn") { + return ModelType::kTdnn; + } else if (model_type == "zipformer2_ctc") { + return ModelType::kZipformerCtc; + } else if (model_type == "wenet_ctc") { + return ModelType::kWenetCtc; + } else if (model_type == "telespeech_ctc") { + return ModelType::kTeleSpeechCtc; + } else { + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); + return ModelType::kUnknown; + } +} + +std::unique_ptr OfflineCtcModel::Create( + const OfflineModelConfig &config) { + // TODO(fangjun): Refactor it. 
We don't need to use model_type here + ModelType model_type = ModelType::kUnknown; + + std::string filename; + if (!config.nemo_ctc.model.empty()) { + filename = config.nemo_ctc.model; + } else if (!config.tdnn.model.empty()) { + filename = config.tdnn.model; + } else if (!config.zipformer_ctc.model.empty()) { + filename = config.zipformer_ctc.model; + } else if (!config.wenet_ctc.model.empty()) { + filename = config.wenet_ctc.model; + } else if (!config.telespeech_ctc.empty()) { + filename = config.telespeech_ctc; + } else { + SHERPA_ONNX_LOGE("Please specify a CTC model"); + exit(-1); + } + + { + auto buffer = ReadFile(filename); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kEncDecCTCModelBPE: + case ModelType::kEncDecCTCModel: + return std::make_unique(config); + case ModelType::kEncDecHybridRNNTCTCBPEModel: + return std::make_unique(config); + case ModelType::kTdnn: + return std::make_unique(config); + case ModelType::kZipformerCtc: + return std::make_unique(config); + case ModelType::kWenetCtc: + return std::make_unique(config); + case ModelType::kTeleSpeechCtc: + return std::make_unique(config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); + return nullptr; + } + + return nullptr; +} + +template +std::unique_ptr OfflineCtcModel::Create( + Manager *mgr, const OfflineModelConfig &config) { + // TODO(fangjun): Refactor it. We don't need to use model_type here + ModelType model_type = ModelType::kUnknown; + + std::string filename; + if (!config.nemo_ctc.model.empty()) { + filename = config.nemo_ctc.model; + } else if (!config.tdnn.model.empty()) { + filename = config.tdnn.model; + } else if (!config.zipformer_ctc.model.empty()) { + filename = config.zipformer_ctc.model; + } else if (!config.wenet_ctc.model.empty()) { + filename = config.wenet_ctc.model; + } else if (!config.telespeech_ctc.empty()) { + filename = config.telespeech_ctc; + } else { + SHERPA_ONNX_LOGE("Please specify a CTC model"); + exit(-1); + } + + { + auto buffer = ReadFile(mgr, filename); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kEncDecCTCModelBPE: + case ModelType::kEncDecCTCModel: + return std::make_unique(mgr, config); + case ModelType::kEncDecHybridRNNTCTCBPEModel: + return std::make_unique(mgr, + config); + case ModelType::kTdnn: + return std::make_unique(mgr, config); + case ModelType::kZipformerCtc: + return std::make_unique(mgr, config); + case ModelType::kWenetCtc: + return std::make_unique(mgr, config); + case ModelType::kTeleSpeechCtc: + return std::make_unique(mgr, config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); + return nullptr; + } + + return nullptr; +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr OfflineCtcModel::Create( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr OfflineCtcModel::Create( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-model.h new file mode 100644 index 00000000..5ecbf47d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-ctc-model.h @@ -0,0 +1,71 @@ +// sherpa-mnn/csrc/offline-ctc-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef 
SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +class OfflineCtcModel { + public: + virtual ~OfflineCtcModel() = default; + + static std::unique_ptr Create( + const OfflineModelConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OfflineModelConfig &config); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int + */ + virtual std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) = 0; + + /** Return the vocabulary size of the model + */ + virtual int32_t VocabSize() const = 0; + + /** SubsamplingFactor of the model + * + * For NeMo Citrinet, the subsampling factor is usually 4. + * For NeMo Conformer CTC, the subsampling factor is usually 8. + */ + virtual int32_t SubsamplingFactor() const { return 1; } + + /** Return an allocator for allocating memory + */ + virtual MNNAllocator *Allocator() const = 0; + + /** For some models, e.g., those from NeMo, they require some preprocessing + * for the features. + */ + virtual std::string FeatureNormalizationMethod() const { return {}; } + + // Return true if the model supports batch size > 1 + virtual bool SupportBatchProcessing() const { return true; } + + // return true for models from https://github.com/salute-developers/GigaAM + // return false otherwise + virtual bool IsGigaAM() const { return false; } +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-decoder.h new file mode 100644 index 00000000..64365ea4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-decoder.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-decoder.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +struct OfflineFireRedAsrDecoderResult { + /// The decoded token IDs + std::vector tokens; +}; + +class OfflineFireRedAsrDecoder { + public: + virtual ~OfflineFireRedAsrDecoder() = default; + + /** Run beam search given the output from the FireRedAsr encoder model. + * + * @param n_layer_cross_k A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model). + * @param n_layer_cross_v A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model). + * + * @return Return a vector of size `N` containing the decoded results. 
+ */ + virtual std::vector Decode( + MNN::Express::VARP n_layer_cross_k, MNN::Express::VARP n_layer_cross_v) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.cc new file mode 100644 index 00000000..6f2f913c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.cc @@ -0,0 +1,87 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +// Note: this functions works only for batch size == 1 at present +std::vector +OfflineFireRedAsrGreedySearchDecoder::Decode(MNN::Express::VARP cross_k, + MNN::Express::VARP cross_v) { + const auto &meta_data = model_->metaData(); + + auto memory_info = + (MNNAllocator*)(nullptr); + + // For multilingual models, initial_tokens contains [sot, language, task] + // - language is English by default + // - task is transcribe by default + // + // For non-multilingual models, initial_tokens contains [sot] + std::array token_shape = {1, 1}; + int token = meta_data.sos_id; + + int32_t batch_size = 1; + + MNN::Express::VARP tokens = MNNUtilsCreateTensor( + memory_info, &token, 1, token_shape.data(), token_shape.size()); + + std::array offset_shape{1}; + MNN::Express::VARP offset = MNNUtilsCreateTensor( + model_->Allocator(), offset_shape.data(), offset_shape.size()); + *(offset->writeMap()) = 0; + + std::vector ans(1); + + auto self_kv_cache = model_->GetInitialSelfKVCache(); + + std::tuple + decoder_out = {MNN::Express::VARP{nullptr}, + std::move(self_kv_cache.first), + std::move(self_kv_cache.second), + std::move(cross_k), + std::move(cross_v), + std::move(offset)}; + + for (int32_t i = 0; i < meta_data.max_len; ++i) { + decoder_out = model_->ForwardDecoder(View(tokens), + std::move(std::get<1>(decoder_out)), + std::move(std::get<2>(decoder_out)), + std::move(std::get<3>(decoder_out)), + std::move(std::get<4>(decoder_out)), + std::move(std::get<5>(decoder_out))); + + const auto &logits = std::get<0>(decoder_out); + const float *p_logits = logits->readMap(); + + auto logits_shape = logits->getInfo()->dim; + int32_t vocab_size = logits_shape[2]; + + int32_t max_token_id = static_cast(std::distance( + p_logits, std::max_element(p_logits, p_logits + vocab_size))); + if (max_token_id == meta_data.eos_id) { + break; + } + + ans[0].tokens.push_back(max_token_id); + + token = max_token_id; + + // increment offset + *(std::get<5>(decoder_out)->writeMap()) += 1; + } + + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.h new file mode 100644 index 00000000..cc2f0c31 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ + 
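Aside (illustrative only, not part of this diff): the encoder runs once, and its cached cross-attention K/V drive the greedy-search decoder declared just below. A rough sketch of the wiring follows; RunFireRedAsr is a hypothetical helper, and building the 'features' and 'features_length' tensors from audio is omitted.

#include <utility>
#include <vector>

#include "sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.h"
#include "sherpa-mnn/csrc/offline-fire-red-asr-model.h"

// Sketch of the end-to-end FireRedAsr flow.
std::vector<sherpa_mnn::OfflineFireRedAsrDecoderResult> RunFireRedAsr(
    const sherpa_mnn::OfflineModelConfig &config, MNN::Express::VARP features,
    MNN::Express::VARP features_length) {
  sherpa_mnn::OfflineFireRedAsrModel model(config);
  sherpa_mnn::OfflineFireRedAsrGreedySearchDecoder decoder(&model);

  // Encoder output: cross-attention K/V caches reused at every decode step.
  auto enc = model.ForwardEncoder(std::move(features), std::move(features_length));

  // The greedy decoder currently assumes batch size 1, so the caller reads
  // result[0].tokens and maps the IDs back to text with the model's tokens.txt.
  return decoder.Decode(std::move(enc.first), std::move(enc.second));
}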
+#include + +#include "sherpa-mnn/csrc/offline-fire-red-asr-decoder.h" +#include "sherpa-mnn/csrc/offline-fire-red-asr-model.h" + +namespace sherpa_mnn { + +class OfflineFireRedAsrGreedySearchDecoder : public OfflineFireRedAsrDecoder { + public: + explicit OfflineFireRedAsrGreedySearchDecoder(OfflineFireRedAsrModel *model) + : model_(model) {} + + std::vector Decode( + MNN::Express::VARP cross_k, MNN::Express::VARP cross_v) override; + + private: + OfflineFireRedAsrModel *model_; // not owned +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-config.cc new file mode 100644 index 00000000..11544b8e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-config.cc @@ -0,0 +1,56 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-fire-red-asr-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineFireRedAsrModelConfig::Register(ParseOptions *po) { + po->Register("fire-red-asr-encoder", &encoder, + "Path to onnx encoder of FireRedAsr"); + + po->Register("fire-red-asr-decoder", &decoder, + "Path to onnx decoder of FireRedAsr"); +} + +bool OfflineFireRedAsrModelConfig::Validate() const { + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --fire-red-asr-encoder"); + return false; + } + + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("FireRedAsr encoder file '%s' does not exist", + encoder.c_str()); + return false; + } + + if (decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --fire-red-asr-decoder"); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("FireRedAsr decoder file '%s' does not exist", + decoder.c_str()); + return false; + } + + return true; +} + +std::string OfflineFireRedAsrModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineFireRedAsrModelConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-config.h new file mode 100644 index 00000000..d5283ffc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-config.h @@ -0,0 +1,31 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +// see https://github.com/FireRedTeam/FireRedASR +struct OfflineFireRedAsrModelConfig { + std::string encoder; + std::string decoder; + + OfflineFireRedAsrModelConfig() = default; + OfflineFireRedAsrModelConfig(const std::string &encoder, + const std::string &decoder) + : encoder(encoder), decoder(decoder) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ diff --git 
a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-meta-data.h new file mode 100644 index 00000000..de609f3e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model-meta-data.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-model-meta-data.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_mnn { + +struct OfflineFireRedAsrModelMetaData { + int32_t sos_id; + int32_t eos_id; + int32_t max_len; + + int32_t num_decoder_layers; + int32_t num_head; + int32_t head_dim; + + std::vector mean; + std::vector inv_stddev; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model.cc new file mode 100644 index 00000000..9b9d45db --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model.cc @@ -0,0 +1,250 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-fire-red-asr-model.h" + +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineFireRedAsrModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.fire_red_asr.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.fire_red_asr.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.fire_red_asr.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.fire_red_asr.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + std::pair ForwardEncoder(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + std::vector inputs{std::move(features), + std::move(features_length)}; + + auto encoder_out = encoder_sess_->onForward(inputs); + + return {std::move(encoder_out[0]), std::move(encoder_out[1])}; + } + + std::tuple + ForwardDecoder(MNN::Express::VARP tokens, MNN::Express::VARP n_layer_self_k_cache, + MNN::Express::VARP n_layer_self_v_cache, MNN::Express::VARP n_layer_cross_k, + MNN::Express::VARP n_layer_cross_v, MNN::Express::VARP offset) { + std::vector decoder_input = {std::move(tokens), + std::move(n_layer_self_k_cache), + std::move(n_layer_self_v_cache), + std::move(n_layer_cross_k), + std::move(n_layer_cross_v), + std::move(offset)}; + + auto decoder_out = decoder_sess_->onForward(decoder_input); + + return std::tuple{ + std::move(decoder_out[0]), std::move(decoder_out[1]), + std::move(decoder_out[2]), 
std::move(decoder_input[3]), + std::move(decoder_input[4]), std::move(decoder_input[5])}; + } + + std::pair GetInitialSelfKVCache() { + int32_t batch_size = 1; + std::array shape{meta_data_.num_decoder_layers, batch_size, + meta_data_.max_len, meta_data_.num_head, + meta_data_.head_dim}; + + MNN::Express::VARP n_layer_self_k_cache = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + + MNN::Express::VARP n_layer_self_v_cache = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + + auto n = shape[0] * shape[1] * shape[2] * shape[3] * shape[4]; + + float *p_k = n_layer_self_k_cache->writeMap(); + float *p_v = n_layer_self_v_cache->writeMap(); + + memset(p_k, 0, sizeof(float) * n); + memset(p_v, 0, sizeof(float) * n); + + return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)}; + } + + MNNAllocator *Allocator() { return allocator_; } + + const OfflineFireRedAsrModelMetaData& metaData() const { + return meta_data_; + } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.num_decoder_layers, + "num_decoder_layers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_head, "num_head"); + SHERPA_ONNX_READ_META_DATA(meta_data_.head_dim, "head_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.sos_id, "sos"); + SHERPA_ONNX_READ_META_DATA(meta_data_.eos_id, "eos"); + SHERPA_ONNX_READ_META_DATA(meta_data_.max_len, "max_len"); + + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.mean, "cmvn_mean"); + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, + "cmvn_inv_stddev"); + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + OfflineFireRedAsrModelMetaData meta_data_; +}; + +OfflineFireRedAsrModel::OfflineFireRedAsrModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineFireRedAsrModel::OfflineFireRedAsrModel(Manager *mgr, + const OfflineModelConfig &config) + : 
impl_(std::make_unique(mgr, config)) {} + +OfflineFireRedAsrModel::~OfflineFireRedAsrModel() = default; + +std::pair OfflineFireRedAsrModel::ForwardEncoder( + MNN::Express::VARP features, MNN::Express::VARP features_length) const { + return impl_->ForwardEncoder(std::move(features), std::move(features_length)); +} + +std::tuple +OfflineFireRedAsrModel::ForwardDecoder(MNN::Express::VARP tokens, + MNN::Express::VARP n_layer_self_k_cache, + MNN::Express::VARP n_layer_self_v_cache, + MNN::Express::VARP n_layer_cross_k, + MNN::Express::VARP n_layer_cross_v, + MNN::Express::VARP offset) const { + return impl_->ForwardDecoder( + std::move(tokens), std::move(n_layer_self_k_cache), + std::move(n_layer_self_v_cache), std::move(n_layer_cross_k), + std::move(n_layer_cross_v), std::move(offset)); +} + +std::pair +OfflineFireRedAsrModel::GetInitialSelfKVCache() const { + return impl_->GetInitialSelfKVCache(); +} + +MNNAllocator *OfflineFireRedAsrModel::Allocator() const { + return impl_->Allocator(); +} + +const OfflineFireRedAsrModelMetaData& OfflineFireRedAsrModel::metaData() + const { + return impl_->metaData(); +} + +#if __ANDROID_API__ >= 9 +template OfflineFireRedAsrModel::OfflineFireRedAsrModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineFireRedAsrModel::OfflineFireRedAsrModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model.h new file mode 100644 index 00000000..31f9a8b1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-fire-red-asr-model.h @@ -0,0 +1,92 @@ +// sherpa-mnn/csrc/offline-fire-red-asr-model.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-fire-red-asr-model-meta-data.h" +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +class OfflineFireRedAsrModel { + public: + explicit OfflineFireRedAsrModel(const OfflineModelConfig &config); + + template + OfflineFireRedAsrModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineFireRedAsrModel(); + + /** Run the encoder model. + * + * @param features A tensor of shape (N, T, C). + * @param features_len A tensor of shape (N,) with dtype int64. + * + * @return Return a pair containing: + * - n_layer_cross_k: A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model) + * - n_layer_cross_v: A 4-D tensor of shape + * (num_decoder_layers, N, T, d_model) + */ + std::pair ForwardEncoder( + MNN::Express::VARP features, MNN::Express::VARP features_length) const; + + /** Run the decoder model. + * + * @param tokens A int64 tensor of shape (N, num_words) + * @param n_layer_self_k_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + * @param n_layer_self_v_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + * @param n_layer_cross_k A 5-D tensor of shape + * (num_decoder_layers, N, T, d_model). + * @param n_layer_cross_v A 5-D tensor of shape + * (num_decoder_layers, N, T, d_model). 
+ * @param offset A int64 tensor of shape (N,) + * + * @return Return a tuple containing 6 tensors: + * + * - logits A 3-D tensor of shape (N, num_words, vocab_size) + * - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache + * - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache + * - out_n_layer_cross_k Same as n_layer_cross_k + * - out_n_layer_cross_v Same as n_layer_cross_v + * - out_offset Same as offset + */ + std::tuple + ForwardDecoder(MNN::Express::VARP tokens, MNN::Express::VARP n_layer_self_k_cache, + MNN::Express::VARP n_layer_self_v_cache, MNN::Express::VARP n_layer_cross_k, + MNN::Express::VARP n_layer_cross_v, MNN::Express::VARP offset) const; + + /** Return the initial self kv cache in a pair + * - n_layer_self_k_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + * - n_layer_self_v_cache A 5-D tensor of shape + * (num_decoder_layers, N, max_len, num_head, head_dim). + */ + std::pair GetInitialSelfKVCache() const; + + const OfflineFireRedAsrModelMetaData& metaData() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm-config.cc new file mode 100644 index 00000000..b3916af8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm-config.cc @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/offline-lm-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-lm-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineLMConfig::Register(ParseOptions *po) { + po->Register("lm", &model, "Path to LM model."); + po->Register("lm-scale", &scale, "LM scale."); + po->Register("lm-num-threads", &lm_num_threads, + "Number of threads to run the neural network of LM model"); + po->Register("lm-provider", &lm_provider, + "Specify a provider to LM model use: cpu, cuda, coreml"); +} + +bool OfflineLMConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("'%s' does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineLMConfig::ToString() const { + std::ostringstream os; + + os << "OfflineLMConfig("; + os << "model=\"" << model << "\", "; + os << "scale=" << scale << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm-config.h new file mode 100644 index 00000000..e57023db --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm-config.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineLMConfig { + // path to the onnx model + std::string model; + + // LM scale + float scale = 0.5; + int32_t lm_num_threads = 1; + std::string lm_provider = "cpu"; + + OfflineLMConfig() = default; + + OfflineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, + const std::string &lm_provider) + : model(model), + scale(scale), + 
lm_num_threads(lm_num_threads), + lm_provider(lm_provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm.cc new file mode 100644 index 00000000..27bec4de --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm.cc @@ -0,0 +1,96 @@ +// sherpa-mnn/csrc/offline-lm.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-lm.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/offline-rnn-lm.h" + +namespace sherpa_mnn { + +std::unique_ptr OfflineLM::Create(const OfflineLMConfig &config) { + return std::make_unique(config); +} + +template +std::unique_ptr OfflineLM::Create(Manager *mgr, + const OfflineLMConfig &config) { + return std::make_unique(mgr, config); +} + +void OfflineLM::ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps) { + // compute the max token seq so that we know how much space to allocate + int32_t max_token_seq = 0; + int32_t num_hyps = 0; + + // we subtract context_size below since each token sequence is prepended + // with context_size blanks + for (const auto &h : *hyps) { + num_hyps += h.Size(); + for (const auto &t : h) { + max_token_seq = + std::max(max_token_seq, t.second.ys.size() - context_size); + } + } + + MNNAllocator* allocator; + std::array x_shape{num_hyps, max_token_seq}; + MNN::Express::VARP x = MNNUtilsCreateTensor(allocator, x_shape.data(), + x_shape.size()); + + std::array x_lens_shape{num_hyps}; + MNN::Express::VARP x_lens = MNNUtilsCreateTensor( + allocator, x_lens_shape.data(), x_lens_shape.size()); + + int *p = x->writeMap(); + std::fill(p, p + num_hyps * max_token_seq, 0); + + int *p_lens = x_lens->writeMap(); + + for (const auto &h : *hyps) { + for (const auto &t : h) { + const auto &ys = t.second.ys; + int32_t len = ys.size() - context_size; + std::copy(ys.begin() + context_size, ys.end(), p); + *p_lens = len; + + p += max_token_seq; + ++p_lens; + } + } + auto negative_loglike = Rescore(std::move(x), std::move(x_lens)); + const float *p_nll = negative_loglike->readMap(); + for (auto &h : *hyps) { + for (auto &t : h) { + // Use -scale here since we want to change negative loglike to loglike. 
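+      // For example, with scale = 0.5 and a returned negative log-likelihood
+      // of 4.0, the hypothesis receives lm_log_prob = -2.0 (illustrative
+      // numbers only, not taken from any real model output).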
+ t.second.lm_log_prob = -scale * (*p_nll); + ++p_nll; + } + } +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr OfflineLM::Create( + AAssetManager *mgr, const OfflineLMConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr OfflineLM::Create( + NativeResourceManager *mgr, const OfflineLMConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm.h new file mode 100644 index 00000000..3db38a1a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-lm.h @@ -0,0 +1,50 @@ +// sherpa-mnn/csrc/offline-lm.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_LM_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/hypothesis.h" +#include "sherpa-mnn/csrc/offline-lm-config.h" + +namespace sherpa_mnn { + +class OfflineLM { + public: + virtual ~OfflineLM() = default; + + static std::unique_ptr Create(const OfflineLMConfig &config); + + template + static std::unique_ptr Create(Manager *mgr, + const OfflineLMConfig &config); + + /** Rescore a batch of sentences. + * + * @param x A 2-D tensor of shape (N, L) with data type int64. + * @param x_lens A 1-D tensor of shape (N,) with data type int64. + * It contains number of valid tokens in x before padding. + * @return Return a 1-D tensor of shape (N,) containing the negative log + * likelihood of each utterance. Its data type is float32. + * + * Caution: It returns negative log likelihood (nll), not log likelihood + */ + virtual MNN::Express::VARP Rescore(MNN::Express::VARP x, MNN::Express::VARP x_lens) = 0; + + // This function updates hyp.lm_lob_prob of hyps. + // + // @param scale LM score + // @param context_size Context size of the transducer decoder model + // @param hyps It is changed in-place. + void ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps); +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-model-config.cc new file mode 100644 index 00000000..ed50c7f2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-model-config.cc @@ -0,0 +1,151 @@ +// sherpa-mnn/csrc/offline-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-mnn/csrc/offline-model-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineModelConfig::Register(ParseOptions *po) { + transducer.Register(po); + paraformer.Register(po); + nemo_ctc.Register(po); + whisper.Register(po); + fire_red_asr.Register(po); + tdnn.Register(po); + zipformer_ctc.Register(po); + wenet_ctc.Register(po); + sense_voice.Register(po); + moonshine.Register(po); + + po->Register("telespeech-ctc", &telespeech_ctc, + "Path to model.onnx for telespeech ctc"); + + po->Register("tokens", &tokens, "Path to tokens.txt"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); + + po->Register("model-type", &model_type, + "Specify it to reduce model initialization time. 
" + "Valid values are: transducer, paraformer, nemo_ctc, whisper, " + "tdnn, zipformer2_ctc, telespeech_ctc, fire_red_asr." + "All other values lead to loading the model twice."); + po->Register("modeling-unit", &modeling_unit, + "The modeling unit of the model, commonly used units are bpe, " + "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " + "hotwords are provided, we need it to encode the hotwords into " + "token sequence."); + po->Register("bpe-vocab", &bpe_vocab, + "The vocabulary generated by google's sentencepiece program. " + "It is a file has two columns, one is the token, the other is " + "the log probability, you can get it from the directory where " + "your bpe model is generated. Only used when hotwords provided " + "and the modeling unit is bpe or cjkchar+bpe"); +} + +bool OfflineModelConfig::Validate() const { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str()); + return false; + } + + if (!modeling_unit.empty() && + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) { + if (!FileExists(bpe_vocab)) { + SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str()); + return false; + } + } + + if (!paraformer.model.empty()) { + return paraformer.Validate(); + } + + if (!nemo_ctc.model.empty()) { + return nemo_ctc.Validate(); + } + + if (!whisper.encoder.empty()) { + return whisper.Validate(); + } + + if (!fire_red_asr.encoder.empty()) { + return fire_red_asr.Validate(); + } + + if (!tdnn.model.empty()) { + return tdnn.Validate(); + } + + if (!zipformer_ctc.model.empty()) { + return zipformer_ctc.Validate(); + } + + if (!wenet_ctc.model.empty()) { + return wenet_ctc.Validate(); + } + + if (!sense_voice.model.empty()) { + return sense_voice.Validate(); + } + + if (!moonshine.preprocessor.empty()) { + return moonshine.Validate(); + } + + if (!telespeech_ctc.empty() && !FileExists(telespeech_ctc)) { + SHERPA_ONNX_LOGE("telespeech_ctc: '%s' does not exist", + telespeech_ctc.c_str()); + return false; + } + + if (!transducer.encoder_filename.empty()) { + return transducer.Validate(); + } + + return true; +} + +std::string OfflineModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineModelConfig("; + os << "transducer=" << transducer.ToString() << ", "; + os << "paraformer=" << paraformer.ToString() << ", "; + os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; + os << "whisper=" << whisper.ToString() << ", "; + os << "fire_red_asr=" << fire_red_asr.ToString() << ", "; + os << "tdnn=" << tdnn.ToString() << ", "; + os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; + os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; + os << "sense_voice=" << sense_voice.ToString() << ", "; + os << "moonshine=" << moonshine.ToString() << ", "; + os << "telespeech_ctc=\"" << telespeech_ctc << "\", "; + os << "tokens=\"" << tokens << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\", "; + os << "model_type=\"" << model_type << "\", "; + os << "modeling_unit=\"" << modeling_unit << "\", "; + os << "bpe_vocab=\"" << bpe_vocab << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-model-config.h new file mode 100644 index 00000000..5106fa03 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-model-config.h @@ -0,0 +1,97 @@ +// sherpa-mnn/csrc/offline-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/offline-fire-red-asr-model-config.h" +#include "sherpa-mnn/csrc/offline-moonshine-model-config.h" +#include "sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.h" +#include "sherpa-mnn/csrc/offline-paraformer-model-config.h" +#include "sherpa-mnn/csrc/offline-sense-voice-model-config.h" +#include "sherpa-mnn/csrc/offline-tdnn-model-config.h" +#include "sherpa-mnn/csrc/offline-transducer-model-config.h" +#include "sherpa-mnn/csrc/offline-wenet-ctc-model-config.h" +#include "sherpa-mnn/csrc/offline-whisper-model-config.h" +#include "sherpa-mnn/csrc/offline-zipformer-ctc-model-config.h" + +namespace sherpa_mnn { + +struct OfflineModelConfig { + OfflineTransducerModelConfig transducer; + OfflineParaformerModelConfig paraformer; + OfflineNemoEncDecCtcModelConfig nemo_ctc; + OfflineWhisperModelConfig whisper; + OfflineFireRedAsrModelConfig fire_red_asr; + OfflineTdnnModelConfig tdnn; + OfflineZipformerCtcModelConfig zipformer_ctc; + OfflineWenetCtcModelConfig wenet_ctc; + OfflineSenseVoiceModelConfig sense_voice; + OfflineMoonshineModelConfig moonshine; + std::string telespeech_ctc; + + std::string tokens; + int32_t num_threads = 2; + bool debug = false; + std::string provider = "cpu"; + + // With the help of this field, we only need to load the model once + // instead of twice; and therefore it reduces initialization time. + // + // Valid values: + // - transducer. The given model is from icefall + // - paraformer. It is a paraformer model + // - nemo_ctc. It is a NeMo CTC model. + // + // All other values are invalid and lead to loading the model twice. 
+ std::string model_type; + + std::string modeling_unit = "cjkchar"; + std::string bpe_vocab; + + OfflineModelConfig() = default; + OfflineModelConfig(const OfflineTransducerModelConfig &transducer, + const OfflineParaformerModelConfig ¶former, + const OfflineNemoEncDecCtcModelConfig &nemo_ctc, + const OfflineWhisperModelConfig &whisper, + const OfflineFireRedAsrModelConfig &fire_red_asr, + const OfflineTdnnModelConfig &tdnn, + const OfflineZipformerCtcModelConfig &zipformer_ctc, + const OfflineWenetCtcModelConfig &wenet_ctc, + const OfflineSenseVoiceModelConfig &sense_voice, + const OfflineMoonshineModelConfig &moonshine, + const std::string &telespeech_ctc, + const std::string &tokens, int32_t num_threads, bool debug, + const std::string &provider, const std::string &model_type, + const std::string &modeling_unit, + const std::string &bpe_vocab) + : transducer(transducer), + paraformer(paraformer), + nemo_ctc(nemo_ctc), + whisper(whisper), + fire_red_asr(fire_red_asr), + tdnn(tdnn), + zipformer_ctc(zipformer_ctc), + wenet_ctc(wenet_ctc), + sense_voice(sense_voice), + moonshine(moonshine), + telespeech_ctc(telespeech_ctc), + tokens(tokens), + num_threads(num_threads), + debug(debug), + provider(provider), + model_type(model_type), + modeling_unit(modeling_unit), + bpe_vocab(bpe_vocab) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-decoder.h new file mode 100644 index 00000000..93936dcf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-decoder.h @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/offline-moonshine-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +struct OfflineMoonshineDecoderResult { + /// The decoded token IDs + std::vector tokens; +}; + +class OfflineMoonshineDecoder { + public: + virtual ~OfflineMoonshineDecoder() = default; + + /** Run beam search given the output from the moonshine encoder model. + * + * @param encoder_out A 3-D tensor of shape (batch_size, T, dim) + * @return Return a vector of size `N` containing the decoded results. 
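+   *
+   * Note: with the greedy-search implementation added in this diff, the
+   * returned token IDs include neither the start-of-sequence nor the
+   * end-of-sequence marker.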
+ */ + virtual std::vector Decode( + MNN::Express::VARP encoder_out) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.cc new file mode 100644 index 00000000..24fb3ebb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.cc @@ -0,0 +1,93 @@ +// sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.h" + +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +std::vector +OfflineMoonshineGreedySearchDecoder::Decode(MNN::Express::VARP encoder_out) { + auto encoder_out_shape = encoder_out->getInfo()->dim; + if (encoder_out_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch size == 1. Given: %d\n", + static_cast(encoder_out_shape[0])); + return {}; + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + // encoder_out_shape[1] * 384 is the number of audio samples + // 16000 is the sample rate + // + // + // 384 is from the moonshine paper + int32_t max_len = + static_cast(encoder_out_shape[1] * 384 / 16000.0 * 6); + + int32_t sos = 1; + int32_t eos = 2; + int32_t seq_len = 1; + + std::vector tokens; + + std::array token_shape = {1, 1}; + int seq_len_shape = 1; + + MNN::Express::VARP token_tensor = MNNUtilsCreateTensor( + memory_info, &sos, 1, token_shape.data(), token_shape.size()); + + MNN::Express::VARP seq_len_tensor = + MNNUtilsCreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1); + + MNN::Express::VARP logits{nullptr}; + std::vector states; + + std::tie(logits, states) = model_->ForwardUnCachedDecoder( + std::move(token_tensor), std::move(seq_len_tensor), View(encoder_out)); + + int32_t vocab_size = logits->getInfo()->dim[2]; + + for (int32_t i = 0; i != max_len; ++i) { + const float *p = logits->readMap(); + + int32_t max_token_id = static_cast( + std::distance(p, std::max_element(p, p + vocab_size))); + if (max_token_id == eos) { + break; + } + tokens.push_back(max_token_id); + + seq_len += 1; + + token_tensor = MNNUtilsCreateTensor( + memory_info, &tokens.back(), 1, token_shape.data(), token_shape.size()); + + seq_len_tensor = + MNNUtilsCreateTensor(memory_info, &seq_len, 1, &seq_len_shape, 1); + + // To fix the false alarm of clang-tidy + // error: 'states' used after it was moved + // [bugprone-use-after-move,-warnings-as-errors] + // we use a tmp_states here + std::vector tmp_states{std::move(states)}; + + std::tie(logits, states) = model_->ForwardCachedDecoder( + std::move(token_tensor), std::move(seq_len_tensor), View(encoder_out), + std::move(tmp_states)); + } + + OfflineMoonshineDecoderResult ans; + ans.tokens = std::move(tokens); + + return {ans}; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.h new file mode 100644 index 00000000..487eadd3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_ 
+#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/offline-moonshine-decoder.h" +#include "sherpa-mnn/csrc/offline-moonshine-model.h" + +namespace sherpa_mnn { + +class OfflineMoonshineGreedySearchDecoder : public OfflineMoonshineDecoder { + public: + explicit OfflineMoonshineGreedySearchDecoder(OfflineMoonshineModel *model) + : model_(model) {} + + std::vector Decode( + MNN::Express::VARP encoder_out) override; + + private: + OfflineMoonshineModel *model_; // not owned +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model-config.cc new file mode 100644 index 00000000..d039744f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model-config.cc @@ -0,0 +1,88 @@ +// sherpa-mnn/csrc/offline-moonshine-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-moonshine-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineMoonshineModelConfig::Register(ParseOptions *po) { + po->Register("moonshine-preprocessor", &preprocessor, + "Path to onnx preprocessor of moonshine, e.g., preprocess.onnx"); + + po->Register("moonshine-encoder", &encoder, + "Path to onnx encoder of moonshine, e.g., encode.onnx"); + + po->Register( + "moonshine-uncached-decoder", &uncached_decoder, + "Path to onnx uncached_decoder of moonshine, e.g., uncached_decode.onnx"); + + po->Register( + "moonshine-cached-decoder", &cached_decoder, + "Path to onnx cached_decoder of moonshine, e.g., cached_decode.onnx"); +} + +bool OfflineMoonshineModelConfig::Validate() const { + if (preprocessor.empty()) { + SHERPA_ONNX_LOGE("Please provide --moonshine-preprocessor"); + return false; + } + + if (!FileExists(preprocessor)) { + SHERPA_ONNX_LOGE("moonshine preprocessor file '%s' does not exist", + preprocessor.c_str()); + return false; + } + + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --moonshine-encoder"); + return false; + } + + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("moonshine encoder file '%s' does not exist", + encoder.c_str()); + return false; + } + + if (uncached_decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --moonshine-uncached-decoder"); + return false; + } + + if (!FileExists(uncached_decoder)) { + SHERPA_ONNX_LOGE("moonshine uncached decoder file '%s' does not exist", + uncached_decoder.c_str()); + return false; + } + + if (cached_decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --moonshine-cached-decoder"); + return false; + } + + if (!FileExists(cached_decoder)) { + SHERPA_ONNX_LOGE("moonshine cached decoder file '%s' does not exist", + cached_decoder.c_str()); + return false; + } + + return true; +} + +std::string OfflineMoonshineModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineMoonshineModelConfig("; + os << "preprocessor=\"" << preprocessor << "\", "; + os << "encoder=\"" << encoder << "\", "; + os << "uncached_decoder=\"" << uncached_decoder << "\", "; + os << "cached_decoder=\"" << cached_decoder << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model-config.h new file mode 100644 
index 00000000..cf3ab504 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model-config.h @@ -0,0 +1,37 @@ +// sherpa-mnn/csrc/offline-moonshine-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineMoonshineModelConfig { + std::string preprocessor; + std::string encoder; + std::string uncached_decoder; + std::string cached_decoder; + + OfflineMoonshineModelConfig() = default; + OfflineMoonshineModelConfig(const std::string &preprocessor, + const std::string &encoder, + const std::string &uncached_decoder, + const std::string &cached_decoder) + : preprocessor(preprocessor), + encoder(encoder), + uncached_decoder(uncached_decoder), + cached_decoder(cached_decoder) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model.cc new file mode 100644 index 00000000..434d34de --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model.cc @@ -0,0 +1,286 @@ +// sherpa-mnn/csrc/offline-moonshine-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-moonshine-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineMoonshineModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.moonshine.preprocessor); + InitPreprocessor(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.moonshine.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.moonshine.uncached_decoder); + InitUnCachedDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.moonshine.cached_decoder); + InitCachedDecoder(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.moonshine.preprocessor); + InitPreprocessor(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.moonshine.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.moonshine.uncached_decoder); + InitUnCachedDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.moonshine.cached_decoder); + InitCachedDecoder(buf.data(), buf.size()); + } + } + + MNN::Express::VARP ForwardPreprocessor(MNN::Express::VARP audio) { + auto features = preprocessor_sess_->onForward({audio}); + + return std::move(features[0]); + } + + MNN::Express::VARP ForwardEncoder(MNN::Express::VARP features, MNN::Express::VARP features_len) { + std::vector encoder_inputs{std::move(features), + 
std::move(features_len)}; + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + return std::move(encoder_out[0]); + } + + std::pair> ForwardUnCachedDecoder( + MNN::Express::VARP tokens, MNN::Express::VARP seq_len, MNN::Express::VARP encoder_out) { + std::vector uncached_decoder_input = { + std::move(tokens), + std::move(encoder_out), + std::move(seq_len), + }; + + auto uncached_decoder_out = uncached_decoder_sess_->onForward( + uncached_decoder_input); + + std::vector states; + states.reserve(uncached_decoder_out.size() - 1); + + int32_t i = -1; + for (auto &s : uncached_decoder_out) { + ++i; + if (i == 0) { + continue; + } + + states.push_back(std::move(s)); + } + + return {std::move(uncached_decoder_out[0]), std::move(states)}; + } + + std::pair> ForwardCachedDecoder( + MNN::Express::VARP tokens, MNN::Express::VARP seq_len, MNN::Express::VARP encoder_out, + std::vector states) { + std::vector cached_decoder_input; + cached_decoder_input.reserve(3 + states.size()); + cached_decoder_input.push_back(std::move(tokens)); + cached_decoder_input.push_back(std::move(encoder_out)); + cached_decoder_input.push_back(std::move(seq_len)); + + for (auto &s : states) { + cached_decoder_input.push_back(std::move(s)); + } + + auto cached_decoder_out = cached_decoder_sess_->onForward(cached_decoder_input); + + std::vector next_states; + next_states.reserve(cached_decoder_out.size() - 1); + + int32_t i = -1; + for (auto &s : cached_decoder_out) { + ++i; + if (i == 0) { + continue; + } + + next_states.push_back(std::move(s)); + } + + return {std::move(cached_decoder_out[0]), std::move(next_states)}; + } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void InitPreprocessor(void *model_data, size_t model_data_length) { + preprocessor_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(preprocessor_sess_.get(), &preprocessor_input_names_, + &preprocessor_input_names_ptr_); + + GetOutputNames(preprocessor_sess_.get(), &preprocessor_output_names_, + &preprocessor_output_names_ptr_); + } + + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + } + + void InitUnCachedDecoder(void *model_data, size_t model_data_length) { + uncached_decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(uncached_decoder_sess_.get(), &uncached_decoder_input_names_, + &uncached_decoder_input_names_ptr_); + + GetOutputNames(uncached_decoder_sess_.get(), + &uncached_decoder_output_names_, + &uncached_decoder_output_names_ptr_); + } + + void InitCachedDecoder(void *model_data, size_t model_data_length) { + cached_decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(cached_decoder_sess_.get(), &cached_decoder_input_names_, + &cached_decoder_input_names_ptr_); + + GetOutputNames(cached_decoder_sess_.get(), &cached_decoder_output_names_, + &cached_decoder_output_names_ptr_); + } + + private: 
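+  // One MNN module per Moonshine sub-model (preprocessor, encoder, uncached
+  // decoder, cached decoder), together with their cached input/output names.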
+ OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr preprocessor_sess_; + std::unique_ptr encoder_sess_; + std::unique_ptr uncached_decoder_sess_; + std::unique_ptr cached_decoder_sess_; + + std::vector preprocessor_input_names_; + std::vector preprocessor_input_names_ptr_; + + std::vector preprocessor_output_names_; + std::vector preprocessor_output_names_ptr_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector uncached_decoder_input_names_; + std::vector uncached_decoder_input_names_ptr_; + + std::vector uncached_decoder_output_names_; + std::vector uncached_decoder_output_names_ptr_; + + std::vector cached_decoder_input_names_; + std::vector cached_decoder_input_names_ptr_; + + std::vector cached_decoder_output_names_; + std::vector cached_decoder_output_names_ptr_; +}; + +OfflineMoonshineModel::OfflineMoonshineModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineMoonshineModel::OfflineMoonshineModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineMoonshineModel::~OfflineMoonshineModel() = default; + +MNN::Express::VARP OfflineMoonshineModel::ForwardPreprocessor(MNN::Express::VARP audio) const { + return impl_->ForwardPreprocessor(std::move(audio)); +} + +MNN::Express::VARP OfflineMoonshineModel::ForwardEncoder( + MNN::Express::VARP features, MNN::Express::VARP features_len) const { + return impl_->ForwardEncoder(std::move(features), std::move(features_len)); +} + +std::pair> +OfflineMoonshineModel::ForwardUnCachedDecoder(MNN::Express::VARP token, + MNN::Express::VARP seq_len, + MNN::Express::VARP encoder_out) const { + return impl_->ForwardUnCachedDecoder(std::move(token), std::move(seq_len), + std::move(encoder_out)); +} + +std::pair> +OfflineMoonshineModel::ForwardCachedDecoder( + MNN::Express::VARP token, MNN::Express::VARP seq_len, MNN::Express::VARP encoder_out, + std::vector states) const { + return impl_->ForwardCachedDecoder(std::move(token), std::move(seq_len), + std::move(encoder_out), std::move(states)); +} + +MNNAllocator *OfflineMoonshineModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template OfflineMoonshineModel::OfflineMoonshineModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineMoonshineModel::OfflineMoonshineModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model.h new file mode 100644 index 00000000..de71f7b2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-moonshine-model.h @@ -0,0 +1,87 @@ +// sherpa-mnn/csrc/offline-moonshine-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +// please see +// https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/moonshine/test.py +class OfflineMoonshineModel { + public: + explicit OfflineMoonshineModel(const OfflineModelConfig &config); + + template + 
OfflineMoonshineModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineMoonshineModel(); + + /** Run the preprocessor model. + * + * @param audio A float32 tensor of shape (batch_size, num_samples) + * + * @return Return a float32 tensor of shape (batch_size, T, dim) that + * can be used as the input of ForwardEncoder() + */ + MNN::Express::VARP ForwardPreprocessor(MNN::Express::VARP audio) const; + + /** Run the encoder model. + * + * @param features A float32 tensor of shape (batch_size, T, dim) + * @param features_len A int32 tensor of shape (batch_size,) + * @returns A float32 tensor of shape (batch_size, T, dim). + */ + MNN::Express::VARP ForwardEncoder(MNN::Express::VARP features, MNN::Express::VARP features_len) const; + + /** Run the uncached decoder. + * + * @param token A int32 tensor of shape (batch_size, num_tokens) + * @param seq_len A int32 tensor of shape (batch_size,) containing number + * of predicted tokens so far + * @param encoder_out A float32 tensor of shape (batch_size, T, dim) + * + * @returns Return a pair: + * + * - logits, a float32 tensor of shape (batch_size, 1, dim) + * - states, a list of states + */ + std::pair> ForwardUnCachedDecoder( + MNN::Express::VARP token, MNN::Express::VARP seq_len, MNN::Express::VARP encoder_out) const; + + /** Run the cached decoder. + * + * @param token A int32 tensor of shape (batch_size, num_tokens) + * @param seq_len A int32 tensor of shape (batch_size,) containing number + * of predicted tokens so far + * @param encoder_out A float32 tensor of shape (batch_size, T, dim) + * @param states A list of previous states + * + * @returns Return a pair: + * - logits, a float32 tensor of shape (batch_size, 1, dim) + * - states, a list of new states + */ + std::pair> ForwardCachedDecoder( + MNN::Express::VARP token, MNN::Express::VARP seq_len, MNN::Express::VARP encoder_out, + std::vector states) const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_MOONSHINE_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.cc new file mode 100644 index 00000000..d48a4f4a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.cc @@ -0,0 +1,35 @@ +// sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) { + po->Register("nemo-ctc-model", &model, + "Path to model.onnx of Nemo EncDecCtcModel."); +} + +bool OfflineNemoEncDecCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("NeMo model: '%s' does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineNemoEncDecCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineNemoEncDecCtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.h new file mode 
100644 index 00000000..0665640a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineNemoEncDecCtcModelConfig { + std::string model; + + OfflineNemoEncDecCtcModelConfig() = default; + explicit OfflineNemoEncDecCtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.cc new file mode 100644 index 00000000..f939afa0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -0,0 +1,178 @@ +// sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.cc +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class OfflineNemoEncDecCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.nemo_ctc.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.nemo_ctc.model); + Init(buf.data(), buf.size()); + } + + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + std::vector shape = + features_length->getInfo()->dim; + + MNN::Express::VARP out_features_length = MNNUtilsCreateTensor( + allocator_, shape.data(), shape.size()); + + const int *src = features_length->readMap(); + int *dst = out_features_length->writeMap(); + for (int i = 0; i != shape[0]; ++i) { + dst[i] = src[i] / subsampling_factor_; + } + + // (B, T, C) -> (B, C, T) + features = Transpose12(allocator_, features); + + std::vector inputs = {std::move(features), + std::move(features_length)}; + auto out = + sess_->onForward(inputs); + + std::vector ans; + ans.reserve(2); + ans.push_back(std::move(out[0])); + ans.push_back(std::move(out_features_length)); + return ans; + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + + MNNAllocator *Allocator() { return allocator_; } + + std::string FeatureNormalizationMethod() const { return normalize_type_; } + + bool IsGigaAM() const { return is_giga_am_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const 
uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_, + "normalize_type"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 0; + std::string normalize_type_; + + // it is 1 for models from + // https://github.com/salute-developers/GigaAM + int32_t is_giga_am_ = 0; +}; + +OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( + const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( + Manager *mgr, const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default; + +std::vector OfflineNemoEncDecCtcModel::Forward( + MNN::Express::VARP features, MNN::Express::VARP features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineNemoEncDecCtcModel::VocabSize() const { + return impl_->VocabSize(); +} +int32_t OfflineNemoEncDecCtcModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +MNNAllocator *OfflineNemoEncDecCtcModel::Allocator() const { + return impl_->Allocator(); +} + +std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const { + return impl_->FeatureNormalizationMethod(); +} + +bool OfflineNemoEncDecCtcModel::IsGigaAM() const { return impl_->IsGigaAM(); } + +#if __ANDROID_API__ >= 9 +template OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.h new file mode 100644 index 00000000..6c31753d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.h @@ -0,0 +1,83 @@ +// sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-ctc-model.h" +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +/** This class implements the EncDecCTCModelBPE model from NeMo. 
+ * + * See + * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py + * https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py + */ +class OfflineNemoEncDecCtcModel : public OfflineCtcModel { + public: + explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config); + + template + OfflineNemoEncDecCtcModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineNemoEncDecCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int + */ + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** SubsamplingFactor of the model + * + * For Citrinet, the subsampling factor is usually 4. + * For Conformer CTC, the subsampling factor is usually 8. + */ + int32_t SubsamplingFactor() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string FeatureNormalizationMethod() const override; + + bool IsGigaAM() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +using OfflineNemoEncDecHybridRNNTCTCBPEModel = OfflineNemoEncDecCtcModel; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-decoder.h new file mode 100644 index 00000000..02832d57 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-decoder.h @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/offline-paraformer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +struct OfflineParaformerDecoderResult { + /// The decoded token IDs + std::vector tokens; + + // it contains the start time of each token in seconds + // + // len(timestamps) == len(tokens) + std::vector timestamps; +}; + +class OfflineParaformerDecoder { + public: + virtual ~OfflineParaformerDecoder() = default; + + /** Run beam search given the output from the paraformer model. + * + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) + * @param token_num A 1-D tensor of shape (N). token_num equals to T. + * + * @return Return a vector of size `N` containing the decoded results. 
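+   *
+   * When the optional `us_cif_peak` tensor is supplied (see the greedy
+   * search decoder in this diff), it is used to derive the per-token
+   * timestamps, in seconds, stored in OfflineParaformerDecoderResult.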
+ */ + virtual std::vector Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP token_num, + MNN::Express::VARP us_cif_peak = MNN::Express::VARP(nullptr)) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.cc new file mode 100644 index 00000000..5ad0e944 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.cc @@ -0,0 +1,74 @@ +// sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +std::vector +OfflineParaformerGreedySearchDecoder::Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP /*token_num*/, + MNN::Express::VARP us_cif_peak /*=MNN::Express::VARP(nullptr)*/ +) { + std::vector shape = log_probs->getInfo()->dim; + int32_t batch_size = shape[0]; + int32_t num_tokens = shape[1]; + int32_t vocab_size = shape[2]; + + std::vector results(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + const float *p = + log_probs->readMap() + i * num_tokens * vocab_size; + for (int32_t k = 0; k != num_tokens; ++k) { + auto max_idx = static_cast( + std::distance(p, std::max_element(p, p + vocab_size))); + if (max_idx == eos_id_) { + break; + } + + results[i].tokens.push_back(max_idx); + + p += vocab_size; + } + + if (us_cif_peak.get() != nullptr) { + int32_t dim = us_cif_peak->getInfo()->dim.back(); + + const auto *peak = us_cif_peak->readMap() + i * dim; + std::vector timestamps; + timestamps.reserve(results[i].tokens.size()); + + // 10.0: frameshift is 10 milliseconds + // 6: LfrWindowSize + // 3: us_cif_peak is upsampled by a factor of 3 + // 1000: milliseconds to seconds + float scale = 10.0 * 6 / 3 / 1000; + + for (int32_t k = 0; k != dim; ++k) { + if (peak[k] > 1 - 1e-4) { + timestamps.push_back(k * scale); + } + } + + if (!timestamps.empty()) { + timestamps.pop_back(); + } + + if (timestamps.size() == results[i].tokens.size()) { + results[i].timestamps = std::move(timestamps); + } + } + } + + return results; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.h new file mode 100644 index 00000000..8eb5eeb4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/offline-paraformer-decoder.h" + +namespace sherpa_mnn { + +class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { + public: + explicit OfflineParaformerGreedySearchDecoder(int32_t eos_id) + : eos_id_(eos_id) {} + + std::vector Decode( + MNN::Express::VARP log_probs, MNN::Express::VARP token_num, + MNN::Express::VARP us_cif_peak = MNN::Express::VARP(nullptr)) override; + + private: + int32_t eos_id_; +}; + +} // namespace sherpa_mnn + +#endif // 
SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model-config.cc new file mode 100644 index 00000000..d44f1ab1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model-config.cc @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/offline-paraformer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-paraformer-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineParaformerModelConfig::Register(ParseOptions *po) { + po->Register("paraformer", &model, "Path to model.onnx of paraformer."); +} + +bool OfflineParaformerModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("Paraformer model '%s' does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineParaformerModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineParaformerModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model-config.h new file mode 100644 index 00000000..31307713 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model-config.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/offline-paraformer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineParaformerModelConfig { + std::string model; + + OfflineParaformerModelConfig() = default; + explicit OfflineParaformerModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model.cc new file mode 100644 index 00000000..e84ec134 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model.cc @@ -0,0 +1,163 @@ +// sherpa-mnn/csrc/offline-paraformer-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-paraformer-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineParaformerModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.paraformer.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + 
allocator_{} { + auto buf = ReadFile(mgr, config_.paraformer.model); + Init(buf.data(), buf.size()); + } + + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + std::vector inputs = {std::move(features), + std::move(features_length)}; + + return sess_->onForward(inputs); + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t LfrWindowSize() const { return lfr_window_size_; } + + int32_t LfrWindowShift() const { return lfr_window_shift_; } + + const std::vector &NegativeMean() const { return neg_mean_; } + + const std::vector &InverseStdDev() const { return inv_stddev_; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift"); + + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean"); + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev"); + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + std::vector neg_mean_; + std::vector inv_stddev_; + + int32_t vocab_size_ = 0; // initialized in Init + int32_t lfr_window_size_ = 0; + int32_t lfr_window_shift_ = 0; +}; + +OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineParaformerModel::OfflineParaformerModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineParaformerModel::~OfflineParaformerModel() = default; + +std::vector OfflineParaformerModel::Forward( + MNN::Express::VARP features, MNN::Express::VARP features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineParaformerModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineParaformerModel::LfrWindowSize() const { + return impl_->LfrWindowSize(); +} +int32_t OfflineParaformerModel::LfrWindowShift() const { + return impl_->LfrWindowShift(); +} +const std::vector &OfflineParaformerModel::NegativeMean() const { + return impl_->NegativeMean(); +} +const std::vector &OfflineParaformerModel::InverseStdDev() const { + return impl_->InverseStdDev(); +} + +MNNAllocator *OfflineParaformerModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template OfflineParaformerModel::OfflineParaformerModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineParaformerModel::OfflineParaformerModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); 
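+// The templated constructor is defined in this .cc file, so each supported
+// resource-manager type needs an explicit instantiation here.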
+#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model.h new file mode 100644 index 00000000..7e44ebd7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-paraformer-model.h @@ -0,0 +1,74 @@ +// sherpa-mnn/csrc/offline-paraformer-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +class OfflineParaformerModel { + public: + explicit OfflineParaformerModel(const OfflineModelConfig &config); + + template + OfflineParaformerModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineParaformerModel(); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int32_t. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size) + * - token_num: A 1-D tensor of shape (N, T') containing number + * of valid tokens in each utterance. Its dtype is int. + * If it is a model supporting timestamps, then there are additional two + * outputs: + * - us_alphas + * - us_cif_peak + */ + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length); + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const; + + /** It is lfr_m in config.yaml + */ + int32_t LfrWindowSize() const; + + /** It is lfr_n in config.yaml + */ + int32_t LfrWindowShift() const; + + /** Return negative mean for CMVN + */ + const std::vector &NegativeMean() const; + + /** Return inverse stddev for CMVN + */ + const std::vector &InverseStdDev() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-ct-transformer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-ct-transformer-impl.h new file mode 100644 index 00000000..190f3d13 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-ct-transformer-impl.h @@ -0,0 +1,197 @@ +// sherpa-mnn/csrc/offline-punctuation-ct-transformer-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_ + +#include + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/math.h" +#include "sherpa-mnn/csrc/offline-ct-transformer-model.h" +#include "sherpa-mnn/csrc/offline-punctuation-impl.h" +#include "sherpa-mnn/csrc/offline-punctuation.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { + public: + explicit OfflinePunctuationCtTransformerImpl( + const OfflinePunctuationConfig &config) + : config_(config), 
model_(config.model) {} + +#if __ANDROID_API__ >= 9 + OfflinePunctuationCtTransformerImpl(AAssetManager *mgr, + const OfflinePunctuationConfig &config) + : config_(config), model_(mgr, config.model) {} +#endif + + std::string AddPunctuation(const std::string &text) const override { + if (text.empty()) { + return {}; + } + + std::vector tokens = SplitUtf8(text); + std::vector token_ids; + token_ids.reserve(tokens.size()); + + const auto &meta_data = model_.metaData(); + + for (const auto &t : tokens) { + std::string token = ToLowerCase(t); + if (meta_data.token2id.count(token)) { + token_ids.push_back(meta_data.token2id.at(token)); + } else { + token_ids.push_back(meta_data.unk_id); + } + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t segment_size = 20; + int32_t max_len = 200; + int32_t num_segments = + ceil((static_cast(token_ids.size()) + segment_size - 1) / + segment_size); + + std::vector punctuations; + int32_t last = -1; + for (int32_t i = 0; i != num_segments; ++i) { + int32_t this_start = i * segment_size; // included + int32_t this_end = this_start + segment_size; // not included + if (this_end > static_cast(token_ids.size())) { + this_end = token_ids.size(); + } + + if (last != -1) { + this_start = last; + } + // token_ids[this_start:this_end] is sent to the model + + std::array x_shape = {1, this_end - this_start}; + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, token_ids.data() + this_start, + x_shape[1], x_shape.data(), x_shape.size()); + + int len_shape = 1; + int32_t len = x_shape[1]; + MNN::Express::VARP x_len = + MNNUtilsCreateTensor(memory_info, &len, 1, &len_shape, 1); + + MNN::Express::VARP out = model_.Forward(std::move(x), std::move(x_len)); + + // [N, T, num_punctuations] + std::vector out_shape = + out->getInfo()->dim; + + assert(out_shape[0] == 1); + assert(out_shape[1] == len); + assert(out_shape[2] == meta_data.num_punctuations); + + std::vector this_punctuations; + this_punctuations.reserve(len); + + const float *p = out->readMap(); + for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) { + auto index = static_cast(std::distance( + p, std::max_element(p, p + meta_data.num_punctuations))); + this_punctuations.push_back(index); + } // for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) + + int32_t dot_index = -1; + int32_t comma_index = -1; + + for (int32_t m = static_cast(this_punctuations.size()) - 2; + m >= 1; --m) { + int32_t punct_id = this_punctuations[m]; + + if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { + dot_index = m; + break; + } + + if (comma_index == -1 && punct_id == meta_data.comma_id) { + comma_index = m; + } + } // for (int32_t k = this_punctuations.size() - 1; k >= 1; --k) + + if (dot_index == -1 && len >= max_len && comma_index != -1) { + dot_index = comma_index; + this_punctuations[dot_index] = meta_data.dot_id; + } + + if (dot_index == -1) { + if (last == -1) { + last = this_start; + } + + if (i == num_segments - 1) { + dot_index = static_cast(this_punctuations.size()) - 1; + } + } else { + last = this_start + dot_index + 1; + } + + if (dot_index != -1) { + punctuations.insert(punctuations.end(), this_punctuations.begin(), + this_punctuations.begin() + (dot_index + 1)); + } + } // for (int32_t i = 0; i != num_segments; ++i) + + if (punctuations.empty()) { + return text + meta_data.id2punct[meta_data.dot_id]; + } + std::vector words_punct; + + for (int32_t i = 0; i != static_cast(punctuations.size()); ++i) { + if (i >= static_cast(tokens.size())) { + break; + 
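For each token position, AddPunctuation() above takes the model's row of scores over all punctuation classes and keeps the argmax (std::max_element / std::distance). A standalone sketch of that selection step, assuming a flat row-major score buffer of shape (T, num_punctuations):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

// Pick the highest-scoring punctuation class for each of the T positions.
// scores contains T * num_punct values, one row per position.
std::vector<int32_t> ArgmaxPerPosition(const std::vector<float> &scores,
                                       int32_t num_punct) {
  int32_t num_positions = static_cast<int32_t>(scores.size()) / num_punct;
  std::vector<int32_t> best(num_positions);
  const float *p = scores.data();
  for (int32_t i = 0; i != num_positions; ++i, p += num_punct) {
    best[i] = static_cast<int32_t>(
        std::distance(p, std::max_element(p, p + num_punct)));
  }
  return best;
}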
} + std::string &w = tokens[i]; + if (i > 0 && !(words_punct.back()[0] & 0x80) && !(w[0] & 0x80)) { + words_punct.push_back(" "); + } + words_punct.push_back(std::move(w)); + + if (punctuations[i] != meta_data.underline_id) { + words_punct.push_back(meta_data.id2punct[punctuations[i]]); + } + } + + if (words_punct.back() == meta_data.id2punct[meta_data.comma_id] || + words_punct.back() == meta_data.id2punct[meta_data.pause_id]) { + words_punct.back() = meta_data.id2punct[meta_data.dot_id]; + } + + if (words_punct.back() != meta_data.id2punct[meta_data.dot_id] && + words_punct.back() != meta_data.id2punct[meta_data.quest_id]) { + words_punct.push_back(meta_data.id2punct[meta_data.dot_id]); + } + + std::string ans; + for (const auto &w : words_punct) { + ans.append(w); + } + return ans; + } + + private: + OfflinePunctuationConfig config_; + OfflineCtTransformerModel model_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-impl.cc new file mode 100644 index 00000000..2b4a06e8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-impl.cc @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-punctuation-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-punctuation-impl.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-punctuation-ct-transformer-impl.h" + +namespace sherpa_mnn { + +std::unique_ptr OfflinePunctuationImpl::Create( + const OfflinePunctuationConfig &config) { + if (!config.model.ct_transformer.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a punctuation model! Return a null pointer"); + return nullptr; +} + +#if __ANDROID_API__ >= 9 +std::unique_ptr OfflinePunctuationImpl::Create( + AAssetManager *mgr, const OfflinePunctuationConfig &config) { + if (!config.model.ct_transformer.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE("Please specify a punctuation model! 
Return a null pointer"); + return nullptr; +} +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-impl.h new file mode 100644 index 00000000..48903c4d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-impl.h @@ -0,0 +1,36 @@ +// sherpa-mnn/csrc/offline-punctuation-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_ + +#include +#include +#include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/offline-punctuation.h" + +namespace sherpa_mnn { + +class OfflinePunctuationImpl { + public: + virtual ~OfflinePunctuationImpl() = default; + + static std::unique_ptr Create( + const OfflinePunctuationConfig &config); + +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const OfflinePunctuationConfig &config); +#endif + + virtual std::string AddPunctuation(const std::string &text) const = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-model-config.cc new file mode 100644 index 00000000..3d3620c8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-model-config.cc @@ -0,0 +1,53 @@ +// sherpa-mnn/csrc/offline-punctuation-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-punctuation-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflinePunctuationModelConfig::Register(ParseOptions *po) { + po->Register("ct-transformer", &ct_transformer, + "Path to the controllable time-delay (CT) transformer model"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflinePunctuationModelConfig::Validate() const { + if (ct_transformer.empty()) { + SHERPA_ONNX_LOGE("Please provide --ct-transformer"); + return false; + } + + if (!FileExists(ct_transformer)) { + SHERPA_ONNX_LOGE("--ct-transformer %s does not exist", + ct_transformer.c_str()); + return false; + } + + return true; +} + +std::string OfflinePunctuationModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflinePunctuationModelConfig("; + os << "ct_transformer=\"" << ct_transformer << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
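Taken together with the OfflinePunctuation front-end defined in the next two files (offline-punctuation.cc and offline-punctuation.h), application code drives the punctuation API roughly as follows; the model path is a placeholder and error handling is minimal:

#include <iostream>
#include <string>

#include "sherpa-mnn/csrc/offline-punctuation.h"

int main() {
  sherpa_mnn::OfflinePunctuationConfig config;
  config.model.ct_transformer = "./ct-transformer.mnn";  // placeholder path
  config.model.num_threads = 1;

  if (!config.Validate()) {
    std::cerr << "Invalid punctuation config\n";
    return -1;
  }

  sherpa_mnn::OfflinePunctuation punct(config);
  std::string text = "how are you today i am fine thank you";
  std::cout << punct.AddPunctuation(text) << "\n";
  return 0;
}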
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-model-config.h new file mode 100644 index 00000000..4b62dbc3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation-model-config.h @@ -0,0 +1,38 @@ +// sherpa-mnn/csrc/offline-punctuation-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflinePunctuationModelConfig { + std::string ct_transformer; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflinePunctuationModelConfig() = default; + + OfflinePunctuationModelConfig(const std::string &ct_transformer, + int32_t num_threads, bool debug, + const std::string &provider) + : ct_transformer(ct_transformer), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation.cc new file mode 100644 index 00000000..07c0537f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation.cc @@ -0,0 +1,53 @@ +// sherpa-mnn/csrc/offline-punctuation.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-punctuation.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-punctuation-impl.h" + +namespace sherpa_mnn { + +void OfflinePunctuationConfig::Register(ParseOptions *po) { + model.Register(po); +} + +bool OfflinePunctuationConfig::Validate() const { + if (!model.Validate()) { + return false; + } + + return true; +} + +std::string OfflinePunctuationConfig::ToString() const { + std::ostringstream os; + + os << "OfflinePunctuationConfig("; + os << "model=" << model.ToString() << ")"; + + return os.str(); +} + +OfflinePunctuation::OfflinePunctuation(const OfflinePunctuationConfig &config) + : impl_(OfflinePunctuationImpl::Create(config)) {} + +#if __ANDROID_API__ >= 9 +OfflinePunctuation::OfflinePunctuation(AAssetManager *mgr, + const OfflinePunctuationConfig &config) + : impl_(OfflinePunctuationImpl::Create(mgr, config)) {} +#endif + +OfflinePunctuation::~OfflinePunctuation() = default; + +std::string OfflinePunctuation::AddPunctuation(const std::string &text) const { + return impl_->AddPunctuation(text); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation.h new file mode 100644 index 00000000..723cc431 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-punctuation.h @@ -0,0 +1,57 @@ +// sherpa-mnn/csrc/offline-punctuation.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 
+#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/offline-punctuation-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflinePunctuationConfig { + OfflinePunctuationModelConfig model; + + OfflinePunctuationConfig() = default; + + explicit OfflinePunctuationConfig(const OfflinePunctuationModelConfig &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OfflinePunctuationImpl; + +class OfflinePunctuation { + public: + explicit OfflinePunctuation(const OfflinePunctuationConfig &config); + +#if __ANDROID_API__ >= 9 + OfflinePunctuation(AAssetManager *mgr, + const OfflinePunctuationConfig &config); +#endif + + ~OfflinePunctuation(); + + // Add punctuation to the input text and return it. + std::string AddPunctuation(const std::string &text) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-ctc-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-ctc-impl.h new file mode 100644 index 00000000..0c49b410 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-ctc-impl.h @@ -0,0 +1,274 @@ +// sherpa-mnn/csrc/offline-recognizer-ctc-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-ctc-decoder.h" +#include "sherpa-mnn/csrc/offline-ctc-fst-decoder.h" +#include "sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/offline-ctc-model.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/pad-sequence.h" +#include "sherpa-mnn/csrc/symbol-table.h" + +namespace sherpa_mnn { + +static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, + const SymbolTable &sym_table, + int32_t frame_shift_ms, + int32_t subsampling_factor) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.timestamps.size()); + + std::string text; + + for (int32_t i = 0; i != src.tokens.size(); ++i) { + if (sym_table.Contains("SIL") && src.tokens[i] == sym_table["SIL"]) { + // tdnn models from yesno have a SIL token, we should remove it. + continue; + } + auto sym = sym_table[src.tokens[i]]; + text.append(sym); + + if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { + // for bpe models with byte_fallback + // (but don't rewrite printable characters 0x20..0x7e, + // which collide with standard BPE units) + std::ostringstream os; + os << "<0x" << std::hex << std::uppercase + << (static_cast(sym[0]) & 0xff) << ">"; + sym = os.str(); + } + + r.tokens.push_back(std::move(sym)); + } + + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. 
* subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + r.words = std::move(src.words); + + return r; +} + +class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(OfflineCtcModel::Create(config_.model_config)) { + Init(); + } + + template + OfflineRecognizerCtcImpl(Manager *mgr, const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(OfflineCtcModel::Create(mgr, config_.model_config)) { + Init(); + } + + void Init() { + if (!config_.model_config.telespeech_ctc.empty()) { + config_.feat_config.snip_edges = true; + config_.feat_config.num_ceps = 40; + config_.feat_config.feature_dim = 40; + config_.feat_config.low_freq = 40; + config_.feat_config.high_freq = -200; + config_.feat_config.use_energy = false; + config_.feat_config.normalize_samples = false; + config_.feat_config.is_mfcc = true; + } + + if (!config_.model_config.nemo_ctc.model.empty()) { + if (model_->IsGigaAM()) { + config_.feat_config.low_freq = 0; + config_.feat_config.high_freq = 8000; + config_.feat_config.remove_dc_offset = false; + config_.feat_config.preemph_coeff = 0; + config_.feat_config.window_type = "hann"; + config_.feat_config.feature_dim = 64; + } else { + config_.feat_config.low_freq = 0; + config_.feat_config.high_freq = 0; + config_.feat_config.is_librosa = true; + config_.feat_config.remove_dc_offset = false; + config_.feat_config.window_type = "hann"; + } + } + + if (!config_.model_config.wenet_ctc.model.empty()) { + // WeNet CTC models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + if (!config_.ctc_fst_decoder_config.graph.empty()) { + // TODO(fangjun): Support android to read the graph from + // asset_manager + decoder_ = std::make_unique( + config_.ctc_fst_decoder_config); + } else if (config_.decoding_method == "greedy_search") { + if (!symbol_table_.Contains("") && + !symbol_table_.Contains("") && + !symbol_table_.Contains("")) { + SHERPA_ONNX_LOGE( + "We expect that tokens.txt contains " + "the symbol or or and its ID."); + exit(-1); + } + + int32_t blank_id = 0; + if (symbol_table_.Contains("")) { + blank_id = symbol_table_[""]; + } else if (symbol_table_.Contains("")) { + // for tdnn models of the yesno recipe from icefall + blank_id = symbol_table_[""]; + } else if (symbol_table_.Contains("")) { + // for Wenet CTC models + blank_id = symbol_table_[""]; + } + + decoder_ = std::make_unique(blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + if (!model_->SupportBatchProcessing()) { + // If the model does not support batch process, + // we process each stream independently. 
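Init() above wires an OfflineCtcGreedySearchDecoder around the resolved blank id; the decoder itself is not part of this diff, but conceptually CTC greedy search takes per-frame scores, picks the argmax at every frame, collapses consecutive repeats, and drops the blank symbol. A generic sketch of that idea, assuming row-major log-probs of shape (num_frames, vocab_size):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

// Plain CTC greedy search: per-frame argmax, collapse repeats, remove blanks.
std::vector<int32_t> CtcGreedySearch(const std::vector<float> &log_probs,
                                     int32_t vocab_size, int32_t blank_id) {
  std::vector<int32_t> tokens;
  int32_t num_frames = static_cast<int32_t>(log_probs.size()) / vocab_size;
  int32_t prev = -1;
  const float *p = log_probs.data();
  for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) {
    int32_t best = static_cast<int32_t>(
        std::distance(p, std::max_element(p, p + vocab_size)));
    if (best != blank_id && best != prev) {
      tokens.push_back(best);
    }
    prev = best;
  }
  return tokens;
}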
+ for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + return; + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = config_.feat_config.feature_dim; + + std::vector features; + features.reserve(n); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + + for (int32_t i = 0; i != n; ++i) { + std::vector f = ss[i]->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + features_vec[i] = std::move(f); + + features_length_vec[i] = num_frames; + + std::array shape = {num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } // for (int32_t i = 0; i != n; ++i) + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = features[i]; + } + + std::array features_length_shape = {n}; + MNN::Express::VARP x_length = MNNUtilsCreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + MNN::Express::VARP x = PadSequence(model_->Allocator(), features_pointer, + -23.025850929940457f); + auto t = model_->Forward(std::move(x), std::move(x_length)); + + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + + int32_t frame_shift_ms = 10; + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + ss[i]->SetResult(r); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + // Decode a single stream. + // Some models do not support batch size > 1, e.g., WeNet CTC models. + void DecodeStream(OfflineStream *s) const { + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = config_.feat_config.feature_dim; + std::vector f = s->GetFrames(); + + int32_t num_frames = f.size() / feat_dim; + + std::array shape = {1, num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int x_length_scalar = num_frames; + std::array x_length_shape = {1}; + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &x_length_scalar, 1, + x_length_shape.data(), x_length_shape.size()); + + auto t = model_->Forward(std::move(x), std::move(x_length)); + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + int32_t frame_shift_ms = 10; + + auto r = Convert(results[0], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + s->SetResult(r); + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-fire-red-asr-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-fire-red-asr-impl.h new file mode 100644 index 00000000..f952eaf4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-fire-red-asr-impl.h @@ -0,0 +1,158 @@ +// sherpa-mnn/csrc/offline-recognizer-fire-red-asr-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_ + +#include 
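The batched DecodeStreams() above pads every utterance's feature matrix to the longest one via PadSequence, using -23.025850929940457 (log(1e-10), i.e. log(eps) for log-fbank features) as the filler. A simplified standalone sketch of that padding step on plain float buffers:

#include <algorithm>
#include <cstdint>
#include <vector>

// Pad each row-major (num_frames_i, feat_dim) feature matrix to the same
// number of frames, filling the tail frames with pad_value.
std::vector<std::vector<float>> PadToMaxFrames(
    const std::vector<std::vector<float>> &features, int32_t feat_dim,
    float pad_value) {
  int32_t max_frames = 0;
  for (const auto &f : features) {
    max_frames =
        std::max(max_frames, static_cast<int32_t>(f.size()) / feat_dim);
  }

  std::vector<std::vector<float>> padded;
  padded.reserve(features.size());
  for (const auto &f : features) {
    std::vector<float> out(static_cast<size_t>(max_frames) * feat_dim,
                           pad_value);
    std::copy(f.begin(), f.end(), out.begin());
    padded.push_back(std::move(out));
  }
  return padded;
}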
+#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-fire-red-asr-decoder.h" +#include "sherpa-mnn/csrc/offline-fire-red-asr-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/offline-fire-red-asr-model.h" +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +static OfflineRecognitionResult Convert( + const OfflineFireRedAsrDecoderResult &src, const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + if (!sym_table.Contains(i)) { + continue; + } + + const auto &s = sym_table[i]; + text += s; + r.tokens.push_back(s); + } + + r.text = text; + + return r; +} + +class OfflineRecognizerFireRedAsrImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerFireRedAsrImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + Init(); + } + + template + OfflineRecognizerFireRedAsrImpl(Manager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique(mgr, + config.model_config)) { + Init(); + } + + void Init() { + if (config_.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else { + SHERPA_ONNX_LOGE( + "Only greedy_search is supported at present for FireRedAsr. Given %s", + config_.decoding_method.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + const auto &meta_data = model_->metaData(); + + config_.feat_config.normalize_samples = false; + config_.feat_config.high_freq = 0; + config_.feat_config.snip_edges = true; + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // batch decoding is not implemented yet + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void DecodeStream(OfflineStream *s) const { + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = s->FeatureDim(); + std::vector f = s->GetFrames(); + ApplyCMVN(&f); + + int num_frames = f.size() / feat_dim; + + std::array shape{1, num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int len_shape = 1; + MNN::Express::VARP x_len = + MNNUtilsCreateTensor(memory_info, &num_frames, 1, &len_shape, 1); + + auto cross_kv = model_->ForwardEncoder(std::move(x), std::move(x_len)); + + auto results = + decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second)); + + auto r = Convert(results[0], symbol_table_); + + r.text = ApplyInverseTextNormalization(std::move(r.text)); + s->SetResult(r); + } + + void ApplyCMVN(std::vector *v) const { + const auto &meta_data = model_->metaData(); + const auto &mean = meta_data.mean; + const auto &inv_stddev = meta_data.inv_stddev; + int32_t feat_dim = static_cast(mean.size()); + int32_t num_frames = static_cast(v->size()) / feat_dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != 
feat_dim; ++k) { + p[k] = (p[k] - mean[k]) * inv_stddev[k]; + } + + p += feat_dim; + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_FIRE_RED_ASR_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-impl.cc new file mode 100644 index 00000000..165ed218 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-impl.cc @@ -0,0 +1,529 @@ +// sherpa-mnn/csrc/offline-recognizer-impl.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-recognizer-ctc-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer-fire-red-asr-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer-moonshine-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer-paraformer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer-sense-voice-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer-transducer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer-transducer-nemo-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer-whisper-impl.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +std::unique_ptr OfflineRecognizerImpl::Create( + const OfflineRecognizerConfig &config) { + if (!config.model_config.sense_voice.model.empty()) { + return std::make_unique(config); + } + + if (!config.model_config.paraformer.model.empty()) { + return std::make_unique(config); + } + + if (!config.model_config.nemo_ctc.model.empty() || + !config.model_config.zipformer_ctc.model.empty() || + !config.model_config.tdnn.model.empty() || + !config.model_config.wenet_ctc.model.empty()) { + return std::make_unique(config); + } + + if (!config.model_config.whisper.encoder.empty()) { + return std::make_unique(config); + } + + if (!config.model_config.fire_red_asr.encoder.empty()) { + return std::make_unique(config); + } + + if (!config.model_config.moonshine.preprocessor.empty()) { + return std::make_unique(config); + } + + // TODO(fangjun): Refactor it. We only need to use model type for the + // following models: + // 1. transducer and nemo_transducer + if (!config.model_config.model_type.empty()) { + const auto &model_type = config.model_config.model_type; + if (model_type == "transducer") { + return std::make_unique(config); + } else if (model_type == "nemo_transducer") { + return std::make_unique(config); + } else if (model_type == "paraformer") { + return std::make_unique(config); + } else if (model_type == "nemo_ctc" || model_type == "tdnn" || + model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || + model_type == "telespeech_ctc") { + return std::make_unique(config); + } else if (model_type == "whisper") { + // unreachable + return std::make_unique(config); + } else if (model_type == "moonshine") { + // unreachable + return std::make_unique(config); + } else { + SHERPA_ONNX_LOGE( + "Invalid model_type: %s. 
Trying to load the model to get its type", + model_type.c_str()); + } + } + + MNNEnv env; + + std::shared_ptr sess_opts; + + + + std::string model_filename; + if (!config.model_config.transducer.encoder_filename.empty()) { + model_filename = config.model_config.transducer.encoder_filename; + } else if (!config.model_config.paraformer.model.empty()) { + model_filename = config.model_config.paraformer.model; + } else if (!config.model_config.nemo_ctc.model.empty()) { + model_filename = config.model_config.nemo_ctc.model; + } else if (!config.model_config.telespeech_ctc.empty()) { + model_filename = config.model_config.telespeech_ctc; + } else if (!config.model_config.tdnn.model.empty()) { + model_filename = config.model_config.tdnn.model; + } else if (!config.model_config.zipformer_ctc.model.empty()) { + model_filename = config.model_config.zipformer_ctc.model; + } else if (!config.model_config.wenet_ctc.model.empty()) { + model_filename = config.model_config.wenet_ctc.model; + } else if (!config.model_config.whisper.encoder.empty()) { + model_filename = config.model_config.whisper.encoder; + } else { + SHERPA_ONNX_LOGE("Please provide a model"); + exit(-1); + } + + auto buf = ReadFile(model_filename); + + auto encoder_sess = + std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)buf.data(), buf.size(), sess_opts)); + + MNNMeta meta_data = encoder_sess->getInfo()->metaData; + + MNNAllocator* allocator; // used in the macro below + + auto model_type = + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n\n" + "Please refer to the following URLs to add metadata" + "\n" + "(0) Transducer models from icefall" + "\n " + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" + "pruned_transducer_stateless7/export-onnx.py#L303" + "\n" + "(1) Nemo CTC models\n " + "https://huggingface.co/csukuangfj/" + "sherpa-mnn-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" + "\n" + "(2) Paraformer" + "\n " + "https://huggingface.co/csukuangfj/" + "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" + "\n " + "(3) Whisper" + "\n " + "(4) Tdnn models of the yesno recipe from icefall" + "\n " + "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" + "\n" + "(5) Zipformer CTC models from icefall" + "\n " + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" + "zipformer/export-onnx-ctc.py" + "\n" + "(6) CTC models from WeNet" + "\n " + "https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/wenet/run.sh" + "\n" + "(7) CTC models from TeleSpeech" + "\n " + "https://github.com/Tele-AI/TeleSpeech-ASR" + "\n" + "\n"); + exit(-1); + } + + if (model_type == "conformer" || model_type == "zipformer" || + model_type == "zipformer2") { + return std::make_unique(config); + } + + if (model_type == "paraformer") { + return std::make_unique(config); + } + + if ((model_type == "EncDecHybridRNNTCTCBPEModel" || + model_type == "EncDecRNNTBPEModel") && + !config.model_config.transducer.decoder_filename.empty() && + !config.model_config.transducer.joiner_filename.empty()) { + return std::make_unique(config); + } + + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" || + model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || + model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || + model_type == "telespeech_ctc") { + return std::make_unique(config); + } + + if (strncmp(model_type.c_str(), "whisper", 7) == 0) 
{ + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE( + "\nUnsupported model_type: %s\n" + "We support only the following model types at present: \n" + " - Non-streaming transducer models from icefall\n" + " - Non-streaming Paraformer models from FunASR\n" + " - EncDecCTCModelBPE models from NeMo\n" + " - EncDecCTCModel models from NeMo\n" + " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" + " - EncDecRNNTBPEModel models from NeMO" + " - Whisper models\n" + " - Tdnn models\n" + " - Zipformer CTC models\n" + " - WeNet CTC models\n" + " - TeleSpeech CTC models\n", + model_type.c_str()); + + exit(-1); +} + +template +std::unique_ptr OfflineRecognizerImpl::Create( + Manager *mgr, const OfflineRecognizerConfig &config) { + if (!config.model_config.sense_voice.model.empty()) { + return std::make_unique(mgr, config); + } + + if (!config.model_config.paraformer.model.empty()) { + return std::make_unique(mgr, config); + } + + if (!config.model_config.nemo_ctc.model.empty() || + !config.model_config.zipformer_ctc.model.empty() || + !config.model_config.tdnn.model.empty() || + !config.model_config.wenet_ctc.model.empty()) { + return std::make_unique(mgr, config); + } + + if (!config.model_config.whisper.encoder.empty()) { + return std::make_unique(mgr, config); + } + + if (!config.model_config.fire_red_asr.encoder.empty()) { + return std::make_unique(mgr, config); + } + + if (!config.model_config.moonshine.preprocessor.empty()) { + return std::make_unique(mgr, config); + } + + // TODO(fangjun): Refactor it. We only need to use model type for the + // following models: + // 1. transducer and nemo_transducer + if (!config.model_config.model_type.empty()) { + const auto &model_type = config.model_config.model_type; + if (model_type == "transducer") { + return std::make_unique(mgr, config); + } else if (model_type == "nemo_transducer") { + return std::make_unique(mgr, config); + } else if (model_type == "paraformer") { + return std::make_unique(mgr, config); + } else if (model_type == "nemo_ctc" || model_type == "tdnn" || + model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || + model_type == "telespeech_ctc") { + return std::make_unique(mgr, config); + } else if (model_type == "whisper") { + return std::make_unique(mgr, config); + } else if (model_type == "moonshine") { + return std::make_unique(mgr, config); + } else { + SHERPA_ONNX_LOGE( + "Invalid model_type: %s. 
Trying to load the model to get its type", + model_type.c_str()); + } + } + + MNNEnv env; + + std::shared_ptr sess_opts; + + + + std::string model_filename; + if (!config.model_config.transducer.encoder_filename.empty()) { + model_filename = config.model_config.transducer.encoder_filename; + } else if (!config.model_config.paraformer.model.empty()) { + model_filename = config.model_config.paraformer.model; + } else if (!config.model_config.nemo_ctc.model.empty()) { + model_filename = config.model_config.nemo_ctc.model; + } else if (!config.model_config.tdnn.model.empty()) { + model_filename = config.model_config.tdnn.model; + } else if (!config.model_config.zipformer_ctc.model.empty()) { + model_filename = config.model_config.zipformer_ctc.model; + } else if (!config.model_config.wenet_ctc.model.empty()) { + model_filename = config.model_config.wenet_ctc.model; + } else if (!config.model_config.telespeech_ctc.empty()) { + model_filename = config.model_config.telespeech_ctc; + } else if (!config.model_config.whisper.encoder.empty()) { + model_filename = config.model_config.whisper.encoder; + } else { + SHERPA_ONNX_LOGE("Please provide a model"); + exit(-1); + } + + auto buf = ReadFile(mgr, model_filename); + + auto encoder_sess = + std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)buf.data(), buf.size(), sess_opts)); + + MNNMeta meta_data = encoder_sess->getInfo()->metaData; + + MNNAllocator* allocator; // used in the macro below + + auto model_type = + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n\n" + "Please refer to the following URLs to add metadata" + "\n" + "(0) Transducer models from icefall" + "\n " + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" + "pruned_transducer_stateless7/export-onnx.py#L303" + "\n" + "(1) Nemo CTC models\n " + "https://huggingface.co/csukuangfj/" + "sherpa-mnn-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py" + "\n" + "(2) Paraformer" + "\n " + "https://huggingface.co/csukuangfj/" + "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" + "\n " + "(3) Whisper" + "\n " + "(4) Tdnn models of the yesno recipe from icefall" + "\n " + "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" + "\n" + "(5) Zipformer CTC models from icefall" + "\n " + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" + "zipformer/export-onnx-ctc.py" + "\n" + "(6) CTC models from WeNet" + "\n " + "https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/wenet/run.sh" + "\n" + "(7) CTC models from TeleSpeech" + "\n " + "https://github.com/Tele-AI/TeleSpeech-ASR" + "\n" + "\n"); + exit(-1); + } + + if (model_type == "conformer" || model_type == "zipformer" || + model_type == "zipformer2") { + return std::make_unique(mgr, config); + } + + if (model_type == "paraformer") { + return std::make_unique(mgr, config); + } + + if ((model_type == "EncDecHybridRNNTCTCBPEModel" || + model_type == "EncDecRNNTBPEModel") && + !config.model_config.transducer.decoder_filename.empty() && + !config.model_config.transducer.joiner_filename.empty()) { + return std::make_unique(mgr, config); + } + + if (model_type == "EncDecCTCModelBPE" || model_type == "EncDecCTCModel" || + model_type == "EncDecHybridRNNTCTCBPEModel" || model_type == "tdnn" || + model_type == "zipformer2_ctc" || model_type == "wenet_ctc" || + model_type == "telespeech_ctc") { + return std::make_unique(mgr, config); + } + + if 
(strncmp(model_type.c_str(), "whisper", 7) == 0) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE( + "\nUnsupported model_type: %s\n" + "We support only the following model types at present: \n" + " - Non-streaming transducer models from icefall\n" + " - Non-streaming Paraformer models from FunASR\n" + " - EncDecCTCModelBPE models from NeMo\n" + " - EncDecCTCModel models from NeMo\n" + " - EncDecHybridRNNTCTCBPEModel models from NeMo\n" + " - EncDecRNNTBPEModel models from NeMo\n" + " - Whisper models\n" + " - Tdnn models\n" + " - Zipformer CTC models\n" + " - WeNet CTC models\n" + " - TeleSpeech CTC models\n", + model_type.c_str()); + + exit(-1); +} + +OfflineRecognizerImpl::OfflineRecognizerImpl( + const OfflineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + itn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } +} + +template +OfflineRecognizerImpl::OfflineRecognizerImpl( + Manager *mgr, const OfflineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + itn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) +} + +std::string OfflineRecognizerImpl::ApplyInverseTextNormalization( + std::string text) const { + text = RemoveInvalidUtf8Sequences(text); + + if (!itn_list_.empty()) { + for (const auto &tn : itn_list_) { + text = tn->Normalize(text); + } + } + + return text; +} + +void OfflineRecognizerImpl::SetConfig(const OfflineRecognizerConfig &config) { + config_ = config; +} + +#if __ANDROID_API__ >= 9 +template 
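ApplyInverseTextNormalization() above first strips invalid UTF-8 and then runs every loaded rule FST over the text in sequence. A minimal sketch of that chaining idea, with plain std::function rules standing in for kaldifst::TextNormalizer:

#include <functional>
#include <string>
#include <vector>

// Run each normalization rule over the text in order; later rules see the
// output of earlier ones, mirroring the itn_list_ loop above.
std::string ApplyRules(
    std::string text,
    const std::vector<std::function<std::string(const std::string &)>> &rules) {
  for (const auto &rule : rules) {
    text = rule(text);
  }
  return text;
}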
OfflineRecognizerImpl::OfflineRecognizerImpl( + AAssetManager *mgr, const OfflineRecognizerConfig &config); + +template std::unique_ptr OfflineRecognizerImpl::Create( + AAssetManager *mgr, const OfflineRecognizerConfig &config); +#endif + +#if __OHOS__ +template OfflineRecognizerImpl::OfflineRecognizerImpl( + NativeResourceManager *mgr, const OfflineRecognizerConfig &config); +template std::unique_ptr OfflineRecognizerImpl::Create( + NativeResourceManager *mgr, const OfflineRecognizerConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-impl.h new file mode 100644 index 00000000..e3fb915c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-impl.h @@ -0,0 +1,61 @@ +// sherpa-mnn/csrc/offline-recognizer-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ + +#include +#include +#include + +#include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/offline-stream.h" + +namespace sherpa_mnn { + +class OfflineRecognizerImpl { + public: + explicit OfflineRecognizerImpl(const OfflineRecognizerConfig &config); + + static std::unique_ptr Create( + const OfflineRecognizerConfig &config); + + template + OfflineRecognizerImpl(Manager *mgr, const OfflineRecognizerConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OfflineRecognizerConfig &config); + + virtual ~OfflineRecognizerImpl() = default; + + virtual std::unique_ptr CreateStream( + const std::string &hotwords) const { + SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); + exit(-1); + } + + virtual std::unique_ptr CreateStream() const = 0; + + virtual void DecodeStreams(OfflineStream **ss, int32_t n) const = 0; + + virtual void SetConfig(const OfflineRecognizerConfig &config); + + virtual OfflineRecognizerConfig GetConfig() const = 0; + + std::string ApplyInverseTextNormalization(std::string text) const; + + private: + OfflineRecognizerConfig config_; + // for inverse text normalization. 
Used only if + // config.rule_fsts is not empty or + // config.rule_fars is not empty + std::vector> itn_list_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-moonshine-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-moonshine-impl.h new file mode 100644 index 00000000..00ba9fa1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-moonshine-impl.h @@ -0,0 +1,137 @@ +// sherpa-mnn/csrc/offline-recognizer-moonshine-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/offline-moonshine-decoder.h" +#include "sherpa-mnn/csrc/offline-moonshine-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/offline-moonshine-model.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +static OfflineRecognitionResult Convert( + const OfflineMoonshineDecoderResult &src, const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + if (!sym_table.Contains(i)) { + continue; + } + + const auto &s = sym_table[i]; + text += s; + r.tokens.push_back(s); + } + + r.text = text; + + return r; +} + +class OfflineRecognizerMoonshineImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerMoonshineImpl(const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + Init(); + } + + template + OfflineRecognizerMoonshineImpl(Manager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_( + std::make_unique(mgr, config.model_config)) { + Init(); + } + + void Init() { + if (config_.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else { + SHERPA_ONNX_LOGE( + "Only greedy_search is supported at present for moonshine. 
Given %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + MoonshineTag tag; + return std::make_unique(tag); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // batch decoding is not implemented yet + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void DecodeStream(OfflineStream *s) const { + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector audio = s->GetFrames(); + + + std::array shape{1, static_cast(audio.size())}; + + MNN::Express::VARP audio_tensor = MNNUtilsCreateTensor( + memory_info, audio.data(), audio.size(), shape.data(), shape.size()); + + MNN::Express::VARP features = + model_->ForwardPreprocessor(std::move(audio_tensor)); + + int32_t features_len = features->getInfo()->dim[1]; + + int features_shape = 1; + + MNN::Express::VARP features_len_tensor = MNNUtilsCreateTensor( + memory_info, &features_len, 1, &features_shape, 1); + + MNN::Express::VARP encoder_out = model_->ForwardEncoder( + std::move(features), std::move(features_len_tensor)); + + auto results = decoder_->Decode(std::move(encoder_out)); + + auto r = Convert(results[0], symbol_table_); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + s->SetResult(r); + + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_MOONSHINE_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-paraformer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-paraformer-impl.h new file mode 100644 index 00000000..0ad44fc9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-paraformer-impl.h @@ -0,0 +1,261 @@ +// sherpa-mnn/csrc/offline-recognizer-paraformer-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/offline-paraformer-decoder.h" +#include "sherpa-mnn/csrc/offline-paraformer-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/offline-paraformer-model.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/pad-sequence.h" +#include "sherpa-mnn/csrc/symbol-table.h" + +namespace sherpa_mnn { + +static OfflineRecognitionResult Convert( + const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps = src.timestamps; + + std::string text; + + // When the current token ends with "@@" we set mergeable to true + bool mergeable = false; + + for (int32_t i = 0; i != src.tokens.size(); ++i) { + auto sym = sym_table[src.tokens[i]]; + r.tokens.push_back(sym); + + if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) { + // sym does not end with "@@" + const uint8_t *p = reinterpret_cast(sym.c_str()); + if (p[0] < 0x80) { + // an ascii + if (mergeable) { + mergeable = false; + text.append(sym); + } else { + text.append(" "); + text.append(sym); + } + } else { + // not an ascii + mergeable = false; + + if (i > 0) { + const uint8_t p = 
reinterpret_cast( + sym_table[src.tokens[i - 1]].c_str())[0]; + if (p < 0x80) { + // put a space between ascii and non-ascii + text.append(" "); + } + } + text.append(sym); + } + } else { + // this sym ends with @@ + sym = std::string(sym.data(), sym.size() - 2); + if (mergeable) { + text.append(sym); + } else { + text.append(" "); + text.append(sym); + mergeable = true; + } + } + } + r.text = std::move(text); + + return r; +} + +class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerParaformerImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + if (config.decoding_method == "greedy_search") { + int32_t eos_id = symbol_table_[""]; + decoder_ = std::make_unique(eos_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + + InitFeatConfig(); + } + + template + OfflineRecognizerParaformerImpl(Manager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique(mgr, + config.model_config)) { + if (config.decoding_method == "greedy_search") { + int32_t eos_id = symbol_table_[""]; + decoder_ = std::make_unique(eos_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + + InitFeatConfig(); + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // 1. Apply LFR + // 2. 
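Convert() above joins Paraformer subword tokens: a token ending in "@@" continues into the next token, plain tokens are separated by spaces, and extra care is taken at ASCII/non-ASCII boundaries. A simplified sketch of just the "@@" joining idea:

#include <string>
#include <vector>

// Join subword tokens where a trailing "@@" glues a token to the next one,
// e.g. {"hel@@", "lo", "world"} -> "hello world".
std::string JoinSubwordTokens(const std::vector<std::string> &tokens) {
  std::string text;
  bool glue_next = false;  // previous token ended with "@@"
  for (std::string tok : tokens) {
    bool continues =
        tok.size() > 2 && tok.compare(tok.size() - 2, 2, "@@") == 0;
    if (continues) {
      tok.resize(tok.size() - 2);
    }
    if (!text.empty() && !glue_next) {
      text += ' ';
    }
    text += tok;
    glue_next = continues;
  }
  return text;
}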
Apply CMVN + // + // Please refer to + // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf + // for what LFR means + // + // "Lower Frame Rate Neural Network Acoustic Models" + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector features; + features.reserve(n); + + int32_t feat_dim = + config_.feat_config.feature_dim * model_->LfrWindowSize(); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + std::vector f = ss[i]->GetFrames(); + + f = ApplyLFR(f); + ApplyCMVN(&f); + + int32_t num_frames = f.size() / feat_dim; + features_vec[i] = std::move(f); + + features_length_vec[i] = num_frames; + + std::array shape = {num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = features[i]; + } + + std::array features_length_shape = {n}; + MNN::Express::VARP x_length = MNNUtilsCreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + // Caution(fangjun): We cannot pad it with log(eps), + // i.e., -23.025850929940457f + MNN::Express::VARP x = PadSequence(model_->Allocator(), features_pointer, 0); + + std::vector t; + t = model_->Forward(std::move(x), std::move(x_length)); + + std::vector results; + if (t.size() == 2) { + results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + } else { + results = + decoder_->Decode(std::move(t[0]), std::move(t[1]), std::move(t[3])); + } + + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + ss[i]->SetResult(r); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void InitFeatConfig() { + // Paraformer models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + config_.feat_config.window_type = "hamming"; + config_.feat_config.high_freq = 0; + config_.feat_config.snip_edges = true; + } + + std::vector ApplyLFR(const std::vector &in) const { + int32_t lfr_window_size = model_->LfrWindowSize(); + int32_t lfr_window_shift = model_->LfrWindowShift(); + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t in_num_frames = in.size() / in_feat_dim; + int32_t out_num_frames = + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; + int32_t out_feat_dim = in_feat_dim * lfr_window_size; + + std::vector out(out_num_frames * out_feat_dim); + + const float *p_in = in.data(); + float *p_out = out.data(); + + for (int32_t i = 0; i != out_num_frames; ++i) { + std::copy(p_in, p_in + out_feat_dim, p_out); + + p_out += out_feat_dim; + p_in += lfr_window_shift * in_feat_dim; + } + + return out; + } + + void ApplyCMVN(std::vector *v) const { + const std::vector &neg_mean = model_->NegativeMean(); + const std::vector &inv_stddev = model_->InverseStdDev(); + + int32_t dim = neg_mean.size(); + int32_t num_frames = v->size() / dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != dim; ++k) { + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; + } + + p += dim; + } + } + + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + 
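ApplyLFR() above stacks lfr_window_size consecutive frames into a single feature vector and then hops forward by lfr_window_shift frames. A quick worked example of the resulting shapes, using lfr_m = 7, lfr_n = 6 and 80-dim fbank features as assumed typical Paraformer values (the real numbers come from the model metadata):

#include <cassert>
#include <cstdint>

int main() {
  int32_t in_feat_dim = 80;      // assumed fbank dimension
  int32_t lfr_window_size = 7;   // lfr_m, read from model metadata
  int32_t lfr_window_shift = 6;  // lfr_n, read from model metadata
  int32_t in_num_frames = 100;

  // Same arithmetic as ApplyLFR() above.
  int32_t out_num_frames =
      (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
  int32_t out_feat_dim = in_feat_dim * lfr_window_size;

  assert(out_num_frames == 16);  // (100 - 7) / 6 + 1
  assert(out_feat_dim == 560);   // 80 * 7
  return 0;
}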
std::unique_ptr decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-sense-voice-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-sense-voice-impl.h new file mode 100644 index 00000000..9fabf8ec --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-sense-voice-impl.h @@ -0,0 +1,351 @@ +// sherpa-mnn/csrc/offline-recognizer-sense-voice-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-ctc-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/offline-sense-voice-model.h" +#include "sherpa-mnn/csrc/pad-sequence.h" +#include "sherpa-mnn/csrc/symbol-table.h" + +namespace sherpa_mnn { + +static OfflineRecognitionResult ConvertSenseVoiceResult( + const OfflineCtcDecoderResult &src, const SymbolTable &sym_table, + int32_t frame_shift_ms, int32_t subsampling_factor) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.timestamps.size()); + + std::string text; + + for (int32_t i = 4; i < src.tokens.size(); ++i) { + auto sym = sym_table[src.tokens[i]]; + text.append(sym); + + r.tokens.push_back(std::move(sym)); + } + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + + for (int32_t i = 4; i < src.timestamps.size(); ++i) { + float time = frame_shift_s * (src.timestamps[i] - 4); + r.timestamps.push_back(time); + } + + r.words = std::move(src.words); + + // parse lang, emotion and event from tokens. + if (src.tokens.size() >= 3) { + r.lang = sym_table[src.tokens[0]]; + r.emotion = sym_table[src.tokens[1]]; + r.event = sym_table[src.tokens[2]]; + } + + return r; +} + +class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerSenseVoiceImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + const auto &meta_data = model_->metaData(); + if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(meta_data.blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + + InitFeatConfig(); + } + + template + OfflineRecognizerSenseVoiceImpl(Manager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique(mgr, + config.model_config)) { + const auto &meta_data = model_->metaData(); + if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(meta_data.blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. 
Given %s", + config.decoding_method.c_str()); + exit(-1); + } + + InitFeatConfig(); + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + if (n == 1) { + DecodeOneStream(ss[0]); + return; + } + + const auto &meta_data = model_->metaData(); + // 1. Apply LFR + // 2. Apply CMVN + // + // Please refer to + // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf + // for what LFR means + // + // "Lower Frame Rate Neural Network Acoustic Models" + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector features; + features.reserve(n); + + int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size; + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + std::vector f = ss[i]->GetFrames(); + + f = ApplyLFR(f); + ApplyCMVN(&f); + + int32_t num_frames = f.size() / feat_dim; + features_vec[i] = std::move(f); + + features_length_vec[i] = num_frames; + + std::array shape = {num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = features[i]; + } + + std::array features_length_shape = {n}; + MNN::Express::VARP x_length = MNNUtilsCreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + // Caution(fangjun): We cannot pad it with log(eps), + // i.e., -23.025850929940457f + MNN::Express::VARP x = PadSequence(model_->Allocator(), features_pointer, 0); + + int32_t language = 0; + if (config_.model_config.sense_voice.language.empty()) { + language = 0; + } else if (meta_data.lang2id.count( + config_.model_config.sense_voice.language)) { + language = + meta_data.lang2id.at(config_.model_config.sense_voice.language); + } else { + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", + config_.model_config.sense_voice.language.c_str()); + } + + std::vector language_array(n); + std::fill(language_array.begin(), language_array.end(), language); + + std::vector text_norm_array(n); + std::fill(text_norm_array.begin(), text_norm_array.end(), + config_.model_config.sense_voice.use_itn + ? 
meta_data.with_itn_id + : meta_data.without_itn_id); + + MNN::Express::VARP language_tensor = MNNUtilsCreateTensor( + memory_info, language_array.data(), n, features_length_shape.data(), + features_length_shape.size()); + + MNN::Express::VARP text_norm_tensor = MNNUtilsCreateTensor( + memory_info, text_norm_array.data(), n, features_length_shape.data(), + features_length_shape.size()); + + MNN::Express::VARP logits{nullptr}; + logits = model_->Forward(std::move(x), std::move(x_length), + std::move(language_tensor), + std::move(text_norm_tensor)); + // decoder_->Decode() requires that logits_length is of dtype int64 + std::vector features_length_vec_64; + features_length_vec_64.reserve(n); + for (auto i : features_length_vec) { + i += 4; + features_length_vec_64.push_back(i); + } + + MNN::Express::VARP logits_length = MNNUtilsCreateTensor( + memory_info, features_length_vec_64.data(), n, + features_length_shape.data(), features_length_shape.size()); + + auto results = + decoder_->Decode(std::move(logits), std::move(logits_length)); + + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = meta_data.window_shift; + for (int32_t i = 0; i != n; ++i) { + auto r = ConvertSenseVoiceResult(results[i], symbol_table_, + frame_shift_ms, subsampling_factor); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + ss[i]->SetResult(r); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void DecodeOneStream(OfflineStream *s) const { + const auto &meta_data = model_->metaData(); + + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = config_.feat_config.feature_dim * meta_data.window_size; + std::vector f = s->GetFrames(); + f = ApplyLFR(f); + ApplyCMVN(&f); + int32_t num_frames = f.size() / feat_dim; + std::array shape = {1, num_frames, feat_dim}; + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, f.data(), f.size(), + shape.data(), shape.size()); + + int scale_shape = 1; + + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &num_frames, 1, &scale_shape, 1); + + int32_t language = 0; + if (config_.model_config.sense_voice.language.empty()) { + language = 0; + } else if (meta_data.lang2id.count( + config_.model_config.sense_voice.language)) { + language = + meta_data.lang2id.at(config_.model_config.sense_voice.language); + } else { + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", + config_.model_config.sense_voice.language.c_str()); + } + + int32_t text_norm = config_.model_config.sense_voice.use_itn + ? 
meta_data.with_itn_id + : meta_data.without_itn_id; + + MNN::Express::VARP language_tensor = + MNNUtilsCreateTensor(memory_info, &language, 1, &scale_shape, 1); + + MNN::Express::VARP text_norm_tensor = + MNNUtilsCreateTensor(memory_info, &text_norm, 1, &scale_shape, 1); + + MNN::Express::VARP logits{nullptr}; + logits = model_->Forward(std::move(x), std::move(x_length), + std::move(language_tensor), + std::move(text_norm_tensor)); + + int new_num_frames = num_frames + 4; + MNN::Express::VARP logits_length = MNNUtilsCreateTensor( + memory_info, &new_num_frames, 1, &scale_shape, 1); + + auto results = + decoder_->Decode(std::move(logits), std::move(logits_length)); + + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = meta_data.window_shift; + auto r = ConvertSenseVoiceResult(results[0], symbol_table_, frame_shift_ms, + subsampling_factor); + + r.text = ApplyInverseTextNormalization(std::move(r.text)); + s->SetResult(r); + } + + void InitFeatConfig() { + const auto &meta_data = model_->metaData(); + + config_.feat_config.normalize_samples = meta_data.normalize_samples; + config_.feat_config.window_type = "hamming"; + config_.feat_config.high_freq = 0; + config_.feat_config.snip_edges = true; + } + std::vector ApplyLFR(const std::vector &in) const { + const auto &meta_data = model_->metaData(); + + int32_t lfr_window_size = meta_data.window_size; + int32_t lfr_window_shift = meta_data.window_shift; + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t in_num_frames = in.size() / in_feat_dim; + int32_t out_num_frames = + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; + int32_t out_feat_dim = in_feat_dim * lfr_window_size; + + std::vector out(out_num_frames * out_feat_dim); + + const float *p_in = in.data(); + float *p_out = out.data(); + + for (int32_t i = 0; i != out_num_frames; ++i) { + std::copy(p_in, p_in + out_feat_dim, p_out); + + p_out += out_feat_dim; + p_in += lfr_window_shift * in_feat_dim; + } + + return out; + } + + void ApplyCMVN(std::vector *v) const { + const auto &meta_data = model_->metaData(); + + const std::vector &neg_mean = meta_data.neg_mean; + const std::vector &inv_stddev = meta_data.inv_stddev; + + int32_t dim = neg_mean.size(); + int32_t num_frames = v->size() / dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != dim; ++k) { + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; + } + + p += dim; + } + } + + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_SENSE_VOICE_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-transducer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-transducer-impl.h new file mode 100644 index 00000000..107d644e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-transducer-impl.h @@ -0,0 +1,308 @@ +// sherpa-mnn/csrc/offline-recognizer-transducer-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ + +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/context-graph.h" +#include "sherpa-mnn/csrc/log.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include 
"sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/offline-transducer-decoder.h" +#include "sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/offline-transducer-model.h" +#include "sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.h" +#include "sherpa-mnn/csrc/pad-sequence.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" + +namespace sherpa_mnn { + +static OfflineRecognitionResult Convert( + const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table, + int32_t frame_shift_ms, int32_t subsampling_factor) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.timestamps.size()); + + std::string text; + for (auto i : src.tokens) { + auto sym = sym_table[i]; + text.append(sym); + + if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { + // for bpe models with byte_fallback, + // (but don't rewrite printable characters 0x20..0x7e, + // which collide with standard BPE units) + std::ostringstream os; + os << "<0x" << std::hex << std::uppercase + << (static_cast(sym[0]) & 0xff) << ">"; + sym = os.str(); + } + + r.tokens.push_back(std::move(sym)); + } + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + return r; +} + +class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerTransducerImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config_.model_config)) { + if (symbol_table_.Contains("")) { + unk_id_ = symbol_table_[""]; + } + + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), unk_id_, config_.blank_penalty); + } else if (config_.decoding_method == "modified_beam_search") { + if (!config_.lm_config.model.empty()) { + lm_ = OfflineLM::Create(config.lm_config); + } + + if (!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + + if (!config_.hotwords_file.empty()) { + InitHotwords(); + } + + decoder_ = std::make_unique( + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale, unk_id_, config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + template + explicit OfflineRecognizerTransducerImpl( + Manager *mgr, const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique(mgr, + config_.model_config)) { + if (symbol_table_.Contains("")) { + unk_id_ = symbol_table_[""]; + } + + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), unk_id_, config_.blank_penalty); + } else if (config_.decoding_method == "modified_beam_search") { + if (!config_.lm_config.model.empty()) { + lm_ = OfflineLM::Create(mgr, config.lm_config); + } + + if (!config_.model_config.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = 
std::make_unique(iss); + } + + if (!config_.hotwords_file.empty()) { + InitHotwords(mgr); + } + + decoder_ = std::make_unique( + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale, unk_id_, config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream( + const std::string &hotwords) const override { + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); + std::istringstream is(hws); + std::vector> current; + std::vector current_scores; + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), ¤t, ¤t_scores)) { + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", + hotwords.c_str()); + } + + int32_t num_default_hws = hotwords_.size(); + int32_t num_hws = current.size(); + + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); + + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_hws, + config_.hotwords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_hws, + config_.hotwords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing. + } + + auto context_graph = std::make_shared( + current, config_.hotwords_score, current_scores); + return std::make_unique(config_.feat_config, context_graph); + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config, + hotwords_graph_); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = ss[0]->FeatureDim(); + + std::vector features; + + features.reserve(n); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + auto f = ss[i]->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + features_length_vec[i] = num_frames; + features_vec[i] = std::move(f); + + std::array shape = {num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = features[i]; + } + + std::array features_length_shape = {n}; + MNN::Express::VARP x_length = MNNUtilsCreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + MNN::Express::VARP x = PadSequence(model_->Allocator(), features_pointer, + -23.025850929940457f); + + auto t = model_->RunEncoder(std::move(x), std::move(x_length)); + auto results = + decoder_->Decode(std::move(t.first), std::move(t.second), ss, n); + + int32_t frame_shift_ms = 10; + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + + ss[i]->SetResult(r); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + void InitHotwords() { + // each line in hotwords_file contains space-separated words + + std::ifstream 
is(config_.hotwords_file); + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + + template + void InitHotwords(Manager *mgr) { + // each line in hotwords_file contains space-separated words + + auto buf = ReadFile(mgr, config_.hotwords_file); + + std::istringstream is(std::string(buf.begin(), buf.end())); + + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::vector> hotwords_; + std::vector boost_scores_; + ContextGraphPtr hotwords_graph_; + std::unique_ptr bpe_encoder_; + std::unique_ptr model_; + std::unique_ptr decoder_; + std::unique_ptr lm_; + int32_t unk_id_ = -1; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-transducer-nemo-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-transducer-nemo-impl.h new file mode 100644 index 00000000..2d27553e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-transducer-nemo-impl.h @@ -0,0 +1,189 @@ +// sherpa-mnn/csrc/offline-recognizer-transducer-nemo-impl.h +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ + +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.h" +#include "sherpa-mnn/csrc/offline-transducer-nemo-model.h" +#include "sherpa-mnn/csrc/pad-sequence.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/transpose.h" +#include "sherpa-mnn/csrc/utils.h" + +namespace sherpa_mnn { + +// defined in ./offline-recognizer-transducer-impl.h +OfflineRecognitionResult Convert(const OfflineTransducerDecoderResult &src, + const SymbolTable &sym_table, + int32_t frame_shift_ms, + int32_t subsampling_factor); + +class OfflineRecognizerTransducerNeMoImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerTransducerNeMoImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique( + config_.model_config)) { + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + PostInit(); + } + + 
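+  // The templated constructor below mirrors the one above but reads the model through a platform resource manager (e.g. an Android asset manager), following the pattern used by the other recognizer implementations in this directory.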
template + explicit OfflineRecognizerTransducerNeMoImpl( + Manager *mgr, const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique( + mgr, config_.model_config)) { + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config_.decoding_method.c_str()); + exit(-1); + } + + PostInit(); + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = ss[0]->FeatureDim(); + + std::vector features; + + features.reserve(n); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + auto f = ss[i]->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + features_length_vec[i] = num_frames; + features_vec[i] = std::move(f); + + std::array shape = {num_frames, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = features[i]; + } + + std::array features_length_shape = {n}; + MNN::Express::VARP x_length = MNNUtilsCreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + MNN::Express::VARP x = PadSequence(model_->Allocator(), features_pointer, 0); + + auto t = model_->RunEncoder(std::move(x), std::move(x_length)); + // t[0] encoder_out, float tensor, (batch_size, dim, T) + // t[1] encoder_out_length, int64 tensor, (batch_size,) + + MNN::Express::VARP encoder_out = Transpose12(model_->Allocator(), t[0]); + + auto results = decoder_->Decode(std::move(encoder_out), std::move(t[1])); + + int32_t frame_shift_ms = 10; + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + + ss[i]->SetResult(r); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void PostInit() { + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + config_.feat_config.dither = 0; + + if (model_->IsGigaAM()) { + config_.feat_config.low_freq = 0; + config_.feat_config.high_freq = 8000; + config_.feat_config.remove_dc_offset = false; + config_.feat_config.preemph_coeff = 0; + config_.feat_config.window_type = "hann"; + config_.feat_config.feature_dim = 64; + } else { + config_.feat_config.low_freq = 0; + // config_.feat_config.high_freq = 8000; + config_.feat_config.is_librosa = true; + config_.feat_config.remove_dc_offset = false; + // config_.feat_config.window_type = "hann"; + } + + int32_t vocab_size = model_->VocabSize(); + + // check the blank ID + if (!symbol_table_.Contains("")) { + SHERPA_ONNX_LOGE("tokens.txt does not include the blank token "); + exit(-1); + } + + if (symbol_table_[""] != vocab_size - 1) { + SHERPA_ONNX_LOGE(" is not the last token!"); + exit(-1); + } + + if (symbol_table_.NumSymbols() != vocab_size) { + SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)", + symbol_table_.NumSymbols(), 
vocab_size); + exit(-1); + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-whisper-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-whisper-impl.h new file mode 100644 index 00000000..0772e0d6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer-whisper-impl.h @@ -0,0 +1,173 @@ +// sherpa-mnn/csrc/offline-recognizer-whisper-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/offline-whisper-decoder.h" +#include "sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/offline-whisper-model.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + Init(); + } + + template + OfflineRecognizerWhisperImpl(Manager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_( + std::make_unique(mgr, config.model_config)) { + Init(); + } + + void Init() { + // tokens.txt from whisper is base64 encoded, so we need to decode it + symbol_table_.ApplyBase64Decode(); + + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + config_.model_config.whisper, model_.get()); + } else { + SHERPA_ONNX_LOGE( + "Only greedy_search is supported at present for whisper. Given %s", + config_.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + WhisperTag tag; + tag.dim = model_->FeatureDim(); + return std::make_unique(tag); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // batch decoding is not implemented yet + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + void SetConfig(const OfflineRecognizerConfig &config) override { + config_.model_config.whisper = config.model_config.whisper; + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void DecodeStream(OfflineStream *s) const { + decoder_->SetConfig(config_.model_config.whisper); + + int32_t max_num_frames = 3000; + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = s->FeatureDim(); + std::vector f = s->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + // we use 50 here so that there will be some zero tail paddings + if (num_frames >= max_num_frames - 50) { + SHERPA_ONNX_LOGE( + "Only waves less than 30 seconds are supported. 
We process only the " + "first 30 seconds and discard the remaining data"); + num_frames = max_num_frames - 50; + } + + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); + + // note that 1000 is an experience-value. + // You can replace 1000 by other values, say, 100. + // + // Since we have removed the 30 seconds constraint, we need + // tail_padding_frames so that whisper is able to detect the eot token. + int32_t tail_padding_frames = 1000; + + if (config_.model_config.whisper.tail_paddings > 0) { + tail_padding_frames = config_.model_config.whisper.tail_paddings; + } + + int32_t actual_frames = + std::min(num_frames + tail_padding_frames, max_num_frames); + + std::array shape{1, actual_frames, feat_dim}; + + MNN::Express::VARP mel = MNNUtilsCreateTensor( + model_->Allocator(), shape.data(), shape.size()); + + float *p_mel = mel->writeMap(); + std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel); + + std::fill_n(p_mel + num_frames * feat_dim, + (actual_frames - num_frames) * feat_dim, 0); + + mel = Transpose12(model_->Allocator(), mel); + + auto cross_kv = model_->ForwardEncoder(std::move(mel)); + + auto results = decoder_->Decode(std::move(cross_kv.first), + std::move(cross_kv.second), num_frames); + + auto r = Convert(results[0], symbol_table_); + s->SetResult(r); + } + + private: + OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, + const SymbolTable &sym_table) const { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + if (!sym_table.Contains(i)) { + continue; + } + + std::string s = sym_table[i]; + s = ApplyInverseTextNormalization(s); + + text += s; + r.tokens.push_back(s); + } + + r.text = text; + r.lang = src.lang; + + return r; + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer.cc new file mode 100644 index 00000000..14f9823c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer.cc @@ -0,0 +1,186 @@ +// sherpa-mnn/csrc/offline-recognizer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-recognizer.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-lm-config.h" +#include "sherpa-mnn/csrc/offline-recognizer-impl.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +void OfflineRecognizerConfig::Register(ParseOptions *po) { + feat_config.Register(po); + model_config.Register(po); + lm_config.Register(po); + ctc_fst_decoder_config.Register(po); + + po->Register( + "decoding-method", &decoding_method, + "decoding method," + "Valid values: greedy_search, modified_beam_search. " + "modified_beam_search is applicable only for transducer models."); + + po->Register("max-active-paths", &max_active_paths, + "Used only when decoding_method is modified_beam_search"); + + po->Register("blank-penalty", &blank_penalty, + "The penalty applied on blank symbol during decoding. " + "Note: It is a positive value. 
" + "Increasing value will lead to lower deletion at the cost" + "of higher insertions. " + "Currently only applicable for transducer models."); + + po->Register( + "hotwords-file", &hotwords_file, + "The file containing hotwords, one words/phrases per line, For example: " + "HELLO WORLD" + "你好世界"); + + po->Register("hotwords-score", &hotwords_score, + "The bonus score for each token in context word/phrase. " + "Used only when decoding_method is modified_beam_search"); + + po->Register( + "rule-fsts", &rule_fsts, + "If not empty, it specifies fsts for inverse text normalization. " + "If there are multiple fsts, they are separated by a comma."); + + po->Register( + "rule-fars", &rule_fars, + "If not empty, it specifies fst archives for inverse text normalization. " + "If there are multiple archives, they are separated by a comma."); +} + +bool OfflineRecognizerConfig::Validate() const { + if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) { + if (max_active_paths <= 0) { + SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d", + max_active_paths); + return false; + } + if (!lm_config.Validate()) { + return false; + } + } + + if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { + SHERPA_ONNX_LOGE( + "Please use --decoding-method=modified_beam_search if you" + " provide --hotwords-file. Given --decoding-method='%s'", + decoding_method.c_str()); + return false; + } + + if (!ctc_fst_decoder_config.graph.empty() && + !ctc_fst_decoder_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in fst_decoder"); + return false; + } + + if (!hotwords_file.empty() && !FileExists(hotwords_file)) { + SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist", + hotwords_file.c_str()); + return false; + } + + if (!rule_fsts.empty()) { + std::vector files; + SplitStringToVector(rule_fsts, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + if (!rule_fars.empty()) { + std::vector files; + SplitStringToVector(rule_fars, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. 
", f.c_str()); + return false; + } + } + } + + return model_config.Validate(); +} + +std::string OfflineRecognizerConfig::ToString() const { + std::ostringstream os; + + os << "OfflineRecognizerConfig("; + os << "feat_config=" << feat_config.ToString() << ", "; + os << "model_config=" << model_config.ToString() << ", "; + os << "lm_config=" << lm_config.ToString() << ", "; + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", "; + os << "decoding_method=\"" << decoding_method << "\", "; + os << "max_active_paths=" << max_active_paths << ", "; + os << "hotwords_file=\"" << hotwords_file << "\", "; + os << "hotwords_score=" << hotwords_score << ", "; + os << "blank_penalty=" << blank_penalty << ", "; + os << "rule_fsts=\"" << rule_fsts << "\", "; + os << "rule_fars=\"" << rule_fars << "\")"; + + return os.str(); +} + +template +OfflineRecognizer::OfflineRecognizer(Manager *mgr, + const OfflineRecognizerConfig &config) + : impl_(OfflineRecognizerImpl::Create(mgr, config)) {} + +OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config) + : impl_(OfflineRecognizerImpl::Create(config)) {} + +OfflineRecognizer::~OfflineRecognizer() = default; + +std::unique_ptr OfflineRecognizer::CreateStream( + const std::string &hotwords) const { + return impl_->CreateStream(hotwords); +} + +std::unique_ptr OfflineRecognizer::CreateStream() const { + return impl_->CreateStream(); +} + +void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { + impl_->DecodeStreams(ss, n); +} + +void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) { + impl_->SetConfig(config); +} + +OfflineRecognizerConfig OfflineRecognizer::GetConfig() const { + return impl_->GetConfig(); +} + +#if __ANDROID_API__ >= 9 +template OfflineRecognizer::OfflineRecognizer( + AAssetManager *mgr, const OfflineRecognizerConfig &config); +#endif + +#if __OHOS__ +template OfflineRecognizer::OfflineRecognizer( + NativeResourceManager *mgr, const OfflineRecognizerConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer.h new file mode 100644 index 00000000..aeaed769 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-recognizer.h @@ -0,0 +1,131 @@ +// sherpa-mnn/csrc/offline-recognizer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/features.h" +#include "sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h" +#include "sherpa-mnn/csrc/offline-lm-config.h" +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/offline-stream.h" +#include "sherpa-mnn/csrc/offline-transducer-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineRecognitionResult; + +struct OfflineRecognizerConfig { + FeatureExtractorConfig feat_config; + OfflineModelConfig model_config; + OfflineLMConfig lm_config; + OfflineCtcFstDecoderConfig ctc_fst_decoder_config; + + std::string decoding_method = "greedy_search"; + int32_t max_active_paths = 4; + + std::string hotwords_file; + float hotwords_score = 1.5; + + float blank_penalty = 0.0; + + // If there are multiple rules, they are applied from left to right. + std::string rule_fsts; + + // If there are multiple FST archives, they are applied from left to right. 
+ std::string rule_fars; + + // only greedy_search is implemented + // TODO(fangjun): Implement modified_beam_search + + OfflineRecognizerConfig() = default; + OfflineRecognizerConfig( + const FeatureExtractorConfig &feat_config, + const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, + const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, + const std::string &decoding_method, int32_t max_active_paths, + const std::string &hotwords_file, float hotwords_score, + float blank_penalty, const std::string &rule_fsts, + const std::string &rule_fars) + : feat_config(feat_config), + model_config(model_config), + lm_config(lm_config), + ctc_fst_decoder_config(ctc_fst_decoder_config), + decoding_method(decoding_method), + max_active_paths(max_active_paths), + hotwords_file(hotwords_file), + hotwords_score(hotwords_score), + blank_penalty(blank_penalty), + rule_fsts(rule_fsts), + rule_fars(rule_fars) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OfflineRecognizerImpl; + +class OfflineRecognizer { + public: + ~OfflineRecognizer(); + + template <typename Manager> + OfflineRecognizer(Manager *mgr, const OfflineRecognizerConfig &config); + + explicit OfflineRecognizer(const OfflineRecognizerConfig &config); + + /// Create a stream for decoding. + std::unique_ptr<OfflineStream> CreateStream() const; + + /** Create a stream for decoding. + * + * @param hotwords The hotwords for this stream. It may contain several + * hotwords separated by "/". Within each hotword, the CJK characters + * or BPE tokens are separated by spaces (" "). + * For example, the hotwords I LOVE YOU and HELLO WORLD look like: + * + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" + */ + std::unique_ptr<OfflineStream> CreateStream( + const std::string &hotwords) const; + + /** Decode a single stream. + * + * @param s The stream to decode. + */ + void DecodeStream(OfflineStream *s) const { + OfflineStream *ss[1] = {s}; + DecodeStreams(ss, 1); + } + + /** Decode a list of streams. + * + * @param ss Pointer to an array of streams. + * @param n Size of the input array. + */ + void DecodeStreams(OfflineStream **ss, int32_t n) const; + + /** The underlying inference session objects are not affected by this method. + * The exact behavior is defined by the specific recognizer impl. + * For instance, the whisper recognizer retrieves the language and + * task from the config and ignores any remaining fields in `config`. 
+ */ + void SetConfig(const OfflineRecognizerConfig &config); + + OfflineRecognizerConfig GetConfig() const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-rnn-lm.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-rnn-lm.cc new file mode 100644 index 00000000..60075a7d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-rnn-lm.cc @@ -0,0 +1,104 @@ +// sherpa-mnn/csrc/offline-rnn-lm.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-rnn-lm.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineRnnLM::Impl { + public: + explicit Impl(const OfflineLMConfig &config) + : config_(config), + sess_opts_{GetSessionOptions(config)}, + allocator_{} { + auto buf = ReadFile(config_.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineLMConfig &config) + : config_(config), + sess_opts_{GetSessionOptions(config)}, + allocator_{} { + auto buf = ReadFile(mgr, config_.model); + Init(buf.data(), buf.size()); + } + + MNN::Express::VARP Rescore(MNN::Express::VARP x, MNN::Express::VARP x_lens) { + std::vector inputs = {std::move(x), std::move(x_lens)}; + + auto out = + sess_->onForward(inputs); + + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + } + + private: + OfflineLMConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; +}; + +OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineRnnLM::OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineRnnLM::~OfflineRnnLM() = default; + +MNN::Express::VARP OfflineRnnLM::Rescore(MNN::Express::VARP x, MNN::Express::VARP x_lens) { + return impl_->Rescore(std::move(x), std::move(x_lens)); +} + +#if __ANDROID_API__ >= 9 +template OfflineRnnLM::OfflineRnnLM(AAssetManager *mgr, + const OfflineLMConfig &config); +#endif + +#if __OHOS__ +template OfflineRnnLM::OfflineRnnLM(NativeResourceManager *mgr, + const OfflineLMConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-rnn-lm.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-rnn-lm.h new file mode 100644 index 00000000..808ce21f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-rnn-lm.h @@ -0,0 +1,44 @@ +// sherpa-mnn/csrc/offline-rnn-lm.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ + 
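+// OfflineRnnLM loads an RNN language model with MNN and exposes Rescore(), which returns the log-likelihood of each padded token sequence in a batch; the transducer recognizer uses it for LM rescoring when decoding_method is modified_beam_search.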
+#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-lm-config.h" +#include "sherpa-mnn/csrc/offline-lm.h" + +namespace sherpa_mnn { + +class OfflineRnnLM : public OfflineLM { + public: + ~OfflineRnnLM() override; + + explicit OfflineRnnLM(const OfflineLMConfig &config); + + template + OfflineRnnLM(Manager *mgr, const OfflineLMConfig &config); + + /** Rescore a batch of sentences. + * + * @param x A 2-D tensor of shape (N, L) with data type int64. + * @param x_lens A 1-D tensor of shape (N,) with data type int64. + * It contains number of valid tokens in x before padding. + * @return Return a 1-D tensor of shape (N,) containing the log likelihood + * of each utterance. Its data type is float32. + * + * Caution: It returns log likelihood, not negative log likelihood (nll). + */ + MNN::Express::VARP Rescore(MNN::Express::VARP x, MNN::Express::VARP x_lens) override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-config.cc new file mode 100644 index 00000000..87e9e573 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-config.cc @@ -0,0 +1,55 @@ +// sherpa-mnn/csrc/offline-sense-voice-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-sense-voice-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineSenseVoiceModelConfig::Register(ParseOptions *po) { + po->Register("sense-voice-model", &model, + "Path to model.onnx of SenseVoice."); + po->Register( + "sense-voice-language", &language, + "Valid values: auto, zh, en, ja, ko, yue. If left empty, auto is used"); + po->Register( + "sense-voice-use-itn", &use_itn, + "True to enable inverse text normalization. False to disable it."); +} + +bool OfflineSenseVoiceModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("SenseVoice model '%s' does not exist", model.c_str()); + return false; + } + + if (!language.empty()) { + if (language != "auto" && language != "zh" && language != "en" && + language != "ja" && language != "ko" && language != "yue") { + SHERPA_ONNX_LOGE( + "Invalid sense-voice-language: '%s'. Valid values are: auto, zh, en, " + "ja, ko, yue. Or you can leave it empty to use 'auto'", + language.c_str()); + + return false; + } + } + + return true; +} + +std::string OfflineSenseVoiceModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSenseVoiceModelConfig("; + os << "model=\"" << model << "\", "; + os << "language=\"" << language << "\", "; + os << "use_itn=" << (use_itn ? 
"True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-config.h new file mode 100644 index 00000000..04f13ede --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-config.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-sense-voice-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineSenseVoiceModelConfig { + std::string model; + + // "" or "auto" to let the model recognize the language + // valid values: + // zh, en, ja, ko, yue, auto + std::string language = "auto"; + + // true to use inverse text normalization + // false to not use inverse text normalization + bool use_itn = false; + + OfflineSenseVoiceModelConfig() = default; + explicit OfflineSenseVoiceModelConfig(const std::string &model, + const std::string &language, + bool use_itn) + : model(model), language(language), use_itn(use_itn) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-meta-data.h new file mode 100644 index 00000000..d1052b86 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model-meta-data.h @@ -0,0 +1,50 @@ +// sherpa-mnn/csrc/offline-sense-voice-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_mnn { + +struct OfflineSenseVoiceModelMetaData { + // ID for using inverse text normalization + int32_t with_itn_id; + + // ID for not using inverse text normalization + int32_t without_itn_id; + + int32_t window_size; // lfr_m + int32_t window_shift; // lfr_n + int32_t vocab_size; + + int32_t subsampling_factor = 1; + + // Usually 0 for SenseVoice models. 
+ // 0 means samples are scaled to [-32768, 32767] before are sent to the + // feature extractor + int32_t normalize_samples = 0; + + int32_t blank_id = 0; + + // possible values: + // zh, en, ja, ko, yue, auto + // where + // zh is Chinese (Mandarin) + // en is English + // ja is Japanese + // ko is Korean + // yue is Cantonese + // auto is to let the model recognize the language + std::unordered_map lang2id; + + std::vector neg_mean; + std::vector inv_stddev; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model.cc new file mode 100644 index 00000000..f1d643fe --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model.cc @@ -0,0 +1,176 @@ +// sherpa-mnn/csrc/offline-sense-voice-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-sense-voice-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineSenseVoiceModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.sense_voice.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.sense_voice.model); + Init(buf.data(), buf.size()); + } + + MNN::Express::VARP Forward(MNN::Express::VARP features, MNN::Express::VARP features_length, + MNN::Express::VARP language, MNN::Express::VARP text_norm) { + std::vector inputs = { + std::move(features), + std::move(features_length), + std::move(language), + std::move(text_norm), + }; + + auto ans = + sess_->onForward(inputs); + return std::move(ans[0]); + } + + const OfflineSenseVoiceModelMetaData& metaData() const { + return meta_data_; + } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.vocab_size, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "lfr_window_size"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_shift, "lfr_window_shift"); + SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples, + "normalize_samples"); + + SHERPA_ONNX_READ_META_DATA(meta_data_.with_itn_id, "with_itn"); + + 
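+    // with_itn_id / without_itn_id are the integer ids fed to the model's text_norm input; the recognizer picks one of them based on the sense_voice.use_itn config option.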
SHERPA_ONNX_READ_META_DATA(meta_data_.without_itn_id, "without_itn"); + + int32_t lang_auto = 0; + int32_t lang_zh = 0; + int32_t lang_en = 0; + int32_t lang_ja = 0; + int32_t lang_ko = 0; + int32_t lang_yue = 0; + + SHERPA_ONNX_READ_META_DATA(lang_auto, "lang_auto"); + SHERPA_ONNX_READ_META_DATA(lang_zh, "lang_zh"); + SHERPA_ONNX_READ_META_DATA(lang_en, "lang_en"); + SHERPA_ONNX_READ_META_DATA(lang_ja, "lang_ja"); + SHERPA_ONNX_READ_META_DATA(lang_ko, "lang_ko"); + SHERPA_ONNX_READ_META_DATA(lang_yue, "lang_yue"); + + meta_data_.lang2id = { + {"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en}, + {"ja", lang_ja}, {"ko", lang_ko}, {"yue", lang_yue}, + }; + + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.neg_mean, "neg_mean"); + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(meta_data_.inv_stddev, "inv_stddev"); + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineSenseVoiceModelMetaData meta_data_; +}; + +OfflineSenseVoiceModel::OfflineSenseVoiceModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineSenseVoiceModel::OfflineSenseVoiceModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineSenseVoiceModel::~OfflineSenseVoiceModel() = default; + +MNN::Express::VARP OfflineSenseVoiceModel::Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length, + MNN::Express::VARP language, + MNN::Express::VARP text_norm) const { + return impl_->Forward(std::move(features), std::move(features_length), + std::move(language), std::move(text_norm)); +} + +const OfflineSenseVoiceModelMetaData &OfflineSenseVoiceModel::metaData() + const { + return impl_->metaData(); +} + +MNNAllocator *OfflineSenseVoiceModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template OfflineSenseVoiceModel::OfflineSenseVoiceModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineSenseVoiceModel::OfflineSenseVoiceModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model.h new file mode 100644 index 00000000..d34bd14a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-sense-voice-model.h @@ -0,0 +1,55 @@ +// sherpa-mnn/csrc/offline-sense-voice-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/offline-sense-voice-model-meta-data.h" + +namespace sherpa_mnn { + +class OfflineSenseVoiceModel { + public: + explicit OfflineSenseVoiceModel(const OfflineModelConfig &config); + + template + OfflineSenseVoiceModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineSenseVoiceModel(); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int32_t. 
+ * @param language A 1-D tensor of shape (N,) with dtype int32_t + * @param text_norm A 1-D tensor of shape (N,) with dtype int32_t + * + * @return Return logits of shape (N, T, C) with dtype float + * + * Note: The subsampling factor is 1 for SenseVoice, so there is + * no need to output logits_length. + */ + MNN::Express::VARP Forward(MNN::Express::VARP features, MNN::Express::VARP features_length, + MNN::Express::VARP language, MNN::Express::VARP text_norm) const; + + const OfflineSenseVoiceModelMetaData& metaData() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SENSE_VOICE_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-impl.cc new file mode 100644 index 00000000..462659c9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-impl.cc @@ -0,0 +1,60 @@ +// sherpa-mnn/csrc/offline-speaker-diarization-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speaker-diarization-impl.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-speaker-diarization-pyannote-impl.h" + +namespace sherpa_mnn { + +std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + const OfflineSpeakerDiarizationConfig &config) { + if (!config.segmentation.pyannote.model.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model."); + + return nullptr; +} + +template +std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + Manager *mgr, const OfflineSpeakerDiarizationConfig &config) { + if (!config.segmentation.pyannote.model.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model."); + + return nullptr; +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + NativeResourceManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-impl.h new file mode 100644 index 00000000..17bee615 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-impl.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-speaker-diarization-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ + +#include +#include + +#include "sherpa-mnn/csrc/offline-speaker-diarization.h" +namespace sherpa_mnn { + +class OfflineSpeakerDiarizationImpl { + public: + static std::unique_ptr Create( + const OfflineSpeakerDiarizationConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OfflineSpeakerDiarizationConfig &config); + + virtual ~OfflineSpeakerDiarizationImpl() = default; + + 
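+ // Expected sample rate (in Hz) of the input audio samples.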
virtual int32_t SampleRate() const = 0; + + // Note: Only config.clustering is used. All other fields in config are + // ignored + virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0; + + virtual OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-pyannote-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-pyannote-impl.h new file mode 100644 index 00000000..143196b9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-pyannote-impl.h @@ -0,0 +1,743 @@ +// sherpa-mnn/csrc/offline-speaker-diarization-pyannote-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "Eigen/Dense" +#include "sherpa-mnn/csrc/fast-clustering.h" +#include "sherpa-mnn/csrc/math.h" +#include "sherpa-mnn/csrc/offline-speaker-diarization-impl.h" +#include "sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +namespace sherpa_mnn { + +namespace { // NOLINT + +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41 +template +inline void hash_combine(std::size_t *seed, const T &v) { // NOLINT + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); // NOLINT +} + +// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47 +struct PairHash { + template + std::size_t operator()(const std::pair &pair) const { + std::size_t result = 0; + hash_combine(&result, pair.first); + hash_combine(&result, pair.second); + return result; + } +}; +} // namespace + +using Matrix2D = + Eigen::Matrix; + +using Matrix2DInt32 = + Eigen::Matrix; + +using FloatRowVector = Eigen::Matrix; +using Int32RowVector = Eigen::Matrix; + +using Int32Pair = std::pair; + +class OfflineSpeakerDiarizationPyannoteImpl + : public OfflineSpeakerDiarizationImpl { + public: + ~OfflineSpeakerDiarizationPyannoteImpl() override = default; + + explicit OfflineSpeakerDiarizationPyannoteImpl( + const OfflineSpeakerDiarizationConfig &config) + : config_(config), + segmentation_model_(config_.segmentation), + embedding_extractor_(config_.embedding), + clustering_(std::make_unique(config_.clustering)) { + Init(); + } + + template + OfflineSpeakerDiarizationPyannoteImpl( + Manager *mgr, const OfflineSpeakerDiarizationConfig &config) + : config_(config), + segmentation_model_(mgr, config_.segmentation), + embedding_extractor_(mgr, config_.embedding), + clustering_(std::make_unique(config_.clustering)) { + Init(); + } + + int32_t SampleRate() const override { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + + return meta_data.sample_rate; + } + + void SetConfig(const OfflineSpeakerDiarizationConfig &config) override { + if (!config.clustering.Validate()) { + SHERPA_ONNX_LOGE("Invalid clustering config. 
Skip it"); + return; + } + clustering_ = std::make_unique(config.clustering); + config_.clustering = config.clustering; + } + + OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const override { + std::vector segmentations = RunSpeakerSegmentationModel(audio, n); + // segmentations[i] is for chunk_i + // Each matrix is of shape (num_frames, num_powerset_classes) + if (segmentations.empty()) { + return {}; + } + + std::vector labels; + labels.reserve(segmentations.size()); + + for (const auto &m : segmentations) { + labels.push_back(ToMultiLabel(m)); + } + + segmentations.clear(); + + if (labels.size() == 1) { + if (callback) { + callback(1, 1, callback_arg); + } + + return HandleOneChunkSpecialCase(labels[0], n); + } + + // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) + + // speaker count per frame + Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels); + + if (speakers_per_frame.maxCoeff() == 0) { + SHERPA_ONNX_LOGE("No speakers found in the audio samples"); + return {}; + } + + auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); + + // The embedding model may output NaN. valid_indexes contains indexes + // in chunk_speaker_samples_list_pair.second that don't lead to + // NaN embeddings. + std::vector valid_indexes; + valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size()); + + Matrix2D embeddings = + ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, + &valid_indexes, std::move(callback), callback_arg); + + if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) { + std::vector chunk_speaker_pair; + std::vector> sample_indexes; + + chunk_speaker_pair.reserve(valid_indexes.size()); + sample_indexes.reserve(valid_indexes.size()); + for (auto i : valid_indexes) { + chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]); + sample_indexes.push_back( + std::move(chunk_speaker_samples_list_pair.second[i])); + } + + chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair); + chunk_speaker_samples_list_pair.second = std::move(sample_indexes); + } + + std::vector cluster_labels = clustering_->Cluster( + &embeddings(0, 0), embeddings.rows(), embeddings.cols()); + + int32_t max_cluster_index = + *std::max_element(cluster_labels.begin(), cluster_labels.end()); + + auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster( + chunk_speaker_samples_list_pair.first, cluster_labels); + + auto new_labels = + ReLabel(labels, max_cluster_index, chunk_speaker_to_cluster); + + Matrix2DInt32 speaker_count = ComputeSpeakerCount(new_labels, n); + + Matrix2DInt32 final_labels = + FinalizeLabels(speaker_count, speakers_per_frame); + + auto result = ComputeResult(final_labels); + + return result; + } + + private: + void Init() { InitPowersetMapping(); } + + // see also + // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68 + void InitPowersetMapping() { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t num_classes = meta_data.num_classes; + int32_t powerset_max_classes = meta_data.powerset_max_classes; + int32_t num_speakers = meta_data.num_speakers; + + powerset_mapping_ = Matrix2DInt32(num_classes, num_speakers); + powerset_mapping_.setZero(); + + int32_t k = 1; + for (int32_t i = 1; i <= powerset_max_classes; ++i) { + if (i == 1) { + for (int32_t j = 0; j != num_speakers; ++j, ++k) { + 
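+ // each of these powerset classes has exactly one active speaker (speaker j)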
powerset_mapping_(k, j) = 1; + } + } else if (i == 2) { + for (int32_t j = 0; j != num_speakers; ++j) { + for (int32_t m = j + 1; m < num_speakers; ++m, ++k) { + powerset_mapping_(k, j) = 1; + powerset_mapping_(k, m) = 1; + } + } + } else { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "powerset_max_classes = %{public}d is currently not supported!", i); +#else + SHERPA_ONNX_LOGE( + "powerset_max_classes = %d is currently not supported!", i); +#endif + SHERPA_ONNX_EXIT(-1); + } + } + } + + std::vector RunSpeakerSegmentationModel(const float *audio, + int32_t n) const { + std::vector ans; + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + + if (n <= 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "number of audio samples is %{public}d (<= 0). Please provide a " + "positive number", + n); +#else + SHERPA_ONNX_LOGE( + "number of audio samples is %d (<= 0). Please provide a positive " + "number", + n); +#endif + return {}; + } + + if (n <= window_size) { + std::vector buf(window_size); + // NOTE: buf is zero initialized by default + + std::copy(audio, audio + n, buf.data()); + + Matrix2D m = ProcessChunk(buf.data()); + + ans.push_back(std::move(m)); + + return ans; + } + + int32_t num_chunks = (n - window_size) / window_shift + 1; + bool has_last_chunk = ((n - window_size) % window_shift) > 0; + + ans.reserve(num_chunks + has_last_chunk); + + const float *p = audio; + + for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) { + Matrix2D m = ProcessChunk(p); + + ans.push_back(std::move(m)); + } + + if (has_last_chunk) { + std::vector buf(window_size); + std::copy(p, audio + n, buf.data()); + + Matrix2D m = ProcessChunk(buf.data()); + + ans.push_back(std::move(m)); + } + + return ans; + } + + Matrix2D ProcessChunk(const float *p) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array shape = {1, 1, window_size}; + + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, const_cast(p), + window_size, shape.data(), shape.size()); + + MNN::Express::VARP out = segmentation_model_.Forward(std::move(x)); + std::vector out_shape = out->getInfo()->dim; + Matrix2D m(out_shape[1], out_shape[2]); + std::copy(out->readMap(), out->readMap() + m.size(), + &m(0, 0)); + return m; + } + + Matrix2DInt32 ToMultiLabel(const Matrix2D &m) const { + int32_t num_rows = m.rows(); + Matrix2DInt32 ans(num_rows, powerset_mapping_.cols()); + + std::ptrdiff_t col_id; + + for (int32_t i = 0; i != num_rows; ++i) { + m.row(i).maxCoeff(&col_id); + ans.row(i) = powerset_mapping_.row(col_id); + } + + return ans; + } + + // See also + // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122 + Int32RowVector ComputeSpeakersPerFrame( + const std::vector &labels) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + int32_t num_chunks = labels.size(); + + int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) / + receptive_field_shift + + 1; + + FloatRowVector count(num_frames); + FloatRowVector weight(num_frames); + count.setZero(); + weight.setZero(); + + for (int32_t i = 0; i != num_chunks; ++i) { + int32_t start = + static_cast(i) * window_shift / 
receptive_field_shift + 0.5; + + auto seq = Eigen::seqN(start, labels[i].rows()); + + count(seq).array() += labels[i].rowwise().sum().array().cast(); + + weight(seq).array() += 1; + } + + return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast(); + } + + // ans.first: a list of (chunk_id, speaker_id) + // ans.second: a list of list of (start_sample_index, end_sample_index) + // + // ans.first[i] corresponds to ans.second[i] + std::pair, std::vector>> + GetChunkSpeakerSampleIndexes(const std::vector &labels) const { + auto new_labels = ExcludeOverlap(labels); + + std::vector chunk_speaker_list; + std::vector> samples_index_list; + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + int32_t num_speakers = meta_data.num_speakers; + + int32_t chunk_index = 0; + for (const auto &label : new_labels) { + Matrix2DInt32 tmp = label.transpose(); + // tmp: (num_speakers, num_frames) + int32_t num_frames = tmp.cols(); + + int32_t sample_offset = chunk_index * window_shift; + + for (int32_t speaker_index = 0; speaker_index != num_speakers; + ++speaker_index) { + auto d = tmp.row(speaker_index); + if (d.sum() < 10) { + // skip segments less than 10 frames + continue; + } + + Int32Pair this_chunk_speaker = {chunk_index, speaker_index}; + std::vector this_speaker_samples; + + bool is_active = false; + int32_t start_index; + + for (int32_t k = 0; k != num_frames; ++k) { + if (d[k] != 0) { + if (!is_active) { + is_active = true; + start_index = k; + } + } else if (is_active) { + is_active = false; + + int32_t start_samples = + static_cast(start_index) / num_frames * window_size + + sample_offset; + int32_t end_samples = + static_cast(k) / num_frames * window_size + + sample_offset; + + this_speaker_samples.emplace_back(start_samples, end_samples); + } + } + + if (is_active) { + int32_t start_samples = + static_cast(start_index) / num_frames * window_size + + sample_offset; + int32_t end_samples = + static_cast(num_frames - 1) / num_frames * window_size + + sample_offset; + this_speaker_samples.emplace_back(start_samples, end_samples); + } + + chunk_speaker_list.push_back(std::move(this_chunk_speaker)); + samples_index_list.push_back(std::move(this_speaker_samples)); + } // for (int32_t speaker_index = 0; + chunk_index += 1; + } // for (const auto &label : new_labels) + + return {chunk_speaker_list, samples_index_list}; + } + + // If there are multiple speakers at a frame, then this frame is excluded. 
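+ // labels[i] is a 0-1 matrix of shape (num_frames, num_speakers) for chunk i.
+ // The returned matrices have the same shape; frames (rows) with more than
+ // one active speaker are left as all zeros.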
+ std::vector ExcludeOverlap( + const std::vector &labels) const { + int32_t num_chunks = labels.size(); + std::vector ans; + ans.reserve(num_chunks); + + for (const auto &label : labels) { + Matrix2DInt32 new_label(label.rows(), label.cols()); + new_label.setZero(); + Int32RowVector v = label.rowwise().sum(); + + for (int32_t i = 0; i != v.cols(); ++i) { + if (v[i] < 2) { + new_label.row(i) = label.row(i); + } + } + + ans.push_back(std::move(new_label)); + } + + return ans; + } + + /** + * @param sample_indexes[i] contains the sample segment start and end indexes + * for the i-th (chunk, speaker) pair + * @return Return a matrix of shape (sample_indexes.size(), embedding_dim) + * where ans.row[i] contains the embedding for the + * i-th (chunk, speaker) pair + */ + Matrix2D ComputeEmbeddings( + const float *audio, int32_t n, + const std::vector> &sample_indexes, + std::vector *valid_indexes, + OfflineSpeakerDiarizationProgressCallback callback, + void *callback_arg) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t sample_rate = meta_data.sample_rate; + Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim()); + + auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); }; + + int32_t k = 0; + int32_t cur_row_index = 0; + for (const auto &v : sample_indexes) { + auto stream = embedding_extractor_.CreateStream(); + for (const auto &p : v) { + int32_t end = (p.second <= n) ? p.second : n; + int32_t num_samples = end - p.first; + + if (num_samples > 0) { + stream->AcceptWaveform(sample_rate, audio + p.first, num_samples); + } + } + + stream->InputFinished(); + if (!embedding_extractor_.IsReady(stream.get())) { + SHERPA_ONNX_LOGE( + "This segment is too short, which should not happen since we have " + "already filtered short segments"); + SHERPA_ONNX_EXIT(-1); + } + + std::vector embedding = embedding_extractor_.Compute(stream.get()); + + if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) { + // a valid embedding + std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0)); + cur_row_index += 1; + valid_indexes->push_back(k); + } + + k += 1; + + if (callback) { + callback(k, ans.rows(), callback_arg); + } + } + + if (k != cur_row_index) { + auto seq = Eigen::seqN(0, cur_row_index); + ans = ans(seq, Eigen::all); + } + + return ans; + } + + std::unordered_map ConvertChunkSpeakerToCluster( + const std::vector &chunk_speaker_pair, + const std::vector &cluster_labels) const { + std::unordered_map ans; + + int32_t k = 0; + for (const auto &p : chunk_speaker_pair) { + ans[p] = cluster_labels[k]; + k += 1; + } + + return ans; + } + + std::vector ReLabel( + const std::vector &labels, int32_t max_cluster_index, + std::unordered_map chunk_speaker_to_cluster) + const { + std::vector new_labels; + new_labels.reserve(labels.size()); + + int32_t chunk_index = 0; + for (const auto &label : labels) { + Matrix2DInt32 new_label(label.rows(), max_cluster_index + 1); + new_label.setZero(); + + Matrix2DInt32 t = label.transpose(); + // t: (num_speakers, num_frames) + + for (int32_t speaker_index = 0; speaker_index != t.rows(); + ++speaker_index) { + if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) == 0) { + continue; + } + + int32_t new_speaker_index = + chunk_speaker_to_cluster.at({chunk_index, speaker_index}); + + for (int32_t k = 0; k != t.cols(); ++k) { + if (t(speaker_index, k) == 1) { + new_label(k, new_speaker_index) = 1; + } + } + } + + new_labels.push_back(std::move(new_label)); + + chunk_index += 1; + } + + return 
new_labels; + } + + Matrix2DInt32 ComputeSpeakerCount(const std::vector &labels, + int32_t num_samples) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + int32_t num_chunks = labels.size(); + + int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) / + receptive_field_shift + + 1; + + Matrix2DInt32 count(num_frames, labels[0].cols()); + count.setZero(); + + for (int32_t i = 0; i != num_chunks; ++i) { + int32_t start = + static_cast(i) * window_shift / receptive_field_shift + 0.5; + + auto seq = Eigen::seqN(start, labels[i].rows()); + + count(seq, Eigen::all).array() += labels[i].array(); + } + + bool has_last_chunk = ((num_samples - window_size) % window_shift) > 0; + + if (!has_last_chunk) { + return count; + } + + int32_t last_frame = num_samples / receptive_field_shift; + return count(Eigen::seq(0, last_frame), Eigen::all); + } + + Matrix2DInt32 FinalizeLabels(const Matrix2DInt32 &count, + const Int32RowVector &speakers_per_frame) const { + int32_t num_rows = count.rows(); + int32_t num_cols = count.cols(); + + Matrix2DInt32 ans(num_rows, num_cols); + ans.setZero(); + + for (int32_t i = 0; i != num_rows; ++i) { + int32_t k = speakers_per_frame[i]; + if (k == 0) { + continue; + } + auto top_k = TopkIndex(&count(i, 0), num_cols, k); + + for (int32_t m : top_k) { + ans(i, m) = 1; + } + } + + return ans; + } + + OfflineSpeakerDiarizationResult ComputeResult( + const Matrix2DInt32 &final_labels) const { + Matrix2DInt32 final_labels_t = final_labels.transpose(); + int32_t num_speakers = final_labels_t.rows(); + int32_t num_frames = final_labels_t.cols(); + + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + int32_t receptive_field_size = meta_data.receptive_field_size; + int32_t sample_rate = meta_data.sample_rate; + + float scale = static_cast(receptive_field_shift) / sample_rate; + float scale_offset = 0.5 * receptive_field_size / sample_rate; + + OfflineSpeakerDiarizationResult ans; + + for (int32_t speaker_index = 0; speaker_index != num_speakers; + ++speaker_index) { + std::vector this_speaker; + + bool is_active = final_labels_t(speaker_index, 0) > 0; + int32_t start_index = is_active ? 
0 : -1; + + for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) { + if (is_active) { + if (final_labels_t(speaker_index, frame_index) == 0) { + float start_time = start_index * scale + scale_offset; + float end_time = frame_index * scale + scale_offset; + + OfflineSpeakerDiarizationSegment segment(start_time, end_time, + speaker_index); + this_speaker.push_back(segment); + + is_active = false; + } + } else if (final_labels_t(speaker_index, frame_index) == 1) { + is_active = true; + start_index = frame_index; + } + } + + if (is_active) { + float start_time = start_index * scale + scale_offset; + float end_time = (num_frames - 1) * scale + scale_offset; + + OfflineSpeakerDiarizationSegment segment(start_time, end_time, + speaker_index); + this_speaker.push_back(segment); + } + + // merge segments if the gap between them is less than min_duration_off + MergeSegments(&this_speaker); + + for (const auto &seg : this_speaker) { + if (seg.Duration() > config_.min_duration_on) { + ans.Add(seg); + } + } + } // for (int32_t speaker_index = 0; speaker_index != num_speakers; + + return ans; + } + + OfflineSpeakerDiarizationResult HandleOneChunkSpecialCase( + const Matrix2DInt32 &final_labels, int32_t num_samples) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + int32_t receptive_field_shift = meta_data.receptive_field_shift; + + bool has_last_chunk = (num_samples - window_size) % window_shift > 0; + if (!has_last_chunk) { + return ComputeResult(final_labels); + } + + int32_t num_frames = final_labels.rows(); + + int32_t new_num_frames = num_samples / receptive_field_shift; + + num_frames = (new_num_frames <= num_frames) ? 
new_num_frames : num_frames; + + return ComputeResult(final_labels(Eigen::seq(0, num_frames), Eigen::all)); + } + + void MergeSegments( + std::vector *segments) const { + float min_duration_off = config_.min_duration_off; + bool changed = true; + while (changed) { + changed = false; + for (int32_t i = 0; i < static_cast(segments->size()) - 1; ++i) { + auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off); + if (s) { + (*segments)[i] = s.value(); + segments->erase(segments->begin() + i + 1); + + changed = true; + break; + } + } + } + } + + private: + OfflineSpeakerDiarizationConfig config_; + OfflineSpeakerSegmentationPyannoteModel segmentation_model_; + SpeakerEmbeddingExtractor embedding_extractor_; + std::unique_ptr clustering_; + Matrix2DInt32 powerset_mapping_; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-result.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-result.cc new file mode 100644 index 00000000..379a91e2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-result.cc @@ -0,0 +1,113 @@ +// sherpa-mnn/csrc/offline-speaker-diarization-result.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speaker-diarization-result.h" + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment( + float start, float end, int32_t speaker, const std::string &text /*= {}*/) { + if (start > end) { + SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end); + SHERPA_ONNX_EXIT(-1); + } + + start_ = start; + end_ = end; + speaker_ = speaker; + text_ = text; +} + +std::optional +OfflineSpeakerDiarizationSegment::Merge( + const OfflineSpeakerDiarizationSegment &other, float gap) const { + if (other.speaker_ != speaker_) { + SHERPA_ONNX_LOGE( + "The two segments should have the same speaker. 
this->speaker: %d, " + "other.speaker: %d", + speaker_, other.speaker_); + return std::nullopt; + } + + if (end_ < other.start_ && end_ + gap >= other.start_) { + return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_); + } else if (other.end_ < start_ && other.end_ + gap >= start_) { + return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_); + } else { + return std::nullopt; + } +} + +std::string OfflineSpeakerDiarizationSegment::ToString() const { + std::array s{}; + + snprintf(s.data(), s.size(), "%.3f -- %.3f speaker_%02d", start_, end_, + speaker_); + + std::ostringstream os; + os << s.data(); + + if (!text_.empty()) { + os << " " << text_; + } + + return os.str(); +} + +void OfflineSpeakerDiarizationResult::Add( + const OfflineSpeakerDiarizationSegment &segment) { + segments_.push_back(segment); +} + +int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const { + std::unordered_set count; + for (const auto &s : segments_) { + count.insert(s.Speaker()); + } + + return count.size(); +} + +int32_t OfflineSpeakerDiarizationResult::NumSegments() const { + return segments_.size(); +} + +// Return a list of segments sorted by segment.start time +std::vector +OfflineSpeakerDiarizationResult::SortByStartTime() const { + auto ans = segments_; + std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) { + return (a.Start() < b.Start()) || + ((a.Start() == b.Start()) && (a.Speaker() < b.Speaker())); + }); + + return ans; +} + +std::vector> +OfflineSpeakerDiarizationResult::SortBySpeaker() const { + auto tmp = segments_; + std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) { + return (a.Speaker() < b.Speaker()) || + ((a.Speaker() == b.Speaker()) && (a.Start() < b.Start())); + }); + + std::vector> ans(NumSpeakers()); + for (auto &s : tmp) { + ans[s.Speaker()].push_back(std::move(s)); + } + + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-result.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-result.h new file mode 100644 index 00000000..b6340c11 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization-result.h @@ -0,0 +1,67 @@ +// sherpa-mnn/csrc/offline-speaker-diarization-result.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ + +#include +#include +#include +#include + +namespace sherpa_mnn { + +class OfflineSpeakerDiarizationSegment { + public: + OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker, + const std::string &text = {}); + + // If the gap between the two segments is less than the given gap, then we + // merge them and return a new segment. Otherwise, it returns null. 
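+ // Note: both segments must belong to the same speaker; otherwise
+ // std::nullopt is returned.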
+ std::optional Merge( + const OfflineSpeakerDiarizationSegment &other, float gap) const; + + float Start() const { return start_; } + float End() const { return end_; } + int32_t Speaker() const { return speaker_; } + const std::string &Text() const { return text_; } + float Duration() const { return end_ - start_; } + + void SetText(const std::string &text) { text_ = text; } + + std::string ToString() const; + + private: + float start_; // in seconds + float end_; // in seconds + int32_t speaker_; // ID of the speaker, starting from 0 + std::string text_; // If not empty, it contains the speech recognition result + // of this segment +}; + +class OfflineSpeakerDiarizationResult { + public: + // Add a new segment + void Add(const OfflineSpeakerDiarizationSegment &segment); + + // Number of distinct speakers contained in this object at this point + int32_t NumSpeakers() const; + + int32_t NumSegments() const; + + // Return a list of segments sorted by segment.start time + std::vector SortByStartTime() const; + + // ans.size() == NumSpeakers(). + // ans[i] is for speaker_i and is sorted by start time + std::vector> SortBySpeaker() + const; + + private: + std::vector segments_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization.cc new file mode 100644 index 00000000..690948b6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization.cc @@ -0,0 +1,119 @@ +// sherpa-mnn/csrc/offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speaker-diarization.h" + +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/offline-speaker-diarization-impl.h" + +namespace sherpa_mnn { + +void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) { + ParseOptions po_segmentation("segmentation", po); + segmentation.Register(&po_segmentation); + + ParseOptions po_embedding("embedding", po); + embedding.Register(&po_embedding); + + ParseOptions po_clustering("clustering", po); + clustering.Register(&po_clustering); + + po->Register("min-duration-on", &min_duration_on, + "if a segment is less than this value, then it is discarded. " + "Set it to 0 so that no segment is discarded"); + + po->Register("min-duration-off", &min_duration_off, + "if the gap between to segments of the same speaker is less " + "than this value, then these two segments are merged into a " + "single segment. 
We do it recursively."); +} + +bool OfflineSpeakerDiarizationConfig::Validate() const { + if (!segmentation.Validate()) { + return false; + } + + if (!embedding.Validate()) { + return false; + } + + if (!clustering.Validate()) { + return false; + } + + if (min_duration_on < 0) { + SHERPA_ONNX_LOGE("min_duration_on %.3f is negative", min_duration_on); + return false; + } + + if (min_duration_off < 0) { + SHERPA_ONNX_LOGE("min_duration_off %.3f is negative", min_duration_off); + return false; + } + + return true; +} + +std::string OfflineSpeakerDiarizationConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerDiarizationConfig("; + os << "segmentation=" << segmentation.ToString() << ", "; + os << "embedding=" << embedding.ToString() << ", "; + os << "clustering=" << clustering.ToString() << ", "; + os << "min_duration_on=" << min_duration_on << ", "; + os << "min_duration_off=" << min_duration_off << ")"; + + return os.str(); +} + +OfflineSpeakerDiarization::OfflineSpeakerDiarization( + const OfflineSpeakerDiarizationConfig &config) + : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {} + +template +OfflineSpeakerDiarization::OfflineSpeakerDiarization( + Manager *mgr, const OfflineSpeakerDiarizationConfig &config) + : impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {} + +OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; + +int32_t OfflineSpeakerDiarization::SampleRate() const { + return impl_->SampleRate(); +} + +void OfflineSpeakerDiarization::SetConfig( + const OfflineSpeakerDiarizationConfig &config) { + impl_->SetConfig(config); +} + +OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/, + void *callback_arg /*= nullptr*/) const { + return impl_->Process(audio, n, std::move(callback), callback_arg); +} + +#if __ANDROID_API__ >= 9 +template OfflineSpeakerDiarization::OfflineSpeakerDiarization( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + +#if __OHOS__ +template OfflineSpeakerDiarization::OfflineSpeakerDiarization( + NativeResourceManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization.h new file mode 100644 index 00000000..aae7f2d5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-diarization.h @@ -0,0 +1,84 @@ +// sherpa-mnn/csrc/offline-speaker-diarization.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/fast-clustering-config.h" +#include "sherpa-mnn/csrc/offline-speaker-diarization-result.h" +#include "sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +namespace sherpa_mnn { + +struct OfflineSpeakerDiarizationConfig { + OfflineSpeakerSegmentationModelConfig segmentation; + SpeakerEmbeddingExtractorConfig embedding; + FastClusteringConfig clustering; + + // if a segment is less than this value, then it is discarded + float min_duration_on = 0.3; // in seconds + + // if the gap between to segments of the same speaker is less than this value, + // then these two segments are merged into a single segment. 
+ // We do this recursively. + float min_duration_off = 0.5; // in seconds + + OfflineSpeakerDiarizationConfig() = default; + + OfflineSpeakerDiarizationConfig( + const OfflineSpeakerSegmentationModelConfig &segmentation, + const SpeakerEmbeddingExtractorConfig &embedding, + const FastClusteringConfig &clustering, float min_duration_on, + float min_duration_off) + : segmentation(segmentation), + embedding(embedding), + clustering(clustering), + min_duration_on(min_duration_on), + min_duration_off(min_duration_off) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +class OfflineSpeakerDiarizationImpl; + +using OfflineSpeakerDiarizationProgressCallback = std::function; + +class OfflineSpeakerDiarization { + public: + explicit OfflineSpeakerDiarization( + const OfflineSpeakerDiarizationConfig &config); + + template + OfflineSpeakerDiarization(Manager *mgr, + const OfflineSpeakerDiarizationConfig &config); + + ~OfflineSpeakerDiarization(); + + // Expected sample rate of the input audio samples + int32_t SampleRate() const; + + // Note: Only config.clustering is used. All other fields in config are + // ignored + void SetConfig(const OfflineSpeakerDiarizationConfig &config); + + OfflineSpeakerDiarizationResult Process( + const float *audio, int32_t n, + OfflineSpeakerDiarizationProgressCallback callback = nullptr, + void *callback_arg = nullptr) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-model-config.cc new file mode 100644 index 00000000..c74832cd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-model-config.cc @@ -0,0 +1,57 @@ +// sherpa-mnn/csrc/offline-speaker-segmentation-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) { + pyannote.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflineSpeakerSegmentationModelConfig::Validate() const { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + if (!pyannote.model.empty()) { + return pyannote.Validate(); + } + + if (pyannote.model.empty()) { + SHERPA_ONNX_LOGE( + "You have to provide at least one speaker segmentation model"); + return false; + } + + return true; +} + +std::string OfflineSpeakerSegmentationModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerSegmentationModelConfig("; + os << "pyannote=" << pyannote.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h new file mode 100644 index 00000000..9fa45f3a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h @@ -0,0 +1,40 @@ +// sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineSpeakerSegmentationModelConfig { + OfflineSpeakerSegmentationPyannoteModelConfig pyannote; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflineSpeakerSegmentationModelConfig() = default; + + explicit OfflineSpeakerSegmentationModelConfig( + const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote, + int32_t num_threads, bool debug, const std::string &provider) + : pyannote(pyannote), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.cc new file mode 100644 index 00000000..3051ba05 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.cc @@ -0,0 +1,38 @@ +// sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineSpeakerSegmentationPyannoteModelConfig::Register(ParseOptions *po) { + po->Register("pyannote-model", &model, + "Path to model.onnx of the Pyannote segmentation model."); +} + +bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("Pyannote segmentation model: '%s' does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeakerSegmentationPyannoteModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.h new file mode 100644 index 00000000..98f4d633 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.h @@ -0,0 +1,30 @@ +// sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef 
SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineSpeakerSegmentationPyannoteModelConfig { + std::string model; + + OfflineSpeakerSegmentationPyannoteModelConfig() = default; + + explicit OfflineSpeakerSegmentationPyannoteModelConfig( + const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h new file mode 100644 index 00000000..db643628 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_mnn { + +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. +struct OfflineSpeakerSegmentationPyannoteModelMetaData { + int32_t sample_rate = 0; + int32_t window_size = 0; // in samples + int32_t window_shift = 0; // in samples + int32_t receptive_field_size = 0; // in samples + int32_t receptive_field_shift = 0; // in samples + int32_t num_speakers = 0; + int32_t powerset_max_classes = 0; + int32_t num_classes = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.cc new file mode 100644 index 00000000..23b437e0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -0,0 +1,149 @@ +// sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" + +namespace sherpa_mnn { + +class OfflineSpeakerSegmentationPyannoteModel::Impl { + public: + explicit Impl(const OfflineSpeakerSegmentationModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.pyannote.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.pyannote.model); + Init(buf.data(), buf.size()); + } + + const 
OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() + const { + return meta_data_; + } + + MNN::Express::VARP Forward(MNN::Express::VARP x) { + auto out = sess_->onForward({x}); + + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size"); + + meta_data_.window_shift = + static_cast(0.1 * meta_data_.window_size); + + SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size, + "receptive_field_size"); + SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift, + "receptive_field_shift"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes, + "powerset_max_classes"); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes"); + } + + private: + OfflineSpeakerSegmentationModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_; +}; + +OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineSpeakerSegmentationPyannoteModel:: + ~OfflineSpeakerSegmentationPyannoteModel() = default; + +const OfflineSpeakerSegmentationPyannoteModelMetaData & +OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const { + return impl_->GetModelMetaData(); +} + +MNN::Express::VARP OfflineSpeakerSegmentationPyannoteModel::Forward( + MNN::Express::VARP x) const { + return impl_->Forward(std::move(x)); +} + +#if __ANDROID_API__ >= 9 +template OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + AAssetManager *mgr, + const OfflineSpeakerSegmentationModelConfig &config); +#endif + +#if __OHOS__ +template OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + NativeResourceManager *mgr, + const OfflineSpeakerSegmentationModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.h new file mode 100644 index 00000000..b3ddfce8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.h @@ -0,0 +1,44 @@ +// 
sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" + +namespace sherpa_mnn { + +class OfflineSpeakerSegmentationPyannoteModel { + public: + explicit OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config); + + template + OfflineSpeakerSegmentationPyannoteModel( + Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config); + + ~OfflineSpeakerSegmentationPyannoteModel(); + + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() + const; + + /** + * @param x A 3-D float tensor of shape (batch_size, 1, num_samples) + * @return Return a float tensor of + * shape (batch_size, num_frames, num_speakers). Note that + * num_speakers here uses powerset encoding. + */ + MNN::Express::VARP Forward(MNN::Express::VARP x) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-impl.h new file mode 100644 index 00000000..fcf80de1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-impl.h @@ -0,0 +1,148 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ + +#include +#include +#include +#include + +#include "kaldi-native-fbank/csrc/feature-window.h" +#include "kaldi-native-fbank/csrc/istft.h" +#include "kaldi-native-fbank/csrc/stft.h" +#include "sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.h" +#include "sherpa-mnn/csrc/offline-speech-denoiser-impl.h" +#include "sherpa-mnn/csrc/offline-speech-denoiser.h" +#include "sherpa-mnn/csrc/resample.h" + +namespace sherpa_mnn { + +class OfflineSpeechDenoiserGtcrnImpl : public OfflineSpeechDenoiserImpl { + public: + explicit OfflineSpeechDenoiserGtcrnImpl( + const OfflineSpeechDenoiserConfig &config) + : model_(config.model) {} + + template + OfflineSpeechDenoiserGtcrnImpl(Manager *mgr, + const OfflineSpeechDenoiserConfig &config) + : model_(mgr, config.model) {} + + DenoisedAudio Run(const float *samples, int32_t n, + int32_t sample_rate) const override { + const auto &meta = model_.GetMetaData(); + + std::vector tmp; + auto p = samples; + + if (sample_rate != meta.sample_rate) { + SHERPA_ONNX_LOGE( + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + sample_rate, meta.sample_rate); + + float min_freq = std::min(sample_rate, meta.sample_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + auto resampler = std::make_unique( + sample_rate, meta.sample_rate, lowpass_cutoff, lowpass_filter_width); + resampler->Resample(samples, n, true, &tmp); + p = tmp.data(); + n = tmp.size(); + } + + knf::StftConfig stft_config; + stft_config.n_fft = meta.n_fft; + stft_config.hop_length = meta.hop_length; + stft_config.win_length = 
meta.window_length; + stft_config.window_type = meta.window_type; + if (stft_config.window_type == "hann_sqrt") { + auto window = knf::GetWindow("hann", stft_config.win_length); + for (auto &w : window) { + w = std::sqrt(w); + } + stft_config.window = std::move(window); + } + + knf::Stft stft(stft_config); + knf::StftResult stft_result = stft.Compute(p, n); + + auto states = model_.GetInitStates(); + OfflineSpeechDenoiserGtcrnModel::States next_states; + + knf::StftResult enhanced_stft_result; + enhanced_stft_result.num_frames = stft_result.num_frames; + for (int32_t i = 0; i < stft_result.num_frames; ++i) { + auto p = Process(stft_result, i, std::move(states), &next_states); + states = std::move(next_states); + + enhanced_stft_result.real.insert(enhanced_stft_result.real.end(), + p.first.begin(), p.first.end()); + enhanced_stft_result.imag.insert(enhanced_stft_result.imag.end(), + p.second.begin(), p.second.end()); + } + + knf::IStft istft(stft_config); + + DenoisedAudio denoised_audio; + denoised_audio.sample_rate = meta.sample_rate; + denoised_audio.samples = istft.Compute(enhanced_stft_result); + return denoised_audio; + } + + int32_t GetSampleRate() const override { + return model_.GetMetaData().sample_rate; + } + + private: + std::pair, std::vector> Process( + const knf::StftResult &stft_result, int32_t frame_index, + OfflineSpeechDenoiserGtcrnModel::States states, + OfflineSpeechDenoiserGtcrnModel::States *next_states) const { + const auto &meta = model_.GetMetaData(); + int32_t n_fft = meta.n_fft; + std::vector x((n_fft / 2 + 1) * 2); + + const float *p_real = + stft_result.real.data() + frame_index * (n_fft / 2 + 1); + const float *p_imag = + stft_result.imag.data() + frame_index * (n_fft / 2 + 1); + + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { + x[2 * i] = p_real[i]; + x[2 * i + 1] = p_imag[i]; + } + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{1, n_fft / 2 + 1, 1, 2}; + MNN::Express::VARP x_tensor = MNNUtilsCreateTensor( + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + + MNN::Express::VARP output{nullptr}; + std::tie(output, *next_states) = + model_.Run(std::move(x_tensor), std::move(states)); + + std::vector real(n_fft / 2 + 1); + std::vector imag(n_fft / 2 + 1); + const auto *p = output->readMap(); + for (int32_t i = 0; i < n_fft / 2 + 1; ++i) { + real[i] = p[2 * i]; + imag[i] = p[2 * i + 1]; + } + + return {std::move(real), std::move(imag)}; + } + + private: + OfflineSpeechDenoiserGtcrnModel model_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.cc new file mode 100644 index 00000000..91683420 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.cc @@ -0,0 +1,40 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineSpeechDenoiserGtcrnModelConfig::Register(ParseOptions *po) { + po->Register("speech-denoiser-gtcrn-model", &model, + "Path to the gtcrn model for speech denoising"); +} + +bool OfflineSpeechDenoiserGtcrnModelConfig::Validate() const { + if (model.empty()) { + 
SHERPA_ONNX_LOGE("Please provide --speech-denoiser-gtcrn-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("gtcrn model file '%s' does not exist", model.c_str()); + return false; + } + return true; +} + +std::string OfflineSpeechDenoiserGtcrnModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeechDenoiserGtcrnModelConfig("; + os << "model=\"" << model << "\")"; + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.h new file mode 100644 index 00000000..253e1b88 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.h @@ -0,0 +1,25 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineSpeechDenoiserGtcrnModelConfig { + std::string model; + OfflineSpeechDenoiserGtcrnModelConfig() = default; + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h new file mode 100644 index 00000000..fa8c156b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h @@ -0,0 +1,31 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ + +#include +#include +#include + +namespace sherpa_mnn { + +// please refer to +// https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/kokoro/add-meta-data.py +struct OfflineSpeechDenoiserGtcrnModelMetaData { + int32_t sample_rate = 0; + int32_t version = 1; + int32_t n_fft = 0; + int32_t hop_length = 0; + int32_t window_length = 0; + std::string window_type; + + std::vector conv_cache_shape; + std::vector tra_cache_shape; + std::vector inter_cache_shape; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.cc new file mode 100644 index 00000000..153ebf18 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.cc @@ -0,0 +1,193 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.h" + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineSpeechDenoiserGtcrnModel::Impl { + public: + explicit Impl(const 
OfflineSpeechDenoiserModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.gtcrn.model); + Init(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineSpeechDenoiserModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.gtcrn.model); + Init(buf.data(), buf.size()); + } + } + + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const { + return meta_; + } + + States GetInitStates() const { + MNN::Express::VARP conv_cache = MNNUtilsCreateTensor( + allocator_, meta_.conv_cache_shape.data(), + meta_.conv_cache_shape.size()); + + MNN::Express::VARP tra_cache = MNNUtilsCreateTensor( + allocator_, meta_.tra_cache_shape.data(), meta_.tra_cache_shape.size()); + + MNN::Express::VARP inter_cache = MNNUtilsCreateTensor( + allocator_, meta_.inter_cache_shape.data(), + meta_.inter_cache_shape.size()); + + Fill(conv_cache, 0); + Fill(tra_cache, 0); + Fill(inter_cache, 0); + + std::vector states; + + states.reserve(3); + states.push_back(std::move(conv_cache)); + states.push_back(std::move(tra_cache)); + states.push_back(std::move(inter_cache)); + + return states; + } + + std::pair Run(MNN::Express::VARP x, States states) const { + std::vector inputs; + inputs.reserve(1 + states.size()); + inputs.push_back(std::move(x)); + for (auto &s : states) { + inputs.push_back(std::move(s)); + } + + auto out = + sess_->onForward(inputs); + + std::vector next_states; + next_states.reserve(out.size() - 1); + for (int32_t k = 1; k < out.size(); ++k) { + next_states.push_back(std::move(out[k])); + } + + return {std::move(out[0]), std::move(next_states)}; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---gtcrn model---\n"; + PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : output_names_) { + os << i << " " << s << "\n"; + ++i; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + + std::string model_type; + SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type"); + if (model_type != "gtcrn") { + SHERPA_ONNX_LOGE("Expect model type 'gtcrn'. 
Given: '%s'", + model_type.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + SHERPA_ONNX_READ_META_DATA(meta_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_.n_fft, "n_fft"); + SHERPA_ONNX_READ_META_DATA(meta_.hop_length, "hop_length"); + SHERPA_ONNX_READ_META_DATA(meta_.window_length, "window_length"); + SHERPA_ONNX_READ_META_DATA_STR(meta_.window_type, "window_type"); + SHERPA_ONNX_READ_META_DATA(meta_.version, "version"); + + SHERPA_ONNX_READ_META_DATA_VEC(meta_.conv_cache_shape, "conv_cache_shape"); + SHERPA_ONNX_READ_META_DATA_VEC(meta_.tra_cache_shape, "tra_cache_shape"); + SHERPA_ONNX_READ_META_DATA_VEC(meta_.inter_cache_shape, + "inter_cache_shape"); + } + + private: + OfflineSpeechDenoiserModelConfig config_; + OfflineSpeechDenoiserGtcrnModelMetaData meta_; + + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; +}; + +OfflineSpeechDenoiserGtcrnModel::~OfflineSpeechDenoiserGtcrnModel() = default; + +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel( + const OfflineSpeechDenoiserModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineSpeechDenoiserGtcrnModel::OfflineSpeechDenoiserGtcrnModel( + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineSpeechDenoiserGtcrnModel::States +OfflineSpeechDenoiserGtcrnModel::GetInitStates() const { + return impl_->GetInitStates(); +} + +std::pair +OfflineSpeechDenoiserGtcrnModel::Run(MNN::Express::VARP x, States states) const { + return impl_->Run(std::move(x), std::move(states)); +} + +const OfflineSpeechDenoiserGtcrnModelMetaData & +OfflineSpeechDenoiserGtcrnModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.h new file mode 100644 index 00000000..fa5d9229 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.h @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-meta-data.h" +#include "sherpa-mnn/csrc/offline-speech-denoiser-model-config.h" +#include "sherpa-mnn/csrc/offline-speech-denoiser.h" + +namespace sherpa_mnn { + +class OfflineSpeechDenoiserGtcrnModel { + public: + ~OfflineSpeechDenoiserGtcrnModel(); + explicit OfflineSpeechDenoiserGtcrnModel( + const OfflineSpeechDenoiserModelConfig &config); + + template + OfflineSpeechDenoiserGtcrnModel( + Manager *mgr, const OfflineSpeechDenoiserModelConfig &config); + + using States = std::vector; + + States GetInitStates() const; + + std::pair Run(MNN::Express::VARP x, States states) const; + + const OfflineSpeechDenoiserGtcrnModelMetaData &GetMetaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-impl.cc 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-impl.cc new file mode 100644 index 00000000..4e58e8d1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-impl.cc @@ -0,0 +1,53 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-impl.cc +// +// Copyright (c) 2025 Xiaomi Corporation +#include "sherpa-mnn/csrc/offline-speech-denoiser-impl.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-impl.h" + +namespace sherpa_mnn { + +std::unique_ptr OfflineSpeechDenoiserImpl::Create( + const OfflineSpeechDenoiserConfig &config) { + if (!config.model.gtcrn.model.empty()) { + return std::make_unique(config); + } + SHERPA_ONNX_LOGE("Please provide a speech denoising model."); + return nullptr; +} + +template +std::unique_ptr OfflineSpeechDenoiserImpl::Create( + Manager *mgr, const OfflineSpeechDenoiserConfig &config) { + if (!config.model.gtcrn.model.empty()) { + return std::make_unique(mgr, config); + } + SHERPA_ONNX_LOGE("Please provide a speech denoising model."); + return nullptr; +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr +OfflineSpeechDenoiserImpl::Create(AAssetManager *mgr, + const OfflineSpeechDenoiserConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr +OfflineSpeechDenoiserImpl::Create(NativeResourceManager *mgr, + const OfflineSpeechDenoiserConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-impl.h new file mode 100644 index 00000000..a103aa0e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-impl.h @@ -0,0 +1,33 @@ +// sherpa-mnn/csrc/offline-speaker-speech-denoiser-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ + +#include + +#include "sherpa-mnn/csrc/offline-speech-denoiser.h" + +namespace sherpa_mnn { + +class OfflineSpeechDenoiserImpl { + public: + virtual ~OfflineSpeechDenoiserImpl() = default; + + static std::unique_ptr Create( + const OfflineSpeechDenoiserConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OfflineSpeechDenoiserConfig &config); + + virtual DenoisedAudio Run(const float *samples, int32_t n, + int32_t sample_rate) const = 0; + + virtual int32_t GetSampleRate() const = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-model-config.cc new file mode 100644 index 00000000..ca43d0dc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-model-config.cc @@ -0,0 +1,40 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speech-denoiser-model-config.h" + +#include + +namespace sherpa_mnn { + +void OfflineSpeechDenoiserModelConfig::Register(ParseOptions *po) { + gtcrn.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + 
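// The flags registered in this method map one-to-one onto the fields of
// OfflineSpeechDenoiserModelConfig (gtcrn.model, num_threads, debug, provider).
// A minimal wiring sketch, assuming the kaldi-style ParseOptions interface used
// throughout this codebase (the usage string and argv handling are illustrative):
//
//   sherpa_mnn::ParseOptions po("offline speech denoising");
//   sherpa_mnn::OfflineSpeechDenoiserConfig config;
//   config.Register(&po);   // registers --speech-denoiser-gtcrn-model,
//                           // --num-threads, --debug, --provider
//   po.Read(argc, argv);    // assumed ParseOptions::Read(argc, argv)
//   if (!config.Validate()) {
//     return -1;            // e.g. the gtcrn model path was not given
//   }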
po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflineSpeechDenoiserModelConfig::Validate() const { + return gtcrn.Validate(); +} + +std::string OfflineSpeechDenoiserModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeechDenoiserModelConfig("; + os << "gtcrn=" << gtcrn.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-model-config.h new file mode 100644 index 00000000..95b8a56b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser-model-config.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-speech-denoiser-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineSpeechDenoiserModelConfig { + OfflineSpeechDenoiserGtcrnModelConfig gtcrn; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflineSpeechDenoiserModelConfig() = default; + + OfflineSpeechDenoiserModelConfig(OfflineSpeechDenoiserGtcrnModelConfig gtcrn, + int32_t num_threads, bool debug, + const std::string &provider) + : gtcrn(gtcrn), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser.cc new file mode 100644 index 00000000..c1bbcfba --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser.cc @@ -0,0 +1,64 @@ +// sherpa-mnn/csrc/offline-speech-denoiser.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speech-denoiser.h" + +#include "sherpa-mnn/csrc/offline-speech-denoiser-impl.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +namespace sherpa_mnn { + +void OfflineSpeechDenoiserConfig::Register(ParseOptions *po) { + model.Register(po); +} + +bool OfflineSpeechDenoiserConfig::Validate() const { return model.Validate(); } + +std::string OfflineSpeechDenoiserConfig::ToString() const { + std::ostringstream os; + + os << "OfflineSpeechDenoiserConfig("; + os << "model=" << model.ToString() << ")"; + return os.str(); +} + +template +OfflineSpeechDenoiser::OfflineSpeechDenoiser( + Manager *mgr, const OfflineSpeechDenoiserConfig &config) + : impl_(OfflineSpeechDenoiserImpl::Create(mgr, config)) {} + +OfflineSpeechDenoiser::OfflineSpeechDenoiser( + const OfflineSpeechDenoiserConfig &config) + : impl_(OfflineSpeechDenoiserImpl::Create(config)) {} + +OfflineSpeechDenoiser::~OfflineSpeechDenoiser() = default; + +DenoisedAudio 
OfflineSpeechDenoiser::Run(const float *samples, int32_t n, + int32_t sample_rate) const { + return impl_->Run(samples, n, sample_rate); +} + +int32_t OfflineSpeechDenoiser::GetSampleRate() const { + return impl_->GetSampleRate(); +} + +#if __ANDROID_API__ >= 9 +template OfflineSpeechDenoiser::OfflineSpeechDenoiser( + AAssetManager *mgr, const OfflineSpeechDenoiserConfig &config); +#endif + +#if __OHOS__ +template OfflineSpeechDenoiser::OfflineSpeechDenoiser( + NativeResourceManager *mgr, const OfflineSpeechDenoiserConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser.h new file mode 100644 index 00000000..c877cac3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-speech-denoiser.h @@ -0,0 +1,61 @@ +// sherpa-mnn/csrc/offline-speech-denoiser.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-speech-denoiser-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct DenoisedAudio { + std::vector samples; + int32_t sample_rate; +}; + +struct OfflineSpeechDenoiserConfig { + OfflineSpeechDenoiserModelConfig model; + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OfflineSpeechDenoiserImpl; + +class OfflineSpeechDenoiser { + public: + explicit OfflineSpeechDenoiser(const OfflineSpeechDenoiserConfig &config); + ~OfflineSpeechDenoiser(); + + template + OfflineSpeechDenoiser(Manager *mgr, + const OfflineSpeechDenoiserConfig &config); + + /* + * @param samples 1-D array of audio samples. Each sample is in the + * range [-1, 1]. + * @param n Number of samples + * @param sample_rate Sample rate of the input samples + * + */ + DenoisedAudio Run(const float *samples, int32_t n, int32_t sample_rate) const; + + /* + * Return the sample rate of the denoised audio + */ + int32_t GetSampleRate() const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEECH_DENOISER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-stream.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-stream.cc new file mode 100644 index 00000000..3788764b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-stream.cc @@ -0,0 +1,434 @@ +// sherpa-mnn/csrc/offline-stream.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-stream.h" + +#include +#include +#include +#include +#include +#include + +#include "kaldi-native-fbank/csrc/online-feature.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/resample.h" + +namespace sherpa_mnn { + +/* Compute mean and inverse stddev over rows. 
+ * + * @param p A pointer to a 2-d array of shape (num_rows, num_cols) + * @param num_rows Number of rows + * @param num_cols Number of columns + * @param mean On return, it contains p.mean(axis=0) + * @param inv_stddev On return, it contains 1/p.std(axis=0) + */ +static void ComputeMeanAndInvStd(const float *p, int32_t num_rows, + int32_t num_cols, std::vector *mean, + std::vector *inv_stddev) { + std::vector sum(num_cols); + std::vector sum_sq(num_cols); + + for (int32_t i = 0; i != num_rows; ++i) { + for (int32_t c = 0; c != num_cols; ++c) { + auto t = p[c]; + sum[c] += t; + sum_sq[c] += t * t; + } + p += num_cols; + } + + mean->resize(num_cols); + inv_stddev->resize(num_cols); + + for (int32_t i = 0; i != num_cols; ++i) { + auto t = sum[i] / num_rows; + (*mean)[i] = t; + + float stddev = std::sqrt(sum_sq[i] / num_rows - t * t); + (*inv_stddev)[i] = 1.0f / (stddev + 1e-5f); + } +} + +class OfflineStream::Impl { + public: + explicit Impl(const FeatureExtractorConfig &config, + ContextGraphPtr context_graph) + : config_(config), context_graph_(std::move(context_graph)) { + if (config.is_mfcc) { + mfcc_opts_.frame_opts.dither = config_.dither; + mfcc_opts_.frame_opts.snip_edges = config_.snip_edges; + mfcc_opts_.frame_opts.samp_freq = config_.sampling_rate; + mfcc_opts_.frame_opts.frame_shift_ms = config_.frame_shift_ms; + mfcc_opts_.frame_opts.frame_length_ms = config_.frame_length_ms; + mfcc_opts_.frame_opts.remove_dc_offset = config_.remove_dc_offset; + mfcc_opts_.frame_opts.window_type = config_.window_type; + + mfcc_opts_.mel_opts.num_bins = config_.feature_dim; + + mfcc_opts_.mel_opts.high_freq = config_.high_freq; + mfcc_opts_.mel_opts.low_freq = config_.low_freq; + + mfcc_opts_.mel_opts.is_librosa = config_.is_librosa; + + mfcc_opts_.num_ceps = config_.num_ceps; + mfcc_opts_.use_energy = config_.use_energy; + + mfcc_ = std::make_unique(mfcc_opts_); + } else { + opts_.frame_opts.dither = config.dither; + opts_.frame_opts.snip_edges = config.snip_edges; + opts_.frame_opts.samp_freq = config.sampling_rate; + opts_.frame_opts.frame_shift_ms = config.frame_shift_ms; + opts_.frame_opts.frame_length_ms = config.frame_length_ms; + opts_.frame_opts.remove_dc_offset = config.remove_dc_offset; + opts_.frame_opts.window_type = config.window_type; + + opts_.mel_opts.num_bins = config.feature_dim; + + opts_.mel_opts.high_freq = config.high_freq; + opts_.mel_opts.low_freq = config.low_freq; + + opts_.mel_opts.is_librosa = config.is_librosa; + + fbank_ = std::make_unique(opts_); + } + } + + explicit Impl(WhisperTag tag) { + config_.normalize_samples = true; + opts_.frame_opts.samp_freq = 16000; + opts_.mel_opts.num_bins = tag.dim; + + knf::WhisperFeatureOptions whisper_opts; + whisper_opts.frame_opts = opts_.frame_opts; + whisper_opts.dim = tag.dim; + + whisper_fbank_ = std::make_unique(whisper_opts); + config_.sampling_rate = opts_.frame_opts.samp_freq; + } + + explicit Impl(CEDTag /*tag*/) : is_ced_(true) { + // see + // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py + + opts_.frame_opts.frame_length_ms = 32; + opts_.frame_opts.dither = 0; + opts_.frame_opts.preemph_coeff = 0; + opts_.frame_opts.remove_dc_offset = false; + opts_.frame_opts.window_type = "hann"; + opts_.frame_opts.snip_edges = false; + + opts_.frame_opts.samp_freq = 16000; // fixed to 16000 + opts_.mel_opts.num_bins = 64; + opts_.mel_opts.low_freq = 0; + opts_.mel_opts.high_freq = 8000; + opts_.use_log_fbank = false; + + config_.sampling_rate = opts_.frame_opts.samp_freq; + + fbank_ = 
std::make_unique(opts_); + } + + explicit Impl(MoonshineTag /*tag*/) : is_moonshine_(true) { + config_.sampling_rate = 16000; + } + + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { + if (config_.normalize_samples) { + AcceptWaveformImpl(sampling_rate, waveform, n); + } else { + std::vector buf(n); + for (int32_t i = 0; i != n; ++i) { + buf[i] = waveform[i] * 32768; + } + AcceptWaveformImpl(sampling_rate, buf.data(), n); + } + } + + void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform, + int32_t n) { + if (sampling_rate != config_.sampling_rate) { + SHERPA_ONNX_LOGE( + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + sampling_rate, static_cast(config_.sampling_rate)); + + float min_freq = std::min(sampling_rate, config_.sampling_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + auto resampler = std::make_unique( + sampling_rate, config_.sampling_rate, lowpass_cutoff, + lowpass_filter_width); + std::vector samples; + resampler->Resample(waveform, n, true, &samples); + + if (is_moonshine_) { + samples_.insert(samples_.end(), samples.begin(), samples.end()); + } else if (fbank_) { + fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), + samples.size()); + fbank_->InputFinished(); + } else if (mfcc_) { + mfcc_->AcceptWaveform(config_.sampling_rate, samples.data(), + samples.size()); + mfcc_->InputFinished(); + } else { + whisper_fbank_->AcceptWaveform(config_.sampling_rate, samples.data(), + samples.size()); + whisper_fbank_->InputFinished(); + } + + return; + } // if (sampling_rate != config_.sampling_rate) + + if (is_moonshine_) { + samples_.insert(samples_.end(), waveform, waveform + n); + } else if (fbank_) { + fbank_->AcceptWaveform(sampling_rate, waveform, n); + fbank_->InputFinished(); + } else if (mfcc_) { + mfcc_->AcceptWaveform(sampling_rate, waveform, n); + mfcc_->InputFinished(); + } else { + whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n); + whisper_fbank_->InputFinished(); + } + } + + int32_t FeatureDim() const { + if (is_moonshine_) { + return samples_.size(); + } + + return mfcc_ ? mfcc_opts_.num_ceps : opts_.mel_opts.num_bins; + } + + std::vector GetFrames() const { + if (is_moonshine_) { + return samples_; + } + + int32_t n = fbank_ ? fbank_->NumFramesReady() + : mfcc_ ? mfcc_->NumFramesReady() + : whisper_fbank_->NumFramesReady(); + assert(n > 0 && "Please first call AcceptWaveform()"); + + int32_t feature_dim = FeatureDim(); + + std::vector features(n * feature_dim); + + float *p = features.data(); + + for (int32_t i = 0; i != n; ++i) { + const float *f = fbank_ ? fbank_->GetFrame(i) + : mfcc_ ? mfcc_->GetFrame(i) + : whisper_fbank_->GetFrame(i); + std::copy(f, f + feature_dim, p); + p += feature_dim; + } + + NemoNormalizeFeatures(features.data(), n, feature_dim); + + if (is_ced_) { + AmplitudeToDB(features.data(), features.size()); + } + + return features; + } + + void SetResult(const OfflineRecognitionResult &r) { r_ = r; } + + const OfflineRecognitionResult &GetResult() const { return r_; } + + const ContextGraphPtr &GetContextGraph() const { return context_graph_; } + + private: + // see + // https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L359 + void AmplitudeToDB(float *p, int32_t n) const { + float multiplier = 10; + float top_db = 120; + float amin = 1e-10; + + float max_x = std::numeric_limits::min(); + + for (int32_t i = 0; i != n; ++i) { + float x = p[i]; + x = (x > amin) ? 
x : amin; + x = log10f(x) * multiplier; + + max_x = (x > max_x) ? x : max_x; + p[i] = x; + } + + float d = max_x - top_db; + for (int32_t i = 0; i != n; ++i) { + float x = p[i]; + x = (x > d) ? x : d; + p[i] = x; + } + } + + void NemoNormalizeFeatures(float *p, int32_t num_frames, + int32_t feature_dim) const { + if (config_.nemo_normalize_type.empty()) { + return; + } + + if (config_.nemo_normalize_type != "per_feature") { + SHERPA_ONNX_LOGE( + "Only normalize_type=per_feature is implemented. Given: %s", + config_.nemo_normalize_type.c_str()); + exit(-1); + } + + NemoNormalizePerFeature(p, num_frames, feature_dim); + } + + static void NemoNormalizePerFeature(float *p, int32_t num_frames, + int32_t feature_dim) { + std::vector mean; + std::vector inv_stddev; + + ComputeMeanAndInvStd(p, num_frames, feature_dim, &mean, &inv_stddev); + + for (int32_t n = 0; n != num_frames; ++n) { + for (int32_t i = 0; i != feature_dim; ++i) { + p[i] = (p[i] - mean[i]) * inv_stddev[i]; + } + p += feature_dim; + } + } + + private: + FeatureExtractorConfig config_; + std::unique_ptr fbank_; + std::unique_ptr mfcc_; + std::unique_ptr whisper_fbank_; + knf::FbankOptions opts_; + knf::MfccOptions mfcc_opts_; + OfflineRecognitionResult r_; + ContextGraphPtr context_graph_; + bool is_ced_ = false; + bool is_moonshine_ = false; + + // used only when is_moonshine_== true + std::vector samples_; +}; + +OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/, + ContextGraphPtr context_graph /*= nullptr*/) + : impl_(std::make_unique(config, std::move(context_graph))) {} + +OfflineStream::OfflineStream(WhisperTag tag) + : impl_(std::make_unique(tag)) {} + +OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique(tag)) {} + +OfflineStream::OfflineStream(MoonshineTag tag) + : impl_(std::make_unique(tag)) {} + +OfflineStream::~OfflineStream() = default; + +void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const { + impl_->AcceptWaveform(sampling_rate, waveform, n); +} + +int32_t OfflineStream::FeatureDim() const { return impl_->FeatureDim(); } + +std::vector OfflineStream::GetFrames() const { + return impl_->GetFrames(); +} + +void OfflineStream::SetResult(const OfflineRecognitionResult &r) { + impl_->SetResult(r); +} + +const ContextGraphPtr &OfflineStream::GetContextGraph() const { + return impl_->GetContextGraph(); +} + +const OfflineRecognitionResult &OfflineStream::GetResult() const { + return impl_->GetResult(); +} +std::string OfflineRecognitionResult::AsJsonString() const { + std::ostringstream os; + os << "{"; + + os << "\"lang\"" + << ": "; + os << std::quoted(lang) << ", "; + + os << "\"emotion\"" + << ": "; + os << std::quoted(emotion) << ", "; + + os << "\"event\"" + << ": "; + os << std::quoted(event) << ", "; + + os << "\"text\"" + << ": "; + os << std::quoted(text) << ", "; + + os << "\"" + << "timestamps" + << "\"" + << ": "; + os << "["; + + std::string sep = ""; + for (auto t : timestamps) { + os << sep << std::fixed << std::setprecision(2) << t; + sep = ", "; + } + os << "], "; + + os << "\"" + << "tokens" + << "\"" + << ":"; + os << "["; + + sep = ""; + auto oldFlags = os.flags(); + for (const auto &t : tokens) { + if (t.size() == 1 && static_cast(t[0]) > 0x7f) { + const uint8_t *p = reinterpret_cast(t.c_str()); + os << sep << "\"" + << "<0x" << std::hex << std::uppercase << static_cast(p[0]) + << ">" + << "\""; + os.flags(oldFlags); + } else { + os << sep << std::quoted(t); + } + sep = ", "; + } + os << "], "; + + sep = ""; + + os 
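// For reference, the string assembled in this method has the following general
// shape (all values are illustrative; timestamps are printed with two decimal
// places and single-byte tokens above 0x7f are escaped as "<0xAB>"-style hex):
//
//   {"lang": "en", "emotion": "", "event": "", "text": "hello world",
//    "timestamps": [0.00, 0.48], "tokens": ["hello", " world"], "words": []}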
<< "\"" + << "words" + << "\"" + << ": "; + os << "["; + for (int32_t w : words) { + os << sep << w; + sep = ", "; + } + + os << "]"; + os << "}"; + + return os.str(); +} +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-stream.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-stream.h new file mode 100644 index 00000000..548de060 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-stream.h @@ -0,0 +1,106 @@ +// sherpa-mnn/csrc/offline-stream.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ +#include + +#include +#include +#include + +#include "sherpa-mnn/csrc/context-graph.h" +#include "sherpa-mnn/csrc/features.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineRecognitionResult { + // Recognition results. + // For English, it consists of space separated words. + // For Chinese, it consists of Chinese words without spaces. + std::string text; + + // Decoded results at the token level. + // For instance, for BPE-based models it consists of a list of BPE tokens. + std::vector tokens; + + std::string lang; + + // emotion target of the audio. + std::string emotion; + + // event target of the audio. + std::string event; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + std::vector timestamps; + + std::vector words; + + std::string AsJsonString() const; +}; + +struct WhisperTag { + int32_t dim = 80; +}; + +struct CEDTag {}; + +// It uses a neural network model, a preprocessor, to convert +// audio samples to features +struct MoonshineTag {}; + +class OfflineStream { + public: + explicit OfflineStream(const FeatureExtractorConfig &config = {}, + ContextGraphPtr context_graph = {}); + + explicit OfflineStream(WhisperTag tag); + explicit OfflineStream(CEDTag tag); + explicit OfflineStream(MoonshineTag tag); + ~OfflineStream(); + + /** + @param sampling_rate The sampling_rate of the input waveform. If it does + not equal to config.sampling_rate, we will do + resampling inside. + @param waveform Pointer to a 1-D array of size n. It must be normalized to + the range [-1, 1]. + @param n Number of entries in waveform + + Caution: You can only invoke this function once so you have to input + all the samples at once + */ + void AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const; + + /// Return feature dim of this extractor. + /// + /// Note: if it is Moonshine, then it returns the number of audio samples + /// currently received. + int32_t FeatureDim() const; + + // Get all the feature frames of this stream in a 1-D array, which is + // flattened from a 2-D array of shape (num_frames, feat_dim). + std::vector GetFrames() const; + + /** Set the recognition result for this stream. 
*/ + void SetResult(const OfflineRecognitionResult &r); + + /** Get the recognition result of this stream */ + const OfflineRecognitionResult &GetResult() const; + + /** Get the ContextGraph of this stream */ + const ContextGraphPtr &GetContextGraph() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-ctc-model.cc new file mode 100644 index 00000000..6c577619 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-ctc-model.cc @@ -0,0 +1,147 @@ +// sherpa-mnn/csrc/offline-tdnn-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tdnn-ctc-model.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class OfflineTdnnCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.tdnn.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.tdnn.model); + Init(buf.data(), buf.size()); + } + + std::vector Forward(MNN::Express::VARP features) { + auto nnet_out = + sess_->onForward({features}); + + std::vector nnet_out_shape = + nnet_out[0]->getInfo()->dim; + + std::vector out_length_vec(nnet_out_shape[0], nnet_out_shape[1]); + std::vector out_length_shape(1, nnet_out_shape[0]); + + auto memory_info = + (MNNAllocator*)(nullptr); + + MNN::Express::VARP nnet_out_length = MNNUtilsCreateTensor( + memory_info, out_length_vec.data(), out_length_vec.size(), + out_length_shape.data(), out_length_shape.size()); + + std::vector ans; + ans.reserve(2); + ans.push_back(std::move(nnet_out[0])); + ans.push_back(Clone(nullptr, nnet_out_length)); + return ans; + } + + int32_t VocabSize() const { return vocab_size_; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector 
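// Usage sketch for the OfflineStream API declared in offline-stream.h above.
// The audio buffer is a placeholder, and, per the caution in AcceptWaveform(),
// all samples must be supplied in a single call:
//
//   sherpa_mnn::OfflineStream stream;   // default FeatureExtractorConfig
//   stream.AcceptWaveform(16000, samples.data(),
//                         static_cast<int32_t>(samples.size()));
//   int32_t feat_dim = stream.FeatureDim();
//   std::vector<float> frames = stream.GetFrames();  // flattened (num_frames, feat_dim)
//   int32_t num_frames = static_cast<int32_t>(frames.size()) / feat_dim;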
output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; +}; + +OfflineTdnnCtcModel::OfflineTdnnCtcModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTdnnCtcModel::OfflineTdnnCtcModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default; + +std::vector OfflineTdnnCtcModel::Forward( + MNN::Express::VARP features, MNN::Express::VARP /*features_length*/) { + return impl_->Forward(std::move(features)); +} + +int32_t OfflineTdnnCtcModel::VocabSize() const { return impl_->VocabSize(); } + +MNNAllocator *OfflineTdnnCtcModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template OfflineTdnnCtcModel::OfflineTdnnCtcModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineTdnnCtcModel::OfflineTdnnCtcModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-ctc-model.h new file mode 100644 index 00000000..f2a4d057 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-ctc-model.h @@ -0,0 +1,59 @@ +// sherpa-mnn/csrc/offline-tdnn-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-ctc-model.h" +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +/** This class implements the tdnn model of the yesno recipe from icefall. + * + * See + * https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn + */ +class OfflineTdnnCtcModel : public OfflineCtcModel { + public: + explicit OfflineTdnnCtcModel(const OfflineModelConfig &config); + + template + OfflineTdnnCtcModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineTdnnCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a pair containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). 
Its dtype is int + */ + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP /*features_length*/) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-model-config.cc new file mode 100644 index 00000000..a3821859 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-model-config.cc @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/offline-tdnn-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tdnn-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineTdnnModelConfig::Register(ParseOptions *po) { + po->Register("tdnn-model", &model, "Path to onnx model"); +} + +bool OfflineTdnnModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("tdnn model file %s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineTdnnModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTdnnModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-model-config.h new file mode 100644 index 00000000..39659902 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tdnn-model-config.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/offline-tdnn-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +// for https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn +struct OfflineTdnnModelConfig { + std::string model; + + OfflineTdnnModelConfig() = default; + explicit OfflineTdnnModelConfig(const std::string &model) : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-telespeech-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-telespeech-ctc-model.cc new file mode 100644 index 00000000..22c60503 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-telespeech-ctc-model.cc @@ -0,0 +1,161 @@ +// sherpa-mnn/csrc/offline-telespeech-ctc-model.cc +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-telespeech-ctc-model.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class 
OfflineTeleSpeechCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.telespeech_ctc); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.telespeech_ctc); + Init(buf.data(), buf.size()); + } + + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP /*features_length*/) { + std::vector shape = + features->getInfo()->dim; + + if (static_cast(shape[0]) != 1) { + SHERPA_ONNX_LOGE("This model supports only batch size 1. Given %d", + static_cast(shape[0])); + } + + auto out = sess_->onForward({features}); + + std::vector logits_shape = {1}; + MNN::Express::VARP logits_length = MNNUtilsCreateTensor( + allocator_, logits_shape.data(), logits_shape.size()); + + int *dst = logits_length->writeMap(); + dst[0] = out[0]->getInfo()->dim[0]; + + // (T, B, C) -> (B, T, C) + MNN::Express::VARP logits = Transpose01(allocator_, out[0]); + + std::vector ans; + ans.reserve(2); + ans.push_back(std::move(logits)); + ans.push_back(std::move(logits_length)); + + return ans; + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + auto iter = meta_data.find("vocab_size"); + if (iter != meta_data.end()){ + vocab_size_ = std::stoi(iter->second); + } + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 4; +}; + +OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel( + const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel( + Manager *mgr, const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTeleSpeechCtcModel::~OfflineTeleSpeechCtcModel() = default; + +std::vector OfflineTeleSpeechCtcModel::Forward( + MNN::Express::VARP features, MNN::Express::VARP features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineTeleSpeechCtcModel::VocabSize() const { + return impl_->VocabSize(); +} +int32_t OfflineTeleSpeechCtcModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +MNNAllocator *OfflineTeleSpeechCtcModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel( + AAssetManager *mgr, const OfflineModelConfig 
&config); +#endif + +#if __OHOS__ +template OfflineTeleSpeechCtcModel::OfflineTeleSpeechCtcModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-telespeech-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-telespeech-ctc-model.h new file mode 100644 index 00000000..dcb2cf32 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-telespeech-ctc-model.h @@ -0,0 +1,74 @@ +// sherpa-mnn/csrc/offline-telespeech-ctc-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_ +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-ctc-model.h" +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +/** This class implements the CTC model from + * https://github.com/Tele-AI/TeleSpeech-ASR. + * + * See + * https://github.com/lovemefan/telespeech-asr-python/blob/main/telespeechasr/onnx/onnx_infer.py + * and + * https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/tele-speech/test.py + */ +class OfflineTeleSpeechCtcModel : public OfflineCtcModel { + public: + explicit OfflineTeleSpeechCtcModel(const OfflineModelConfig &config); + + template + OfflineTeleSpeechCtcModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineTeleSpeechCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int + */ + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** SubsamplingFactor of the model + */ + int32_t SubsamplingFactor() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + // TeleSpeech CTC models do not support batch size > 1 + bool SupportBatchProcessing() const override { return false; } + + std::string FeatureNormalizationMethod() const override { + return "per_feature"; + } + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TELESPEECH_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-decoder.h new file mode 100644 index 00000000..6aa35b79 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-decoder.h @@ -0,0 +1,43 @@ +// sherpa-mnn/csrc/offline-transducer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-stream.h" + +namespace sherpa_mnn { + +struct OfflineTransducerDecoderResult { + /// The decoded token IDs + std::vector tokens; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. 
+ /// Note: The index is after subsampling + std::vector timestamps; +}; + +class OfflineTransducerDecoder { + public: + virtual ~OfflineTransducerDecoder() = default; + + /** Run transducer beam search given the output from the encoder model. + * + * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim) + * @param encoder_out_length A 1-D tensor of shape (N,) containing number + * of valid frames in encoder_out before padding. + * + * @return Return a vector of size `N` containing the decoded results. + */ + virtual std::vector Decode( + MNN::Express::VARP encoder_out, MNN::Express::VARP encoder_out_length, + OfflineStream **ss = nullptr, int32_t n = 0) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.cc new file mode 100644 index 00000000..dc388d26 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.cc @@ -0,0 +1,87 @@ +// sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/packed-sequence.h" +#include "sherpa-mnn/csrc/slice.h" + +namespace sherpa_mnn { + +std::vector +OfflineTransducerGreedySearchDecoder::Decode(MNN::Express::VARP encoder_out, + MNN::Express::VARP encoder_out_length, + OfflineStream **ss /*= nullptr*/, + int32_t n /*= 0*/) { + PackedSequence packed_encoder_out = PackPaddedSequence( + model_->Allocator(), encoder_out, encoder_out_length); + + int32_t batch_size = + static_cast(packed_encoder_out.sorted_indexes.size()); + + int32_t vocab_size = model_->VocabSize(); + int32_t context_size = model_->ContextSize(); + + std::vector ans(batch_size); + for (auto &r : ans) { + r.tokens.resize(context_size, -1); + // 0 is the ID of the blank token + r.tokens.back() = 0; + } + + auto decoder_input = model_->BuildDecoderInput(ans, ans.size()); + MNN::Express::VARP decoder_out = model_->RunDecoder(std::move(decoder_input)); + + int32_t start = 0; + int32_t t = 0; + for (auto n : packed_encoder_out.batch_sizes) { + MNN::Express::VARP cur_encoder_out = packed_encoder_out.Get(start, n); + MNN::Express::VARP cur_decoder_out = Slice(model_->Allocator(), decoder_out, 0, n); + start += n; + MNN::Express::VARP logit = model_->RunJoiner(std::move(cur_encoder_out), + std::move(cur_decoder_out)); + float *p_logit = logit->writeMap(); + bool emitted = false; + for (int32_t i = 0; i != n; ++i) { + if (blank_penalty_ > 0.0) { + p_logit[0] -= blank_penalty_; // assuming blank id is 0 + } + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + p_logit += vocab_size; + // blank id is hardcoded to 0 + // also, it treats unk as blank + if (y != 0 && y != unk_id_) { + ans[i].tokens.push_back(y); + ans[i].timestamps.push_back(t); + emitted = true; + } + } + if (emitted) { + MNN::Express::VARP decoder_input = model_->BuildDecoderInput(ans, n); + decoder_out = model_->RunDecoder(std::move(decoder_input)); + } + ++t; + } + + for (auto &r : ans) { + r.tokens = {r.tokens.begin() + context_size, r.tokens.end()}; + } + + std::vector unsorted_ans(batch_size); + for (int32_t i = 0; i != batch_size; ++i) { + 
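    // PackPaddedSequence() reordered the batch before decoding, so
    // sorted_indexes is used here to put each decoded result back at the
    // position of the corresponding input stream.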
unsorted_ans[packed_encoder_out.sorted_indexes[i]] = std::move(ans[i]); + } + + return unsorted_ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.h new file mode 100644 index 00000000..9e0c0df2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.h @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/offline-transducer-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/offline-transducer-decoder.h" +#include "sherpa-mnn/csrc/offline-transducer-model.h" + +namespace sherpa_mnn { + +class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder { + public: + OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model, + int32_t unk_id, + float blank_penalty) + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} + + std::vector Decode( + MNN::Express::VARP encoder_out, MNN::Express::VARP encoder_out_length, + OfflineStream **ss = nullptr, int32_t n = 0) override; + + private: + OfflineTransducerModel *model_; // Not owned + int32_t unk_id_; + float blank_penalty_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.cc new file mode 100644 index 00000000..e5f422c1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.cc @@ -0,0 +1,117 @@ +// sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +static std::pair BuildDecoderInput( + int32_t token, MNNAllocator *allocator) { + std::array shape{1, 1}; + + MNN::Express::VARP decoder_input = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + + std::array length_shape{1}; + MNN::Express::VARP decoder_input_length = MNNUtilsCreateTensor( + allocator, length_shape.data(), length_shape.size()); + + int32_t *p = decoder_input->writeMap(); + + int32_t *p_length = decoder_input_length->writeMap(); + + p[0] = token; + + p_length[0] = 1; + + return {std::move(decoder_input), std::move(decoder_input_length)}; +} + +static OfflineTransducerDecoderResult DecodeOne( + const float *p, int32_t num_rows, int32_t num_cols, + OfflineTransducerNeMoModel *model, float blank_penalty) { + auto memory_info = + (MNNAllocator*)(nullptr); + + OfflineTransducerDecoderResult ans; + + int32_t vocab_size = model->VocabSize(); + int32_t blank_id = vocab_size - 1; + + auto decoder_input_pair = BuildDecoderInput(blank_id, model->Allocator()); + + std::pair> decoder_output_pair = + model->RunDecoder(std::move(decoder_input_pair.first), + std::move(decoder_input_pair.second), + model->GetDecoderInitStates(1)); + + std::array encoder_shape{1, num_cols, 1}; + + for (int32_t t = 0; t != num_rows; ++t) { + MNN::Express::VARP cur_encoder_out = MNNUtilsCreateTensor( + 
memory_info, const_cast(p) + t * num_cols, num_cols, + encoder_shape.data(), encoder_shape.size()); + + MNN::Express::VARP logit = model->RunJoiner(std::move(cur_encoder_out), + View(decoder_output_pair.first)); + + float *p_logit = logit->writeMap(); + if (blank_penalty > 0) { + p_logit[blank_id] -= blank_penalty; + } + + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + + if (y != blank_id) { + ans.tokens.push_back(y); + ans.timestamps.push_back(t); + + decoder_input_pair = BuildDecoderInput(y, model->Allocator()); + + decoder_output_pair = + model->RunDecoder(std::move(decoder_input_pair.first), + std::move(decoder_input_pair.second), + std::move(decoder_output_pair.second)); + } // if (y != blank_id) + } // for (int32_t i = 0; i != num_rows; ++i) + + return ans; +} + +std::vector +OfflineTransducerGreedySearchNeMoDecoder::Decode( + MNN::Express::VARP encoder_out, MNN::Express::VARP encoder_out_length, + OfflineStream ** /*ss = nullptr*/, int32_t /*n= 0*/) { + auto shape = encoder_out->getInfo()->dim; + + int32_t batch_size = static_cast(shape[0]); + int32_t dim1 = static_cast(shape[1]); + int32_t dim2 = static_cast(shape[2]); + + const int *p_length = encoder_out_length->readMap(); + const float *p = encoder_out->readMap(); + + std::vector ans(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + const float *this_p = p + dim1 * dim2 * i; + int32_t this_len = p_length[i]; + + ans[i] = DecodeOne(this_p, this_len, dim2, model_, blank_penalty_); + } + + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.h new file mode 100644 index 00000000..42b880c1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.h @@ -0,0 +1,33 @@ +// sherpa-mnn/csrc/offline-transducer-greedy-search-nemo-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/offline-transducer-decoder.h" +#include "sherpa-mnn/csrc/offline-transducer-nemo-model.h" + +namespace sherpa_mnn { + +class OfflineTransducerGreedySearchNeMoDecoder + : public OfflineTransducerDecoder { + public: + OfflineTransducerGreedySearchNeMoDecoder(OfflineTransducerNeMoModel *model, + float blank_penalty) + : model_(model), blank_penalty_(blank_penalty) {} + + std::vector Decode( + MNN::Express::VARP encoder_out, MNN::Express::VARP encoder_out_length, + OfflineStream **ss = nullptr, int32_t n = 0) override; + + private: + OfflineTransducerNeMoModel *model_; // Not owned + float blank_penalty_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model-config.cc new file mode 100644 index 00000000..7f5f371d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model-config.cc @@ -0,0 +1,52 @@ +// sherpa-mnn/csrc/offline-transducer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-mnn/csrc/offline-transducer-model-config.h" + +#include + +#include 
"sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineTransducerModelConfig::Register(ParseOptions *po) { + po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); + po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); + po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); +} + +bool OfflineTransducerModelConfig::Validate() const { + if (!FileExists(encoder_filename)) { + SHERPA_ONNX_LOGE("transducer encoder: '%s' does not exist", + encoder_filename.c_str()); + return false; + } + + if (!FileExists(decoder_filename)) { + SHERPA_ONNX_LOGE("transducer decoder: '%s' does not exist", + decoder_filename.c_str()); + return false; + } + + if (!FileExists(joiner_filename)) { + SHERPA_ONNX_LOGE("transducer joiner: '%s' does not exist", + joiner_filename.c_str()); + return false; + } + + return true; +} + +std::string OfflineTransducerModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTransducerModelConfig("; + os << "encoder_filename=\"" << encoder_filename << "\", "; + os << "decoder_filename=\"" << decoder_filename << "\", "; + os << "joiner_filename=\"" << joiner_filename << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model-config.h new file mode 100644 index 00000000..34c68f40 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model-config.h @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/offline-transducer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineTransducerModelConfig { + std::string encoder_filename; + std::string decoder_filename; + std::string joiner_filename; + + OfflineTransducerModelConfig() = default; + OfflineTransducerModelConfig(const std::string &encoder_filename, + const std::string &decoder_filename, + const std::string &joiner_filename) + : encoder_filename(encoder_filename), + decoder_filename(decoder_filename), + joiner_filename(joiner_filename) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model.cc new file mode 100644 index 00000000..e4c73f34 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model.cc @@ -0,0 +1,307 @@ +// sherpa-mnn/csrc/offline-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-transducer-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" + +namespace sherpa_mnn { + +class OfflineTransducerModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : 
config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } + + std::pair RunEncoder(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + std::vector encoder_inputs = {std::move(features), + std::move(features_length)}; + + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + return {std::move(encoder_out[0]), std::move(encoder_out[1])}; + } + + MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input) { + auto decoder_out = decoder_sess_->onForward({decoder_input}); + return std::move(decoder_out[0]); + } + + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = joiner_sess_->onForward( + joiner_input); + + return std::move(logit[0]); + } + + int32_t VocabSize() const { return vocab_size_; } + int32_t ContextSize() const { return context_size_; } + int32_t SubsamplingFactor() const { return 4; } + MNNAllocator *Allocator() { return allocator_; } + + MNN::Express::VARP BuildDecoderInput( + const std::vector &results, + int32_t end_index) { + assert(end_index <= results.size()); + + int32_t batch_size = end_index; + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + + MNN::Express::VARP decoder_input = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + int *p = decoder_input->writeMap(); + + for (int32_t i = 0; i != batch_size; ++i) { + const auto &r = results[i]; + const int *begin = r.tokens.data() + r.tokens.size() - context_size; + const int *end = r.tokens.data() + r.tokens.size(); + std::copy(begin, end, p); + p += context_size; + } + + return decoder_input; + } + + MNN::Express::VARP BuildDecoderInput(const std::vector &results, + int32_t end_index) { + assert(end_index <= results.size()); + + int32_t batch_size = end_index; + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + + MNN::Express::VARP decoder_input = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + int *p = decoder_input->writeMap(); + + for (int32_t i = 0; i != batch_size; ++i) { + const auto &r = results[i]; + const int *begin = r.ys.data() + r.ys.size() - context_size; + const int *end = r.ys.data() + r.ys.size(); + std::copy(begin, end, p); + p += context_size; + } + + return decoder_input; + } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + 
GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = decoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); + } + + void InitJoiner(void *model_data, size_t model_data_length) { + joiner_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + MNNMeta meta_data = joiner_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + int32_t vocab_size_ = 0; // initialized in InitDecoder + int32_t context_size_ = 0; // initialized in InitDecoder +}; + +OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTransducerModel::OfflineTransducerModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTransducerModel::~OfflineTransducerModel() = default; + +std::pair OfflineTransducerModel::RunEncoder( + MNN::Express::VARP features, MNN::Express::VARP features_length) { + return impl_->RunEncoder(std::move(features), std::move(features_length)); +} + +MNN::Express::VARP OfflineTransducerModel::RunDecoder(MNN::Express::VARP decoder_input) { + return impl_->RunDecoder(std::move(decoder_input)); +} + +MNN::Express::VARP OfflineTransducerModel::RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) { + return 
impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); +} + +int32_t OfflineTransducerModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineTransducerModel::ContextSize() const { + return impl_->ContextSize(); +} + +int32_t OfflineTransducerModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +MNNAllocator *OfflineTransducerModel::Allocator() const { + return impl_->Allocator(); +} + +MNN::Express::VARP OfflineTransducerModel::BuildDecoderInput( + const std::vector &results, + int32_t end_index) const { + return impl_->BuildDecoderInput(results, end_index); +} + +MNN::Express::VARP OfflineTransducerModel::BuildDecoderInput( + const std::vector &results, int32_t end_index) const { + return impl_->BuildDecoderInput(results, end_index); +} + +#if __ANDROID_API__ >= 9 +template OfflineTransducerModel::OfflineTransducerModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineTransducerModel::OfflineTransducerModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model.h new file mode 100644 index 00000000..035fc7c5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-model.h @@ -0,0 +1,104 @@ +// sherpa-mnn/csrc/offline-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/hypothesis.h" +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +struct OfflineTransducerDecoderResult; + +class OfflineTransducerModel { + public: + explicit OfflineTransducerModel(const OfflineModelConfig &config); + + template + OfflineTransducerModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineTransducerModel(); + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a pair containing: + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) + * - encoder_out_length: A 1-D tensor of shape (N,) containing number + * of frames in `encoder_out` before padding. + */ + std::pair RunEncoder(MNN::Express::VARP features, + MNN::Express::VARP features_length); + + /** Run the decoder network. + * + * Caution: We assume there are no recurrent connections in the decoder and + * the decoder is stateless. See + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py + * for an example + * + * @param decoder_input It is usually of shape (N, context_size) + * @return Return a tensor of shape (N, decoder_dim). + */ + MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input); + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. A tensor of shape + * (N, joiner_dim). + * @param decoder_out Output of the decoder network. A tensor of shape + * (N, joiner_dim). + * @return Return a tensor of shape (N, vocab_size). In icefall, the last + * last layer of the joint network is `nn.Linear`, + * not `nn.LogSoftmax`. 
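+   *
+   * Callers are expected to post-process the raw logits themselves. A
+   * minimal greedy-search style sketch (illustrative only; the decoders in
+   * this patch do the equivalent internally):
+   *
+   *   auto logit = model->RunJoiner(std::move(cur_encoder_out),
+   *                                 View(decoder_out));
+   *   const float *p = logit->readMap<float>();
+   *   auto y = std::distance(p, std::max_element(p, p + model->VocabSize()));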
+ */ + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out); + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const; + + /** Return the context_size of the decoder model. + */ + int32_t ContextSize() const; + + /** Return the subsampling factor of the model. + */ + int32_t SubsamplingFactor() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + /** Build decoder_input from the current results. + * + * @param results Current decoded results. + * @param end_index We only use results[0:end_index] to build + * the decoder_input. results[end_index] is not used. + * @return Return a tensor of shape (results.size(), ContextSize()) + */ + MNN::Express::VARP BuildDecoderInput( + const std::vector &results, + int32_t end_index) const; + + MNN::Express::VARP BuildDecoderInput(const std::vector &results, + int32_t end_index) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.cc new file mode 100644 index 00000000..1b338027 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -0,0 +1,192 @@ +// sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/context-graph.h" +#include "sherpa-mnn/csrc/hypothesis.h" +#include "sherpa-mnn/csrc/log.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/packed-sequence.h" +#include "sherpa-mnn/csrc/slice.h" + +namespace sherpa_mnn { + +std::vector +OfflineTransducerModifiedBeamSearchDecoder::Decode( + MNN::Express::VARP encoder_out, MNN::Express::VARP encoder_out_length, + OfflineStream **ss /*=nullptr */, int32_t n /*= 0*/) { + PackedSequence packed_encoder_out = PackPaddedSequence( + model_->Allocator(), encoder_out, encoder_out_length); + + int32_t batch_size = + static_cast(packed_encoder_out.sorted_indexes.size()); + + if (ss != nullptr) SHERPA_ONNX_CHECK_EQ(batch_size, n); + + int32_t vocab_size = model_->VocabSize(); + int32_t context_size = model_->ContextSize(); + + std::vector blanks(context_size, -1); + blanks.back() = 0; + + std::deque finalized; + std::vector cur; + std::vector prev; + + std::vector context_graphs(batch_size, nullptr); + + for (int32_t i = 0; i < batch_size; ++i) { + const ContextState *context_state = nullptr; + if (ss != nullptr) { + context_graphs[i] = + ss[packed_encoder_out.sorted_indexes[i]]->GetContextGraph(); + if (context_graphs[i] != nullptr) + context_state = context_graphs[i]->Root(); + } + Hypotheses blank_hyp({{blanks, 0, context_state}}); + cur.emplace_back(std::move(blank_hyp)); + } + + int32_t start = 0; + int32_t t = 0; + for (auto n : packed_encoder_out.batch_sizes) { + MNN::Express::VARP cur_encoder_out = packed_encoder_out.Get(start, n); + start += n; + + if (n < static_cast(cur.size())) { + for (int32_t k = static_cast(cur.size()) - 1; k >= n; --k) { + finalized.push_front(std::move(cur[k])); + } + + cur.erase(cur.begin() + n, cur.end()); + } // if (n < static_cast(cur.size())) + + // Due to merging paths with identical token sequences, 
+ // not all utterances have "max_active_paths" paths. + auto hyps_row_splits = GetHypsRowSplits(cur); + int32_t num_hyps = hyps_row_splits.back(); + + prev.clear(); + prev.reserve(num_hyps); + + for (auto &hyps : cur) { + for (auto &h : hyps) { + prev.push_back(std::move(h.second)); + } + } + cur.clear(); + cur.reserve(n); + + auto decoder_input = model_->BuildDecoderInput(prev, num_hyps); + // decoder_input shape: (num_hyps, context_size) + + auto decoder_out = model_->RunDecoder(std::move(decoder_input)); + // decoder_out is (num_hyps, joiner_dim) + + cur_encoder_out = + Repeat(model_->Allocator(), cur_encoder_out, hyps_row_splits); + // now cur_encoder_out is of shape (num_hyps, joiner_dim) + + MNN::Express::VARP logit = + model_->RunJoiner(std::move(cur_encoder_out), View(decoder_out)); + + float *p_logit = logit->writeMap(); + if (blank_penalty_ > 0.0) { + // assuming blank id is 0 + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); + } + LogSoftmax(p_logit, vocab_size, num_hyps); + + // now p_logit contains log_softmax output, we rename it to p_logprob + // to match what it actually contains + float *p_logprob = p_logit; + + // add log_prob of each hypothesis to p_logprob before taking top_k + for (int32_t i = 0; i != num_hyps; ++i) { + float log_prob = prev[i].log_prob; + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { + *p_logprob += log_prob; + } + } + p_logprob = p_logit; // we changed p_logprob in the above for loop + + // Now compute top_k for each utterance + for (int32_t i = 0; i != n; ++i) { + int32_t start = hyps_row_splits[i]; + int32_t end = hyps_row_splits[i + 1]; + auto topk = + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_); + + Hypotheses hyps; + for (auto k : topk) { + int32_t hyp_index = k / vocab_size + start; + int32_t new_token = k % vocab_size; + Hypothesis new_hyp = prev[hyp_index]; + + float context_score = 0; + auto context_state = new_hyp.context_state; + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { + new_hyp.ys.push_back(new_token); + new_hyp.timestamps.push_back(t); + if (context_graphs[i] != nullptr) { + auto context_res = + context_graphs[i]->ForwardOneStep(context_state, new_token); + context_score = std::get<0>(context_res); + new_hyp.context_state = std::get<1>(context_res); + } + } + + new_hyp.log_prob = p_logprob[k] + context_score; + hyps.Add(std::move(new_hyp)); + } // for (auto k : topk) + p_logprob += (end - start) * vocab_size; + cur.push_back(std::move(hyps)); + } // for (int32_t i = 0; i != n; ++i) + + ++t; + } // for (auto n : packed_encoder_out.batch_sizes) + + for (auto &h : finalized) { + cur.push_back(std::move(h)); + } + + // Finalize context biasing matching.. 
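+  // Finalize() is expected to settle the score of any hotword that is only
+  // partially matched at the end of the utterance, so that incomplete
+  // matches do not keep their accumulated boost (see ContextGraph).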
+ for (int32_t i = 0; i < cur.size(); ++i) { + for (auto iter = cur[i].begin(); iter != cur[i].end(); ++iter) { + if (context_graphs[i] != nullptr) { + auto context_res = + context_graphs[i]->Finalize(iter->second.context_state); + iter->second.log_prob += context_res.first; + iter->second.context_state = context_res.second; + } + } + } + + if (lm_) { + // use LM for rescoring + lm_->ComputeLMScore(lm_scale_, context_size, &cur); + } + + std::vector unsorted_ans(batch_size); + for (int32_t i = 0; i != batch_size; ++i) { + Hypothesis hyp = cur[i].GetMostProbable(true); + + auto &r = unsorted_ans[packed_encoder_out.sorted_indexes[i]]; + + // strip leading blanks + r.tokens = {hyp.ys.begin() + context_size, hyp.ys.end()}; + r.timestamps = std::move(hyp.timestamps); + } + + return unsorted_ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.h new file mode 100644 index 00000000..8c636740 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.h @@ -0,0 +1,47 @@ +// sherpa-mnn/csrc/offline-transducer-modified-beam-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/offline-lm.h" +#include "sherpa-mnn/csrc/offline-transducer-decoder.h" +#include "sherpa-mnn/csrc/offline-transducer-model.h" + +namespace sherpa_mnn { + +class OfflineTransducerModifiedBeamSearchDecoder + : public OfflineTransducerDecoder { + public: + OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model, + OfflineLM *lm, + int32_t max_active_paths, + float lm_scale, int32_t unk_id, + float blank_penalty) + : model_(model), + lm_(lm), + max_active_paths_(max_active_paths), + lm_scale_(lm_scale), + unk_id_(unk_id), + blank_penalty_(blank_penalty) {} + + std::vector Decode( + MNN::Express::VARP encoder_out, MNN::Express::VARP encoder_out_length, + OfflineStream **ss = nullptr, int32_t n = 0) override; + + private: + OfflineTransducerModel *model_; // Not owned + OfflineLM *lm_; // Not owned; may be nullptr + + int32_t max_active_paths_; + float lm_scale_; // used only when lm_ is not nullptr + int32_t unk_id_; + float blank_penalty_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-nemo-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-nemo-model.cc new file mode 100644 index 00000000..fb0a0616 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-nemo-model.cc @@ -0,0 +1,320 @@ +// sherpa-mnn/csrc/offline-transducer-nemo-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-transducer-nemo-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include 
"sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class OfflineTransducerNeMoModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder_filename); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder_filename); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner_filename); + InitJoiner(buf.data(), buf.size()); + } + } + + std::vector RunEncoder(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + // (B, T, C) -> (B, C, T) + features = Transpose12(allocator_, features); + + std::vector encoder_inputs = {std::move(features), + std::move(features_length)}; + + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + return encoder_out; + } + + std::pair> RunDecoder( + MNN::Express::VARP targets, MNN::Express::VARP targets_length, + std::vector states) { + std::vector decoder_inputs; + decoder_inputs.reserve(2 + states.size()); + + decoder_inputs.push_back(std::move(targets)); + decoder_inputs.push_back(std::move(targets_length)); + + for (auto &s : states) { + decoder_inputs.push_back(std::move(s)); + } + + auto decoder_out = decoder_sess_->onForward(decoder_inputs); + + std::vector states_next; + states_next.reserve(states.size()); + + // decoder_out[0]: decoder_output + // decoder_out[1]: decoder_output_length + // decoder_out[2:] states_next + + for (int32_t i = 0; i != states.size(); ++i) { + states_next.push_back(std::move(decoder_out[i + 2])); + } + + // we discard decoder_out[1] + return {std::move(decoder_out[0]), std::move(states_next)}; + } + + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = joiner_sess_->onForward( + joiner_input); + + return std::move(logit[0]); + } + + std::vector GetDecoderInitStates(int32_t batch_size) { + std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + MNN::Express::VARP s0 = MNNUtilsCreateTensor(allocator_, s0_shape.data(), + s0_shape.size()); + + Fill(s0, 0); + + std::array s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + + MNN::Express::VARP s1 = MNNUtilsCreateTensor(allocator_, s1_shape.data(), + s1_shape.size()); + + Fill(s1, 0); + + std::vector states; + + states.reserve(2); + states.push_back(std::move(s0)); + states.push_back(std::move(s1)); + + return states; + } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + int32_t VocabSize() const { return vocab_size_; } + + MNNAllocator *Allocator() { return allocator_; } + + std::string FeatureNormalizationMethod() const { return normalize_type_; } + + bool IsGigaAM() const { return is_giga_am_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, 
model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + + // need to increase by 1 since the blank token is not included in computing + // vocab_size in NeMo. + vocab_size_ += 1; + + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA_STR_ALLOW_EMPTY(normalize_type_, + "normalize_type"); + SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); + SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(is_giga_am_, "is_giga_am", 0); + + if (normalize_type_ == "NA") { + normalize_type_ = ""; + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + void InitJoiner(void *model_data, size_t model_data_length) { + joiner_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 8; + std::string normalize_type_; + int32_t pred_rnn_layers_ = -1; + int32_t pred_hidden_ = -1; + int32_t is_giga_am_ = 0; +}; + +OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( + const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( + Manager *mgr, const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTransducerNeMoModel::~OfflineTransducerNeMoModel() = default; + +std::vector OfflineTransducerNeMoModel::RunEncoder( + MNN::Express::VARP features, MNN::Express::VARP features_length) const { + return impl_->RunEncoder(std::move(features), std::move(features_length)); +} + +std::pair> 
+OfflineTransducerNeMoModel::RunDecoder(MNN::Express::VARP targets, + MNN::Express::VARP targets_length, + std::vector states) const { + return impl_->RunDecoder(std::move(targets), std::move(targets_length), + std::move(states)); +} + +std::vector OfflineTransducerNeMoModel::GetDecoderInitStates( + int32_t batch_size) const { + return impl_->GetDecoderInitStates(batch_size); +} + +MNN::Express::VARP OfflineTransducerNeMoModel::RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) const { + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); +} + +int32_t OfflineTransducerNeMoModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +int32_t OfflineTransducerNeMoModel::VocabSize() const { + return impl_->VocabSize(); +} + +MNNAllocator *OfflineTransducerNeMoModel::Allocator() const { + return impl_->Allocator(); +} + +std::string OfflineTransducerNeMoModel::FeatureNormalizationMethod() const { + return impl_->FeatureNormalizationMethod(); +} + +bool OfflineTransducerNeMoModel::IsGigaAM() const { return impl_->IsGigaAM(); } + +#if __ANDROID_API__ >= 9 +template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineTransducerNeMoModel::OfflineTransducerNeMoModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-nemo-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-nemo-model.h new file mode 100644 index 00000000..b5369f10 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-transducer-nemo-model.h @@ -0,0 +1,98 @@ +// sherpa-mnn/csrc/offline-transducer-nemo-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +// see +// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40 +// Its decoder is stateful, not stateless. +class OfflineTransducerNeMoModel { + public: + explicit OfflineTransducerNeMoModel(const OfflineModelConfig &config); + + template + OfflineTransducerNeMoModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineTransducerNeMoModel(); + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a vector containing: + * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) + * - encoder_out_length: A 1-D tensor of shape (N,) containing number + * of frames in `encoder_out` before padding. + */ + std::vector RunEncoder(MNN::Express::VARP features, + MNN::Express::VARP features_length) const; + + /** Run the decoder network. + * + * @param targets A int32 tensor of shape (batch_size, 1) + * @param targets_length A int32 tensor of shape (batch_size,) + * @param states The states for the decoder model. 
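+   *               For the first call, pass GetDecoderInitStates(batch_size),
+   *               which yields two zero-initialized tensors of shape
+   *               (pred_rnn_layers, batch_size, pred_hidden); on subsequent
+   *               calls pass the states returned by the previous invocation,
+   *               as the greedy-search NeMo decoder does.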
+ * @return Return a vector: + * - ans[0] is the decoder_out (a float tensor) + * - ans[1] is the decoder_out_length (a int32 tensor) + * - ans[2:] is the states_next + */ + std::pair> RunDecoder( + MNN::Express::VARP targets, MNN::Express::VARP targets_length, + std::vector states) const; + + std::vector GetDecoderInitStates(int32_t batch_size) const; + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. + * @param decoder_out Output of the decoder network. + * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. + */ + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) const; + + /** Return the subsampling factor of the model. + */ + int32_t SubsamplingFactor() const; + + int32_t VocabSize() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string FeatureNormalizationMethod() const; + + bool IsGigaAM() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_NEMO_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-character-frontend.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-character-frontend.cc new file mode 100644 index 00000000..d69f1688 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-character-frontend.cc @@ -0,0 +1,208 @@ +// sherpa-mnn/csrc/offline-tts-character-frontend.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-tts-character-frontend.h" + +namespace sherpa_mnn { + +static std::unordered_map ReadTokens(std::istream &is) { + std::wstring_convert, char32_t> conv; + std::unordered_map token2id; + + std::string line; + + std::string sym; + std::u32string s; + int32_t id = 0; + while (std::getline(is, line)) { + std::istringstream iss(line); + iss >> sym; + if (iss.eof()) { + id = atoi(sym.c_str()); + sym = " "; + } else { + iss >> id; + } + + // eat the trailing \r\n on windows + iss >> std::ws; + if (!iss.eof()) { + SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str()); + exit(-1); + } + + // Form models from coqui-ai/TTS, we have saved the IDs of the following + // symbols in OfflineTtsVitsModelMetaData, so it is safe to skip them here. + if (sym == "" || sym == "" || sym == "" || sym == "") { + continue; + } + + s = conv.from_bytes(sym); + if (s.size() != 1) { + SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d", + line.c_str(), static_cast(s.size())); + exit(-1); + } + + char32_t c = s[0]; + + if (token2id.count(c)) { + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. 
Existing ID: %d", + sym.c_str(), line.c_str(), token2id.at(c)); + exit(-1); + } + + token2id.insert({c, id}); + } + + return token2id; +} + +OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( + const std::string &tokens, const OfflineTtsVitsModelMetaData &meta_data) + : meta_data_(meta_data) { + std::ifstream is(tokens); + token2id_ = ReadTokens(is); +} + +template +OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( + Manager *mgr, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data) + : meta_data_(meta_data) { + auto buf = ReadFile(mgr, tokens); + std::istrstream is(buf.data(), buf.size()); + token2id_ = ReadTokens(is); +} + +std::vector OfflineTtsCharacterFrontend::ConvertTextToTokenIds( + const std::string &_text, const std::string & /*voice = ""*/) const { + // see + // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 + int32_t use_eos_bos = meta_data_.use_eos_bos; + int32_t bos_id = meta_data_.bos_id; + int32_t eos_id = meta_data_.eos_id; + int32_t blank_id = meta_data_.blank_id; + int32_t add_blank = meta_data_.add_blank; + + std::string text(_text.size(), 0); + std::transform(_text.begin(), _text.end(), text.begin(), + [](auto c) { return std::tolower(c); }); + + std::wstring_convert, char32_t> conv; + std::u32string s = conv.from_bytes(text); + + std::vector ans; + + std::vector this_sentence; + if (add_blank) { + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + + this_sentence.push_back(blank_id); + + for (char32_t c : s) { + if (token2id_.count(c)) { + this_sentence.push_back(token2id_.at(c)); + this_sentence.push_back(blank_id); + } else { + SHERPA_ONNX_LOGE("Skip unknown character. Unicode codepoint: \\U+%04x.", + static_cast(c)); + } + + if (c == '.' || c == ':' || c == '?' || c == '!') { + // end of a sentence + if (use_eos_bos) { + this_sentence.push_back(eos_id); + } + + ans.emplace_back(std::move(this_sentence)); + this_sentence = {}; + + // re-initialize this_sentence + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + this_sentence.push_back(blank_id); + } + } + + if (use_eos_bos) { + this_sentence.push_back(eos_id); + } + + if (static_cast(this_sentence.size()) > 1 + use_eos_bos) { + ans.emplace_back(std::move(this_sentence)); + } + } else { + // not adding blank + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + + for (char32_t c : s) { + if (token2id_.count(c)) { + this_sentence.push_back(token2id_.at(c)); + } + + if (c == '.' || c == ':' || c == '?' 
|| c == '!') { + // end of a sentence + if (use_eos_bos) { + this_sentence.push_back(eos_id); + } + + ans.emplace_back(std::move(this_sentence)); + this_sentence = {}; + + // re-initialize this_sentence + if (use_eos_bos) { + this_sentence.push_back(bos_id); + } + } + } + + if (this_sentence.size() > 1) { + ans.emplace_back(std::move(this_sentence)); + } + } + + return ans; +} + +#if __ANDROID_API__ >= 9 +template OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( + AAssetManager *mgr, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data); + +#endif + +#if __OHOS__ +template OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( + NativeResourceManager *mgr, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data); + +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-character-frontend.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-character-frontend.h new file mode 100644 index 00000000..bae54baf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-character-frontend.h @@ -0,0 +1,48 @@ +// sherpa-mnn/csrc/offline-tts-character-frontend.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h" + +namespace sherpa_mnn { + +class OfflineTtsCharacterFrontend : public OfflineTtsFrontend { + public: + OfflineTtsCharacterFrontend(const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data); + + template + OfflineTtsCharacterFrontend(Manager *mgr, const std::string &tokens, + const OfflineTtsVitsModelMetaData &meta_data); + + /** Convert a string to token IDs. + * + * @param text The input text. + * Example 1: "This is the first sample sentence; this is the + * second one." Example 2: "这是第一句。这是第二句。" + * @param voice Optional. It is for espeak-ng. + * + * @return Return a vector-of-vector of token IDs. Each subvector contains + * a sentence that can be processed independently. + * If a frontend does not support splitting the text into + * sentences, the resulting vector contains only one subvector. 
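+   *
+   * With this character-based frontend, sentences are split at '.', ':',
+   * '?' and '!', so "hi. bye." would produce two subvectors (a hypothetical
+   * example; the exact IDs depend on tokens.txt and the add_blank setting).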
+ */ + std::vector ConvertTextToTokenIds( + const std::string &text, const std::string &voice = "") const override; + + private: + OfflineTtsVitsModelMetaData meta_data_; + std::unordered_map token2id_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_CHARACTER_FRONTEND_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-frontend.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-frontend.cc new file mode 100644 index 00000000..8da821ce --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-frontend.cc @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/offline-tts-frontend.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-frontend.h" + +#include +#include + +namespace sherpa_mnn { + +std::string TokenIDs::ToString() const { + std::ostringstream os; + os << "TokenIDs("; + os << "tokens=["; + std::string sep; + for (auto i : tokens) { + os << sep << i; + sep = ", "; + } + os << "], "; + + os << "tones=["; + sep = {}; + for (auto i : tones) { + os << sep << i; + sep = ", "; + } + os << "]"; + os << ")"; + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-frontend.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-frontend.h new file mode 100644 index 00000000..7e177f02 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-frontend.h @@ -0,0 +1,60 @@ +// sherpa-mnn/csrc/offline-tts-frontend.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_ +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +struct TokenIDs { + TokenIDs() = default; + + /*implicit*/ TokenIDs(std::vector tokens) // NOLINT + : tokens{std::move(tokens)} {} + + + TokenIDs(std::vector tokens, // NOLINT + std::vector tones) // NOLINT + : tokens{std::move(tokens)}, tones{std::move(tones)} {} + + std::string ToString() const; + + std::vector tokens; + + // Used only in MeloTTS + std::vector tones; +}; + +class OfflineTtsFrontend { + public: + virtual ~OfflineTtsFrontend() = default; + + /** Convert a string to token IDs. + * + * @param text The input text. + * Example 1: "This is the first sample sentence; this is the + * second one." Example 2: "这是第一句。这是第二句。" + * @param voice Optional. It is for espeak-ng. + * + * @return Return a vector-of-vector of token IDs. Each subvector contains + * a sentence that can be processed independently. + * If a frontend does not support splitting the text into sentences, + * the resulting vector contains only one subvector. 
+ */ + virtual std::vector ConvertTextToTokenIds( + const std::string &text, const std::string &voice = "") const = 0; +}; + +// implementation is in ./piper-phonemize-lexicon.cc +void InitEspeak(const std::string &data_dir); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_FRONTEND_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-impl.cc new file mode 100644 index 00000000..4d7f3c26 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-impl.cc @@ -0,0 +1,70 @@ +// sherpa-mnn/csrc/offline-tts-impl.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-impl.h" + +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/offline-tts-kokoro-impl.h" +#include "sherpa-mnn/csrc/offline-tts-matcha-impl.h" +#include "sherpa-mnn/csrc/offline-tts-vits-impl.h" + +namespace sherpa_mnn { + +std::vector OfflineTtsImpl::AddBlank(const std::vector &x, + int32_t blank_id /*= 0*/) const { + // we assume the blank ID is 0 + std::vector buffer(x.size() * 2 + 1, blank_id); + int32_t i = 1; + for (auto k : x) { + buffer[i] = k; + i += 2; + } + return buffer; +} + +std::unique_ptr OfflineTtsImpl::Create( + const OfflineTtsConfig &config) { + if (!config.model.vits.model.empty()) { + return std::make_unique(config); + } else if (!config.model.matcha.acoustic_model.empty()) { + return std::make_unique(config); + } + + return std::make_unique(config); +} + +template +std::unique_ptr OfflineTtsImpl::Create( + Manager *mgr, const OfflineTtsConfig &config) { + if (!config.model.vits.model.empty()) { + return std::make_unique(mgr, config); + } else if (!config.model.matcha.acoustic_model.empty()) { + return std::make_unique(mgr, config); + } + + return std::make_unique(mgr, config); +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr OfflineTtsImpl::Create( + AAssetManager *mgr, const OfflineTtsConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr OfflineTtsImpl::Create( + NativeResourceManager *mgr, const OfflineTtsConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-impl.h new file mode 100644 index 00000000..771f0836 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-impl.h @@ -0,0 +1,43 @@ +// sherpa-mnn/csrc/offline-tts-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts.h" + +namespace sherpa_mnn { + +class OfflineTtsImpl { + public: + virtual ~OfflineTtsImpl() = default; + + static std::unique_ptr Create(const OfflineTtsConfig &config); + + template + static std::unique_ptr Create(Manager *mgr, + const OfflineTtsConfig &config); + + virtual GeneratedAudio Generate( + const std::string &text, int sid = 0, float speed = 1.0, + GeneratedAudioCallback callback = nullptr) const = 0; + + // Return the sample rate of the generated audio + virtual int32_t SampleRate() const = 0; + + // Number of supported speakers. + // If it supports only a single speaker, then it return 0 or 1. 
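+  // Implementations such as the Kokoro backend clamp an out-of-range sid to
+  // 0 in Generate() based on this value.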
+ virtual int32_t NumSpeakers() const = 0; + + std::vector AddBlank(const std::vector &x, + int32_t blank_id = 0) const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-impl.h new file mode 100644 index 00000000..50b5eae4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-impl.h @@ -0,0 +1,440 @@ +// sherpa-mnn/csrc/offline-tts-kokoro-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" +#include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/kokoro-multi-lang-lexicon.h" +#include "sherpa-mnn/csrc/lexicon.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-tts-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-impl.h" +#include "sherpa-mnn/csrc/offline-tts-kokoro-model.h" +#include "sherpa-mnn/csrc/piper-phonemize-lexicon.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineTtsKokoroImpl : public OfflineTtsImpl { + public: + explicit OfflineTtsKokoroImpl(const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(config.model)) { + InitFrontend(); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + tn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if (config.model.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + tn_list_.reserve(files.size() + tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } + } + + template + OfflineTtsKokoroImpl(Manager *mgr, const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(mgr, config.model)) { + InitFrontend(mgr); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + tn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + tn_list_.reserve(files.size() + 
tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) + } + + int32_t SampleRate() const override { + return model_->GetMetaData().sample_rate; + } + + int32_t NumSpeakers() const override { + return model_->GetMetaData().num_speakers; + } + + GeneratedAudio Generate( + const std::string &_text, int sid = 0, float speed = 1.0, + GeneratedAudioCallback callback = nullptr) const override { + const auto &meta_data = model_->GetMetaData(); + int32_t num_speakers = meta_data.num_speakers; + + if (num_speakers == 0 && sid != 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%{public}d. sid is ignored", + static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%d. sid is ignored", + static_cast(sid)); +#endif + } + + if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This model contains only %{public}d speakers. sid should be in the " + "range [%{public}d, %{public}d]. Given: %{public}d. Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This model contains only %d speakers. sid should be in the range " + "[%d, %d]. Given: %d. 
Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#endif + sid = 0; + } + + std::string text = _text; + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Raw text: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("Raw text: %s", text.c_str()); +#endif + std::ostringstream os; + os << "In bytes (hex):\n"; + const auto p = reinterpret_cast(text.c_str()); + for (int32_t i = 0; i != text.size(); ++i) { + os << std::setw(2) << std::setfill('0') << std::hex + << static_cast(p[i]) << " "; + } + os << "\n"; + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + if (!tn_list_.empty()) { + for (const auto &tn : tn_list_) { + text = tn->Normalize(text); + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("After normalizing: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str()); +#endif + } + } + } + + std::vector token_ids = + frontend_->ConvertTextToTokenIds(text, meta_data.voice); + + if (token_ids.empty() || + (token_ids.size() == 1 && token_ids[0].tokens.empty())) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Failed to convert '%{public}s' to token IDs", + text.c_str()); +#else + SHERPA_ONNX_LOGE("Failed to convert '%s' to token IDs", text.c_str()); +#endif + return {}; + } + + std::vector> x; + + x.reserve(token_ids.size()); + + for (auto &i : token_ids) { + x.push_back(std::move(i.tokens)); + } + + int32_t x_size = static_cast(x.size()); + + if (config_.max_num_sentences != 1) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "max_num_sentences (%{public}d) != 1 is ignored for Kokoro TTS " + "models", + config_.max_num_sentences); +#else + SHERPA_ONNX_LOGE( + "max_num_sentences (%d) != 1 is ignored for Kokoro TTS models", + config_.max_num_sentences); +#endif + } + + // the input text is too long, we process sentences within it in batches + // to avoid OOM. Batch size is config_.max_num_sentences + std::vector> batch_x; + + int32_t batch_size = 1; + batch_x.reserve(config_.max_num_sentences); + int32_t num_batches = x_size / batch_size; + + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Split it into %{public}d batches. batch size: " + "%{public}d. Number of sentences: %{public}d", + num_batches, batch_size, x_size); +#else + SHERPA_ONNX_LOGE( + "Split it into %d batches. batch size: %d. Number " + "of sentences: %d", + num_batches, batch_size, x_size); +#endif + } + + GeneratedAudio ans; + + int32_t should_continue = 1; + + int32_t k = 0; + + for (int32_t b = 0; b != num_batches && should_continue; ++b) { + batch_x.clear(); + for (int32_t i = 0; i != batch_size; ++i, ++k) { + batch_x.push_back(std::move(x[k])); + } + + auto audio = Process(batch_x, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + should_continue = callback(audio.samples.data(), audio.samples.size(), + (b + 1) * 1.0 / num_batches); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. 
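+        // An illustrative callback (assuming the GeneratedAudioCallback
+        // signature used in this call) copies the chunk into a caller-owned
+        // buffer (here a hypothetical std::vector<float> all) and returns 1
+        // to keep generation going:
+        //   [&all](const float *p, int32_t n, float /*progress*/) -> int32_t {
+        //     all.insert(all.end(), p, p + n);
+        //     return 1;
+        //   }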
+ } + } + + batch_x.clear(); + while (k < static_cast(x.size()) && should_continue) { + batch_x.push_back(std::move(x[k])); + + ++k; + } + + if (!batch_x.empty()) { + auto audio = Process(batch_x, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + callback(audio.samples.data(), audio.samples.size(), 1.0); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. + } + } + + return ans; + } + + private: + template + void InitFrontend(Manager *mgr) { + const auto &meta_data = model_->GetMetaData(); + + if (meta_data.version >= 2) { + // this is a multi-lingual model, we require that you pass lexicon + // and dict_dir + if (config_.model.kokoro.lexicon.empty() || + config_.model.kokoro.dict_dir.empty()) { + SHERPA_ONNX_LOGE("Current model version: '%d'", meta_data.version); + SHERPA_ONNX_LOGE( + "You are using a multi-lingual Kokoro model (e.g., Kokoro >= " + "v1.0). please pass --kokoro-lexicon and --kokoro-dict-dir"); + SHERPA_ONNX_EXIT(-1); + } + + frontend_ = std::make_unique( + mgr, config_.model.kokoro.tokens, config_.model.kokoro.lexicon, + config_.model.kokoro.dict_dir, config_.model.kokoro.data_dir, + meta_data, config_.model.debug); + + return; + } + + frontend_ = std::make_unique( + mgr, config_.model.kokoro.tokens, config_.model.kokoro.data_dir, + meta_data); + } + + void InitFrontend() { + const auto &meta_data = model_->GetMetaData(); + if (meta_data.version >= 2) { + // this is a multi-lingual model, we require that you pass lexicon + // and dict_dir + if (config_.model.kokoro.lexicon.empty() || + config_.model.kokoro.dict_dir.empty()) { + SHERPA_ONNX_LOGE("Current model version: '%d'", meta_data.version); + SHERPA_ONNX_LOGE( + "You are using a multi-lingual Kokoro model (e.g., Kokoro >= " + "v1.0). 
please pass --kokoro-lexicon and --kokoro-dict-dir"); + SHERPA_ONNX_EXIT(-1); + } + + frontend_ = std::make_unique( + config_.model.kokoro.tokens, config_.model.kokoro.lexicon, + config_.model.kokoro.dict_dir, config_.model.kokoro.data_dir, + meta_data, config_.model.debug); + + return; + } + + // this is for kokoro v0.19, which supports only English + frontend_ = std::make_unique( + config_.model.kokoro.tokens, config_.model.kokoro.data_dir, meta_data); + } + + GeneratedAudio Process(const std::vector> &tokens, + int32_t sid, float speed) const { + int32_t num_tokens = 0; + for (const auto &k : tokens) { + num_tokens += k.size(); + } + + std::vector x; + x.reserve(num_tokens); + for (const auto &k : tokens) { + x.insert(x.end(), k.begin(), k.end()); + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape = {1, static_cast(x.size())}; + MNN::Express::VARP x_tensor = MNNUtilsCreateTensor( + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + + MNN::Express::VARP audio = model_->Run(std::move(x_tensor), sid, speed); + + std::vector audio_shape = + audio->getInfo()->dim; + + int total = 1; + // The output shape may be (1, 1, total) or (1, total) or (total,) + for (auto i : audio_shape) { + total *= i; + } + + const float *p = audio->readMap(); + + GeneratedAudio ans; + ans.sample_rate = model_->GetMetaData().sample_rate; + ans.samples = std::vector(p, p + total); + + float silence_scale = config_.silence_scale; + if (silence_scale != 1) { + ans = ans.ScaleSilence(silence_scale); + } + + return ans; + } + + private: + OfflineTtsConfig config_; + std::unique_ptr model_; + std::vector> tn_list_; + std::unique_ptr frontend_; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-config.cc new file mode 100644 index 00000000..ef760b3b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-config.cc @@ -0,0 +1,135 @@ +// sherpa-mnn/csrc/offline-tts-kokoro-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-kokoro-model-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +void OfflineTtsKokoroModelConfig::Register(ParseOptions *po) { + po->Register("kokoro-model", &model, "Path to Kokoro model"); + po->Register("kokoro-voices", &voices, + "Path to voices.bin for Kokoro models"); + po->Register("kokoro-tokens", &tokens, + "Path to tokens.txt for Kokoro models"); + po->Register( + "kokoro-lexicon", &lexicon, + "Path to lexicon.txt for Kokoro models. Used only for Kokoro >= v1.0" + "You can pass multiple files, separated by ','. Example: " + "./lexicon-us-en.txt,./lexicon-zh.txt"); + po->Register("kokoro-data-dir", &data_dir, + "Path to the directory containing dict for espeak-ng."); + po->Register("kokoro-dict-dir", &dict_dir, + "Path to the directory containing dict for jieba. " + "Used only for Kokoro >= v1.0"); + po->Register("kokoro-length-scale", &length_scale, + "Speech speed. 
Larger->Slower; Smaller->faster."); +} + +bool OfflineTtsKokoroModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --kokoro-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--kokoro-model: '%s' does not exist", model.c_str()); + return false; + } + + if (tokens.empty()) { + SHERPA_ONNX_LOGE("Please provide --kokoro-tokens"); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("--kokoro-tokens: '%s' does not exist", tokens.c_str()); + return false; + } + + if (!lexicon.empty()) { + std::vector files; + SplitStringToVector(lexicon, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE( + "lexicon '%s' does not exist. Please re-check --kokoro-lexicon", + f.c_str()); + return false; + } + } + } + + if (data_dir.empty()) { + SHERPA_ONNX_LOGE("Please provide --kokoro-data-dir"); + return false; + } + + if (!FileExists(data_dir + "/phontab")) { + SHERPA_ONNX_LOGE( + "'%s/phontab' does not exist. Please check --kokoro-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phonindex")) { + SHERPA_ONNX_LOGE( + "'%s/phonindex' does not exist. Please check --kokoro-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phondata")) { + SHERPA_ONNX_LOGE( + "'%s/phondata' does not exist. Please check --kokoro-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/intonations")) { + SHERPA_ONNX_LOGE( + "'%s/intonations' does not exist. Please check --kokoro-data-dir", + data_dir.c_str()); + return false; + } + + if (!dict_dir.empty()) { + std::vector required_files = { + "jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8", + "idf.utf8", "stop_words.utf8", + }; + + for (const auto &f : required_files) { + if (!FileExists(dict_dir + "/" + f)) { + SHERPA_ONNX_LOGE("'%s/%s' does not exist. 
Please check kokoro-dict-dir", + dict_dir.c_str(), f.c_str()); + return false; + } + } + } + + return true; +} + +std::string OfflineTtsKokoroModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsKokoroModelConfig("; + os << "model=\"" << model << "\", "; + os << "voices=\"" << voices << "\", "; + os << "tokens=\"" << tokens << "\", "; + os << "lexicon=\"" << lexicon << "\", "; + os << "data_dir=\"" << data_dir << "\", "; + os << "dict_dir=\"" << dict_dir << "\", "; + os << "length_scale=" << length_scale << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-config.h new file mode 100644 index 00000000..90b478c9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-config.h @@ -0,0 +1,54 @@ +// sherpa-mnn/csrc/offline-tts-kokoro-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineTtsKokoroModelConfig { + std::string model; + std::string voices; + std::string tokens; + + // Note: You can pass multiple files, separated by ",", to lexicon + // Example: lexicon = "./lexicon-gb-en.txt,./lexicon-zh.txt"; + std::string lexicon; + + std::string data_dir; + + std::string dict_dir; + + // speed = 1 / length_scale + float length_scale = 1.0; + + OfflineTtsKokoroModelConfig() = default; + + OfflineTtsKokoroModelConfig(const std::string &model, + const std::string &voices, + const std::string &tokens, + const std::string &lexicon, + const std::string &data_dir, + const std::string &dict_dir, float length_scale) + : model(model), + voices(voices), + tokens(tokens), + lexicon(lexicon), + data_dir(data_dir), + dict_dir(dict_dir), + length_scale(length_scale) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-meta-data.h new file mode 100644 index 00000000..a0bf3f14 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model-meta-data.h @@ -0,0 +1,27 @@ +// sherpa-mnn/csrc/offline-tts-kokoro-model-meta-data.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_mnn { + +// please refer to +// https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/kokoro/add-meta-data.py +struct OfflineTtsKokoroModelMetaData { + int32_t sample_rate = 0; + int32_t num_speakers = 0; + int32_t version = 1; + int32_t has_espeak = 1; + int32_t max_token_len = 0; + + std::string voice; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model.cc new file mode 100644 index 00000000..13102e01 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model.cc @@ -0,0 
+1,251 @@ +// sherpa-mnn/csrc/offline-tts-kokoro-model.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-kokoro-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineTtsKokoroModel::Impl { + public: + explicit Impl(const OfflineTtsModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto model_buf = ReadFile(config.kokoro.model); + auto voices_buf = ReadFile(config.kokoro.voices); + Init(model_buf.data(), model_buf.size(), voices_buf.data(), + voices_buf.size()); + } + + template + Impl(Manager *mgr, const OfflineTtsModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto model_buf = ReadFile(mgr, config.kokoro.model); + auto voices_buf = ReadFile(mgr, config.kokoro.voices); + Init(model_buf.data(), model_buf.size(), voices_buf.data(), + voices_buf.size()); + } + + const OfflineTtsKokoroModelMetaData &GetMetaData() const { + return meta_data_; + } + + MNN::Express::VARP Run(MNN::Express::VARP x, int32_t sid, float speed) { + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector x_shape = x->getInfo()->dim; + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + // there is a 0 at the front and end of x + int32_t len = static_cast(x_shape[1]) - 2; + int32_t num_speakers = meta_data_.num_speakers; + int32_t dim0 = style_dim_[0]; + int32_t dim1 = style_dim_[2]; + if (len >= dim0) { + SHERPA_ONNX_LOGE("Bad things happened! 
%d vs %d", len, dim0); + SHERPA_ONNX_EXIT(-1); + } + + /*const*/ float *p = styles_.data() + sid * dim0 * dim1 + len * dim1; + + std::array style_embedding_shape = {1, dim1}; + MNN::Express::VARP style_embedding = MNNUtilsCreateTensor( + memory_info, p, dim1, style_embedding_shape.data(), + style_embedding_shape.size()); + + int speed_shape = 1; + + MNN::Express::VARP speed_tensor = + MNNUtilsCreateTensor(memory_info, &speed, 1, &speed_shape, 1); + + std::vector inputs = { + std::move(x), std::move(style_embedding), std::move(speed_tensor)}; + + auto out = + sess_->onForward(inputs); + + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length, const char *voices_data, + size_t voices_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---kokoro model---\n"; + PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : output_names_) { + os << i << " " << s << "\n"; + ++i; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 1); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.has_espeak, "has_espeak"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", + "en-us"); + + if (config_.debug) { + std::vector speaker_names; + SHERPA_ONNX_READ_META_DATA_VEC_STRING(speaker_names, "speaker_names"); + std::ostringstream os; + os << "\n"; + for (int32_t i = 0; i != speaker_names.size(); ++i) { + os << i << "->" << speaker_names[i] << ", "; + } + os << "\n"; + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + SHERPA_ONNX_READ_META_DATA_VEC(style_dim_, "style_dim"); + if (style_dim_.size() != 3) { + SHERPA_ONNX_LOGE("style_dim should be 3-d, given: %d", + static_cast(style_dim_.size())); + SHERPA_ONNX_EXIT(-1); + } + + if (style_dim_[1] != 1) { + SHERPA_ONNX_LOGE("style_dim[0] should be 1, given: %d", style_dim_[1]); + SHERPA_ONNX_EXIT(-1); + } + + int32_t actual_num_floats = voices_data_length / sizeof(float); + int32_t expected_num_floats = + style_dim_[0] * style_dim_[2] * meta_data_.num_speakers; + + if (actual_num_floats != expected_num_floats) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Corrupted --kokoro-voices '%{public}s'. Expected #floats: " + "%{public}d, actual: %{public}d", + config_.kokoro.voices.c_str(), expected_num_floats, + actual_num_floats); +#else + SHERPA_ONNX_LOGE( + "Corrupted --kokoro-voices '%s'. 
Expected #floats: %d, actual: %d", + config_.kokoro.voices.c_str(), expected_num_floats, + actual_num_floats); +#endif + + SHERPA_ONNX_EXIT(-1); + } + + styles_ = std::vector( + reinterpret_cast(voices_data), + reinterpret_cast(voices_data) + expected_num_floats); + + meta_data_.max_token_len = style_dim_[0]; + } + + private: + OfflineTtsModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineTtsKokoroModelMetaData meta_data_; + std::vector style_dim_; + + // (num_speakers, style_dim_[0], style_dim_[2]) + std::vector styles_; +}; + +OfflineTtsKokoroModel::OfflineTtsKokoroModel( + const OfflineTtsModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTtsKokoroModel::OfflineTtsKokoroModel( + Manager *mgr, const OfflineTtsModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTtsKokoroModel::~OfflineTtsKokoroModel() = default; + +const OfflineTtsKokoroModelMetaData &OfflineTtsKokoroModel::GetMetaData() + const { + return impl_->GetMetaData(); +} + +MNN::Express::VARP OfflineTtsKokoroModel::Run(MNN::Express::VARP x, int sid /*= 0*/, + float speed /*= 1.0*/) const { + return impl_->Run(std::move(x), sid, speed); +} + +#if __ANDROID_API__ >= 9 +template OfflineTtsKokoroModel::OfflineTtsKokoroModel( + AAssetManager *mgr, const OfflineTtsModelConfig &config); +#endif + +#if __OHOS__ +template OfflineTtsKokoroModel::OfflineTtsKokoroModel( + NativeResourceManager *mgr, const OfflineTtsModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model.h new file mode 100644 index 00000000..f913e6ce --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-kokoro-model.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-tts-kokoro-model.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-tts-kokoro-model-meta-data.h" +#include "sherpa-mnn/csrc/offline-tts-model-config.h" + +namespace sherpa_mnn { + +class OfflineTtsKokoroModel { + public: + ~OfflineTtsKokoroModel(); + + explicit OfflineTtsKokoroModel(const OfflineTtsModelConfig &config); + + template + OfflineTtsKokoroModel(Manager *mgr, const OfflineTtsModelConfig &config); + + // Return a float32 tensor containing the mel + // of shape (batch_size, mel_dim, num_frames) + MNN::Express::VARP Run(MNN::Express::VARP x, int sid = 0, float speed = 1.0) const; + + const OfflineTtsKokoroModelMetaData &GetMetaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_KOKORO_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-impl.h new file mode 100644 index 00000000..bec06ab4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-impl.h @@ -0,0 +1,419 @@ +// sherpa-mnn/csrc/offline-tts-matcha-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_IMPL_H_ + +#include 
+#include +#include +#include +#include + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" +#include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-mnn/csrc/hifigan-vocoder.h" +#include "sherpa-mnn/csrc/jieba-lexicon.h" +#include "sherpa-mnn/csrc/lexicon.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/melo-tts-lexicon.h" +#include "sherpa-mnn/csrc/offline-tts-character-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-impl.h" +#include "sherpa-mnn/csrc/offline-tts-matcha-model.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/piper-phonemize-lexicon.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineTtsMatchaImpl : public OfflineTtsImpl { + public: + explicit OfflineTtsMatchaImpl(const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(config.model)), + vocoder_(std::make_unique( + config.model.num_threads, config.model.provider, + config.model.matcha.vocoder)) { + InitFrontend(); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + tn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if (config.model.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + tn_list_.reserve(files.size() + tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } + } + + template + OfflineTtsMatchaImpl(Manager *mgr, const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(mgr, config.model)), + vocoder_(std::make_unique( + mgr, config.model.num_threads, config.model.provider, + config.model.matcha.vocoder)) { + InitFrontend(mgr); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + tn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + tn_list_.reserve(files.size() + tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + 
fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) + } + + int32_t SampleRate() const override { + return model_->GetMetaData().sample_rate; + } + + int32_t NumSpeakers() const override { + return model_->GetMetaData().num_speakers; + } + + GeneratedAudio Generate( + const std::string &_text, int sid = 0, float speed = 1.0, + GeneratedAudioCallback callback = nullptr) const override { + const auto &meta_data = model_->GetMetaData(); + int32_t num_speakers = meta_data.num_speakers; + + if (num_speakers == 0 && sid != 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%{public}d. sid is ignored", + static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%d. sid is ignored", + static_cast(sid)); +#endif + } + + if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This model contains only %{public}d speakers. sid should be in the " + "range [%{public}d, %{public}d]. Given: %{public}d. Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This model contains only %d speakers. sid should be in the range " + "[%d, %d]. Given: %d. Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#endif + sid = 0; + } + + std::string text = _text; + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Raw text: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("Raw text: %s", text.c_str()); +#endif + } + + if (!tn_list_.empty()) { + for (const auto &tn : tn_list_) { + text = tn->Normalize(text); + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("After normalizing: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str()); +#endif + } + } + } + + std::vector token_ids = + frontend_->ConvertTextToTokenIds(text, meta_data.voice); + + if (token_ids.empty() || + (token_ids.size() == 1 && token_ids[0].tokens.empty())) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Failed to convert '%{public}s' to token IDs", + text.c_str()); +#else + SHERPA_ONNX_LOGE("Failed to convert '%s' to token IDs", text.c_str()); +#endif + return {}; + } + + std::vector> x; + + x.reserve(token_ids.size()); + + for (auto &i : token_ids) { + x.push_back(std::move(i.tokens)); + } + + for (auto &k : x) { + k = AddBlank(k, meta_data.pad_id); + } + + int32_t x_size = static_cast(x.size()); + + if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { + auto ans = Process(x, sid, speed); + if (callback) { + callback(ans.samples.data(), ans.samples.size(), 1.0); + } + return ans; + } + + // the input text is too long, we process sentences within it in batches + // to avoid OOM. Batch size is config_.max_num_sentences + std::vector> batch_x; + + int32_t batch_size = config_.max_num_sentences; + batch_x.reserve(config_.max_num_sentences); + int32_t num_batches = x_size / batch_size; + + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Text is too long. Split it into %{public}d batches. batch size: " + "%{public}d. Number of sentences: %{public}d", + num_batches, batch_size, x_size); +#else + SHERPA_ONNX_LOGE( + "Text is too long. Split it into %d batches. 
batch size: %d. Number " + "of sentences: %d", + num_batches, batch_size, x_size); +#endif + } + + GeneratedAudio ans; + + int32_t should_continue = 1; + + int32_t k = 0; + + for (int32_t b = 0; b != num_batches && should_continue; ++b) { + batch_x.clear(); + for (int32_t i = 0; i != batch_size; ++i, ++k) { + batch_x.push_back(std::move(x[k])); + } + + auto audio = Process(batch_x, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + should_continue = callback(audio.samples.data(), audio.samples.size(), + (b + 1) * 1.0 / num_batches); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. + } + } + + batch_x.clear(); + while (k < static_cast(x.size()) && should_continue) { + batch_x.push_back(std::move(x[k])); + + ++k; + } + + if (!batch_x.empty()) { + auto audio = Process(batch_x, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + callback(audio.samples.data(), audio.samples.size(), 1.0); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. + } + } + + return ans; + } + + private: + template + void InitFrontend(Manager *mgr) { + // for piper phonemizer + // we require that you copy espeak_ng_data + // from assets to disk + // + // for jieba + // we require that you copy dict from assets to disk + const auto &meta_data = model_->GetMetaData(); + + if (meta_data.jieba && !meta_data.has_espeak) { + frontend_ = std::make_unique( + mgr, config_.model.matcha.lexicon, config_.model.matcha.tokens, + config_.model.matcha.dict_dir, config_.model.debug); + } else if (meta_data.has_espeak && !meta_data.jieba) { + frontend_ = std::make_unique( + mgr, config_.model.matcha.tokens, config_.model.matcha.data_dir, + meta_data); + } else { + SHERPA_ONNX_LOGE("jieba + espeaker-ng is not supported yet"); + SHERPA_ONNX_EXIT(-1); + } + } + + void InitFrontend() { + const auto &meta_data = model_->GetMetaData(); + + if (meta_data.jieba && !meta_data.has_espeak) { + frontend_ = std::make_unique( + config_.model.matcha.lexicon, config_.model.matcha.tokens, + config_.model.matcha.dict_dir, config_.model.debug); + } else if (meta_data.has_espeak && !meta_data.jieba) { + frontend_ = std::make_unique( + config_.model.matcha.tokens, config_.model.matcha.data_dir, + meta_data); + } else { + SHERPA_ONNX_LOGE("jieba + espeaker-ng is not supported yet"); + SHERPA_ONNX_EXIT(-1); + } + } + + GeneratedAudio Process(const std::vector> &tokens, + int32_t sid, float speed) const { + int32_t num_tokens = 0; + for (const auto &k : tokens) { + num_tokens += k.size(); + } + + std::vector x; + x.reserve(num_tokens); + for (const auto &k : tokens) { + x.insert(x.end(), k.begin(), k.end()); + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape = {1, static_cast(x.size())}; + MNN::Express::VARP x_tensor = MNNUtilsCreateTensor( + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + + MNN::Express::VARP mel = model_->Run(std::move(x_tensor), sid, speed); + MNN::Express::VARP audio = vocoder_->Run(std::move(mel)); + + std::vector audio_shape = + audio->getInfo()->dim; + + int total = 1; + // The output shape may be (1, 
1, total) or (1, total) or (total,) + for (auto i : audio_shape) { + total *= i; + } + + const float *p = audio->readMap(); + + GeneratedAudio ans; + ans.sample_rate = model_->GetMetaData().sample_rate; + ans.samples = std::vector(p, p + total); + + float silence_scale = config_.silence_scale; + if (silence_scale != 1) { + ans = ans.ScaleSilence(silence_scale); + } + + return ans; + } + + private: + OfflineTtsConfig config_; + std::unique_ptr model_; + std::unique_ptr vocoder_; + std::vector> tn_list_; + std::unique_ptr frontend_; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-config.cc new file mode 100644 index 00000000..6c225f94 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-config.cc @@ -0,0 +1,143 @@ +// sherpa-mnn/csrc/offline-tts-matcha-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-matcha-model-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineTtsMatchaModelConfig::Register(ParseOptions *po) { + po->Register("matcha-acoustic-model", &acoustic_model, + "Path to matcha acoustic model"); + po->Register("matcha-vocoder", &vocoder, "Path to matcha vocoder"); + po->Register("matcha-lexicon", &lexicon, + "Path to lexicon.txt for Matcha models"); + po->Register("matcha-tokens", &tokens, + "Path to tokens.txt for Matcha models"); + po->Register("matcha-data-dir", &data_dir, + "Path to the directory containing dict for espeak-ng. If it is " + "given, --matcha-lexicon is ignored."); + po->Register("matcha-dict-dir", &dict_dir, + "Path to the directory containing dict for jieba. Used only for " + "Chinese TTS models using jieba"); + po->Register("matcha-noise-scale", &noise_scale, + "noise_scale for Matcha models"); + po->Register("matcha-length-scale", &length_scale, + "Speech speed. Larger->Slower; Smaller->faster."); +} + +bool OfflineTtsMatchaModelConfig::Validate() const { + if (acoustic_model.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-acoustic-model"); + return false; + } + + if (!FileExists(acoustic_model)) { + SHERPA_ONNX_LOGE("--matcha-acoustic-model: '%s' does not exist", + acoustic_model.c_str()); + return false; + } + + if (vocoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-vocoder"); + return false; + } + + if (!FileExists(vocoder)) { + SHERPA_ONNX_LOGE("--matcha-vocoder: '%s' does not exist", vocoder.c_str()); + return false; + } + + if (tokens.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-tokens"); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("--matcha-tokens: '%s' does not exist", tokens.c_str()); + return false; + } + + if (!data_dir.empty()) { + if (!FileExists(data_dir + "/phontab")) { + SHERPA_ONNX_LOGE( + "'%s/phontab' does not exist. Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phonindex")) { + SHERPA_ONNX_LOGE( + "'%s/phonindex' does not exist. Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phondata")) { + SHERPA_ONNX_LOGE( + "'%s/phondata' does not exist. 
Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/intonations")) { + SHERPA_ONNX_LOGE( + "'%s/intonations' does not exist. Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + } + + if (!dict_dir.empty()) { + std::vector required_files = { + "jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8", + "idf.utf8", "stop_words.utf8", + }; + + for (const auto &f : required_files) { + if (!FileExists(dict_dir + "/" + f)) { + SHERPA_ONNX_LOGE( + "'%s/%s' does not exist. Please check --matcha-dict-dir", + dict_dir.c_str(), f.c_str()); + return false; + } + } + + // we require that --matcha-lexicon is not empty + if (lexicon.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-lexicon"); + return false; + } + + if (!FileExists(lexicon)) { + SHERPA_ONNX_LOGE("--matcha-lexicon: '%s' does not exist", + lexicon.c_str()); + return false; + } + } + + return true; +} + +std::string OfflineTtsMatchaModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsMatchaModelConfig("; + os << "acoustic_model=\"" << acoustic_model << "\", "; + os << "vocoder=\"" << vocoder << "\", "; + os << "lexicon=\"" << lexicon << "\", "; + os << "tokens=\"" << tokens << "\", "; + os << "data_dir=\"" << data_dir << "\", "; + os << "dict_dir=\"" << dict_dir << "\", "; + os << "noise_scale=" << noise_scale << ", "; + os << "length_scale=" << length_scale << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-config.h new file mode 100644 index 00000000..038b23ed --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-config.h @@ -0,0 +1,56 @@ +// sherpa-mnn/csrc/offline-tts-matcha-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineTtsMatchaModelConfig { + std::string acoustic_model; + std::string vocoder; + std::string lexicon; + std::string tokens; + + // If data_dir is given, lexicon is ignored + // data_dir is for piper-phonemizer, which uses espeak-ng + std::string data_dir; + + // Used for Chinese TTS models using jieba + std::string dict_dir; + + float noise_scale = 1; + float length_scale = 1; + + OfflineTtsMatchaModelConfig() = default; + + OfflineTtsMatchaModelConfig(const std::string &acoustic_model, + const std::string &vocoder, + const std::string &lexicon, + const std::string &tokens, + const std::string &data_dir, + const std::string &dict_dir, + float noise_scale = 1.0, float length_scale = 1) + : acoustic_model(acoustic_model), + vocoder(vocoder), + lexicon(lexicon), + tokens(tokens), + data_dir(data_dir), + dict_dir(dict_dir), + noise_scale(noise_scale), + length_scale(length_scale) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-meta-data.h new file mode 100644 index 00000000..543e1d44 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model-meta-data.h 
@@ -0,0 +1,30 @@ +// sherpa-mnn/csrc/offline-tts-matcha-model-meta-data.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_mnn { + +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. +struct OfflineTtsMatchaModelMetaData { + int32_t sample_rate = 0; + int32_t num_speakers = 0; + int32_t version = 1; + int32_t jieba = 0; + int32_t has_espeak = 0; + int32_t use_eos_bos = 0; + int32_t pad_id = 0; + + std::string voice; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model.cc new file mode 100644 index 00000000..f20f9ab1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model.cc @@ -0,0 +1,215 @@ +// sherpa-mnn/csrc/offline-tts-matcha-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-matcha-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" + +namespace sherpa_mnn { + +class OfflineTtsMatchaModel::Impl { + public: + explicit Impl(const OfflineTtsModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config.matcha.acoustic_model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineTtsModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config.matcha.acoustic_model); + Init(buf.data(), buf.size()); + } + + const OfflineTtsMatchaModelMetaData &GetMetaData() const { + return meta_data_; + } + + MNN::Express::VARP Run(MNN::Express::VARP x, int sid, float speed) { + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector x_shape = x->getInfo()->dim; + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int len = x_shape[1]; + int len_shape = 1; + + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &len, 1, &len_shape, 1); + + int scale_shape = 1; + float noise_scale = config_.matcha.noise_scale; + float length_scale = config_.matcha.length_scale; + + if (speed != 1 && speed > 0) { + length_scale = 1. 
/ speed; + } + + MNN::Express::VARP noise_scale_tensor = + MNNUtilsCreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); + + MNN::Express::VARP length_scale_tensor = MNNUtilsCreateTensor( + memory_info, &length_scale, 1, &scale_shape, 1); + + MNN::Express::VARP sid_tensor = + MNNUtilsCreateTensor(memory_info, &sid, 1, &scale_shape, 1); + + std::array scales = {noise_scale, length_scale}; + int scales_shape = 2; + + MNN::Express::VARP scales_tensor = MNNUtilsCreateTensor( + memory_info, scales.data(), scales.size(), &scales_shape, 1); + + std::vector inputs; + inputs.reserve(5); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + if (input_names_[2] == "scales") { + // for models from + // https://github.com/shivammehta25/Matcha-TTS + inputs.push_back(std::move(scales_tensor)); + } else { + // for models from icefall + inputs.push_back(std::move(noise_scale_tensor)); + inputs.push_back(std::move(length_scale_tensor)); + } + + if (input_names_.size() == 5 && input_names_.back() == "sid") { + // for models from icefall + inputs.push_back(std::move(sid_tensor)); + + // Note that we have not supported multi-speaker tts models from + // https://github.com/shivammehta25/Matcha-TTS + } + + auto out = + sess_->onForward(inputs); + + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---matcha model---\n"; + PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : output_names_) { + os << i << " " << s << "\n"; + ++i; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 1); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.jieba, "jieba"); + SHERPA_ONNX_READ_META_DATA(meta_data_.has_espeak, "has_espeak"); + SHERPA_ONNX_READ_META_DATA(meta_data_.use_eos_bos, "use_eos_bos"); + SHERPA_ONNX_READ_META_DATA(meta_data_.pad_id, "pad_id"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", + "en-us"); + } + + private: + OfflineTtsModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineTtsMatchaModelMetaData meta_data_; +}; + +OfflineTtsMatchaModel::OfflineTtsMatchaModel( + const OfflineTtsModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTtsMatchaModel::OfflineTtsMatchaModel( + Manager *mgr, const OfflineTtsModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTtsMatchaModel::~OfflineTtsMatchaModel() = default; + +const 
OfflineTtsMatchaModelMetaData &OfflineTtsMatchaModel::GetMetaData() + const { + return impl_->GetMetaData(); +} + +MNN::Express::VARP OfflineTtsMatchaModel::Run(MNN::Express::VARP x, int sid /*= 0*/, + float speed /*= 1.0*/) const { + return impl_->Run(std::move(x), sid, speed); +} + +#if __ANDROID_API__ >= 9 +template OfflineTtsMatchaModel::OfflineTtsMatchaModel( + AAssetManager *mgr, const OfflineTtsModelConfig &config); +#endif + +#if __OHOS__ +template OfflineTtsMatchaModel::OfflineTtsMatchaModel( + NativeResourceManager *mgr, const OfflineTtsModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model.h new file mode 100644 index 00000000..e60bf4fa --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-matcha-model.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/offline-tts-matcha-model.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-tts-matcha-model-meta-data.h" +#include "sherpa-mnn/csrc/offline-tts-model-config.h" + +namespace sherpa_mnn { + +class OfflineTtsMatchaModel { + public: + ~OfflineTtsMatchaModel(); + + explicit OfflineTtsMatchaModel(const OfflineTtsModelConfig &config); + + template + OfflineTtsMatchaModel(Manager *mgr, const OfflineTtsModelConfig &config); + + // Return a float32 tensor containing the mel + // of shape (batch_size, mel_dim, num_frames) + MNN::Express::VARP Run(MNN::Express::VARP x, int sid = 0, float speed = 1.0) const; + + const OfflineTtsMatchaModelMetaData &GetMetaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-model-config.cc new file mode 100644 index 00000000..6d456e15 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-model-config.cc @@ -0,0 +1,57 @@ +// sherpa-mnn/csrc/offline-tts-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-model-config.h" + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineTtsModelConfig::Register(ParseOptions *po) { + vits.Register(po); + matcha.Register(po); + kokoro.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OfflineTtsModelConfig::Validate() const { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + if (!vits.model.empty()) { + return vits.Validate(); + } + + if (!matcha.acoustic_model.empty()) { + return matcha.Validate(); + } + + return kokoro.Validate(); +} + +std::string OfflineTtsModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsModelConfig("; + os << "vits=" << vits.ToString() << ", "; + os << "matcha=" << matcha.ToString() << ", "; + os << "kokoro=" << kokoro.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-model-config.h new file mode 100644 index 00000000..4f41d787 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-model-config.h @@ -0,0 +1,48 @@ +// sherpa-mnn/csrc/offline-tts-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/offline-tts-kokoro-model-config.h" +#include "sherpa-mnn/csrc/offline-tts-matcha-model-config.h" +#include "sherpa-mnn/csrc/offline-tts-vits-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineTtsModelConfig { + OfflineTtsVitsModelConfig vits; + OfflineTtsMatchaModelConfig matcha; + OfflineTtsKokoroModelConfig kokoro; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OfflineTtsModelConfig() = default; + + OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits, + const OfflineTtsMatchaModelConfig &matcha, + const OfflineTtsKokoroModelConfig &kokoro, + int32_t num_threads, bool debug, + const std::string &provider) + : vits(vits), + matcha(matcha), + kokoro(kokoro), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-impl.h new file mode 100644 index 00000000..9c7125d8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-impl.h @@ -0,0 +1,505 @@ +// sherpa-mnn/csrc/offline-tts-vits-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" +#include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/jieba-lexicon.h" +#include "sherpa-mnn/csrc/lexicon.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/melo-tts-lexicon.h" +#include "sherpa-mnn/csrc/offline-tts-character-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-impl.h" +#include "sherpa-mnn/csrc/offline-tts-vits-model.h" +#include "sherpa-mnn/csrc/piper-phonemize-lexicon.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineTtsVitsImpl : public OfflineTtsImpl { + public: + explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(config.model)) { + InitFrontend(); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + tn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if 
(config.model.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + tn_list_.reserve(files.size() + tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } + } + + template + OfflineTtsVitsImpl(Manager *mgr, const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(mgr, config.model)) { + InitFrontend(mgr); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + tn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + tn_list_.reserve(files.size() + tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) + } + + int32_t SampleRate() const override { + return model_->GetMetaData().sample_rate; + } + + int32_t NumSpeakers() const override { + return model_->GetMetaData().num_speakers; + } + + GeneratedAudio Generate( + const std::string &_text, int sid = 0, float speed = 1.0, + GeneratedAudioCallback callback = nullptr) const override { + const auto &meta_data = model_->GetMetaData(); + int32_t num_speakers = meta_data.num_speakers; + + if (num_speakers == 0 && sid != 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%{public}d. sid is ignored", + static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%d. sid is ignored", + static_cast(sid)); +#endif + } + + if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This model contains only %{public}d speakers. sid should be in the " + "range [%{public}d, %{public}d]. Given: %{public}d. Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This model contains only %d speakers. sid should be in the range " + "[%d, %d]. Given: %d. 
Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#endif + sid = 0; + } + + std::string text = _text; + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Raw text: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("Raw text: %s", text.c_str()); +#endif + } + + if (!tn_list_.empty()) { + for (const auto &tn : tn_list_) { + text = tn->Normalize(text); + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("After normalizing: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str()); +#endif + } + } + } + + std::vector token_ids = + frontend_->ConvertTextToTokenIds(text, meta_data.voice); + + if (token_ids.empty() || + (token_ids.size() == 1 && token_ids[0].tokens.empty())) { + SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); + return {}; + } + + std::vector> x; + std::vector> tones; + + x.reserve(token_ids.size()); + + for (auto &i : token_ids) { + x.push_back(std::move(i.tokens)); + } + + if (!token_ids[0].tones.empty()) { + tones.reserve(token_ids.size()); + for (auto &i : token_ids) { + tones.push_back(std::move(i.tones)); + } + } + + // TODO(fangjun): add blank inside the frontend, not here + if (meta_data.add_blank && config_.model.vits.data_dir.empty() && + meta_data.frontend != "characters") { + for (auto &k : x) { + k = AddBlank(k); + } + + for (auto &k : tones) { + k = AddBlank(k); + } + } + + int32_t x_size = static_cast(x.size()); + + if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { + auto ans = Process(x, tones, sid, speed); + if (callback) { + callback(ans.samples.data(), ans.samples.size(), 1.0); + } + return ans; + } + + // the input text is too long, we process sentences within it in batches + // to avoid OOM. Batch size is config_.max_num_sentences + std::vector> batch_x; + std::vector> batch_tones; + + int32_t batch_size = config_.max_num_sentences; + batch_x.reserve(config_.max_num_sentences); + batch_tones.reserve(config_.max_num_sentences); + int32_t num_batches = x_size / batch_size; + + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Text is too long. Split it into %{public}d batches. batch size: " + "%{public}d. Number of sentences: %{public}d", + num_batches, batch_size, x_size); +#else + SHERPA_ONNX_LOGE( + "Text is too long. Split it into %d batches. batch size: %d. Number " + "of sentences: %d", + num_batches, batch_size, x_size); +#endif + } + + GeneratedAudio ans; + + int32_t should_continue = 1; + + int32_t k = 0; + + for (int32_t b = 0; b != num_batches && should_continue; ++b) { + batch_x.clear(); + batch_tones.clear(); + for (int32_t i = 0; i != batch_size; ++i, ++k) { + batch_x.push_back(std::move(x[k])); + + if (!tones.empty()) { + batch_tones.push_back(std::move(tones[k])); + } + } + + auto audio = Process(batch_x, batch_tones, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + should_continue = callback(audio.samples.data(), audio.samples.size(), + (b + 1) * 1.0 / num_batches); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. 
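+        // The third argument passed above is the fraction of batches
+        // processed so far, i.e. (b + 1) / num_batches; the final partial
+        // batch below reports 1.0. Returning 0 from the callback clears
+        // should_continue, which ends both this loop and the remainder
+        // loop that follows.
+        //
+        // Hypothetical early-stop callback (illustrative only):
+        //
+        //   auto cb = [](const float *samples, int32_t n, float progress) {
+        //     return progress < 0.5f ? 1 : 0;  // stop after half the batches
+        //   };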
+ } + } + + batch_x.clear(); + batch_tones.clear(); + while (k < static_cast(x.size()) && should_continue) { + batch_x.push_back(std::move(x[k])); + if (!tones.empty()) { + batch_tones.push_back(std::move(tones[k])); + } + + ++k; + } + + if (!batch_x.empty()) { + auto audio = Process(batch_x, batch_tones, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + callback(audio.samples.data(), audio.samples.size(), 1.0); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. + } + } + + return ans; + } + + private: + template + void InitFrontend(Manager *mgr) { + const auto &meta_data = model_->GetMetaData(); + + if (meta_data.frontend == "characters") { + frontend_ = std::make_unique( + mgr, config_.model.vits.tokens, meta_data); + } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() && + meta_data.is_melo_tts) { + frontend_ = std::make_unique( + mgr, config_.model.vits.lexicon, config_.model.vits.tokens, + config_.model.vits.dict_dir, model_->GetMetaData(), + config_.model.debug); + } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) { + frontend_ = std::make_unique( + mgr, config_.model.vits.lexicon, config_.model.vits.tokens, + config_.model.vits.dict_dir, config_.model.debug); + } else if (meta_data.is_melo_tts && meta_data.language == "English") { + frontend_ = std::make_unique( + mgr, config_.model.vits.lexicon, config_.model.vits.tokens, + model_->GetMetaData(), config_.model.debug); + } else if ((meta_data.is_piper || meta_data.is_coqui || + meta_data.is_icefall) && + !config_.model.vits.data_dir.empty()) { + frontend_ = std::make_unique( + mgr, config_.model.vits.tokens, config_.model.vits.data_dir, + meta_data); + } else { + if (config_.model.vits.lexicon.empty()) { + SHERPA_ONNX_LOGE( + "Not a model using characters as modeling unit. 
Please provide " + "--vits-lexicon if you leave --vits-data-dir empty"); + exit(-1); + } + + frontend_ = std::make_unique( + mgr, config_.model.vits.lexicon, config_.model.vits.tokens, + meta_data.punctuations, meta_data.language, config_.model.debug); + } + } + + void InitFrontend() { + const auto &meta_data = model_->GetMetaData(); + + if (meta_data.jieba && config_.model.vits.dict_dir.empty()) { + SHERPA_ONNX_LOGE( + "Please provide --vits-dict-dir for Chinese TTS models using jieba"); + exit(-1); + } + + if (!meta_data.jieba && !config_.model.vits.dict_dir.empty()) { + SHERPA_ONNX_LOGE( + "Current model is not using jieba but you provided --vits-dict-dir"); + exit(-1); + } + + if (meta_data.frontend == "characters") { + frontend_ = std::make_unique( + config_.model.vits.tokens, meta_data); + } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() && + meta_data.is_melo_tts) { + frontend_ = std::make_unique( + config_.model.vits.lexicon, config_.model.vits.tokens, + config_.model.vits.dict_dir, model_->GetMetaData(), + config_.model.debug); + } else if (meta_data.is_melo_tts && meta_data.language == "English") { + frontend_ = std::make_unique( + config_.model.vits.lexicon, config_.model.vits.tokens, + model_->GetMetaData(), config_.model.debug); + } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) { + frontend_ = std::make_unique( + config_.model.vits.lexicon, config_.model.vits.tokens, + config_.model.vits.dict_dir, config_.model.debug); + } else if ((meta_data.is_piper || meta_data.is_coqui || + meta_data.is_icefall) && + !config_.model.vits.data_dir.empty()) { + frontend_ = std::make_unique( + config_.model.vits.tokens, config_.model.vits.data_dir, + model_->GetMetaData()); + } else { + if (config_.model.vits.lexicon.empty()) { + SHERPA_ONNX_LOGE( + "Not a model using characters as modeling unit. 
Please provide " + "--vits-lexicon if you leave --vits-data-dir empty"); + exit(-1); + } + frontend_ = std::make_unique( + config_.model.vits.lexicon, config_.model.vits.tokens, + meta_data.punctuations, meta_data.language, config_.model.debug); + } + } + + GeneratedAudio Process(const std::vector> &tokens, + const std::vector> &tones, + int32_t sid, float speed) const { + int32_t num_tokens = 0; + for (const auto &k : tokens) { + num_tokens += k.size(); + } + + std::vector x; + x.reserve(num_tokens); + for (const auto &k : tokens) { + x.insert(x.end(), k.begin(), k.end()); + } + + std::vector tone_list; + if (!tones.empty()) { + tone_list.reserve(num_tokens); + for (const auto &k : tones) { + tone_list.insert(tone_list.end(), k.begin(), k.end()); + } + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape = {1, static_cast(x.size())}; + MNN::Express::VARP x_tensor = MNNUtilsCreateTensor( + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + + MNN::Express::VARP tones_tensor{nullptr}; + if (!tones.empty()) { + tones_tensor = MNNUtilsCreateTensor(memory_info, tone_list.data(), + tone_list.size(), x_shape.data(), + x_shape.size()); + } + + MNN::Express::VARP audio{nullptr}; + if (tones.empty()) { + audio = model_->Run(std::move(x_tensor), sid, speed); + } else { + audio = + model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed); + } + + std::vector audio_shape = + audio->getInfo()->dim; + + int total = 1; + // The output shape may be (1, 1, total) or (1, total) or (total,) + for (auto i : audio_shape) { + total *= i; + } + + const float *p = audio->readMap(); + + GeneratedAudio ans; + ans.sample_rate = model_->GetMetaData().sample_rate; + ans.samples = std::vector(p, p + total); + + float silence_scale = config_.silence_scale; + if (silence_scale != 1) { + ans = ans.ScaleSilence(silence_scale); + } + + return ans; + } + + private: + OfflineTtsConfig config_; + std::unique_ptr model_; + std::vector> tn_list_; + std::unique_ptr frontend_; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-config.cc new file mode 100644 index 00000000..f87d8ff1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-config.cc @@ -0,0 +1,115 @@ +// sherpa-mnn/csrc/offline-tts-vits-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-vits-model-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineTtsVitsModelConfig::Register(ParseOptions *po) { + po->Register("vits-model", &model, "Path to VITS model"); + po->Register("vits-lexicon", &lexicon, "Path to lexicon.txt for VITS models"); + po->Register("vits-tokens", &tokens, "Path to tokens.txt for VITS models"); + po->Register("vits-data-dir", &data_dir, + "Path to the directory containing dict for espeak-ng. If it is " + "given, --vits-lexicon is ignored."); + po->Register("vits-dict-dir", &dict_dir, + "Path to the directory containing dict for jieba. Used only for " + "Chinese TTS models using jieba"); + po->Register("vits-noise-scale", &noise_scale, "noise_scale for VITS models"); + po->Register("vits-noise-scale-w", &noise_scale_w, + "noise_scale_w for VITS models"); + po->Register("vits-length-scale", &length_scale, + "Speech speed. 
Larger->Slower; Smaller->faster."); +} + +bool OfflineTtsVitsModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --vits-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--vits-model: '%s' does not exist", model.c_str()); + return false; + } + + if (tokens.empty()) { + SHERPA_ONNX_LOGE("Please provide --vits-tokens"); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("--vits-tokens: '%s' does not exist", tokens.c_str()); + return false; + } + + if (!data_dir.empty()) { + if (!FileExists(data_dir + "/phontab")) { + SHERPA_ONNX_LOGE( + "'%s/phontab' does not exist. Please check --vits-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phonindex")) { + SHERPA_ONNX_LOGE( + "'%s/phonindex' does not exist. Please check --vits-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phondata")) { + SHERPA_ONNX_LOGE( + "'%s/phondata' does not exist. Please check --vits-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/intonations")) { + SHERPA_ONNX_LOGE( + "'%s/intonations' does not exist. Please check --vits-data-dir", + data_dir.c_str()); + return false; + } + } + + if (!dict_dir.empty()) { + std::vector required_files = { + "jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8", + "idf.utf8", "stop_words.utf8", + }; + + for (const auto &f : required_files) { + if (!FileExists(dict_dir + "/" + f)) { + SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check vits-dict-dir", + dict_dir.c_str(), f.c_str()); + return false; + } + } + } + return true; +} + +std::string OfflineTtsVitsModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsVitsModelConfig("; + os << "model=\"" << model << "\", "; + os << "lexicon=\"" << lexicon << "\", "; + os << "tokens=\"" << tokens << "\", "; + os << "data_dir=\"" << data_dir << "\", "; + os << "dict_dir=\"" << dict_dir << "\", "; + os << "noise_scale=" << noise_scale << ", "; + os << "noise_scale_w=" << noise_scale_w << ", "; + os << "length_scale=" << length_scale << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-config.h new file mode 100644 index 00000000..eaf7cbcf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-config.h @@ -0,0 +1,59 @@ +// sherpa-mnn/csrc/offline-tts-vits-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineTtsVitsModelConfig { + std::string model; + std::string lexicon; + std::string tokens; + + // If data_dir is given, lexicon is ignored + // data_dir is for piper-phonemize, which uses espeak-ng + std::string data_dir; + + // Used for Chinese TTS models using jieba + std::string dict_dir; + + float noise_scale = 0.667; + float noise_scale_w = 0.8; + float length_scale = 1; + + // used only for multi-speaker models, e.g, vctk speech dataset. 
+ // Not applicable for single-speaker models, e.g., ljspeech dataset + + OfflineTtsVitsModelConfig() = default; + + OfflineTtsVitsModelConfig(const std::string &model, + const std::string &lexicon, + const std::string &tokens, + const std::string &data_dir, + const std::string &dict_dir, + float noise_scale = 0.667, + float noise_scale_w = 0.8, float length_scale = 1) + : model(model), + lexicon(lexicon), + tokens(tokens), + data_dir(data_dir), + dict_dir(dict_dir), + noise_scale(noise_scale), + noise_scale_w(noise_scale_w), + length_scale(length_scale) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h new file mode 100644 index 00000000..8e3219dc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h @@ -0,0 +1,49 @@ +// sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_mnn { + +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. +struct OfflineTtsVitsModelMetaData { + int32_t sample_rate = 0; + int32_t add_blank = 0; + int32_t num_speakers = 0; + + bool is_piper = false; + bool is_coqui = false; + bool is_icefall = false; + bool is_melo_tts = false; + + // for Chinese TTS models from + // https://github.com/Plachtaa/VITS-fast-fine-tuning + int32_t jieba = 0; + + // the following options are for models from coqui-ai/TTS + int32_t blank_id = 0; + int32_t bos_id = 0; + int32_t eos_id = 0; + int32_t use_eos_bos = 0; + int32_t pad_id = 0; + + // for melo tts + int32_t speaker_id = 0; + int32_t version = 0; + + std::string punctuations; + std::string language; + std::string voice; + std::string frontend; // characters +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model.cc new file mode 100644 index 00000000..c99749e3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model.cc @@ -0,0 +1,379 @@ +// sherpa-mnn/csrc/offline-tts-vits-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts-vits-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" + +namespace sherpa_mnn { + +class OfflineTtsVitsModel::Impl { + public: + explicit Impl(const OfflineTtsModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config.vits.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineTtsModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + 
allocator_{} { + auto buf = ReadFile(mgr, config.vits.model); + Init(buf.data(), buf.size()); + } + + MNN::Express::VARP Run(MNN::Express::VARP x, int sid, float speed) { + if (meta_data_.is_piper || meta_data_.is_coqui) { + return RunVitsPiperOrCoqui(std::move(x), sid, speed); + } + + return RunVits(std::move(x), sid, speed); + } + + MNN::Express::VARP Run(MNN::Express::VARP x, MNN::Express::VARP tones, int sid, float speed) { + if (meta_data_.num_speakers == 1) { + // For MeloTTS, we hardcode sid to the one contained in the meta data + sid = meta_data_.speaker_id; + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector x_shape = x->getInfo()->dim; + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int len = x_shape[1]; + int len_shape = 1; + + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &len, 1, &len_shape, 1); + + int scale_shape = 1; + float noise_scale = config_.vits.noise_scale; + float length_scale = config_.vits.length_scale; + float noise_scale_w = config_.vits.noise_scale_w; + + if (speed != 1 && speed > 0) { + length_scale = 1. / speed; + } + + MNN::Express::VARP noise_scale_tensor = + MNNUtilsCreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); + + MNN::Express::VARP length_scale_tensor = MNNUtilsCreateTensor( + memory_info, &length_scale, 1, &scale_shape, 1); + + MNN::Express::VARP noise_scale_w_tensor = MNNUtilsCreateTensor( + memory_info, &noise_scale_w, 1, &scale_shape, 1); + + MNN::Express::VARP sid_tensor = + MNNUtilsCreateTensor(memory_info, &sid, 1, &scale_shape, 1); + + std::vector inputs; + inputs.reserve(7); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + inputs.push_back(std::move(tones)); + inputs.push_back(std::move(sid_tensor)); + inputs.push_back(std::move(noise_scale_tensor)); + inputs.push_back(std::move(length_scale_tensor)); + inputs.push_back(std::move(noise_scale_w_tensor)); + + auto out = + sess_->onForward(inputs); + + return std::move(out[0]); + } + + const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---vits model---\n"; + PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : output_names_) { + os << i << " " << s << "\n"; + ++i; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank", + 0); + + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id", + 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, 
"n_speakers"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations, + "punctuation", ""); + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.voice, "voice", ""); + + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.frontend, "frontend", + ""); + + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.jieba, "jieba", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.blank_id, "blank_id", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.bos_id, "bos_id", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.eos_id, "eos_id", 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_eos_bos, + "use_eos_bos", 1); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.pad_id, "pad_id", 0); + + std::string comment; + SHERPA_ONNX_READ_META_DATA_STR(comment, "comment"); + + if (comment.find("piper") != std::string::npos) { + meta_data_.is_piper = true; + } + + if (comment.find("coqui") != std::string::npos) { + meta_data_.is_coqui = true; + } + + if (comment.find("icefall") != std::string::npos) { + meta_data_.is_icefall = true; + } + + if (comment.find("melo") != std::string::npos) { + meta_data_.is_melo_tts = true; + int32_t expected_version = 2; + if (meta_data_.version < expected_version) { + SHERPA_ONNX_LOGE( + "Please download the latest MeloTTS model and retry. Current " + "version: %d. Expected version: %d", + meta_data_.version, expected_version); + exit(-1); + } + + // NOTE(fangjun): + // version 0 is the first version + // version 2: add jieba=1 to the metadata + } + } + + MNN::Express::VARP RunVitsPiperOrCoqui(MNN::Express::VARP x, int sid, float speed) { + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector x_shape = x->getInfo()->dim; + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int len = x_shape[1]; + int len_shape = 1; + + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &len, 1, &len_shape, 1); + + float noise_scale = config_.vits.noise_scale; + float length_scale = config_.vits.length_scale; + float noise_scale_w = config_.vits.noise_scale_w; + + if (speed != 1 && speed > 0) { + length_scale = 1. / speed; + } + std::array scales = {noise_scale, length_scale, noise_scale_w}; + + int scale_shape = 3; + + MNN::Express::VARP scales_tensor = MNNUtilsCreateTensor( + memory_info, scales.data(), scales.size(), &scale_shape, 1); + + int sid_shape = 1; + MNN::Express::VARP sid_tensor = + MNNUtilsCreateTensor(memory_info, &sid, 1, &sid_shape, 1); + + int lang_id_shape = 1; + int lang_id = 0; + MNN::Express::VARP lang_id_tensor = + MNNUtilsCreateTensor(memory_info, &lang_id, 1, &lang_id_shape, 1); + + std::vector inputs; + inputs.reserve(5); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + inputs.push_back(std::move(scales_tensor)); + + if (input_names_.size() >= 4 && input_names_[3] == "sid") { + inputs.push_back(std::move(sid_tensor)); + } + + if (input_names_.size() >= 5 && input_names_[4] == "langid") { + inputs.push_back(std::move(lang_id_tensor)); + } + + auto out = + sess_->onForward(inputs); + + return std::move(out[0]); + } + + MNN::Express::VARP RunVits(MNN::Express::VARP x, int sid, float speed) { + auto memory_info = + (MNNAllocator*)(nullptr); + + std::vector x_shape = x->getInfo()->dim; + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. 
Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int len = x_shape[1]; + int len_shape = 1; + + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &len, 1, &len_shape, 1); + + int scale_shape = 1; + float noise_scale = config_.vits.noise_scale; + float length_scale = config_.vits.length_scale; + float noise_scale_w = config_.vits.noise_scale_w; + + if (speed != 1 && speed > 0) { + length_scale = 1. / speed; + } + + MNN::Express::VARP noise_scale_tensor = + MNNUtilsCreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); + + MNN::Express::VARP length_scale_tensor = MNNUtilsCreateTensor( + memory_info, &length_scale, 1, &scale_shape, 1); + + MNN::Express::VARP noise_scale_w_tensor = MNNUtilsCreateTensor( + memory_info, &noise_scale_w, 1, &scale_shape, 1); + + MNN::Express::VARP sid_tensor = + MNNUtilsCreateTensor(memory_info, &sid, 1, &scale_shape, 1); + + std::vector inputs; + inputs.reserve(6); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + inputs.push_back(std::move(noise_scale_tensor)); + inputs.push_back(std::move(length_scale_tensor)); + inputs.push_back(std::move(noise_scale_w_tensor)); + + if (input_names_.size() == 6 && + (input_names_.back() == "sid" || input_names_.back() == "speaker")) { + inputs.push_back(std::move(sid_tensor)); + } + + auto out = + sess_->onForward(inputs); + + return std::move(out[0]); + } + + private: + OfflineTtsModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineTtsVitsModelMetaData meta_data_; +}; + +OfflineTtsVitsModel::OfflineTtsVitsModel(const OfflineTtsModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTtsVitsModel::OfflineTtsVitsModel(Manager *mgr, + const OfflineTtsModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTtsVitsModel::~OfflineTtsVitsModel() = default; + +MNN::Express::VARP OfflineTtsVitsModel::Run(MNN::Express::VARP x, int sid /*=0*/, + float speed /*= 1.0*/) { + return impl_->Run(std::move(x), sid, speed); +} + +MNN::Express::VARP OfflineTtsVitsModel::Run(MNN::Express::VARP x, MNN::Express::VARP tones, + int sid /*= 0*/, + float speed /*= 1.0*/) const { + return impl_->Run(std::move(x), std::move(tones), sid, speed); +} + +const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +#if __ANDROID_API__ >= 9 +template OfflineTtsVitsModel::OfflineTtsVitsModel( + AAssetManager *mgr, const OfflineTtsModelConfig &config); +#endif + +#if __OHOS__ +template OfflineTtsVitsModel::OfflineTtsVitsModel( + NativeResourceManager *mgr, const OfflineTtsModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model.h new file mode 100644 index 00000000..5145531e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts-vits-model.h @@ -0,0 +1,51 @@ +// sherpa-mnn/csrc/offline-tts-vits-model.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-tts-model-config.h" +#include "sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h" + +namespace 
sherpa_mnn { + +class OfflineTtsVitsModel { + public: + ~OfflineTtsVitsModel(); + + explicit OfflineTtsVitsModel(const OfflineTtsModelConfig &config); + + template + OfflineTtsVitsModel(Manager *mgr, const OfflineTtsModelConfig &config); + + /** Run the model. + * + * @param x A int64 tensor of shape (1, num_tokens) + // @param sid Speaker ID. Used only for multi-speaker models, e.g., models + // trained using the VCTK dataset. It is not used for + // single-speaker models, e.g., models trained using the ljspeech + // dataset. + * @return Return a float32 tensor containing audio samples. You can flatten + * it to a 1-D tensor. + */ + MNN::Express::VARP Run(MNN::Express::VARP x, int sid = 0, float speed = 1.0); + + // This is for MeloTTS + MNN::Express::VARP Run(MNN::Express::VARP x, MNN::Express::VARP tones, int sid = 0, + float speed = 1.0) const; + + const OfflineTtsVitsModelMetaData &GetMetaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_VITS_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts.cc new file mode 100644 index 00000000..528dffd0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts.cc @@ -0,0 +1,213 @@ +// sherpa-mnn/csrc/offline-tts.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-tts-impl.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +struct SilenceInterval { + int32_t start; + int32_t end; +}; + +GeneratedAudio GeneratedAudio::ScaleSilence(float scale) const { + if (scale == 1) { + return *this; + } + // if the interval is larger than 0.2 second, then we assume it is a pause + int32_t threshold = static_cast(sample_rate * 0.2); + + std::vector intervals; + int32_t num_samples = static_cast(samples.size()); + + int32_t last = -1; + int32_t i; + for (i = 0; i != num_samples; ++i) { + if (fabs(samples[i]) <= 0.01) { + if (last == -1) { + last = i; + } + continue; + } + + if (last != -1 && i - last < threshold) { + last = -1; + continue; + } + + if (last != -1) { + intervals.push_back({last, i}); + last = -1; + } + } + + if (last != -1 && num_samples - last > threshold) { + intervals.push_back({last, num_samples}); + } + + if (intervals.empty()) { + return *this; + } + + GeneratedAudio ans; + ans.sample_rate = sample_rate; + ans.samples.reserve(samples.size()); + + i = 0; + for (const auto &interval : intervals) { + ans.samples.insert(ans.samples.end(), samples.begin() + i, + samples.begin() + interval.start); + i = interval.end; + int32_t n = static_cast((interval.end - interval.start) * scale); + + ans.samples.insert(ans.samples.end(), samples.begin() + interval.start, + samples.begin() + interval.start + n); + } + + if (i < num_samples) { + ans.samples.insert(ans.samples.end(), samples.begin() + i, samples.end()); + } + + return ans; +} + +void OfflineTtsConfig::Register(ParseOptions *po) { + model.Register(po); + + po->Register("tts-rule-fsts", &rule_fsts, + "It not empty, it contains a list of rule FST filenames." 
+ "Multiple filenames are separated by a comma and they are " + "applied from left to right. An example value: " + "rule1.fst,rule2.fst,rule3.fst"); + + po->Register("tts-rule-fars", &rule_fars, + "It not empty, it contains a list of rule FST archive filenames." + "Multiple filenames are separated by a comma and they are " + "applied from left to right. An example value: " + "rule1.far,rule2.far,rule3.far. Note that an *.far can contain " + "multiple *.fst files"); + + po->Register( + "tts-max-num-sentences", &max_num_sentences, + "Maximum number of sentences that we process at a time. " + "This is to avoid OOM for very long input text. " + "If you set it to -1, then we process all sentences in a single batch."); + + po->Register("tts-silence-scale", &silence_scale, + "Duration of the pause is scaled by this number. So a smaller " + "value leads to a shorter pause."); +} + +bool OfflineTtsConfig::Validate() const { + if (!rule_fsts.empty()) { + std::vector files; + SplitStringToVector(rule_fsts, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + if (!rule_fars.empty()) { + std::vector files; + SplitStringToVector(rule_fars, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + if (silence_scale < 0.001) { + SHERPA_ONNX_LOGE("--tts-silence-scale '%.3f' is too small", silence_scale); + return false; + } + + return model.Validate(); +} + +std::string OfflineTtsConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsConfig("; + os << "model=" << model.ToString() << ", "; + os << "rule_fsts=\"" << rule_fsts << "\", "; + os << "rule_fars=\"" << rule_fars << "\", "; + os << "max_num_sentences=" << max_num_sentences << ", "; + os << "silence_scale=" << silence_scale << ")"; + + return os.str(); +} + +OfflineTts::OfflineTts(const OfflineTtsConfig &config) + : impl_(OfflineTtsImpl::Create(config)) {} + +template +OfflineTts::OfflineTts(Manager *mgr, const OfflineTtsConfig &config) + : impl_(OfflineTtsImpl::Create(mgr, config)) {} + +OfflineTts::~OfflineTts() = default; + +GeneratedAudio OfflineTts::Generate( + const std::string &text, int sid /*=0*/, float speed /*= 1.0*/, + GeneratedAudioCallback callback /*= nullptr*/) const { +#if !defined(_WIN32) + return impl_->Generate(text, sid, speed, std::move(callback)); +#else + if (IsUtf8(text)) { + return impl_->Generate(text, sid, speed, std::move(callback)); + } else if (IsGB2312(text)) { + auto utf8_text = Gb2312ToUtf8(text); + static bool printed = false; + if (!printed) { + SHERPA_ONNX_LOGE( + "Detected GB2312 encoded string! Converting it to UTF8."); + printed = true; + } + return impl_->Generate(utf8_text, sid, speed, std::move(callback)); + } else { + SHERPA_ONNX_LOGE( + "Non UTF8 encoded string is received. 
You would not get expected " + "results!"); + return impl_->Generate(text, sid, speed, std::move(callback)); + } +#endif +} + +int32_t OfflineTts::SampleRate() const { return impl_->SampleRate(); } + +int32_t OfflineTts::NumSpeakers() const { return impl_->NumSpeakers(); } + +#if __ANDROID_API__ >= 9 +template OfflineTts::OfflineTts(AAssetManager *mgr, + const OfflineTtsConfig &config); +#endif + +#if __OHOS__ +template OfflineTts::OfflineTts(NativeResourceManager *mgr, + const OfflineTtsConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts.h new file mode 100644 index 00000000..2be536c4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-tts.h @@ -0,0 +1,111 @@ +// sherpa-mnn/csrc/offline-tts.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineTtsConfig { + OfflineTtsModelConfig model; + // If not empty, it contains a list of rule FST filenames. + // Filenames are separated by a comma. + // Example value: rule1.fst,rule2,fst,rule3.fst + // + // If there are multiple rules, they are applied from left to right. + std::string rule_fsts; + + // If there are multiple FST archives, they are applied from left to right. + std::string rule_fars; + + // Maximum number of sentences that we process at a time. + // This is to avoid OOM for very long input text. + // If you set it to -1, then we process all sentences in a single batch. + int32_t max_num_sentences = 1; + + // A silence interval contains audio samples with value close to 0. + // + // the duration of the new interval is old_duration * silence_scale. + float silence_scale = 0.2; + + OfflineTtsConfig() = default; + OfflineTtsConfig(const OfflineTtsModelConfig &model, + const std::string &rule_fsts, const std::string &rule_fars, + int32_t max_num_sentences, float silence_scale) + : model(model), + rule_fsts(rule_fsts), + rule_fars(rule_fars), + max_num_sentences(max_num_sentences), + silence_scale(silence_scale) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct GeneratedAudio { + std::vector samples; + int32_t sample_rate; + + // Silence means pause here. + // If scale > 1, then it increases the duration of a pause + // If scale < 1, then it reduces the duration of a pause + GeneratedAudio ScaleSilence(float scale) const; +}; + +class OfflineTtsImpl; + +// If the callback returns 0, then it stop generating +// if the callback returns 1, then it keeps generating +using GeneratedAudioCallback = std::function; + +class OfflineTts { + public: + ~OfflineTts(); + explicit OfflineTts(const OfflineTtsConfig &config); + + template + OfflineTts(Manager *mgr, const OfflineTtsConfig &config); + + // @param text A string containing words separated by spaces + // @param sid Speaker ID. Used only for multi-speaker models, e.g., models + // trained using the VCTK dataset. It is not used for + // single-speaker models, e.g., models trained using the ljspeech + // dataset. + // @param speed The speed for the generated speech. E.g., 2 means 2x faster. + // @param callback If not NULL, it is called whenever config.max_num_sentences + // sentences have been processed. 
Note that the passed + // pointer `samples` for the callback might be invalidated + // after the callback is returned, so the caller should not + // keep a reference to it. The caller can copy the data if + // he/she wants to access the samples after the callback + // returns. The callback is called in the current thread. + GeneratedAudio Generate(const std::string &text, int sid = 0, + float speed = 1.0, + GeneratedAudioCallback callback = nullptr) const; + + // Return the sample rate of the generated audio + int32_t SampleRate() const; + + // Number of supported speakers. + // If it supports only a single speaker, then it return 0 or 1. + int32_t NumSpeakers() const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server-impl.cc new file mode 100644 index 00000000..1462f409 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server-impl.cc @@ -0,0 +1,286 @@ +// sherpa-mnn/csrc/offline-websocket-server-impl.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-websocket-server-impl.h" + +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineWebsocketDecoderConfig::Register(ParseOptions *po) { + recognizer_config.Register(po); + + po->Register("max-batch-size", &max_batch_size, + "Max batch size for decoding."); + + po->Register( + "max-utterance-length", &max_utterance_length, + "Max utterance length in seconds. If we receive an utterance " + "longer than this value, we will reject the connection. " + "If you have enough memory, you can select a large value for it."); +} + +void OfflineWebsocketDecoderConfig::Validate() const { + if (!recognizer_config.Validate()) { + SHERPA_ONNX_LOGE("Error in recongizer config"); + exit(-1); + } + + if (max_batch_size <= 0) { + SHERPA_ONNX_LOGE("Expect --max-batch-size > 0. Given: %d", max_batch_size); + exit(-1); + } + + if (max_utterance_length <= 0) { + SHERPA_ONNX_LOGE("Expect --max-utterance-length > 0. Given: %f", + max_utterance_length); + exit(-1); + } +} + +OfflineWebsocketDecoder::OfflineWebsocketDecoder(OfflineWebsocketServer *server) + : config_(server->GetConfig().decoder_config), + server_(server), + recognizer_(config_.recognizer_config) {} + +void OfflineWebsocketDecoder::Push(connection_hdl hdl, ConnectionDataPtr d) { + std::lock_guard lock(mutex_); + streams_.push_back({hdl, d}); +} + +void OfflineWebsocketDecoder::Decode() { + std::unique_lock lock(mutex_); + if (streams_.empty()) { + return; + } + + int32_t size = + std::min(static_cast(streams_.size()), config_.max_batch_size); + SHERPA_ONNX_LOGE("size: %d", size); + + // We first lock the mutex for streams_, take items from it, and then + // unlock the mutex; in doing so we don't need to lock the mutex to + // access hdl and connection_data later. + std::vector handles(size); + + // Store connection_data here to prevent the data from being freed + // while we are still using it. 
+ std::vector connection_data(size); + + std::vector samples(size); + std::vector samples_length(size); + std::vector> ss(size); + std::vector p_ss(size); + + for (int32_t i = 0; i != size; ++i) { + auto &p = streams_.front(); + handles[i] = p.first; + connection_data[i] = p.second; + streams_.pop_front(); + + auto sample_rate = connection_data[i]->sample_rate; + auto samples = + reinterpret_cast(&connection_data[i]->data[0]); + auto num_samples = connection_data[i]->expected_byte_size / sizeof(float); + auto s = recognizer_.CreateStream(); + s->AcceptWaveform(sample_rate, samples, num_samples); + + ss[i] = std::move(s); + p_ss[i] = ss[i].get(); + } + + lock.unlock(); + + // Note: DecodeStreams is thread-safe + recognizer_.DecodeStreams(p_ss.data(), size); + + for (int32_t i = 0; i != size; ++i) { + connection_hdl hdl = handles[i]; + asio::post(server_->GetConnectionContext(), + [this, hdl, result = ss[i]->GetResult()]() { + websocketpp::lib::error_code ec; + server_->GetServer().send(hdl, result.AsJsonString(), + websocketpp::frame::opcode::text, + ec); + if (ec) { + server_->GetServer().get_alog().write( + websocketpp::log::alevel::app, ec.message()); + } + }); + } +} + +void OfflineWebsocketServerConfig::Register(ParseOptions *po) { + decoder_config.Register(po); + po->Register("log-file", &log_file, + "Path to the log file. Logs are " + "appended to this file"); +} + +void OfflineWebsocketServerConfig::Validate() const { + decoder_config.Validate(); +} + +OfflineWebsocketServer::OfflineWebsocketServer( + asio::io_context &io_conn, // NOLINT + asio::io_context &io_work, // NOLINT + const OfflineWebsocketServerConfig &config) + : io_conn_(io_conn), + io_work_(io_work), + config_(config), + log_(config.log_file, std::ios::app), + tee_(std::cout, log_), + decoder_(this) { + SetupLog(); + + server_.init_asio(&io_conn_); + + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); + + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); }); + + server_.set_message_handler( + [this](connection_hdl hdl, server::message_ptr msg) { + OnMessage(hdl, msg); + }); +} + +void OfflineWebsocketServer::SetupLog() { + server_.clear_access_channels(websocketpp::log::alevel::all); + server_.set_access_channels(websocketpp::log::alevel::connect); + server_.set_access_channels(websocketpp::log::alevel::disconnect); + + // So that it also prints to std::cout and std::cerr + server_.get_alog().set_ostream(&tee_); + server_.get_elog().set_ostream(&tee_); +} + +void OfflineWebsocketServer::OnOpen(connection_hdl hdl) { + std::lock_guard lock(mutex_); + connections_.emplace(hdl, std::make_shared()); + + SHERPA_ONNX_LOGE("Number of active connections: %d", + static_cast(connections_.size())); +} + +void OfflineWebsocketServer::OnClose(connection_hdl hdl) { + std::lock_guard lock(mutex_); + connections_.erase(hdl); + + SHERPA_ONNX_LOGE("Number of active connections: %d", + static_cast(connections_.size())); +} + +void OfflineWebsocketServer::OnMessage(connection_hdl hdl, + server::message_ptr msg) { + std::unique_lock lock(mutex_); + auto connection_data = connections_.find(hdl)->second; + lock.unlock(); + const std::string &payload = msg->get_payload(); + + switch (msg->get_opcode()) { + case websocketpp::frame::opcode::text: + if (payload == "Done") { + // The client will not send any more data. We can close the + // connection now. 
+ Close(hdl, websocketpp::close::status::normal, "Done"); + } else { + Close(hdl, websocketpp::close::status::normal, + std::string("Invalid payload: ") + payload); + } + break; + + case websocketpp::frame::opcode::binary: { + auto p = reinterpret_cast(payload.data()); + + if (connection_data->expected_byte_size == 0) { + if (payload.size() < 8) { + Close(hdl, websocketpp::close::status::normal, + "Payload is too short"); + break; + } + + connection_data->sample_rate = *reinterpret_cast(p); + + connection_data->expected_byte_size = + *reinterpret_cast(p + 4); + + int32_t max_byte_size_ = decoder_.GetConfig().max_utterance_length * + connection_data->sample_rate * sizeof(float); + if (connection_data->expected_byte_size > max_byte_size_) { + float num_samples = + connection_data->expected_byte_size / sizeof(float); + + float duration = num_samples / connection_data->sample_rate; + + std::ostringstream os; + os << "Max utterance length is configured to " + << decoder_.GetConfig().max_utterance_length + << " seconds, received length is " << duration << " seconds. " + << "Payload is too large!"; + Close(hdl, websocketpp::close::status::message_too_big, os.str()); + break; + } + + connection_data->data.resize(connection_data->expected_byte_size); + std::copy(payload.begin() + 8, payload.end(), + connection_data->data.data()); + connection_data->cur = payload.size() - 8; + } else { + std::copy(payload.begin(), payload.end(), + connection_data->data.data() + connection_data->cur); + connection_data->cur += payload.size(); + } + + if (connection_data->expected_byte_size == connection_data->cur) { + auto d = std::make_shared(std::move(*connection_data)); + // Clear it so that we can handle the next audio file from the client. + // The client can send multiple audio files for recognition without + // the need to create another connection. + connection_data->sample_rate = 0; + connection_data->expected_byte_size = 0; + connection_data->cur = 0; + + decoder_.Push(hdl, d); + + connection_data->Clear(); + + asio::post(io_work_, [this]() { decoder_.Decode(); }); + } + break; + } + + default: + // Unexpected message, ignore it + break; + } +} + +void OfflineWebsocketServer::Close(connection_hdl hdl, + websocketpp::close::status::value code, + const std::string &reason) { + auto con = server_.get_con_from_hdl(hdl); + + std::ostringstream os; + os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason + << "\n"; + + websocketpp::lib::error_code ec; + server_.close(hdl, code, reason, ec); + if (ec) { + os << "Failed to close" << con->get_remote_endpoint() << ". 
" + << ec.message() << "\n"; + } + server_.get_alog().write(websocketpp::log::alevel::app, os.str()); +} + +void OfflineWebsocketServer::Run(uint16_t port) { + server_.set_reuse_addr(true); + server_.listen(asio::ip::tcp::v4(), port); + server_.start_accept(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server-impl.h new file mode 100644 index 00000000..1d5f4ac3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server-impl.h @@ -0,0 +1,205 @@ +// sherpa-mnn/csrc/offline-websocket-server-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/tee-stream.h" +#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS +#include "websocketpp/server.hpp" + +using server = websocketpp::server; +using connection_hdl = websocketpp::connection_hdl; + +namespace sherpa_mnn { + +/** Communication protocol + * + * The client sends a byte stream to the server. The first 4 bytes in little + * endian indicates the sample rate of the audio data that the client will send. + * The next 4 bytes in little endian indicates the total samples in bytes the + * client will send. The remaining bytes represent audio samples. Each audio + * sample is a float occupying 4 bytes and is normalized into the range + * [-1, 1]. + * + * The byte stream can be broken into arbitrary number of messages. + * We require that the first message has to be at least 8 bytes so that + * we can get `sample_rate` and `expected_byte_size` from the first message. + */ +struct ConnectionData { + // Sample rate of the audio samples the client + int32_t sample_rate; + + // Number of expected bytes sent from the client + int32_t expected_byte_size = 0; + + // Number of bytes received so far + int32_t cur = 0; + + // It saves the received samples from the client. + // We will **reinterpret_cast** it to float. + // We expect that data.size() == expected_byte_size + std::vector data; + + void Clear() { + sample_rate = 0; + expected_byte_size = 0; + cur = 0; + data.clear(); + } +}; + +using ConnectionDataPtr = std::shared_ptr; + +struct OfflineWebsocketDecoderConfig { + OfflineRecognizerConfig recognizer_config; + + int32_t max_batch_size = 5; + + float max_utterance_length = 300; // seconds + + void Register(ParseOptions *po); + void Validate() const; +}; + +class OfflineWebsocketServer; + +class OfflineWebsocketDecoder { + public: + /** + * @param config Configuration for the decoder. + * @param server **Borrowed** from outside. + */ + explicit OfflineWebsocketDecoder(OfflineWebsocketServer *server); + + /** Insert received data to the queue for decoding. + * + * @param hdl A handle to the connection. We can use it to send the result + * back to the client once it finishes decoding. + * @param d The received data + */ + void Push(connection_hdl hdl, ConnectionDataPtr d); + + /** It is called by one of the work thread. 
+ */ + void Decode(); + + const OfflineWebsocketDecoderConfig &GetConfig() const { return config_; } + + private: + OfflineWebsocketDecoderConfig config_; + + /** When we have received all the data from the client, we put it into + * this queue; the worker threads will get items from this queue for + * decoding. + * + * Number of items to take from this queue is determined by + * `--max-batch-size`. If there are not enough items in the queue, we won't + * wait and take whatever we have for decoding. + */ + std::mutex mutex_; + std::deque> streams_; + + OfflineWebsocketServer *server_; // Not owned + OfflineRecognizer recognizer_; +}; + +struct OfflineWebsocketServerConfig { + OfflineWebsocketDecoderConfig decoder_config; + std::string log_file = "./log.txt"; + + void Register(ParseOptions *po); + void Validate() const; +}; + +class OfflineWebsocketServer { + public: + OfflineWebsocketServer(asio::io_context &io_conn, // NOLINT + asio::io_context &io_work, // NOLINT + const OfflineWebsocketServerConfig &config); + + asio::io_context &GetConnectionContext() { return io_conn_; } + server &GetServer() { return server_; } + + void Run(uint16_t port); + + const OfflineWebsocketServerConfig &GetConfig() const { return config_; } + + private: + void SetupLog(); + + // When a websocket client is connected, it will invoke this method + // (Not for HTTP) + void OnOpen(connection_hdl hdl); + + // When a websocket client is disconnected, it will invoke this method + void OnClose(connection_hdl hdl); + + // When a message is received from a websocket client, this method will + // be invoked. + // + // The protocol between the client and the server is as follows: + // + // (1) The client connects to the server + // (2) The client starts to send binary byte stream to the server. + // The byte stream can be broken into multiple messages or it can + // be put into a single message. + // The first message has to contain at least 8 bytes. The first + // 4 bytes in little endian contains a int32_t indicating the + // sampling rate. The next 4 bytes in little endian contains a int32_t + // indicating total number of bytes of samples the client will send. + // We assume each sample is a float containing 4 bytes and has been + // normalized to the range [-1, 1]. + // (4) When the server receives all the samples from the client, it will + // start to decode them. Once decoded, the server sends a text message + // to the client containing the decoded results + // (5) After receiving the decoded results from the server, if the client has + // another audio file to send, it repeats (2), (3), (4) + // (6) If the client has no more audio files to decode, the client sends a + // text message containing "Done" to the server and closes the connection + // (7) The server receives a text message "Done" and closes the connection + // + // Note: + // (a) All models in icefall use features extracted from audio samples + // normalized to the range [-1, 1]. Please send normalized audio samples + // if you use models from icefall. + // (b) Only sound files with a single channel is supported + // (c) Only audio samples are sent. For instance, if we want to decode + // a WAVE file, the RIFF header of the WAVE is not sent. 
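  //
  // Illustrative sketch of how a client could build the first binary message
  // (a little-endian host is assumed; the variable names are made up for this
  // example and are not part of the server code):
  //
  //   std::vector<float> samples = ...;  // normalized to [-1, 1]
  //   int32_t sample_rate = 16000;
  //   int32_t num_bytes =
  //       static_cast<int32_t>(samples.size() * sizeof(float));
  //
  //   std::string msg(8 + num_bytes, '\0');
  //   std::memcpy(&msg[0], &sample_rate, 4);  // bytes 0..3: sample rate
  //   std::memcpy(&msg[4], &num_bytes, 4);    // bytes 4..7: payload size
  //   std::memcpy(&msg[8], samples.data(), num_bytes);  // audio samples
  //
  //   // Send `msg` as one or more binary websocket messages (the first one
  //   // must contain at least the 8-byte header), then send the text
  //   // message "Done" when there is nothing more to decode.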
+ void OnMessage(connection_hdl hdl, server::message_ptr msg); + + // Close a websocket connection with given code and reason + void Close(connection_hdl hdl, websocketpp::close::status::value code, + const std::string &reason); + + private: + asio::io_context &io_conn_; + asio::io_context &io_work_; + server server_; + + std::map> + connections_; + std::mutex mutex_; + + OfflineWebsocketServerConfig config_; + + std::ofstream log_; + TeeStream tee_; + + OfflineWebsocketDecoder decoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WEBSOCKET_SERVER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server.cc new file mode 100644 index 00000000..762785f9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-websocket-server.cc @@ -0,0 +1,121 @@ +// sherpa-mnn/csrc/offline-websocket-server.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "asio.hpp" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-websocket-server-impl.h" +#include "sherpa-mnn/csrc/parse-options.h" + +static constexpr const char *kUsageMessage = R"( +Automatic speech recognition with sherpa-mnn using websocket. + +Usage: + +./bin/sherpa-mnn-offline-websocket-server --help + +(1) For transducer models + +./bin/sherpa-mnn-offline-websocket-server \ + --port=6006 \ + --num-work-threads=5 \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --log-file=./log.txt \ + --max-batch-size=5 + +(2) For Paraformer + +./bin/sherpa-mnn-offline-websocket-server \ + --port=6006 \ + --num-work-threads=5 \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --log-file=./log.txt \ + --max-batch-size=5 + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. 
+)"; + +int32_t main(int32_t argc, char *argv[]) { + sherpa_mnn::ParseOptions po(kUsageMessage); + + sherpa_mnn::OfflineWebsocketServerConfig config; + + // the server will listen on this port + int32_t port = 6006; + + // size of the thread pool for handling network connections + int32_t num_io_threads = 1; + + // size of the thread pool for neural network computation and decoding + int32_t num_work_threads = 3; + + po.Register("num-io-threads", &num_io_threads, + "Thread pool size for network connections."); + + po.Register("num-work-threads", &num_work_threads, + "Thread pool size for for neural network " + "computation and decoding."); + + po.Register("port", &port, "The port on which the server will listen."); + + config.Register(&po); + po.DisableOption("sample-rate"); + + if (argc == 1) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + po.Read(argc, argv); + + if (po.NumArgs() != 0) { + SHERPA_ONNX_LOGE("Unrecognized positional arguments!"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + config.Validate(); + + asio::io_context io_conn; // for network connections + asio::io_context io_work; // for neural network and decoding + + sherpa_mnn::OfflineWebsocketServer server(io_conn, io_work, config); + server.Run(port); + + SHERPA_ONNX_LOGE("Started!"); + SHERPA_ONNX_LOGE("Listening on: %d", port); + SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); + + // give some work to do for the io_work pool + auto work_guard = asio::make_work_guard(io_work); + + std::vector io_threads; + + // decrement since the main thread is also used for network communications + for (int32_t i = 0; i < num_io_threads - 1; ++i) { + io_threads.emplace_back([&io_conn]() { io_conn.run(); }); + } + + std::vector work_threads; + for (int32_t i = 0; i < num_work_threads; ++i) { + work_threads.emplace_back([&io_work]() { io_work.run(); }); + } + + io_conn.run(); + + for (auto &t : io_threads) { + t.join(); + } + + for (auto &t : work_threads) { + t.join(); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model-config.cc new file mode 100644 index 00000000..27264cc2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model-config.cc @@ -0,0 +1,37 @@ +// sherpa-mnn/csrc/offline-wenet-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-wenet-ctc-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineWenetCtcModelConfig::Register(ParseOptions *po) { + po->Register( + "wenet-ctc-model", &model, + "Path to model.onnx from WeNet. 
Please see " + "https://github.com/k2-fsa/sherpa-mnn/pull/425 for available models"); +} + +bool OfflineWenetCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("WeNet model: '%s' does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineWenetCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineWenetCtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model-config.h new file mode 100644 index 00000000..8d2f16bd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model-config.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/offline-wenet-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineWenetCtcModelConfig { + std::string model; + + OfflineWenetCtcModelConfig() = default; + explicit OfflineWenetCtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model.cc new file mode 100644 index 00000000..114ea54e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model.cc @@ -0,0 +1,137 @@ +// sherpa-mnn/csrc/offline-wenet-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-wenet-ctc-model.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class OfflineWenetCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.wenet_ctc.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.wenet_ctc.model); + Init(buf.data(), buf.size()); + } + + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + std::vector inputs = {std::move(features), + std::move(features_length)}; + + return sess_->onForward(inputs); + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), 
&input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 0; +}; + +OfflineWenetCtcModel::OfflineWenetCtcModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineWenetCtcModel::OfflineWenetCtcModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineWenetCtcModel::~OfflineWenetCtcModel() = default; + +std::vector OfflineWenetCtcModel::Forward( + MNN::Express::VARP features, MNN::Express::VARP features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineWenetCtcModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +MNNAllocator *OfflineWenetCtcModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template OfflineWenetCtcModel::OfflineWenetCtcModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineWenetCtcModel::OfflineWenetCtcModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model.h new file mode 100644 index 00000000..a45fc8bd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-wenet-ctc-model.h @@ -0,0 +1,73 @@ +// sherpa-mnn/csrc/offline-wenet-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-ctc-model.h" +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +/** This class implements the CTC model from WeNet. + * + * See + * https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/wenet/export-onnx.py + * https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/wenet/test-onnx.py + * https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/wenet/run.sh + * + */ +class OfflineWenetCtcModel : public OfflineCtcModel { + public: + explicit OfflineWenetCtcModel(const OfflineModelConfig &config); + + template + OfflineWenetCtcModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineWenetCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. 
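+   *
+   * Illustrative sketch only; the field name `wenet_ctc.model` comes from
+   * OfflineModelConfig, everything else (paths, feature preparation) is a
+   * hypothetical example, not part of this header:
+   * @code
+   *   OfflineModelConfig config;
+   *   config.wenet_ctc.model = "path/to/model.onnx";  // exported WeNet model
+   *   OfflineWenetCtcModel model(config);
+   *
+   *   // features/features_length are MNN::Express::VARP tensors prepared by
+   *   // the caller, e.g. fbank frames padded to a common length per batch
+   *   auto outputs = model.Forward(std::move(features),
+   *                                std::move(features_length));
+   *   // outputs[0]: log_probs, outputs[1]: log_probs_length (see @return below)
+   * @endcode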
+ * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int + */ + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** SubsamplingFactor of the model + * + * For Citrinet, the subsampling factor is usually 4. + * For Conformer CTC, the subsampling factor is usually 8. + */ + int32_t SubsamplingFactor() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + // WeNet CTC models do not support batch size > 1 + bool SupportBatchProcessing() const override { return false; } + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WENET_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-decoder.h new file mode 100644 index 00000000..1bfd8935 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-decoder.h @@ -0,0 +1,44 @@ +// sherpa-mnn/csrc/offline-whisper-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ + +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-whisper-model-config.h" + +namespace sherpa_mnn { + +struct OfflineWhisperDecoderResult { + /// The decoded token IDs + std::vector tokens; + std::string lang; +}; + +class OfflineWhisperDecoder { + public: + virtual ~OfflineWhisperDecoder() = default; + + /** Run beam search given the output from the whisper encoder model. + * + * @param n_layer_cross_k A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + * @param n_layer_cross_v A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + * + * @return Return a vector of size `N` containing the decoded results. 
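+   *
+   * Illustrative call sequence (a sketch only; `model` and `decoder` are
+   * assumed to be an OfflineWhisperModel and a concrete decoder such as
+   * OfflineWhisperGreedySearchDecoder):
+   * @code
+   *   // cross-attention keys/values produced by the encoder
+   *   auto encoder_out = model.ForwardEncoder(std::move(features));
+   *   auto results = decoder->Decode(std::move(encoder_out.first),
+   *                                  std::move(encoder_out.second),
+   *                                  num_feature_frames);
+   *   // results[0].tokens: decoded token IDs, results[0].lang: detected language
+   * @endcode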
+ */ + virtual std::vector Decode( + MNN::Express::VARP n_layer_cross_k, MNN::Express::VARP n_layer_cross_v, + int32_t num_feature_frames) = 0; + + virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.cc new file mode 100644 index 00000000..c279ee7e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.cc @@ -0,0 +1,158 @@ +// sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.h" + +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +void OfflineWhisperGreedySearchDecoder::SetConfig( + const OfflineWhisperModelConfig &config) { + config_ = config; +} + +std::vector +OfflineWhisperGreedySearchDecoder::Decode(MNN::Express::VARP cross_k, + MNN::Express::VARP cross_v, + int32_t num_feature_frames) { + auto memory_info = + (MNNAllocator*)(nullptr); + + // For multilingual models, initial_tokens contains [sot, language, task] + // - language is English by default + // - task is transcribe by default + // + // For non-multilingual models, initial_tokens contains [sot] + std::vector initial_tokens = model_->GetInitialTokens(); + + if (model_->IsMultiLingual()) { + if (!config_.language.empty()) { + const auto &lang2id = model_->GetLang2ID(); + + if (!lang2id.count(config_.language)) { + SHERPA_ONNX_LOGE("Invalid language: %s", config_.language.c_str()); + exit(-1); + } + + int32_t lang_id = lang2id.at(config_.language); + + // 0: sot, 1: lang_id, 2: task, 3: no_timestamps + initial_tokens[1] = lang_id; + } else { + int32_t lang_id = model_->DetectLanguage(cross_k, cross_v); + + // 0: sot, 1: lang_id, 2: task, 3: no_timestamps + initial_tokens[1] = lang_id; + } + + if (config_.task == "translate") { + initial_tokens[2] = model_->Translate(); + } else if (config_.task != "transcribe") { + // initial_tokens[2] is transcribe by default + SHERPA_ONNX_LOGE( + "Unsupported task: %s. 
Valid values are: transcribe, translate.", + config_.task.c_str()); + } + } + + initial_tokens.push_back(model_->NoTimeStampsToken()); + + int32_t batch_size = 1; + std::array token_shape{ + batch_size, static_cast(initial_tokens.size())}; + + MNN::Express::VARP tokens = MNNUtilsCreateTensor( + memory_info, initial_tokens.data(), initial_tokens.size(), + token_shape.data(), token_shape.size()); + + std::array offset_shape{1}; + MNN::Express::VARP offset = MNNUtilsCreateTensor( + model_->Allocator(), offset_shape.data(), offset_shape.size()); + *(offset->writeMap()) = 0; + + auto self_kv_cache = model_->GetInitialSelfKVCache(); + + auto decoder_out = model_->ForwardDecoder( + std::move(tokens), std::move(self_kv_cache.first), + std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), + std::move(offset)); + + *(std::get<5>(decoder_out)->writeMap()) = + initial_tokens.size(); + + const auto &logits = std::get<0>(decoder_out); + const float *p_logits = logits->readMap(); + + auto logits_shape = logits->getInfo()->dim; + int32_t vocab_size = logits_shape[2]; + + const float *p_start = p_logits + (logits_shape[1] - 1) * vocab_size; + + int32_t max_token_id = static_cast( + std::distance(p_start, std::max_element(p_start, p_start + vocab_size))); + + int32_t n_text_ctx = model_->TextCtx(); + + std::vector predicted_tokens; + + // assume at most 6 tokens per second + int32_t num_possible_tokens = num_feature_frames / 100 * 6; + num_possible_tokens = std::min(num_possible_tokens, n_text_ctx / 2); + + for (int32_t i = 0; i < num_possible_tokens; ++i) { + if (max_token_id == model_->EOT()) { + break; + } + + predicted_tokens.push_back(max_token_id); + + std::array token_shape{1, 1}; + MNN::Express::VARP tokens = MNNUtilsCreateTensor( + model_->Allocator(), token_shape.data(), token_shape.size()); + + int *p_tokens = tokens->writeMap(); + p_tokens[0] = max_token_id; + + decoder_out = model_->ForwardDecoder(std::move(tokens), + std::move(std::get<1>(decoder_out)), + std::move(std::get<2>(decoder_out)), + std::move(std::get<3>(decoder_out)), + std::move(std::get<4>(decoder_out)), + std::move(std::get<5>(decoder_out))); + + int *p_offset = + std::get<5>(decoder_out)->writeMap(); + + *p_offset += 1; + if (*p_offset >= n_text_ctx - 1) { + break; + } + + const auto &logits = std::get<0>(decoder_out); + const float *p_logits = logits->readMap(); + + max_token_id = static_cast(std::distance( + p_logits, std::max_element(p_logits, p_logits + vocab_size))); + } + + std::vector ans(1); + + const auto &id2lang = model_->GetID2Lang(); + if (id2lang.count(initial_tokens[1])) { + ans[0].lang = id2lang.at(initial_tokens[1]); + } else { + ans[0].lang = ""; + } + + ans[0].tokens = std::move(predicted_tokens); + + return ans; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.h new file mode 100644 index 00000000..bd98c67d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.h @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/offline-whisper-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/offline-whisper-decoder.h" +#include "sherpa-mnn/csrc/offline-whisper-model.h" + +namespace sherpa_mnn { + +class 
OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder { + public: + OfflineWhisperGreedySearchDecoder(const OfflineWhisperModelConfig &config, + OfflineWhisperModel *model) + : config_(config), model_(model) {} + + std::vector Decode( + MNN::Express::VARP cross_k, MNN::Express::VARP cross_v, + int32_t num_feature_frames) override; + + void SetConfig(const OfflineWhisperModelConfig &config) override; + + private: + OfflineWhisperModelConfig config_; + OfflineWhisperModel *model_; // not owned +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model-config.cc new file mode 100644 index 00000000..52753d38 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model-config.cc @@ -0,0 +1,91 @@ +// sherpa-mnn/csrc/offline-whisper-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-whisper-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineWhisperModelConfig::Register(ParseOptions *po) { + po->Register("whisper-encoder", &encoder, + "Path to onnx encoder of whisper, e.g., tiny-encoder.onnx, " + "medium.en-encoder.onnx."); + + po->Register("whisper-decoder", &decoder, + "Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, " + "medium.en-decoder.onnx."); + + po->Register( + "whisper-language", &language, + "The spoken language in the input audio file. Example values: " + "en, de, fr, zh, jp. If it is not given for a multilingual model, we will" + " infer the language from the input audio file. " + "Please refer to " + "https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10" + " for valid values. Note that for non-multilingual models, it supports " + "only 'en'"); + + po->Register("whisper-task", &task, + "Valid values: transcribe, translate. " + "Note that for non-multilingual models, it supports " + "only 'transcribe'"); + + po->Register( + "whisper-tail-paddings", &tail_paddings, + "Suggested value: 50 for English models. 300 for multilingual models. " + "Since we have removed the 30-second constraint, we need to add some " + "tail padding frames " + "so that whisper can detect the eot token. Leave it to -1 to use 1000."); +} + +bool OfflineWhisperModelConfig::Validate() const { + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-encoder"); + return false; + } + + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("whisper encoder file '%s' does not exist", + encoder.c_str()); + return false; + } + + if (decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-decoder"); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("whisper decoder file '%s' does not exist", + decoder.c_str()); + return false; + } + + if (task != "translate" && task != "transcribe") { + SHERPA_ONNX_LOGE( + "--whisper-task supports only translate and transcribe. 
Given: %s", + task.c_str()); + + return false; + } + + return true; +} + +std::string OfflineWhisperModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineWhisperModelConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\", "; + os << "language=\"" << language << "\", "; + os << "task=\"" << task << "\", "; + os << "tail_paddings=" << tail_paddings << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model-config.h new file mode 100644 index 00000000..c187f85c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model-config.h @@ -0,0 +1,60 @@ +// sherpa-mnn/csrc/offline-whisper-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineWhisperModelConfig { + std::string encoder; + std::string decoder; + + // Available languages can be found at + // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + // + // Note: For non-multilingual models, it supports only "en" + // + // If empty, we will infer it from the input audio file when + // the model is multilingual. + std::string language; + + // Valid values are transcribe and translate + // + // Note: For non-multilingual models, it supports only "transcribe" + std::string task = "transcribe"; + + // Number of tail padding frames. + // + // Since we remove the 30-second constraint, we need to add some paddings + // at the end. + // + // Recommended values: + // - 50 for English models + // - 300 for multilingual models + int32_t tail_paddings = -1; + + OfflineWhisperModelConfig() = default; + OfflineWhisperModelConfig(const std::string &encoder, + const std::string &decoder, + const std::string &language, + const std::string &task, int32_t tail_paddings) + : encoder(encoder), + decoder(decoder), + language(language), + task(task), + tail_paddings(tail_paddings) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model.cc new file mode 100644 index 00000000..ca84fb97 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model.cc @@ -0,0 +1,477 @@ +// sherpa-mnn/csrc/offline-whisper-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-whisper-model.h" + +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineWhisperModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.whisper.encoder); + 
InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.whisper.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + explicit Impl(const SpokenLanguageIdentificationConfig &config) + : lid_config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.whisper.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.whisper.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.whisper.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.whisper.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const SpokenLanguageIdentificationConfig &config) + : lid_config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.whisper.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.whisper.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + std::pair ForwardEncoder(MNN::Express::VARP features) { + auto encoder_out = encoder_sess_->onForward({features}); + + return {std::move(encoder_out[0]), std::move(encoder_out[1])}; + } + + std::tuple + ForwardDecoder(MNN::Express::VARP tokens, MNN::Express::VARP n_layer_self_k_cache, + MNN::Express::VARP n_layer_self_v_cache, MNN::Express::VARP n_layer_cross_k, + MNN::Express::VARP n_layer_cross_v, MNN::Express::VARP offset) { + std::vector decoder_input = {std::move(tokens), + std::move(n_layer_self_k_cache), + std::move(n_layer_self_v_cache), + std::move(n_layer_cross_k), + std::move(n_layer_cross_v), + std::move(offset)}; + + auto decoder_out = decoder_sess_->onForward(decoder_input); + + return std::tuple{ + std::move(decoder_out[0]), std::move(decoder_out[1]), + std::move(decoder_out[2]), std::move(decoder_input[3]), + std::move(decoder_input[4]), std::move(decoder_input[5])}; + } + + int32_t DetectLanguage(MNN::Express::VARP &cross_k, // NOLINT + MNN::Express::VARP &cross_v) { // NOLINT + int token_val = SOT(); + std::array token_shape{1, 1}; + + auto memory_info = + (MNNAllocator*)(nullptr); + + MNN::Express::VARP tokens = MNNUtilsCreateTensor( + memory_info, &token_val, 1, token_shape.data(), token_shape.size()); + + auto self_kv_cache = GetInitialSelfKVCache(); + + std::array offset_shape{1}; + MNN::Express::VARP offset = MNNUtilsCreateTensor( + Allocator(), offset_shape.data(), offset_shape.size()); + *(offset->writeMap()) = 0; + + auto decoder_out = + ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first), + std::move(self_kv_cache.second), std::move(cross_k), + std::move(cross_v), std::move(offset)); + + cross_k = std::move(std::get<3>(decoder_out)); + cross_v = std::move(std::get<4>(decoder_out)); + + const float *p_logits = std::get<0>(decoder_out)->readMap(); + const auto &all_language_ids = GetAllLanguageIDs(); + + int32_t lang_id = all_language_ids[0]; + float this_logit = p_logits[lang_id]; + + for (int32_t i = 1; i != all_language_ids.size(); ++i) { + int32_t id = all_language_ids[i]; + float p = p_logits[id]; + + if (p > this_logit) { + this_logit = p; + lang_id = id; + } + } + + if (config_.debug) { + SHERPA_ONNX_LOGE("Detected language: %s", + GetID2Lang().at(lang_id).c_str()); + } + + return lang_id; + } + + std::pair GetInitialSelfKVCache() { + std::array 
shape{n_text_layer_, 1, n_text_ctx_, n_text_state_}; + + MNN::Express::VARP n_layer_self_k_cache = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + + MNN::Express::VARP n_layer_self_v_cache = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + + auto n = shape[0] * shape[1] * shape[2] * shape[3]; + + float *p_k = n_layer_self_k_cache->writeMap(); + float *p_v = n_layer_self_v_cache->writeMap(); + + memset(p_k, 0, sizeof(float) * n); + memset(p_v, 0, sizeof(float) * n); + + return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)}; + } + + MNNAllocator *Allocator() { return allocator_; } + + const std::vector &GetInitialTokens() const { return sot_sequence_; } + + const std::vector &GetAllLanguageIDs() const { + return all_language_tokens_; + } + + const std::unordered_map &GetLang2ID() const { + return lang2id_; + } + + const std::unordered_map &GetID2Lang() const { + return id2lang_; + } + + int32_t NoTimeStampsToken() const { return no_timestamps_; } + + int32_t EOT() const { return eot_; } + + int32_t SOT() const { return sot_; } + + int32_t TextCtx() const { return n_text_ctx_; } + + int32_t VocabSize() const { return n_vocab_; } + + int32_t FeatureDim() const { return n_mels_; } + + int32_t Translate() const { return translate_; } + + bool IsMultiLingual() const { return is_multilingual_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(n_mels_, "n_mels"); + SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer"); + SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx"); + SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state"); + SHERPA_ONNX_READ_META_DATA(n_vocab_, "n_vocab"); + SHERPA_ONNX_READ_META_DATA(sot_, "sot"); + SHERPA_ONNX_READ_META_DATA(eot_, "eot"); + SHERPA_ONNX_READ_META_DATA(blank_, "blank_id"); + SHERPA_ONNX_READ_META_DATA(translate_, "translate"); + SHERPA_ONNX_READ_META_DATA(transcribe_, "transcribe"); + SHERPA_ONNX_READ_META_DATA(is_multilingual_, "is_multilingual"); + SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps"); + SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech"); + SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence"); + + if (is_multilingual_) { + SHERPA_ONNX_READ_META_DATA_VEC(all_language_tokens_, + "all_language_tokens"); + SHERPA_ONNX_READ_META_DATA_VEC_STRING(all_language_codes_, + "all_language_codes"); + if (all_language_tokens_.size() != all_language_codes_.size()) { + SHERPA_ONNX_LOGE("# lang_id: %d != # lang_code: %d", + static_cast(all_language_tokens_.size()), + static_cast(all_language_codes_.size())); + exit(-1); + } + + for (int32_t i = 0; + i != static_cast(all_language_tokens_.size()); ++i) { + lang2id_[all_language_codes_[i]] = all_language_tokens_[i]; + id2lang_[all_language_tokens_[i]] = 
all_language_codes_[i]; + } + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + private: + OfflineModelConfig config_; + SpokenLanguageIdentificationConfig lid_config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector all_language_tokens_; + std::vector all_language_codes_; + std::unordered_map lang2id_; + std::unordered_map id2lang_; + + // model meta data + int32_t n_mels_ = 80; + int32_t n_text_layer_ = 0; + int32_t n_text_ctx_ = 0; + int32_t n_text_state_ = 0; + int32_t n_vocab_ = 0; + int32_t sot_ = 0; + int32_t eot_ = 0; + int32_t blank_ = 0; + int32_t translate_ = 0; + int32_t transcribe_ = 0; + int32_t no_timestamps_ = 0; + int32_t no_speech_ = 0; + int32_t is_multilingual_ = 0; + std::vector sot_sequence_; +}; + +OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineWhisperModel::OfflineWhisperModel( + const SpokenLanguageIdentificationConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineWhisperModel::OfflineWhisperModel(Manager *mgr, + const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +template +OfflineWhisperModel::OfflineWhisperModel( + Manager *mgr, const SpokenLanguageIdentificationConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineWhisperModel::~OfflineWhisperModel() = default; + +std::pair OfflineWhisperModel::ForwardEncoder( + MNN::Express::VARP features) const { + return impl_->ForwardEncoder(std::move(features)); +} + +std::tuple +OfflineWhisperModel::ForwardDecoder(MNN::Express::VARP tokens, + MNN::Express::VARP n_layer_self_k_cache, + MNN::Express::VARP n_layer_self_v_cache, + MNN::Express::VARP n_layer_cross_k, + MNN::Express::VARP n_layer_cross_v, + MNN::Express::VARP offset) const { + return impl_->ForwardDecoder( + std::move(tokens), std::move(n_layer_self_k_cache), + std::move(n_layer_self_v_cache), std::move(n_layer_cross_k), + std::move(n_layer_cross_v), std::move(offset)); +} + +int32_t OfflineWhisperModel::DetectLanguage(MNN::Express::VARP &cross_k, // NOLINT + MNN::Express::VARP &cross_v) { // NOLINT + return impl_->DetectLanguage(cross_k, cross_v); +} + +std::pair OfflineWhisperModel::GetInitialSelfKVCache() + const { + return impl_->GetInitialSelfKVCache(); +} + +MNNAllocator *OfflineWhisperModel::Allocator() const { + return impl_->Allocator(); +} + +const std::vector &OfflineWhisperModel::GetInitialTokens() const { + return impl_->GetInitialTokens(); +} + +const std::vector &OfflineWhisperModel::GetAllLanguageIDs() const { + return impl_->GetAllLanguageIDs(); +} + +const std::unordered_map + &OfflineWhisperModel::GetLang2ID() const { + return impl_->GetLang2ID(); +} + +const std::unordered_map + &OfflineWhisperModel::GetID2Lang() 
const { + return impl_->GetID2Lang(); +} + +int32_t OfflineWhisperModel::NoTimeStampsToken() const { + return impl_->NoTimeStampsToken(); +} + +int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); } + +int32_t OfflineWhisperModel::SOT() const { return impl_->SOT(); } + +int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); } + +int32_t OfflineWhisperModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineWhisperModel::FeatureDim() const { return impl_->FeatureDim(); } + +int32_t OfflineWhisperModel::Translate() const { return impl_->Translate(); } + +bool OfflineWhisperModel::IsMultiLingual() const { + return impl_->IsMultiLingual(); +} + +void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames, + int32_t feat_dim) { + // log_spec = torch.clamp(features, min=1e-10).log10() + // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + // mel = (log_spec + 4.0) / 4.0 + + int32_t n = num_frames * feat_dim; + float max_v = -1e20; + for (int32_t i = 0; i != n; ++i) { + float f = features[i]; + + f = std::max(f, 1e-10); + f = std::log10(f); + + max_v = std::max(f, max_v); + + features[i] = f; + } + + max_v -= 8; + + for (int32_t i = 0; i != n; ++i) { + float f = features[i]; + f = std::max(f, max_v); + + f = (f + 4) / 4; + + features[i] = f; + } +} + +#if __ANDROID_API__ >= 9 +template OfflineWhisperModel::OfflineWhisperModel( + AAssetManager *mgr, const OfflineModelConfig &config); + +template OfflineWhisperModel::OfflineWhisperModel( + AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config); +#endif + +#if __OHOS__ +template OfflineWhisperModel::OfflineWhisperModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); + +template OfflineWhisperModel::OfflineWhisperModel( + NativeResourceManager *mgr, + const SpokenLanguageIdentificationConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model.h new file mode 100644 index 00000000..ea92005b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-whisper-model.h @@ -0,0 +1,115 @@ +// sherpa-mnn/csrc/offline-whisper-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/csrc/spoken-language-identification.h" + +namespace sherpa_mnn { + +class OfflineWhisperModel { + public: + explicit OfflineWhisperModel(const OfflineModelConfig &config); + + explicit OfflineWhisperModel( + const SpokenLanguageIdentificationConfig &config); + + template + OfflineWhisperModel(Manager *mgr, const OfflineModelConfig &config); + + template + OfflineWhisperModel(Manager *mgr, + const SpokenLanguageIdentificationConfig &config); + + ~OfflineWhisperModel(); + + /** Run the encoder model. + * + * @param features A tensor of shape (N, C, T). It is changed in-place. + * C is 80 and T is 3000. + * + * @return Return a pair containing: + * - n_layer_cross_k: A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state) + * - n_layer_cross_v: A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state) + */ + std::pair ForwardEncoder(MNN::Express::VARP features) const; + + /** Run the decoder model. 
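+   *
+   * The decoder is autoregressive: it is typically called once per emitted
+   * token, feeding back the self-attention KV caches returned by the previous
+   * call and advancing `offset` by the number of tokens consumed so far
+   * (see OfflineWhisperGreedySearchDecoder for a concrete driver).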
+ * + * @param tokens A int64 tensor of shape (N, num_words) + * @param n_layer_self_k_cache A 4-D tensor of shape + * (n_text_layer, N, n_text_ctx, n_text_state). + * @param n_layer_self_v_cache A 4-D tensor of shape + * (n_text_layer, N, n_text_ctx, n_text_state). + * @param n_layer_cross_k A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + * @param n_layer_cross_v A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + * @param offset A int64 tensor of shape (N,) + * + * @return Return a tuple containing 6 tensors: + * + * - logits A 3-D tensor of shape (N, num_words, vocab_size) + * - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache + * - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache + * - out_n_layer_cross_k Same as n_layer_cross_k + * - out_n_layer_cross_v Same as n_layer_cross_v + * - out_offset Same as offset + */ + std::tuple + ForwardDecoder(MNN::Express::VARP tokens, MNN::Express::VARP n_layer_self_k_cache, + MNN::Express::VARP n_layer_self_v_cache, MNN::Express::VARP n_layer_cross_k, + MNN::Express::VARP n_layer_cross_v, MNN::Express::VARP offset) const; + + int32_t DetectLanguage(MNN::Express::VARP &cross_k, // NOLINT + MNN::Express::VARP &cross_v); // NOLINT + + /** Return the initial self kv cache in a pair + * - n_layer_self_k_cache A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + * - n_layer_self_v_cache A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + */ + std::pair GetInitialSelfKVCache() const; + const std::vector &GetInitialTokens() const; + const std::vector &GetAllLanguageIDs() const; + const std::unordered_map &GetLang2ID() const; + const std::unordered_map &GetID2Lang() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + int32_t NoTimeStampsToken() const; + int32_t EOT() const; + int32_t SOT() const; + int32_t TextCtx() const; + int32_t VocabSize() const; + int32_t FeatureDim() const; + int32_t Translate() const; + bool IsMultiLingual() const; + + static void NormalizeFeatures(float *features, int32_t num_frames, + int32_t feat_dim); + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.cc new file mode 100644 index 00000000..19a9e2ad --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.cc @@ -0,0 +1,40 @@ +// sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineZipformerAudioTaggingModelConfig::Register(ParseOptions *po) { + po->Register("zipformer-model", &model, + "Path to zipformer model for audio tagging"); +} + +bool OfflineZipformerAudioTaggingModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --zipformer-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--zipformer-model: '%s' does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineZipformerAudioTaggingModelConfig::ToString() const { + std::ostringstream os; + + 
os << "OfflineZipformerAudioTaggingModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.h new file mode 100644 index 00000000..4037f0cb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/offline-zipformer-audio-tagging-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OfflineZipformerAudioTaggingModelConfig { + std::string model; + + OfflineZipformerAudioTaggingModelConfig() = default; + + explicit OfflineZipformerAudioTaggingModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.cc new file mode 100644 index 00000000..0d7f39a0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.cc @@ -0,0 +1,116 @@ +// sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.h" + +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OfflineZipformerAudioTaggingModel::Impl { + public: + explicit Impl(const AudioTaggingModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.zipformer.model); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.zipformer.model); + Init(buf.data(), buf.size()); + } +#endif + + MNN::Express::VARP Forward(MNN::Express::VARP features, MNN::Express::VARP features_length) { + std::vector inputs = {std::move(features), + std::move(features_length)}; + + auto ans = + sess_->onForward(inputs); + return std::move(ans[0]); + } + + int32_t NumEventClasses() const { return num_event_classes_; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + // get num_event_classes from the 
output[0].shape, + // which is (N, num_event_classes) + //num_event_classes_ = + // sess_->GetOutputTypeInfo(0)->getInfo()->dim[1]; + } + + private: + AudioTaggingModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t num_event_classes_ = 0; +}; + +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( + const AudioTaggingModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OfflineZipformerAudioTaggingModel::OfflineZipformerAudioTaggingModel( + AAssetManager *mgr, const AudioTaggingModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OfflineZipformerAudioTaggingModel::~OfflineZipformerAudioTaggingModel() = + default; + +MNN::Express::VARP OfflineZipformerAudioTaggingModel::Forward( + MNN::Express::VARP features, MNN::Express::VARP features_length) const { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineZipformerAudioTaggingModel::NumEventClasses() const { + return impl_->NumEventClasses(); +} + +MNNAllocator *OfflineZipformerAudioTaggingModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.h new file mode 100644 index 00000000..8e1c0207 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.h @@ -0,0 +1,64 @@ +// sherpa-mnn/csrc/offline-zipformer-audio-tagging-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/audio-tagging-model-config.h" + +namespace sherpa_mnn { + +/** This class implements the zipformer CTC model of the librispeech recipe + * from icefall. + * + * See + * https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py + */ +class OfflineZipformerAudioTaggingModel { + public: + explicit OfflineZipformerAudioTaggingModel( + const AudioTaggingModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineZipformerAudioTaggingModel(AAssetManager *mgr, + const AudioTaggingModelConfig &config); +#endif + + ~OfflineZipformerAudioTaggingModel(); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a tensor + * - probs: A 2-D tensor of shape (N, num_event_classes). 
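+   *
+   * Illustrative sketch (assumptions: `config.zipformer.model` points to an
+   * exported audio-tagging model and `features`/`features_length` have been
+   * prepared by the caller; nothing else here is prescribed by this header):
+   * @code
+   *   AudioTaggingModelConfig config;
+   *   config.zipformer.model = "path/to/model.onnx";
+   *   OfflineZipformerAudioTaggingModel model(config);
+   *
+   *   MNN::Express::VARP probs = model.Forward(std::move(features),
+   *                                            std::move(features_length));
+   *   const float *p = probs->readMap<float>();  // N * num_event_classes scores
+   * @endcode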
+ */ + MNN::Express::VARP Forward(MNN::Express::VARP features, MNN::Express::VARP features_length) const; + + /** Return the number of event classes of the model + */ + int32_t NumEventClasses() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_AUDIO_TAGGING_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model-config.cc new file mode 100644 index 00000000..b0f5336a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model-config.cc @@ -0,0 +1,35 @@ +// sherpa-mnn/csrc/offline-zipformer-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-zipformer-ctc-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) { + po->Register("zipformer-ctc-model", &model, "Path to zipformer CTC model"); +} + +bool OfflineZipformerCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("zipformer CTC model file '%s' does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string OfflineZipformerCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineZipformerCtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model-config.h new file mode 100644 index 00000000..406d98f7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model-config.h @@ -0,0 +1,32 @@ +// sherpa-mnn/csrc/offline-zipformer-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +// for +// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py +struct OfflineZipformerCtcModelConfig { + std::string model; + + OfflineZipformerCtcModelConfig() = default; + + explicit OfflineZipformerCtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model.cc new file mode 100644 index 00000000..6f4c2e3e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model.cc @@ -0,0 +1,140 @@ +// sherpa-mnn/csrc/offline-zipformer-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-zipformer-ctc-model.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include 
"sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class OfflineZipformerCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.zipformer_ctc.model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.zipformer_ctc.model); + Init(buf.data(), buf.size()); + } + + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + std::vector inputs = {std::move(features), + std::move(features_length)}; + + return sess_->onForward(inputs); + } + + int32_t VocabSize() const { return vocab_size_; } + int32_t SubsamplingFactor() const { return 4; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + auto iter = meta_data.find("vocab_size"); + if (iter != meta_data.end()){ + vocab_size_ = std::stoi(iter->second); + } + } + + private: + OfflineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; +}; + +OfflineZipformerCtcModel::OfflineZipformerCtcModel( + const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineZipformerCtcModel::OfflineZipformerCtcModel( + Manager *mgr, const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineZipformerCtcModel::~OfflineZipformerCtcModel() = default; + +std::vector OfflineZipformerCtcModel::Forward( + MNN::Express::VARP features, MNN::Express::VARP features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineZipformerCtcModel::VocabSize() const { + return impl_->VocabSize(); +} + +MNNAllocator *OfflineZipformerCtcModel::Allocator() const { + return impl_->Allocator(); +} + +int32_t OfflineZipformerCtcModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +#if __ANDROID_API__ >= 9 +template OfflineZipformerCtcModel::OfflineZipformerCtcModel( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineZipformerCtcModel::OfflineZipformerCtcModel( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model.h new file mode 100644 index 00000000..b13e3f9d --- 
/dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/offline-zipformer-ctc-model.h @@ -0,0 +1,62 @@ +// sherpa-mnn/csrc/offline-zipformer-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-ctc-model.h" +#include "sherpa-mnn/csrc/offline-model-config.h" + +namespace sherpa_mnn { + +/** This class implements the zipformer CTC model of the librispeech recipe + * from icefall. + * + * See + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py + */ +class OfflineZipformerCtcModel : public OfflineCtcModel { + public: + explicit OfflineZipformerCtcModel(const OfflineModelConfig &config); + + template + OfflineZipformerCtcModel(Manager *mgr, const OfflineModelConfig &config); + + ~OfflineZipformerCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int + */ + std::vector Forward(MNN::Express::VARP features, + MNN::Express::VARP features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + int32_t SubsamplingFactor() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model-meta-data.h new file mode 100644 index 00000000..336acd5d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model-meta-data.h @@ -0,0 +1,25 @@ +// sherpa-mnn/csrc/online-cnn-bilstm-model-meta-data.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ + +namespace sherpa_mnn { + +struct OnlineCNNBiLSTMModelMetaData { + int32_t comma_id = -1; + int32_t period_id = -1; + int32_t quest_id = -1; + + int32_t upper_id = -1; + int32_t cap_id = -1; + int32_t mix_case_id = -1; + + int32_t num_cases = -1; + int32_t num_punctuations = -1; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model.cc new file mode 100644 index 00000000..ddf409bf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model.cc @@ -0,0 +1,136 @@ +// sherpa-mnn/csrc/online-cnn-bilstm-model.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-mnn/csrc/online-cnn-bilstm-model.h" + +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace 
sherpa_mnn { + +class OnlineCNNBiLSTMModel::Impl { + public: + explicit Impl(const OnlinePunctuationModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.cnn_bilstm); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OnlinePunctuationModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.cnn_bilstm); + Init(buf.data(), buf.size()); + } +#endif + + std::pair Forward(MNN::Express::VARP token_ids, + MNN::Express::VARP valid_ids, + MNN::Express::VARP label_lens) { + std::vector inputs = { + std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; + + auto ans = + sess_->onForward(inputs); + return {std::move(ans[0]), std::move(ans[1])}; + } + + MNNAllocator *Allocator() { return allocator_; } + + const OnlineCNNBiLSTMModelMetaData & metaData() const { + return meta_data_; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + + MNNAllocator* allocator; // used in the macro below + + SHERPA_ONNX_READ_META_DATA(meta_data_.comma_id, "COMMA"); + SHERPA_ONNX_READ_META_DATA(meta_data_.period_id, "PERIOD"); + SHERPA_ONNX_READ_META_DATA(meta_data_.quest_id, "QUESTION"); + + // assert here, because we will use the constant value + assert(meta_data_.comma_id == 1); + assert(meta_data_.period_id == 2); + assert(meta_data_.quest_id == 3); + + SHERPA_ONNX_READ_META_DATA(meta_data_.upper_id, "UPPER"); + SHERPA_ONNX_READ_META_DATA(meta_data_.cap_id, "CAP"); + SHERPA_ONNX_READ_META_DATA(meta_data_.mix_case_id, "MIX_CASE"); + + assert(meta_data_.upper_id == 1); + assert(meta_data_.cap_id == 2); + assert(meta_data_.mix_case_id == 3); + + // output shape is (T', num_cases) + //meta_data_.num_cases = + // sess_->GetOutputTypeInfo(0)->getInfo()->dim[1]; + //meta_data_.num_punctuations = + // sess_->GetOutputTypeInfo(1)->getInfo()->dim[1]; + } + + private: + OnlinePunctuationModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OnlineCNNBiLSTMModelMetaData meta_data_; +}; + +OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( + const OnlinePunctuationModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( + AAssetManager *mgr, const OnlinePunctuationModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default; + +std::pair OnlineCNNBiLSTMModel::Forward( + MNN::Express::VARP token_ids, MNN::Express::VARP valid_ids, MNN::Express::VARP label_lens) const { + return impl_->Forward(std::move(token_ids), std::move(valid_ids), + std::move(label_lens)); +} + +MNNAllocator *OnlineCNNBiLSTMModel::Allocator() const { + return impl_->Allocator(); +} + +const OnlineCNNBiLSTMModelMetaData &OnlineCNNBiLSTMModel::metaData() + const { + return impl_->metaData(); +} + +} // namespace sherpa_mnn diff --git 
a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model.h new file mode 100644 index 00000000..a39ddfff --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-cnn-bilstm-model.h @@ -0,0 +1,62 @@ +// sherpa-mnn/csrc/online-cnn-bilstm-model.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-cnn-bilstm-model-meta-data.h" +#include "sherpa-mnn/csrc/online-punctuation-model-config.h" + +namespace sherpa_mnn { + +/** This class implements + * https://github.com/frankyoujian/Edge-Punct-Casing/blob/main/onnx_decode_sentence.py + */ +class OnlineCNNBiLSTMModel { + public: + explicit OnlineCNNBiLSTMModel(const OnlinePunctuationModelConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineCNNBiLSTMModel(AAssetManager *mgr, + const OnlinePunctuationModelConfig &config); +#endif + + ~OnlineCNNBiLSTMModel(); + + /** Run the forward method of the model. + * + * @param token_ids A tensor of shape (N, T) of dtype int32. + * @param valid_ids A tensor of shape (N, T) of dtype int32. + * @param label_lens A tensor of shape (N) of dtype int32. + * + * @return Return a pair of tensors + * - case_logits: A 2-D tensor of shape (T', num_cases). + * - punct_logits: A 2-D tensor of shape (T', num_puncts). + */ + std::pair Forward(MNN::Express::VARP token_ids, + MNN::Express::VARP valid_ids, + MNN::Express::VARP label_lens) const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + const OnlineCNNBiLSTMModelMetaData& metaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-conformer-transducer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-conformer-transducer-model.cc new file mode 100644 index 00000000..4dfa295d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-conformer-transducer-model.cc @@ -0,0 +1,290 @@ +// sherpa-mnn/csrc/online-conformer-transducer-model.cc +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#include "sherpa-mnn/csrc/online-conformer-transducer-model.h" + +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +OnlineConformerTransducerModel::OnlineConformerTransducerModel( + const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + 
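+    // the joiner is the third network of the transducer: it combines encoder
+    // and decoder (predictor) outputs into per-frame token logits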
auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +template +OnlineConformerTransducerModel::OnlineConformerTransducerModel( + Manager *mgr, const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +void OnlineConformerTransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers"); + SHERPA_ONNX_READ_META_DATA(T_, "T"); + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + SHERPA_ONNX_READ_META_DATA(left_context_, "left_context"); + SHERPA_ONNX_READ_META_DATA(encoder_dim_, "encoder_dim"); + SHERPA_ONNX_READ_META_DATA(pad_length_, "pad_length"); + SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel"); +} + +void OnlineConformerTransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = decoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); +} + +void OnlineConformerTransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + MNNMeta meta_data = joiner_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } +} + +std::vector OnlineConformerTransducerModel::StackStates( 
+ const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + + std::vector attn_vec(batch_size); + std::vector conv_vec(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + assert(states[i].size() == 2); + attn_vec[i] = states[i][0]; + conv_vec[i] = states[i][1]; + } + + auto allocator = + const_cast(this)->allocator_; + + MNN::Express::VARP attn = Cat(allocator, attn_vec, 2); + MNN::Express::VARP conv = Cat(allocator, conv_vec, 2); + + std::vector ans; + ans.reserve(2); + ans.push_back(std::move(attn)); + ans.push_back(std::move(conv)); + + return ans; +} + +std::vector> +OnlineConformerTransducerModel::UnStackStates( + const std::vector &states) const { + const int32_t batch_size = + states[0]->getInfo()->dim[2]; + assert(states.size() == 2); + + std::vector> ans(batch_size); + + auto allocator = + const_cast(this)->allocator_; + + std::vector attn_vec = Unbind(allocator, states[0], 2); + std::vector conv_vec = Unbind(allocator, states[1], 2); + + assert(attn_vec.size() == batch_size); + assert(conv_vec.size() == batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + ans[i].push_back(std::move(attn_vec[i])); + ans[i].push_back(std::move(conv_vec[i])); + } + + return ans; +} + +std::vector OnlineConformerTransducerModel::GetEncoderInitStates() { + // Please see + // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203 + // for details + constexpr int32_t kBatchSize = 1; + std::array h_shape{num_encoder_layers_, left_context_, kBatchSize, + encoder_dim_}; + MNN::Express::VARP h = MNNUtilsCreateTensor(allocator_, h_shape.data(), + h_shape.size()); + + Fill(h, 0); + + std::array c_shape{num_encoder_layers_, cnn_module_kernel_ - 1, + kBatchSize, encoder_dim_}; + + MNN::Express::VARP c = MNNUtilsCreateTensor(allocator_, c_shape.data(), + c_shape.size()); + + Fill(c, 0); + + std::vector states; + + states.reserve(2); + states.push_back(std::move(h)); + states.push_back(std::move(c)); + + return states; +} + +std::pair> +OnlineConformerTransducerModel::RunEncoder(MNN::Express::VARP features, + std::vector states, + MNN::Express::VARP processed_frames) { + std::vector encoder_inputs = { + std::move(features), std::move(states[0]), std::move(states[1]), + std::move(processed_frames)}; + + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + std::vector next_states; + next_states.reserve(2); + next_states.push_back(std::move(encoder_out[1])); + next_states.push_back(std::move(encoder_out[2])); + + return {std::move(encoder_out[0]), std::move(next_states)}; +} + +MNN::Express::VARP OnlineConformerTransducerModel::RunDecoder( + MNN::Express::VARP decoder_input) { + auto decoder_out = decoder_sess_->onForward({decoder_input}); + return std::move(decoder_out[0]); +} + +MNN::Express::VARP OnlineConformerTransducerModel::RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->onForward(joiner_input); + + return std::move(logit[0]); +} + +#if __ANDROID_API__ >= 9 +template OnlineConformerTransducerModel::OnlineConformerTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineConformerTransducerModel::OnlineConformerTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git 
a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-conformer-transducer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-conformer-transducer-model.h new file mode 100644 index 00000000..f03d5eac --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-conformer-transducer-model.h @@ -0,0 +1,100 @@ +// sherpa-mnn/csrc/online-conformer-transducer-model.h +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +class OnlineConformerTransducerModel : public OnlineTransducerModel { + public: + explicit OnlineConformerTransducerModel(const OnlineModelConfig &config); + + template + OnlineConformerTransducerModel(Manager *mgr, const OnlineModelConfig &config); + + std::vector StackStates( + const std::vector> &states) const override; + + std::vector> UnStackStates( + const std::vector &states) const override; + + std::vector GetEncoderInitStates() override; + + std::pair> RunEncoder( + MNN::Express::VARP features, std::vector states, + MNN::Express::VARP processed_frames) override; + + MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input) override; + + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) override; + + int32_t ContextSize() const override { return context_size_; } + + int32_t ChunkSize() const override { return T_; } + + int32_t ChunkShift() const override { return decode_chunk_len_; } + + int32_t VocabSize() const override { return vocab_size_; } + MNNAllocator *Allocator() override { return allocator_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length); + void InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); + + private: + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + OnlineModelConfig config_; + + int32_t num_encoder_layers_ = 0; + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + int32_t cnn_module_kernel_ = 0; + int32_t context_size_ = 0; + int32_t left_context_ = 0; + // TODO(jingzhaoou): to retrieve from model medadata + int32_t right_context_ = 4; + int32_t encoder_dim_ = 0; + int32_t pad_length_ = 0; + int32_t vocab_size_ = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-decoder.h new file mode 100644 index 00000000..19ad63d1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-decoder.h @@ -0,0 +1,65 @@ +// sherpa-mnn/csrc/online-ctc-decoder.h +// +// Copyright (c) 2023 
Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ + +#include +#include + +#include "kaldi-decoder/csrc/faster-decoder.h" +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +class OnlineStream; + +struct OnlineCtcDecoderResult { + /// Number of frames after subsampling we have decoded so far + int32_t frame_offset = 0; + + /// The decoded token IDs + std::vector tokens; + + /// The decoded word IDs + /// Note: tokens.size() is usually not equal to words.size() + /// words is empty for greedy search decoding. + /// it is not empty when an HLG graph or an HLG graph is used. + std::vector words; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + /// Note: The index is after subsampling + /// + /// tokens.size() == timestamps.size() + std::vector timestamps; + + int32_t num_trailing_blanks = 0; +}; + +class OnlineCtcDecoder { + public: + virtual ~OnlineCtcDecoder() = default; + + /** Run streaming CTC decoding given the output from the encoder model. + * + * @param log_probs A 3-D tensor of shape + * (batch_size, num_frames, vocab_size) containing + * lob_probs in row major. + * + * @param results Input & Output parameters.. + */ + virtual void Decode(const float *log_probs, int32_t batch_size, + int32_t num_frames, int32_t vocab_size, + std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) = 0; + + virtual std::unique_ptr CreateFasterDecoder() + const { + return nullptr; + } +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder-config.cc new file mode 100644 index 00000000..c6ba43c3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder-config.cc @@ -0,0 +1,40 @@ +// sherpa-mnn/csrc/online-ctc-fst-decoder-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-ctc-fst-decoder-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +std::string OnlineCtcFstDecoderConfig::ToString() const { + std::ostringstream os; + + os << "OnlineCtcFstDecoderConfig("; + os << "graph=\"" << graph << "\", "; + os << "max_active=" << max_active << ")"; + + return os.str(); +} + +void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) { + po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst"); + + po->Register("ctc-max-active", &max_active, + "Decoder max active states. 
Larger->slower; more accurate"); +} + +bool OnlineCtcFstDecoderConfig::Validate() const { + if (!graph.empty() && !FileExists(graph)) { + SHERPA_ONNX_LOGE("graph: '%s' does not exist", graph.c_str()); + return false; + } + return true; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder-config.h new file mode 100644 index 00000000..d88b3a23 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder-config.h @@ -0,0 +1,32 @@ +// sherpa-mnn/csrc/online-ctc-fst-decoder-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineCtcFstDecoderConfig { + // Path to H.fst, HL.fst or HLG.fst + std::string graph; + int32_t max_active = 3000; + + OnlineCtcFstDecoderConfig() = default; + + OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active) + : graph(graph), max_active(max_active) {} + + std::string ToString() const; + + void Register(ParseOptions *po); + bool Validate() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder.cc new file mode 100644 index 00000000..310c963d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder.cc @@ -0,0 +1,118 @@ +// sherpa-mnn/csrc/online-ctc-fst-decoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-ctc-fst-decoder.h" + +#include +#include +#include +#include +#include + +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/decodable-ctc.h" +#include "kaldifst/csrc/fstext-utils.h" +#include "sherpa-mnn/csrc/fst-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-stream.h" + +namespace sherpa_mnn { + +OnlineCtcFstDecoder::OnlineCtcFstDecoder( + const OnlineCtcFstDecoderConfig &config, int32_t blank_id) + : config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) { + options_.max_active = config_.max_active; +} + +std::unique_ptr +OnlineCtcFstDecoder::CreateFasterDecoder() const { + return std::make_unique(*fst_, options_); +} + +static void DecodeOne(const float *log_probs, int32_t num_rows, + int32_t num_cols, OnlineCtcDecoderResult *result, + OnlineStream *s, int32_t blank_id) { + int32_t &processed_frames = s->GetFasterDecoderProcessedFrames(); + kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols, + processed_frames); + + kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder(); + if (processed_frames == 0) { + decoder->InitDecoding(); + } + + decoder->AdvanceDecoding(&decodable); + + if (decoder->ReachedFinal()) { + fst::VectorFst fst_out; + bool ok = decoder->GetBestPath(&fst_out); + if (ok) { + std::vector isymbols_out; + std::vector osymbols_out; + /*ok =*/fst::GetLinearSymbolSequence(fst_out, &isymbols_out, + &osymbols_out, nullptr); + // TODO(fangjun): handle ok is false + std::vector tokens; + tokens.reserve(isymbols_out.size()); + + std::vector timestamps; + timestamps.reserve(isymbols_out.size()); + + std::ostringstream os; + int32_t prev_id = -1; + int32_t &num_trailing_blanks = result->num_trailing_blanks; + int32_t f = 0; // frame number 
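+      // Convert the 1-based FST input symbols back to 0-based token ids
+      // (hence `i -= 1` below) and apply the usual CTC collapse rule:
+      // drop blanks and repeated ids, recording the frame index of each
+      // emitted token and counting trailing blanks as we go.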
+ + for (auto i : isymbols_out) { + i -= 1; + + if (i == blank_id) { + num_trailing_blanks += 1; + } else { + num_trailing_blanks = 0; + } + + if (i != blank_id && i != prev_id) { + tokens.push_back(i); + timestamps.push_back(f); + } + prev_id = i; + f += 1; + } + + result->tokens = std::move(tokens); + result->words = std::move(osymbols_out); + result->timestamps = std::move(timestamps); + // no need to set frame_offset + } + } + + processed_frames += num_rows; +} + +void OnlineCtcFstDecoder::Decode(const float *log_probs, int32_t batch_size, + int32_t num_frames, int32_t vocab_size, + std::vector *results, + OnlineStream **ss, int32_t n) { + if (batch_size != results->size()) { + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", + batch_size, static_cast(results->size())); + exit(-1); + } + + if (batch_size != n) { + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", batch_size, + n); + exit(-1); + } + + const float *p = log_probs; + + for (int32_t i = 0; i != batch_size; ++i) { + DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, + &(*results)[i], ss[i], blank_id_); + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder.h new file mode 100644 index 00000000..fb17b008 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-fst-decoder.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/online-ctc-fst-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ + +#include +#include + +#include "fst/fst.h" +#include "sherpa-mnn/csrc/online-ctc-decoder.h" +#include "sherpa-mnn/csrc/online-ctc-fst-decoder-config.h" + +namespace sherpa_mnn { + +class OnlineCtcFstDecoder : public OnlineCtcDecoder { + public: + OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config, + int32_t blank_id); + + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames, + int32_t vocab_size, std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) override; + + std::unique_ptr CreateFasterDecoder() + const override; + + private: + OnlineCtcFstDecoderConfig config_; + kaldi_decoder::FasterDecoderOptions options_; + + std::unique_ptr> fst_; + int32_t blank_id_ = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-greedy-search-decoder.cc new file mode 100644 index 00000000..e9fb1617 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-greedy-search-decoder.cc @@ -0,0 +1,59 @@ +// sherpa-mnn/csrc/online-ctc-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-ctc-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlineCtcGreedySearchDecoder::Decode( + const float *log_probs, int32_t batch_size, int32_t num_frames, + int32_t vocab_size, std::vector *results, + OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) { + if (batch_size != results->size()) { + SHERPA_ONNX_LOGE("Size mismatch! 
log_probs.size(0) %d, results.size(0): %d", + batch_size, static_cast(results->size())); + exit(-1); + } + + const float *p = log_probs; + + for (int32_t b = 0; b != batch_size; ++b) { + auto &r = (*results)[b]; + + int32_t prev_id = -1; + + for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) { + int32_t y = static_cast(std::distance( + static_cast(p), + std::max_element(static_cast(p), + static_cast(p) + vocab_size))); + + if (y == blank_id_) { + r.num_trailing_blanks += 1; + } else { + r.num_trailing_blanks = 0; + } + + if (y != blank_id_ && y != prev_id) { + r.tokens.push_back(y); + r.timestamps.push_back(t + r.frame_offset); + } + + prev_id = y; + } // for (int32_t t = 0; t != num_frames; ++t) { + } // for (int32_t b = 0; b != batch_size; ++b) + + // Update frame_offset + for (auto &r : *results) { + r.frame_offset += num_frames; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-greedy-search-decoder.h new file mode 100644 index 00000000..e61e249b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-greedy-search-decoder.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/online-ctc-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/online-ctc-decoder.h" + +namespace sherpa_mnn { + +class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { + public: + explicit OnlineCtcGreedySearchDecoder(int32_t blank_id) + : blank_id_(blank_id) {} + + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames, + int32_t vocab_size, std::vector *results, + OnlineStream **ss = nullptr, int32_t n = 0) override; + + private: + int32_t blank_id_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-model.cc new file mode 100644 index 00000000..9f96d77e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-model.cc @@ -0,0 +1,68 @@ +// sherpa-mnn/csrc/online-ctc-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-ctc-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-nemo-ctc-model.h" +#include "sherpa-mnn/csrc/online-wenet-ctc-model.h" +#include "sherpa-mnn/csrc/online-zipformer2-ctc-model.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +std::unique_ptr OnlineCtcModel::Create( + const OnlineModelConfig &config) { + if (!config.wenet_ctc.model.empty()) { + return std::make_unique(config); + } else if (!config.zipformer2_ctc.model.empty()) { + return std::make_unique(config); + } else if (!config.nemo_ctc.model.empty()) { + return std::make_unique(config); + } else { + SHERPA_ONNX_LOGE("Please specify a CTC model"); + exit(-1); + } +} + +template +std::unique_ptr OnlineCtcModel::Create( + Manager *mgr, const OnlineModelConfig &config) { + if (!config.wenet_ctc.model.empty()) { + return std::make_unique(mgr, config); + } else if 
(!config.zipformer2_ctc.model.empty()) { + return std::make_unique(mgr, config); + } else if (!config.nemo_ctc.model.empty()) { + return std::make_unique(mgr, config); + } else { + SHERPA_ONNX_LOGE("Please specify a CTC model"); + exit(-1); + } +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr OnlineCtcModel::Create( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr OnlineCtcModel::Create( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-model.h new file mode 100644 index 00000000..439aadcd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ctc-model.h @@ -0,0 +1,84 @@ +// sherpa-mnn/csrc/online-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-model-config.h" + +namespace sherpa_mnn { + +class OnlineCtcModel { + public: + virtual ~OnlineCtcModel() = default; + + static std::unique_ptr Create( + const OnlineModelConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OnlineModelConfig &config); + + // Return a list of tensors containing the initial states + virtual std::vector GetInitStates() const = 0; + + /** Stack a list of individual states into a batch. + * + * It is the inverse operation of `UnStackStates`. + * + * @param states states[i] contains the state for the i-th utterance. + * @return Return a single value representing the batched state. + */ + virtual std::vector StackStates( + std::vector> states) const = 0; + + /** Unstack a batch state into a list of individual states. + * + * It is the inverse operation of `StackStates`. + * + * @param states A batched state. + * @return ans[i] contains the state for the i-th utterance. + */ + virtual std::vector> UnStackStates( + std::vector states) const = 0; + + /** + * + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. + * @param states It is from GetInitStates() or returned from this method. + * + * @return Return a list of tensors + * - ans[0] contains log_probs, of shape (N, T, C) + * - ans[1:] contains next_states + */ + virtual std::vector Forward( + MNN::Express::VARP x, std::vector states) const = 0; + + /** Return the vocabulary size of the model + */ + virtual int32_t VocabSize() const = 0; + + /** Return an allocator for allocating memory + */ + virtual MNNAllocator *Allocator() const = 0; + + // The model accepts this number of frames before subsampling as input + virtual int32_t ChunkLength() const = 0; + + // Similar to frame_shift in feature extractor, after processing + // ChunkLength() frames, we advance by ChunkShift() frames + // before we process the next chunk. 
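+  // For example, a hypothetical model with ChunkLength() == 45 and
+  // ChunkShift() == 32 consumes 45 feature frames per call and then
+  // advances by 32 frames before processing the next chunk.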
+ virtual int32_t ChunkShift() const = 0; + + // Return true if the model supports batch size > 1 + virtual bool SupportBatchProcessing() const { return true; } +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ebranchformer-transducer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ebranchformer-transducer-model.cc new file mode 100644 index 00000000..ac04104a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ebranchformer-transducer-model.cc @@ -0,0 +1,413 @@ +// sherpa-mnn/csrc/online-ebranchformer-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation +// 2025 Brno University of Technology (author: Karel Vesely) + +#include "sherpa-mnn/csrc/online-ebranchformer-transducer-model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + const OnlineModelConfig &config) + : + sess_opts_(GetSessionOptions(config)), + config_(config), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +template +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + Manager *mgr, const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +void OnlineEbranchformerTransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + SHERPA_ONNX_READ_META_DATA(T_, "T"); + + SHERPA_ONNX_READ_META_DATA(num_hidden_layers_, "num_hidden_layers"); + SHERPA_ONNX_READ_META_DATA(hidden_size_, "hidden_size"); + 
SHERPA_ONNX_READ_META_DATA(intermediate_size_, "intermediate_size"); + SHERPA_ONNX_READ_META_DATA(csgu_kernel_size_, "csgu_kernel_size"); + SHERPA_ONNX_READ_META_DATA(merge_conv_kernel_, "merge_conv_kernel"); + SHERPA_ONNX_READ_META_DATA(left_context_len_, "left_context_len"); + SHERPA_ONNX_READ_META_DATA(num_heads_, "num_heads"); + SHERPA_ONNX_READ_META_DATA(head_dim_, "head_dim"); + + if (config_.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("T: %{public}d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); + + SHERPA_ONNX_LOGE("num_hidden_layers_: %{public}d", num_hidden_layers_); + SHERPA_ONNX_LOGE("hidden_size_: %{public}d", hidden_size_); + SHERPA_ONNX_LOGE("intermediate_size_: %{public}d", intermediate_size_); + SHERPA_ONNX_LOGE("csgu_kernel_size_: %{public}d", csgu_kernel_size_); + SHERPA_ONNX_LOGE("merge_conv_kernel_: %{public}d", merge_conv_kernel_); + SHERPA_ONNX_LOGE("left_context_len_: %{public}d", left_context_len_); + SHERPA_ONNX_LOGE("num_heads_: %{public}d", num_heads_); + SHERPA_ONNX_LOGE("head_dim_: %{public}d", head_dim_); +#else + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); + + SHERPA_ONNX_LOGE("num_hidden_layers_: %d", num_hidden_layers_); + SHERPA_ONNX_LOGE("hidden_size_: %d", hidden_size_); + SHERPA_ONNX_LOGE("intermediate_size_: %d", intermediate_size_); + SHERPA_ONNX_LOGE("csgu_kernel_size_: %d", csgu_kernel_size_); + SHERPA_ONNX_LOGE("merge_conv_kernel_: %d", merge_conv_kernel_); + SHERPA_ONNX_LOGE("left_context_len_: %d", left_context_len_); + SHERPA_ONNX_LOGE("num_heads_: %d", num_heads_); + SHERPA_ONNX_LOGE("head_dim_: %d", head_dim_); +#endif + } +} + +void OnlineEbranchformerTransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = decoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); +} + +void OnlineEbranchformerTransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + MNNMeta meta_data = joiner_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } +} + +std::vector OnlineEbranchformerTransducerModel::StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + + std::vector buf(batch_size); + + auto allocator = + const_cast(this)->allocator_; + + std::vector ans; + int32_t num_states = static_cast(states[0].size()); + 
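+  // Each utterance carries 4 cached tensors per layer (cached_key,
+  // cached_value, cached_conv, cached_conv_fusion) followed by a single
+  // processed_lens tensor; stack every slot across the batch along axis 0.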
ans.reserve(num_states); + + for (int32_t i = 0; i != num_hidden_layers_; ++i) { + { // cached_key + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][4 * i]; + } + auto v = Cat(allocator, buf, /* axis */ 0); + ans.push_back(std::move(v)); + } + { // cached_value + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][4 * i + 1]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + { // cached_conv + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][4 * i + 2]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + { // cached_conv_fusion + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][4 * i + 3]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + } + + { // processed_lens + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_states - 1]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + + return ans; +} + +std::vector> +OnlineEbranchformerTransducerModel::UnStackStates( + const std::vector &states) const { + assert(static_cast(states.size()) == num_hidden_layers_ * 4 + 1); + + int32_t batch_size = states[0]->getInfo()->dim[0]; + + auto allocator = + const_cast(this)->allocator_; + + std::vector> ans; + ans.resize(batch_size); + + for (int32_t i = 0; i != num_hidden_layers_; ++i) { + { // cached_key + auto v = Unbind(allocator, states[i * 4], /* axis */ 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { // cached_value + auto v = Unbind(allocator, states[i * 4 + 1], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { // cached_conv + auto v = Unbind(allocator, states[i * 4 + 2], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { // cached_conv_fusion + auto v = Unbind(allocator, states[i * 4 + 3], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + } + + { // processed_lens + auto v = Unbind(allocator, states.back(), 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; +} + +std::vector +OnlineEbranchformerTransducerModel::GetEncoderInitStates() { + std::vector ans; + + ans.reserve(num_hidden_layers_ * 4 + 1); + + int32_t left_context_conv = csgu_kernel_size_ - 1; + int32_t channels_conv = intermediate_size_ / 2; + + int32_t left_context_conv_fusion = merge_conv_kernel_ - 1; + int32_t channels_conv_fusion = 2 * hidden_size_; + + for (int32_t i = 0; i != num_hidden_layers_; ++i) { + { // cached_key_{i} + std::array s{1, num_heads_, left_context_len_, head_dim_}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { // cahced_value_{i} + std::array s{1, num_heads_, left_context_len_, head_dim_}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { // cached_conv_{i} + std::array s{1, channels_conv, left_context_conv}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { // cached_conv_fusion_{i} + std::array s{1, channels_conv_fusion, + 
left_context_conv_fusion}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + } // num_hidden_layers_ + + { // processed_lens + std::array s{1}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + return ans; +} + +std::pair> +OnlineEbranchformerTransducerModel::RunEncoder( + MNN::Express::VARP features, std::vector states, + MNN::Express::VARP /* processed_frames */) { + std::vector encoder_inputs; + encoder_inputs.reserve(1 + states.size()); + + encoder_inputs.push_back(std::move(features)); + for (auto &v : states) { + encoder_inputs.push_back(std::move(v)); + } + + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + std::vector next_states; + next_states.reserve(states.size()); + + for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { + next_states.push_back(std::move(encoder_out[i])); + } + return {std::move(encoder_out[0]), std::move(next_states)}; +} + +MNN::Express::VARP OnlineEbranchformerTransducerModel::RunDecoder( + MNN::Express::VARP decoder_input) { + auto decoder_out = decoder_sess_->onForward({decoder_input}); + return std::move(decoder_out[0]); +} + +MNN::Express::VARP OnlineEbranchformerTransducerModel::RunJoiner( + MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->onForward(joiner_input); + + return std::move(logit[0]); +} + +#if __ANDROID_API__ >= 9 +template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ebranchformer-transducer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ebranchformer-transducer-model.h new file mode 100644 index 00000000..f42e6971 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-ebranchformer-transducer-model.h @@ -0,0 +1,110 @@ +// sherpa-mnn/csrc/online-ebranchformer-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +// 2025 Brno University of Technology (author: Karel Vesely) +#ifndef SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +class OnlineEbranchformerTransducerModel : public OnlineTransducerModel { + public: + explicit OnlineEbranchformerTransducerModel(const OnlineModelConfig &config); + + template + OnlineEbranchformerTransducerModel(Manager *mgr, + const OnlineModelConfig &config); + + std::vector StackStates( + const std::vector> &states) const override; + + std::vector> UnStackStates( + const std::vector &states) const override; + + std::vector GetEncoderInitStates() override; + + void SetFeatureDim(int32_t feature_dim) override { + feature_dim_ = feature_dim; + } + + std::pair> RunEncoder( + MNN::Express::VARP features, std::vector states, + MNN::Express::VARP processed_frames) override; + + MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input) override; + + 
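+  // Combines one frame of encoder output with the current decoder output
+  // and returns the joiner logits used for token prediction.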
MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) override; + + int32_t ContextSize() const override { return context_size_; } + + int32_t ChunkSize() const override { return T_; } + + int32_t ChunkShift() const override { return decode_chunk_len_; } + + int32_t VocabSize() const override { return vocab_size_; } + MNNAllocator *Allocator() override { return allocator_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length); + void InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); + + private: + MNNEnv env_; + MNNConfig sess_opts_; + + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + OnlineModelConfig config_; + + int32_t decode_chunk_len_ = 0; + int32_t T_ = 0; + + int32_t num_hidden_layers_ = 0; + int32_t hidden_size_ = 0; + int32_t intermediate_size_ = 0; + int32_t csgu_kernel_size_ = 0; + int32_t merge_conv_kernel_ = 0; + int32_t left_context_len_ = 0; + int32_t num_heads_ = 0; + int32_t head_dim_ = 0; + + int32_t context_size_ = 0; + int32_t vocab_size_ = 0; + int32_t feature_dim_ = 80; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm-config.cc new file mode 100644 index 00000000..f3567368 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm-config.cc @@ -0,0 +1,45 @@ +// sherpa-mnn/csrc/online-lm-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-lm-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlineLMConfig::Register(ParseOptions *po) { + po->Register("lm", &model, "Path to LM model."); + po->Register("lm-scale", &scale, "LM scale."); + po->Register("lm-num-threads", &lm_num_threads, + "Number of threads to run the neural network of LM model"); + po->Register("lm-provider", &lm_provider, + "Specify a provider to LM model use: cpu, cuda, coreml"); + po->Register("lm-shallow-fusion", &shallow_fusion, + "Boolean whether to use shallow fusion or rescore."); +} + +bool OnlineLMConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("'%s' does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OnlineLMConfig::ToString() const { + std::ostringstream os; + + os << "OnlineLMConfig("; + os << "model=\"" << model << "\", "; + os << "scale=" << scale << ", "; + os << "shallow_fusion=" << (shallow_fusion ? 
"True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm-config.h new file mode 100644 index 00000000..93a34f5e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm-config.h @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/online-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineLMConfig { + // path to the onnx model + std::string model; + + // LM scale + float scale = 0.5; + int32_t lm_num_threads = 1; + std::string lm_provider = "cpu"; + // enable shallow fusion + bool shallow_fusion = true; + + OnlineLMConfig() = default; + + OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads, + const std::string &lm_provider, bool shallow_fusion) + : model(model), + scale(scale), + lm_num_threads(lm_num_threads), + lm_provider(lm_provider), + shallow_fusion(shallow_fusion) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm.cc new file mode 100644 index 00000000..99f82b19 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm.cc @@ -0,0 +1,20 @@ +// sherpa-mnn/csrc/online-lm.cc +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-lm.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/online-rnn-lm.h" + +namespace sherpa_mnn { + +std::unique_ptr OnlineLM::Create(const OnlineLMConfig &config) { + return std::make_unique(config); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm.h new file mode 100644 index 00000000..e6972790 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lm.h @@ -0,0 +1,63 @@ +// sherpa-mnn/csrc/online-lm.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_H_ +#define SHERPA_ONNX_CSRC_ONLINE_LM_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/hypothesis.h" +#include "sherpa-mnn/csrc/online-lm-config.h" + +namespace sherpa_mnn { + +class OnlineLM { + public: + virtual ~OnlineLM() = default; + + static std::unique_ptr Create(const OnlineLMConfig &config); + + // init states for classic rescore + virtual std::vector GetInitStates() = 0; + + // init states for shallow fusion + virtual std::pair> GetInitStatesSF() = 0; + + /** ScoreToken a batch of sentences (shallow fusion). + * + * @param x A 2-D tensor of shape (N, 1) with data type int64. + * @param states It contains the states for the LM model + * @return Return a pair containing + * - log_prob of NN LM + * - updated states + * + */ + virtual std::pair> ScoreToken( + MNN::Express::VARP x, std::vector states) = 0; + + /** This function updates hyp.lm_log_prob of hyps (classic rescore). + * + * @param scale LM score + * @param context_size Context size of the transducer decoder model + * @param hyps It is changed in-place. 
+ * + */ + virtual void ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps) = 0; + + /** This function updates lm_log_prob and nn_lm_scores of hyp (shallow fusion). + * + * @param scale LM score + * @param hyps It is changed in-place. + * + */ + virtual void ComputeLMScoreSF(float scale, Hypothesis *hyp) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_LM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lstm-transducer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lstm-transducer-model.cc new file mode 100644 index 00000000..466daa79 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lstm-transducer-model.cc @@ -0,0 +1,274 @@ +// sherpa-mnn/csrc/online-lstm-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-mnn/csrc/online-lstm-transducer-model.h" + +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +OnlineLstmTransducerModel::OnlineLstmTransducerModel( + const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +template +OnlineLstmTransducerModel::OnlineLstmTransducerModel( + Manager *mgr, const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +void OnlineLstmTransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers"); + SHERPA_ONNX_READ_META_DATA(T_, "T"); + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "rnn_hidden_size"); + SHERPA_ONNX_READ_META_DATA(d_model_, "d_model"); +} + +void 
OnlineLstmTransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = decoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); +} + +void OnlineLstmTransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + MNNMeta meta_data = joiner_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } +} + +std::vector OnlineLstmTransducerModel::StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + + std::vector h_buf(batch_size); + std::vector c_buf(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + assert(states[i].size() == 2); + h_buf[i] = states[i][0]; + c_buf[i] = states[i][1]; + } + auto allocator = const_cast(this)->allocator_; + + MNN::Express::VARP h = Cat(allocator, h_buf, 1); + MNN::Express::VARP c = Cat(allocator, c_buf, 1); + + std::vector ans; + ans.reserve(2); + ans.push_back(std::move(h)); + ans.push_back(std::move(c)); + + return ans; +} + +std::vector> OnlineLstmTransducerModel::UnStackStates( + const std::vector &states) const { + int32_t batch_size = states[0]->getInfo()->dim[1]; + assert(states.size() == 2); + + std::vector> ans(batch_size); + + auto allocator = const_cast(this)->allocator_; + + std::vector h_vec = Unbind(allocator, states[0], 1); + std::vector c_vec = Unbind(allocator, states[1], 1); + + assert(h_vec.size() == batch_size); + assert(c_vec.size() == batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + ans[i].push_back(std::move(h_vec[i])); + ans[i].push_back(std::move(c_vec[i])); + } + + return ans; +} + +std::vector OnlineLstmTransducerModel::GetEncoderInitStates() { + // Please see + // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py#L185 + // for details + constexpr int32_t kBatchSize = 1; + std::array h_shape{num_encoder_layers_, kBatchSize, d_model_}; + MNN::Express::VARP h = MNNUtilsCreateTensor(allocator_, h_shape.data(), + h_shape.size()); + + Fill(h, 0); + + std::array c_shape{num_encoder_layers_, kBatchSize, + rnn_hidden_size_}; + + MNN::Express::VARP c = MNNUtilsCreateTensor(allocator_, c_shape.data(), + c_shape.size()); + + Fill(c, 0); + + std::vector states; + + states.reserve(2); + states.push_back(std::move(h)); + states.push_back(std::move(c)); + + return states; +} + +std::pair> 
+OnlineLstmTransducerModel::RunEncoder(MNN::Express::VARP features, + std::vector states, + MNN::Express::VARP /* processed_frames */) { + std::vector encoder_inputs = { + std::move(features), std::move(states[0]), std::move(states[1])}; + + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + std::vector next_states; + next_states.reserve(2); + next_states.push_back(std::move(encoder_out[1])); + next_states.push_back(std::move(encoder_out[2])); + + return {std::move(encoder_out[0]), std::move(next_states)}; +} + +MNN::Express::VARP OnlineLstmTransducerModel::RunDecoder(MNN::Express::VARP decoder_input) { + auto decoder_out = decoder_sess_->onForward({decoder_input}); + return std::move(decoder_out[0]); +} + +MNN::Express::VARP OnlineLstmTransducerModel::RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->onForward(joiner_input); + + return std::move(logit[0]); +} + +#if __ANDROID_API__ >= 9 +template OnlineLstmTransducerModel::OnlineLstmTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineLstmTransducerModel::OnlineLstmTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lstm-transducer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lstm-transducer-model.h new file mode 100644 index 00000000..c1e9aa1c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-lstm-transducer-model.h @@ -0,0 +1,95 @@ +// sherpa-mnn/csrc/online-lstm-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +class OnlineLstmTransducerModel : public OnlineTransducerModel { + public: + explicit OnlineLstmTransducerModel(const OnlineModelConfig &config); + + template + OnlineLstmTransducerModel(Manager *mgr, const OnlineModelConfig &config); + + std::vector StackStates( + const std::vector> &states) const override; + + std::vector> UnStackStates( + const std::vector &states) const override; + + std::vector GetEncoderInitStates() override; + + std::pair> RunEncoder( + MNN::Express::VARP features, std::vector states, + MNN::Express::VARP processed_frames) override; + + MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input) override; + + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) override; + + int32_t ContextSize() const override { return context_size_; } + + int32_t ChunkSize() const override { return T_; } + + int32_t ChunkShift() const override { return decode_chunk_len_; } + + int32_t VocabSize() const override { return vocab_size_; } + MNNAllocator *Allocator() override { return allocator_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length); + void InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); + + private: + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; 
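+  // Input/output tensor names of the encoder, decoder and joiner sessions,
+  // filled by GetInputNames()/GetOutputNames() during InitEncoder(),
+  // InitDecoder() and InitJoiner().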
+ + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + OnlineModelConfig config_; + + int32_t num_encoder_layers_ = 0; + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + int32_t rnn_hidden_size_ = 0; + int32_t d_model_ = 0; + int32_t context_size_ = 0; + int32_t vocab_size_ = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-model-config.cc new file mode 100644 index 00000000..591428ea --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-model-config.cc @@ -0,0 +1,180 @@ +// sherpa-mnn/csrc/online-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-mnn/csrc/online-model-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +void OnlineModelConfig::Register(ParseOptions *po) { + transducer.Register(po); + paraformer.Register(po); + wenet_ctc.Register(po); + zipformer2_ctc.Register(po); + nemo_ctc.Register(po); + provider_config.Register(po); + + po->Register("tokens", &tokens, "Path to tokens.txt"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("warm-up", &warm_up, + "Number of warm-up to run the onnxruntime" + "Valid vales are: zipformer2"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("modeling-unit", &modeling_unit, + "The modeling unit of the model, commonly used units are bpe, " + "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " + "hotwords are provided, we need it to encode the hotwords into " + "token sequence."); + + po->Register("bpe-vocab", &bpe_vocab, + "The vocabulary generated by google's sentencepiece program. " + "It is a file has two columns, one is the token, the other is " + "the log probability, you can get it from the directory where " + "your bpe model is generated. Only used when hotwords provided " + "and the modeling unit is bpe or cjkchar+bpe"); + + po->Register("model-type", &model_type, + "Specify it to reduce model initialization time. " + "Valid values are: conformer, lstm, zipformer, zipformer2, " + "wenet_ctc, nemo_ctc. " + "All other values lead to loading the model twice."); +} + +bool OnlineModelConfig::Validate() const { + // For RK NPU, we reinterpret num_threads: + // + // For RK3588 only + // num_threads == 1 -> Select a core randomly + // num_threads == 0 -> Use NPU core 0 + // num_threads == -1 -> Use NPU core 1 + // num_threads == -2 -> Use NPU core 2 + // num_threads == -3 -> Use NPU core 0 and core 1 + // num_threads == -4 -> Use NPU core 0, core 1, and core 2 + if (provider_config.provider != "rknn") { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. 
Given %d", num_threads); + return false; + } + if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".rknn") || + EndsWith(transducer.decoder, ".rknn") || + EndsWith(transducer.joiner, ".rknn"))) { + SHERPA_ONNX_LOGE( + "--provider is %s, which is not rknn, but you pass rknn model " + "filenames. encoder: '%s', decoder: '%s', joiner: '%s'", + provider_config.provider.c_str(), transducer.encoder.c_str(), + transducer.decoder.c_str(), transducer.joiner.c_str()); + return false; + } + + if (!zipformer2_ctc.model.empty() && + EndsWith(zipformer2_ctc.model, ".rknn")) { + SHERPA_ONNX_LOGE( + "--provider is %s, which is not rknn, but you pass rknn model " + "filename for zipformer2_ctc: '%s'", + provider_config.provider.c_str(), zipformer2_ctc.model.c_str()); + return false; + } + } + + if (provider_config.provider == "rknn") { + if (!transducer.encoder.empty() && (EndsWith(transducer.encoder, ".onnx") || + EndsWith(transducer.decoder, ".onnx") || + EndsWith(transducer.joiner, ".onnx"))) { + SHERPA_ONNX_LOGE( + "--provider is rknn, but you pass onnx model " + "filenames. encoder: '%s', decoder: '%s', joiner: '%s'", + transducer.encoder.c_str(), transducer.decoder.c_str(), + transducer.joiner.c_str()); + return false; + } + + if (!zipformer2_ctc.model.empty() && + EndsWith(zipformer2_ctc.model, ".onnx")) { + SHERPA_ONNX_LOGE( + "--provider rknn, but you pass onnx model filename for " + "zipformer2_ctc: '%s'", + zipformer2_ctc.model.c_str()); + return false; + } + } + + if (!tokens_buf.empty() && FileExists(tokens)) { + SHERPA_ONNX_LOGE( + "you can not provide a tokens_buf and a tokens file: '%s', " + "at the same time, which is confusing", + tokens.c_str()); + return false; + } + + if (tokens_buf.empty() && !FileExists(tokens)) { + SHERPA_ONNX_LOGE( + "tokens: '%s' does not exist, you should provide " + "either a tokens buffer or a tokens file", + tokens.c_str()); + return false; + } + + if (!modeling_unit.empty() && + (modeling_unit == "bpe" || modeling_unit == "cjkchar+bpe")) { + if (!FileExists(bpe_vocab)) { + SHERPA_ONNX_LOGE("bpe_vocab: '%s' does not exist", bpe_vocab.c_str()); + return false; + } + } + + if (!paraformer.encoder.empty()) { + return paraformer.Validate(); + } + + if (!wenet_ctc.model.empty()) { + return wenet_ctc.Validate(); + } + + if (!zipformer2_ctc.model.empty()) { + return zipformer2_ctc.Validate(); + } + + if (!nemo_ctc.model.empty()) { + return nemo_ctc.Validate(); + } + + if (!provider_config.Validate()) { + return false; + } + + return transducer.Validate(); +} + +std::string OnlineModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineModelConfig("; + os << "transducer=" << transducer.ToString() << ", "; + os << "paraformer=" << paraformer.ToString() << ", "; + os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; + os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", "; + os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; + os << "provider_config=" << provider_config.ToString() << ", "; + os << "tokens=\"" << tokens << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "warm_up=" << warm_up << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "model_type=\"" << model_type << "\", "; + os << "modeling_unit=\"" << modeling_unit << "\", "; + os << "bpe_vocab=\"" << bpe_vocab << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-model-config.h new file mode 100644 index 00000000..3a7c42f9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-model-config.h @@ -0,0 +1,86 @@ +// sherpa-mnn/csrc/online-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/online-nemo-ctc-model-config.h" +#include "sherpa-mnn/csrc/online-paraformer-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-model-config.h" +#include "sherpa-mnn/csrc/online-wenet-ctc-model-config.h" +#include "sherpa-mnn/csrc/online-zipformer2-ctc-model-config.h" +#include "sherpa-mnn/csrc/provider-config.h" + +namespace sherpa_mnn { + +struct OnlineModelConfig { + OnlineTransducerModelConfig transducer; + OnlineParaformerModelConfig paraformer; + OnlineWenetCtcModelConfig wenet_ctc; + OnlineZipformer2CtcModelConfig zipformer2_ctc; + OnlineNeMoCtcModelConfig nemo_ctc; + ProviderConfig provider_config; + std::string tokens; + int32_t num_threads = 1; + int32_t warm_up = 0; + bool debug = false; + + // Valid values: + // - conformer, conformer transducer from icefall + // - lstm, lstm transducer from icefall + // - zipformer, zipformer transducer from icefall + // - zipformer2, zipformer2 transducer or CTC from icefall + // - wenet_ctc, wenet CTC model + // - nemo_ctc, NeMo CTC model + // + // All other values are invalid and lead to loading the model twice. 
+ std::string model_type; + + // Valid values: + // - cjkchar + // - bpe + // - cjkchar+bpe + std::string modeling_unit = "cjkchar"; + std::string bpe_vocab; + + /// if tokens_buf is non-empty, + /// the tokens will be loaded from the buffer instead of from the + /// "tokens" file + std::string tokens_buf; + + OnlineModelConfig() = default; + OnlineModelConfig(const OnlineTransducerModelConfig &transducer, + const OnlineParaformerModelConfig ¶former, + const OnlineWenetCtcModelConfig &wenet_ctc, + const OnlineZipformer2CtcModelConfig &zipformer2_ctc, + const OnlineNeMoCtcModelConfig &nemo_ctc, + const ProviderConfig &provider_config, + const std::string &tokens, int32_t num_threads, + int32_t warm_up, bool debug, const std::string &model_type, + const std::string &modeling_unit, + const std::string &bpe_vocab) + : transducer(transducer), + paraformer(paraformer), + wenet_ctc(wenet_ctc), + zipformer2_ctc(zipformer2_ctc), + nemo_ctc(nemo_ctc), + provider_config(provider_config), + tokens(tokens), + num_threads(num_threads), + warm_up(warm_up), + debug(debug), + model_type(model_type), + modeling_unit(modeling_unit), + bpe_vocab(bpe_vocab) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model-config.cc new file mode 100644 index 00000000..60e1865a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model-config.cc @@ -0,0 +1,36 @@ +// sherpa-mnn/csrc/online-nemo-ctc-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-nemo-ctc-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlineNeMoCtcModelConfig::Register(ParseOptions *po) { + po->Register("nemo-ctc-model", &model, + "Path to CTC model.onnx from NeMo. 
Please see " + "https://github.com/k2-fsa/sherpa-mnn/pull/843"); +} + +bool OnlineNeMoCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("NeMo CTC model '%s' does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OnlineNeMoCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineNeMoCtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model-config.h new file mode 100644 index 00000000..358061fd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model-config.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/online-nemo-ctc-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineNeMoCtcModelConfig { + std::string model; + + OnlineNeMoCtcModelConfig() = default; + + explicit OnlineNeMoCtcModelConfig(const std::string &model) : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model.cc new file mode 100644 index 00000000..3cd735a7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model.cc @@ -0,0 +1,340 @@ +// sherpa-mnn/csrc/online-nemo-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-nemo-ctc-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +class OnlineNeMoCtcModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.nemo_ctc.model); + Init(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.nemo_ctc.model); + Init(buf.data(), buf.size()); + } + } + + std::vector Forward(MNN::Express::VARP x, + std::vector states) { + MNN::Express::VARP &cache_last_channel = states[0]; + MNN::Express::VARP &cache_last_time = states[1]; + MNN::Express::VARP &cache_last_channel_len = states[2]; + + int32_t batch_size = x->getInfo()->dim[0]; + + std::array length_shape{batch_size}; + + MNN::Express::VARP length = MNNUtilsCreateTensor( + allocator_, length_shape.data(), length_shape.size()); + + int *p_length = length->writeMap(); + + std::fill(p_length, p_length + batch_size, ChunkLength()); + + // (B, T, C) -> (B, C, T) + x = 
Transpose12(allocator_, x); + + std::vector inputs = { + std::move(x), View(length), std::move(cache_last_channel), + std::move(cache_last_time), std::move(cache_last_channel_len)}; + + auto out = + sess_->onForward(inputs); + // out[0]: logit + // out[1] logit_length + // out[2:] states_next + // + // we need to remove out[1] + + std::vector ans; + ans.reserve(out.size() - 1); + + for (int32_t i = 0; i != out.size(); ++i) { + if (i == 1) { + continue; + } + + ans.push_back(std::move(out[i])); + } + + return ans; + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t ChunkLength() const { return window_size_; } + + int32_t ChunkShift() const { return chunk_shift_; } + + MNNAllocator *Allocator() { return allocator_; } + + // Return a vector containing 3 tensors + // - cache_last_channel + // - cache_last_time_ + // - cache_last_channel_len + std::vector GetInitStates() { + std::vector ans; + ans.reserve(3); + ans.push_back(View(cache_last_channel_)); + ans.push_back(View(cache_last_time_)); + ans.push_back(View(cache_last_channel_len_)); + + return ans; + } + + std::vector StackStates( + std::vector> states) { + int32_t batch_size = static_cast(states.size()); + if (batch_size == 1) { + return std::move(states[0]); + } + + std::vector ans; + + // stack cache_last_channel + std::vector buf(batch_size); + + // there are 3 states to be stacked + for (int32_t i = 0; i != 3; ++i) { + buf.clear(); + buf.reserve(batch_size); + + for (int32_t b = 0; b != batch_size; ++b) { + assert(states[b].size() == 3); + buf.push_back(states[b][i]); + } + + MNN::Express::VARP c{nullptr}; + if (i == 2) { + c = Cat(allocator_, buf, 0); + } else { + c = Cat(allocator_, buf, 0); + } + + ans.push_back(std::move(c)); + } + + return ans; + } + + std::vector> UnStackStates( + std::vector states) const { + assert(states.size() == 3); + + auto allocator = const_cast(this)->allocator_; + + std::vector> ans; + + auto shape = states[0]->getInfo()->dim; + int32_t batch_size = shape[0]; + ans.resize(batch_size); + + if (batch_size == 1) { + ans[0] = std::move(states); + return ans; + } + + for (int32_t i = 0; i != 3; ++i) { + std::vector v; + if (i == 2) { + v = Unbind(allocator, states[i], 0); + } else { + v = Unbind(allocator, states[i], 0); + } + + assert(v.size() == batch_size); + + for (int32_t b = 0; b != batch_size; ++b) { + ans[b].push_back(std::move(v[b])); + } + } + + return ans; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(window_size_, "window_size"); + SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_, + "cache_last_channel_dim1"); + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_, + "cache_last_channel_dim2"); + 
SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_, + "cache_last_channel_dim3"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3"); + + // need to increase by 1 since the blank token is not included in computing + // vocab_size in NeMo. + vocab_size_ += 1; + + InitStates(); + } + + void InitStates() { + std::array cache_last_channel_shape{1, cache_last_channel_dim1_, + cache_last_channel_dim2_, + cache_last_channel_dim3_}; + + cache_last_channel_ = MNNUtilsCreateTensor( + allocator_, cache_last_channel_shape.data(), + cache_last_channel_shape.size()); + + Fill(cache_last_channel_, 0); + + std::array cache_last_time_shape{ + 1, cache_last_time_dim1_, cache_last_time_dim2_, cache_last_time_dim3_}; + + cache_last_time_ = MNNUtilsCreateTensor( + allocator_, cache_last_time_shape.data(), cache_last_time_shape.size()); + + Fill(cache_last_time_, 0); + + int shape = 1; + cache_last_channel_len_ = + MNNUtilsCreateTensor(allocator_, &shape, 1); + + cache_last_channel_len_->writeMap()[0] = 0; + } + + private: + OnlineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t window_size_ = 0; + int32_t chunk_shift_ = 0; + int32_t subsampling_factor_ = 0; + int32_t vocab_size_ = 0; + int32_t cache_last_channel_dim1_ = 0; + int32_t cache_last_channel_dim2_ = 0; + int32_t cache_last_channel_dim3_ = 0; + int32_t cache_last_time_dim1_ = 0; + int32_t cache_last_time_dim2_ = 0; + int32_t cache_last_time_dim3_ = 0; + + MNN::Express::VARP cache_last_channel_{nullptr}; + MNN::Express::VARP cache_last_time_{nullptr}; + MNN::Express::VARP cache_last_channel_len_{nullptr}; +}; + +OnlineNeMoCtcModel::OnlineNeMoCtcModel(const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OnlineNeMoCtcModel::OnlineNeMoCtcModel(Manager *mgr, + const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OnlineNeMoCtcModel::~OnlineNeMoCtcModel() = default; + +std::vector OnlineNeMoCtcModel::Forward( + MNN::Express::VARP x, std::vector states) const { + return impl_->Forward(std::move(x), std::move(states)); +} + +int32_t OnlineNeMoCtcModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OnlineNeMoCtcModel::ChunkLength() const { return impl_->ChunkLength(); } + +int32_t OnlineNeMoCtcModel::ChunkShift() const { return impl_->ChunkShift(); } + +MNNAllocator *OnlineNeMoCtcModel::Allocator() const { + return impl_->Allocator(); +} + +std::vector OnlineNeMoCtcModel::GetInitStates() const { + return impl_->GetInitStates(); +} + +std::vector OnlineNeMoCtcModel::StackStates( + std::vector> states) const { + return impl_->StackStates(std::move(states)); +} + +std::vector> OnlineNeMoCtcModel::UnStackStates( + std::vector states) const { + return impl_->UnStackStates(std::move(states)); +} + +#if __ANDROID_API__ >= 9 +template OnlineNeMoCtcModel::OnlineNeMoCtcModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineNeMoCtcModel::OnlineNeMoCtcModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model.h 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model.h new file mode 100644 index 00000000..e75e848b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-nemo-ctc-model.h @@ -0,0 +1,75 @@ +// sherpa-mnn/csrc/online-nemo-ctc-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-ctc-model.h" +#include "sherpa-mnn/csrc/online-model-config.h" + +namespace sherpa_mnn { + +class OnlineNeMoCtcModel : public OnlineCtcModel { + public: + explicit OnlineNeMoCtcModel(const OnlineModelConfig &config); + + template + OnlineNeMoCtcModel(Manager *mgr, const OnlineModelConfig &config); + + ~OnlineNeMoCtcModel() override; + + // A list of 3 tensors: + // - cache_last_channel + // - cache_last_time + // - cache_last_channel_len + std::vector GetInitStates() const override; + + std::vector StackStates( + std::vector> states) const override; + + std::vector> UnStackStates( + std::vector states) const override; + + /** + * + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. + * @param states It is from GetInitStates() or returned from this method. + * + * @return Return a list of tensors + * - ans[0] contains log_probs, of shape (N, T, C) + * - ans[1:] contains next_states + */ + std::vector Forward( + MNN::Express::VARP x, std::vector states) const override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + // The model accepts this number of frames before subsampling as input + int32_t ChunkLength() const override; + + // Similar to frame_shift in feature extractor, after processing + // ChunkLength() frames, we advance by ChunkShift() frames + // before we process the next chunk. 
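+  //
+  // Illustrative example (numbers are made up; the real values come from the
+  // model metadata keys "window_size" and "chunk_shift"): with
+  // ChunkLength() == 80 and ChunkShift() == 64, the first Forward() call sees
+  // feature frames [0, 80) and the next call sees frames [64, 144), i.e.
+  // consecutive chunks overlap by ChunkLength() - ChunkShift() frames.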
+ int32_t ChunkShift() const override; + + bool SupportBatchProcessing() const override { return true; } + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-decoder.h new file mode 100644 index 00000000..9c91e0da --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-decoder.h @@ -0,0 +1,23 @@ +// sherpa-mnn/csrc/online-paraformer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +struct OnlineParaformerDecoderResult { + /// The decoded token IDs + std::vector tokens; + + int32_t last_non_blank_frame_index = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model-config.cc new file mode 100644 index 00000000..d8d2069f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model-config.cc @@ -0,0 +1,43 @@ +// sherpa-mnn/csrc/online-paraformer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-paraformer-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlineParaformerModelConfig::Register(ParseOptions *po) { + po->Register("paraformer-encoder", &encoder, + "Path to encoder.onnx of paraformer."); + po->Register("paraformer-decoder", &decoder, + "Path to decoder.onnx of paraformer."); +} + +bool OnlineParaformerModelConfig::Validate() const { + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("Paraformer encoder '%s' does not exist", encoder.c_str()); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("Paraformer decoder '%s' does not exist", decoder.c_str()); + return false; + } + + return true; +} + +std::string OnlineParaformerModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineParaformerModelConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model-config.h new file mode 100644 index 00000000..1a1a1d6e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model-config.h @@ -0,0 +1,31 @@ +// sherpa-mnn/csrc/online-paraformer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineParaformerModelConfig { + std::string encoder; + std::string decoder; + + OnlineParaformerModelConfig() = default; + + OnlineParaformerModelConfig(const std::string &encoder, + const std::string &decoder) + : encoder(encoder), decoder(decoder) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // 
SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model.cc new file mode 100644 index 00000000..7071710c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model.cc @@ -0,0 +1,260 @@ +// sherpa-mnn/csrc/online-paraformer-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-paraformer-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OnlineParaformerModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.paraformer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.paraformer.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.paraformer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.paraformer.decoder); + InitDecoder(buf.data(), buf.size()); + } + } + + std::vector ForwardEncoder(MNN::Express::VARP features, + MNN::Express::VARP features_length) { + std::vector inputs = {std::move(features), + std::move(features_length)}; + + return encoder_sess_->onForward(inputs); + } + + std::vector ForwardDecoder(MNN::Express::VARP encoder_out, + MNN::Express::VARP encoder_out_length, + MNN::Express::VARP acoustic_embedding, + MNN::Express::VARP acoustic_embedding_length, + std::vector states) { + std::vector decoder_inputs; + decoder_inputs.reserve(4 + states.size()); + + decoder_inputs.push_back(std::move(encoder_out)); + decoder_inputs.push_back(std::move(encoder_out_length)); + decoder_inputs.push_back(std::move(acoustic_embedding)); + decoder_inputs.push_back(std::move(acoustic_embedding_length)); + + for (auto &v : states) { + decoder_inputs.push_back(std::move(v)); + } + + return decoder_sess_->onForward( + decoder_inputs); + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t LfrWindowSize() const { return lfr_window_size_; } + + int32_t LfrWindowShift() const { return lfr_window_shift_; } + + int32_t EncoderOutputSize() const { return encoder_output_size_; } + + int32_t DecoderKernelSize() const { return decoder_kernel_size_; } + + int32_t DecoderNumBlocks() const { return decoder_num_blocks_; } + + const std::vector &NegativeMean() const { return neg_mean_; } + + const std::vector &InverseStdDev() const { return inv_stddev_; } + + MNNAllocator *Allocator() { return allocator_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + 
&encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift"); + SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size"); + SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks"); + SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size"); + + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean"); + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev"); + + float scale = std::sqrt(encoder_output_size_); + for (auto &f : inv_stddev_) { + f *= scale; + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + } + + private: + OnlineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::unique_ptr decoder_sess_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector neg_mean_; + std::vector inv_stddev_; + + int32_t vocab_size_ = 0; // initialized in Init + int32_t lfr_window_size_ = 0; + int32_t lfr_window_shift_ = 0; + + int32_t encoder_output_size_ = 0; + int32_t decoder_num_blocks_ = 0; + int32_t decoder_kernel_size_ = 0; +}; + +OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OnlineParaformerModel::OnlineParaformerModel(Manager *mgr, + const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OnlineParaformerModel::~OnlineParaformerModel() = default; + +std::vector OnlineParaformerModel::ForwardEncoder( + MNN::Express::VARP features, MNN::Express::VARP features_length) const { + return impl_->ForwardEncoder(std::move(features), std::move(features_length)); +} + +std::vector OnlineParaformerModel::ForwardDecoder( + MNN::Express::VARP encoder_out, MNN::Express::VARP encoder_out_length, + MNN::Express::VARP acoustic_embedding, MNN::Express::VARP acoustic_embedding_length, + std::vector states) const { + return impl_->ForwardDecoder( + std::move(encoder_out), std::move(encoder_out_length), + std::move(acoustic_embedding), std::move(acoustic_embedding_length), + std::move(states)); +} + +int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OnlineParaformerModel::LfrWindowSize() const { + return impl_->LfrWindowSize(); +} +int32_t OnlineParaformerModel::LfrWindowShift() const { + return impl_->LfrWindowShift(); +} + +int32_t OnlineParaformerModel::EncoderOutputSize() const { + return 
impl_->EncoderOutputSize(); +} + +int32_t OnlineParaformerModel::DecoderKernelSize() const { + return impl_->DecoderKernelSize(); +} + +int32_t OnlineParaformerModel::DecoderNumBlocks() const { + return impl_->DecoderNumBlocks(); +} + +const std::vector &OnlineParaformerModel::NegativeMean() const { + return impl_->NegativeMean(); +} +const std::vector &OnlineParaformerModel::InverseStdDev() const { + return impl_->InverseStdDev(); +} + +MNNAllocator *OnlineParaformerModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template OnlineParaformerModel::OnlineParaformerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineParaformerModel::OnlineParaformerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model.h new file mode 100644 index 00000000..b86dcb95 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-paraformer-model.h @@ -0,0 +1,70 @@ +// sherpa-mnn/csrc/online-paraformer-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-model-config.h" + +namespace sherpa_mnn { + +class OnlineParaformerModel { + public: + explicit OnlineParaformerModel(const OnlineModelConfig &config); + + template + OnlineParaformerModel(Manager *mgr, const OnlineModelConfig &config); + + ~OnlineParaformerModel(); + + std::vector ForwardEncoder(MNN::Express::VARP features, + MNN::Express::VARP features_length) const; + + std::vector ForwardDecoder(MNN::Express::VARP encoder_out, + MNN::Express::VARP encoder_out_length, + MNN::Express::VARP acoustic_embedding, + MNN::Express::VARP acoustic_embedding_length, + std::vector states) const; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const; + + /** It is lfr_m in config.yaml + */ + int32_t LfrWindowSize() const; + + /** It is lfr_n in config.yaml + */ + int32_t LfrWindowShift() const; + + int32_t EncoderOutputSize() const; + + int32_t DecoderKernelSize() const; + int32_t DecoderNumBlocks() const; + + /** Return negative mean for CMVN + */ + const std::vector &NegativeMean() const; + + /** Return inverse stddev for CMVN + */ + const std::vector &InverseStdDev() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-cnn-bilstm-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-cnn-bilstm-impl.h new file mode 100644 index 00000000..9ff0f054 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-cnn-bilstm-impl.h @@ -0,0 +1,278 @@ +// sherpa-mnn/csrc/online-punctuation-cnn-bilstm-impl.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ + +#include + +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include 
"android/asset_manager_jni.h" +#endif + +#include // NOLINT + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/math.h" +#include "sherpa-mnn/csrc/online-cnn-bilstm-model-meta-data.h" +#include "sherpa-mnn/csrc/online-cnn-bilstm-model.h" +#include "sherpa-mnn/csrc/online-punctuation-impl.h" +#include "sherpa-mnn/csrc/online-punctuation.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" + +namespace sherpa_mnn { + +static const int32_t kMaxSeqLen = 200; + +class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { + public: + explicit OnlinePunctuationCNNBiLSTMImpl(const OnlinePunctuationConfig &config) + : config_(config), model_(config.model) { + if (!config_.model.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model.bpe_vocab); + } + } + +#if __ANDROID_API__ >= 9 + OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr, + const OnlinePunctuationConfig &config) + : config_(config), model_(mgr, config.model) { + if (!config_.model.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + } +#endif + + std::string AddPunctuationWithCase(const std::string &text) const override { + if (text.empty()) { + return {}; + } + + std::vector tokens_list; // N * kMaxSeqLen + std::vector valids_list; // N * kMaxSeqLen + std::vector label_len_list; // N + + EncodeSentences(text, tokens_list, valids_list, label_len_list); + + const auto &meta_data = model_.metaData(); + + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t n = label_len_list.size(); + + std::array token_ids_shape = {n, kMaxSeqLen}; + MNN::Express::VARP token_ids = MNNUtilsCreateTensor( + memory_info, tokens_list.data(), tokens_list.size(), + token_ids_shape.data(), token_ids_shape.size()); + + std::array valid_ids_shape = {n, kMaxSeqLen}; + MNN::Express::VARP valid_ids = MNNUtilsCreateTensor( + memory_info, valids_list.data(), valids_list.size(), + valid_ids_shape.data(), valid_ids_shape.size()); + + std::array label_len_shape = {n}; + MNN::Express::VARP label_len = MNNUtilsCreateTensor( + memory_info, label_len_list.data(), label_len_list.size(), + label_len_shape.data(), label_len_shape.size()); + + auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), + std::move(label_len)); + + std::vector case_pred; + std::vector punct_pred; + const float *active_case_logits = pair.first->readMap(); + const float *active_punct_logits = pair.second->readMap(); + std::vector case_logits_shape = + pair.first->getInfo()->dim; + + for (int32_t i = 0; i < case_logits_shape[0]; ++i) { + const float *p_cur_case = active_case_logits + i * meta_data.num_cases; + auto index_case = static_cast(std::distance( + p_cur_case, + std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); + case_pred.push_back(index_case); + + const float *p_cur_punct = + active_punct_logits + i * meta_data.num_punctuations; + auto index_punct = static_cast(std::distance( + p_cur_punct, + std::max_element(p_cur_punct, + p_cur_punct + meta_data.num_punctuations))); + punct_pred.push_back(index_punct); + } + + std::string ans = DecodeSentences(text, case_pred, punct_pred); + + return ans; + } + + private: + void EncodeSentences(const std::string &text, + std::vector &tokens_list, // NOLINT + std::vector &valids_list, // NOLINT + std::vector &label_len_list) const { // NOLINT + std::vector tokens; + std::vector 
valids; + int32_t label_len = 0; + + tokens.push_back(1); // hardcode 1 now, 1 - + valids.push_back(1); + + std::stringstream ss(text); + std::string word; + while (ss >> word) { + std::vector word_tokens; + bpe_encoder_->Encode(word, &word_tokens); + + int32_t seq_len = tokens.size() + word_tokens.size(); + if (seq_len > kMaxSeqLen - 1) { + tokens.push_back(2); // hardcode 2 now, 2 - + valids.push_back(1); + + label_len = std::count(valids.begin(), valids.end(), 1); + + if (tokens.size() < kMaxSeqLen) { + tokens.resize(kMaxSeqLen, 0); + valids.resize(kMaxSeqLen, 0); + } + + assert(tokens.size() == kMaxSeqLen); + assert(valids.size() == kMaxSeqLen); + + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); + valids_list.insert(valids_list.end(), valids.begin(), valids.end()); + label_len_list.push_back(label_len); + + std::vector().swap(tokens); + std::vector().swap(valids); + label_len = 0; + tokens.push_back(1); // hardcode 1 now, 1 - + valids.push_back(1); + } + + tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); + valids.push_back(1); // only the first sub word is valid + int32_t remaining_size = static_cast(word_tokens.size()) - 1; + if (remaining_size > 0) { + int32_t valids_cur_size = static_cast(valids.size()); + valids.resize(valids_cur_size + remaining_size, 0); + } + } + + if (tokens.size() > 0) { + tokens.push_back(2); // hardcode 2 now, 2 - + valids.push_back(1); + + label_len = std::count(valids.begin(), valids.end(), 1); + + if (tokens.size() < kMaxSeqLen) { + tokens.resize(kMaxSeqLen, 0); + valids.resize(kMaxSeqLen, 0); + } + + assert(tokens.size() == kMaxSeqLen); + assert(valids.size() == kMaxSeqLen); + + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); + valids_list.insert(valids_list.end(), valids.begin(), valids.end()); + label_len_list.push_back(label_len); + } + } + + std::string DecodeSentences(const std::string &raw_text, + const std::vector &case_pred, + const std::vector &punct_pred) const { + std::string result_text; + std::istringstream iss(raw_text); + std::vector words; + std::string word; + + while (iss >> word) { + words.emplace_back(word); + } + + assert(words.size() == case_pred.size()); + assert(words.size() == punct_pred.size()); + + for (int32_t i = 0; i < words.size(); ++i) { + std::string prefix = ((i != 0) ? " " : ""); + result_text += prefix; + switch (case_pred[i]) { + case 1: // upper + { + std::transform(words[i].begin(), words[i].end(), words[i].begin(), + [](auto c) { return std::toupper(c); }); + result_text += words[i]; + break; + } + case 2: // cap + { + words[i][0] = std::toupper(words[i][0]); + result_text += words[i]; + break; + } + case 3: // mix case + { + // TODO(frankyoujian): + // Need to add a map containing supported mix case words so that we + // can fetch the predicted word from the map e.g. 
mcdonald's -> + // McDonald's + result_text += words[i]; + break; + } + default: { + result_text += words[i]; + break; + } + } + + std::string suffix; + switch (punct_pred[i]) { + case 1: // comma + { + suffix = ","; + break; + } + case 2: // period + { + suffix = "."; + break; + } + case 3: // question + { + suffix = "?"; + break; + } + default: + break; + } + + result_text += suffix; + } + + return result_text; + } + + private: + OnlinePunctuationConfig config_; + OnlineCNNBiLSTMModel model_; + std::unique_ptr bpe_encoder_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-impl.cc new file mode 100644 index 00000000..878ffc8e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-impl.cc @@ -0,0 +1,43 @@ +// sherpa-mnn/csrc/online-punctuation-impl.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-mnn/csrc/online-punctuation-impl.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-punctuation-cnn-bilstm-impl.h" + +namespace sherpa_mnn { + +std::unique_ptr OnlinePunctuationImpl::Create( + const OnlinePunctuationConfig &config) { + if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE( + "Please specify a punctuation model and bpe vocab! Return a null " + "pointer"); + return nullptr; +} + +#if __ANDROID_API__ >= 9 +std::unique_ptr OnlinePunctuationImpl::Create( + AAssetManager *mgr, const OnlinePunctuationConfig &config) { + if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE( + "Please specify a punctuation model and bpe vocab! 
Return a null " + "pointer"); + return nullptr; +} +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-impl.h new file mode 100644 index 00000000..c4289090 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-impl.h @@ -0,0 +1,37 @@ +// sherpa-mnn/csrc/online-punctuation-impl.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ + +#include +#include +#include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/online-punctuation.h" + +namespace sherpa_mnn { + +class OnlinePunctuationImpl { + public: + virtual ~OnlinePunctuationImpl() = default; + + static std::unique_ptr Create( + const OnlinePunctuationConfig &config); + +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const OnlinePunctuationConfig &config); +#endif + + virtual std::string AddPunctuationWithCase(const std::string &text) const = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-model-config.cc new file mode 100644 index 00000000..d270c75a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-model-config.cc @@ -0,0 +1,65 @@ +// sherpa-mnn/csrc/online-punctuation-model-config.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-mnn/csrc/online-punctuation-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlinePunctuationModelConfig::Register(ParseOptions *po) { + po->Register("cnn-bilstm", &cnn_bilstm, + "Path to the light-weight CNN-BiLSTM model"); + + po->Register("bpe-vocab", &bpe_vocab, "Path to the bpe vocab file"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OnlinePunctuationModelConfig::Validate() const { + if (cnn_bilstm.empty()) { + SHERPA_ONNX_LOGE("Please provide --cnn-bilstm"); + return false; + } + + if (!FileExists(cnn_bilstm)) { + SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", cnn_bilstm.c_str()); + return false; + } + + if (bpe_vocab.empty()) { + SHERPA_ONNX_LOGE("Please provide --bpe-vocab"); + return false; + } + + if (!FileExists(bpe_vocab)) { + SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", bpe_vocab.c_str()); + return false; + } + + return true; +} + +std::string OnlinePunctuationModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlinePunctuationModelConfig("; + os << "cnn_bilstm=\"" << cnn_bilstm << "\", "; + os << "bpe_vocab=\"" << bpe_vocab << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-model-config.h new file mode 100644 index 00000000..63fcedb6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation-model-config.h @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/online-punctuation-model-config.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlinePunctuationModelConfig { + std::string cnn_bilstm; + std::string bpe_vocab; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OnlinePunctuationModelConfig() = default; + + OnlinePunctuationModelConfig(const std::string &cnn_bilstm, + const std::string &bpe_vocab, + int32_t num_threads, bool debug, + const std::string &provider) + : cnn_bilstm(cnn_bilstm), + bpe_vocab(bpe_vocab), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation.cc new file mode 100644 index 00000000..1f506325 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation.cc @@ -0,0 +1,52 @@ +// sherpa-mnn/csrc/online-punctuation.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-mnn/csrc/online-punctuation.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-punctuation-impl.h" + +namespace sherpa_mnn { + +void OnlinePunctuationConfig::Register(ParseOptions *po) { model.Register(po); } + +bool OnlinePunctuationConfig::Validate() const { + if (!model.Validate()) { + return false; + } + + return true; +} + +std::string OnlinePunctuationConfig::ToString() const { + std::ostringstream os; + + os << "OnlinePunctuationConfig("; + os << "model=" << model.ToString() << ")"; + + return os.str(); +} + +OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config) + : impl_(OnlinePunctuationImpl::Create(config)) {} + +#if __ANDROID_API__ >= 9 +OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr, + const OnlinePunctuationConfig &config) + : impl_(OnlinePunctuationImpl::Create(mgr, config)) {} +#endif + +OnlinePunctuation::~OnlinePunctuation() = default; + +std::string OnlinePunctuation::AddPunctuationWithCase( + const std::string &text) const { + return impl_->AddPunctuationWithCase(text); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation.h new file mode 100644 index 00000000..12f2128d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-punctuation.h @@ -0,0 +1,57 @@ +// sherpa-mnn/csrc/online-punctuation.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef 
SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/online-punctuation-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlinePunctuationConfig { + OnlinePunctuationModelConfig model; + + OnlinePunctuationConfig() = default; + + explicit OnlinePunctuationConfig(const OnlinePunctuationModelConfig &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OnlinePunctuationImpl; + +class OnlinePunctuation { + public: + explicit OnlinePunctuation(const OnlinePunctuationConfig &config); + +#if __ANDROID_API__ >= 9 + OnlinePunctuation(AAssetManager *mgr, const OnlinePunctuationConfig &config); +#endif + + ~OnlinePunctuation(); + + // Add punctuation and casing to the input text and return it. + std::string AddPunctuationWithCase(const std::string &text) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-ctc-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-ctc-impl.h new file mode 100644 index 00000000..fa288de4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-ctc-impl.h @@ -0,0 +1,329 @@ +// sherpa-mnn/csrc/online-recognizer-ctc-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-ctc-decoder.h" +#include "sherpa-mnn/csrc/online-ctc-fst-decoder.h" +#include "sherpa-mnn/csrc/online-ctc-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/online-ctc-model.h" +#include "sherpa-mnn/csrc/online-recognizer-impl.h" +#include "sherpa-mnn/csrc/symbol-table.h" + +namespace sherpa_mnn { + +OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, + int32_t subsampling_factor, int32_t segment, + int32_t frames_since_start) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + auto sym = sym_table[i]; + + text.append(sym); + + if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { + // for bpe models with byte_fallback + // (but don't rewrite printable characters 0x20..0x7e, + // which collide with standard BPE units) + std::ostringstream os; + os << "<0x" << std::hex << std::uppercase + << (static_cast(sym[0]) & 0xff) << ">"; + sym = os.str(); + } + + r.tokens.push_back(std::move(sym)); + } + + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. 
* subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + r.segment = segment; + r.words = std::move(src.words); + r.start_time = frames_since_start * frame_shift_ms / 1000.; + + return r; +} + +class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerCtcImpl(const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(config), + config_(config), + model_(OnlineCtcModel::Create(config.model_config)), + endpoint_(config_.endpoint_config) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + + if (!config.model_config.wenet_ctc.model.empty()) { + // WeNet CTC models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + InitDecoder(); + } + + template + explicit OnlineRecognizerCtcImpl(Manager *mgr, + const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(mgr, config), + config_(config), + model_(OnlineCtcModel::Create(mgr, config.model_config)), + sym_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (!config.model_config.wenet_ctc.model.empty()) { + // WeNet CTC models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + InitDecoder(); + } + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + stream->SetStates(model_->GetInitStates()); + stream->SetFasterDecoder(decoder_->CreateFasterDecoder()); + + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkLength() < + s->NumFramesReady(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + if (n == 1 || !model_->SupportBatchProcessing()) { + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + return; + } + + // batch processing + int32_t chunk_length = model_->ChunkLength(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feat_dim = ss[0]->FeatureDim(); + + std::vector results(n); + std::vector features_vec(n * chunk_length * feat_dim); + std::vector> states_vec(n); + std::vector all_processed_frames(n); + + for (int32_t i = 0; i != n; ++i) { + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); + std::vector features = + ss[i]->GetFrames(num_processed_frames, chunk_length); + + // Question: should num_processed_frames include chunk_shift? 
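+      // Advance the stream position by chunk_shift frames; the chunk copied
+      // above is chunk_length frames long, so consecutive chunks overlap by
+      // chunk_length - chunk_shift frames.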
+ ss[i]->GetNumProcessedFrames() += chunk_shift; + + std::copy(features.begin(), features.end(), + features_vec.data() + i * chunk_length * feat_dim); + + results[i] = std::move(ss[i]->GetCtcResult()); + states_vec[i] = std::move(ss[i]->GetStates()); + all_processed_frames[i] = num_processed_frames; + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{n, chunk_length, feat_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, features_vec.data(), + features_vec.size(), x_shape.data(), + x_shape.size()); + + auto states = model_->StackStates(std::move(states_vec)); + int32_t num_states = states.size(); + auto out = model_->Forward(std::move(x), std::move(states)); + std::vector out_states; + out_states.reserve(num_states); + + for (int32_t k = 1; k != num_states + 1; ++k) { + out_states.push_back(std::move(out[k])); + } + + std::vector> next_states = + model_->UnStackStates(std::move(out_states)); + + std::vector log_probs_shape = + out[0]->getInfo()->dim; + decoder_->Decode(out[0]->readMap(), log_probs_shape[0], + log_probs_shape[1], log_probs_shape[2], &results, ss, n); + + for (int32_t k = 0; k != n; ++k) { + ss[k]->SetCtcResult(results[k]); + ss[k]->SetStates(std::move(next_states[k])); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + OnlineCtcDecoderResult decoder_result = s->GetCtcResult(); + + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + auto r = + ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(r.text); + return r; + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + // segment is incremented only when the last + // result is not empty + const auto &r = s->GetCtcResult(); + if (!r.tokens.empty()) { + s->GetCurrentSegment() += 1; + } + + // clear result + s->SetCtcResult({}); + + // clear states + s->SetStates(model_->GetInitStates()); + + s->GetFasterDecoderProcessedFrames() = 0; + + // Note: We only update counters. The underlying audio samples + // are not discarded. 
+ s->Reset(); + } + + private: + void InitDecoder() { + if (!sym_.Contains("") && !sym_.Contains("") && + !sym_.Contains("")) { + SHERPA_ONNX_LOGE( + "We expect that tokens.txt contains " + "the symbol or or and its ID."); + exit(-1); + } + + int32_t blank_id = 0; + if (sym_.Contains("")) { + blank_id = sym_[""]; + } else if (sym_.Contains("")) { + // for tdnn models of the yesno recipe from icefall + blank_id = sym_[""]; + } else if (sym_.Contains("")) { + // for WeNet CTC models + blank_id = sym_[""]; + } + + if (!config_.ctc_fst_decoder_config.graph.empty()) { + decoder_ = std::make_unique( + config_.ctc_fst_decoder_config, blank_id); + } else if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique(blank_id); + } else { + SHERPA_ONNX_LOGE( + "Unsupported decoding method: %s for streaming CTC models", + config_.decoding_method.c_str()); + exit(-1); + } + } + + void DecodeStream(OnlineStream *s) const { + int32_t chunk_length = model_->ChunkLength(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feat_dim = s->FeatureDim(); + + const auto num_processed_frames = s->GetNumProcessedFrames(); + std::vector frames = + s->GetFrames(num_processed_frames, chunk_length); + s->GetNumProcessedFrames() += chunk_shift; + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{1, chunk_length, feat_dim}; + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, frames.data(), frames.size(), + x_shape.data(), x_shape.size()); + auto out = model_->Forward(std::move(x), std::move(s->GetStates())); + int32_t num_states = static_cast(out.size()) - 1; + + std::vector states; + states.reserve(num_states); + + for (int32_t i = 0; i != num_states; ++i) { + states.push_back(std::move(out[i + 1])); + } + s->SetStates(std::move(states)); + + std::vector results(1); + results[0] = std::move(s->GetCtcResult()); + + std::vector log_probs_shape = + out[0]->getInfo()->dim; + decoder_->Decode(out[0]->readMap(), log_probs_shape[0], + log_probs_shape[1], log_probs_shape[2], &results, &s, 1); + s->SetCtcResult(results[0]); + } + + private: + OnlineRecognizerConfig config_; + std::unique_ptr model_; + std::unique_ptr decoder_; + SymbolTable sym_; + Endpoint endpoint_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-impl.cc new file mode 100644 index 00000000..3144bac6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-impl.cc @@ -0,0 +1,249 @@ +// sherpa-mnn/csrc/online-recognizer-impl.cc +// +// Copyright (c) 2023-2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-recognizer-impl.h" + +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-recognizer-ctc-impl.h" +#include "sherpa-mnn/csrc/online-recognizer-paraformer-impl.h" +#include "sherpa-mnn/csrc/online-recognizer-transducer-impl.h" +#include "sherpa-mnn/csrc/online-recognizer-transducer-nemo-impl.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/text-utils.h" + +#if SHERPA_ONNX_ENABLE_RKNN +#include "sherpa-mnn/csrc/rknn/online-recognizer-ctc-rknn-impl.h" +#include 
"sherpa-mnn/csrc/rknn/online-recognizer-transducer-rknn-impl.h" +#endif + +namespace sherpa_mnn { + +std::unique_ptr OnlineRecognizerImpl::Create( + const OnlineRecognizerConfig &config) { + if (config.model_config.provider_config.provider == "rknn") { +#if SHERPA_ONNX_ENABLE_RKNN + // Currently, only zipformer v1 is suported for rknn + if (config.model_config.transducer.encoder.empty() && + config.model_config.zipformer2_ctc.model.empty()) { + SHERPA_ONNX_LOGE( + "Only Zipformer transducers and CTC models are currently supported " + "by rknn. Fallback to CPU"); + } else if (!config.model_config.transducer.encoder.empty()) { + return std::make_unique(config); + } else if (!config.model_config.zipformer2_ctc.model.empty()) { + return std::make_unique(config); + } +#else + SHERPA_ONNX_LOGE( + "Please rebuild sherpa-mnn with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " + "want to use rknn. Fallback to CPU"); +#endif + } + + if (!config.model_config.transducer.encoder.empty()) { + MNNEnv env; + + std::shared_ptr sess_opts; + + + + auto decoder_model = ReadFile(config.model_config.transducer.decoder); + auto sess = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)decoder_model.data(), + decoder_model.size(), sess_opts)); + + size_t node_count = sess->getInfo()->outputNames.size(); + + if (node_count == 1) { + return std::make_unique(config); + } else { + return std::make_unique(config); + } + } + + if (!config.model_config.paraformer.encoder.empty()) { + return std::make_unique(config); + } + + if (!config.model_config.wenet_ctc.model.empty() || + !config.model_config.zipformer2_ctc.model.empty() || + !config.model_config.nemo_ctc.model.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a model"); + exit(-1); +} + +template +std::unique_ptr OnlineRecognizerImpl::Create( + Manager *mgr, const OnlineRecognizerConfig &config) { + if (!config.model_config.transducer.encoder.empty()) { + MNNEnv env; + + std::shared_ptr sess_opts; + + + + auto decoder_model = ReadFile(mgr, config.model_config.transducer.decoder); + auto sess = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)decoder_model.data(), + decoder_model.size(), sess_opts)); + + size_t node_count = sess->getInfo()->outputNames.size(); + + if (node_count == 1) { + return std::make_unique(mgr, config); + } else { + return std::make_unique(mgr, config); + } + } + + if (!config.model_config.paraformer.encoder.empty()) { + return std::make_unique(mgr, config); + } + + if (!config.model_config.wenet_ctc.model.empty() || + !config.model_config.zipformer2_ctc.model.empty() || + !config.model_config.nemo_ctc.model.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE("Please specify a model"); + exit(-1); +} + +OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + itn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + 
SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } +} + +template +OnlineRecognizerImpl::OnlineRecognizerImpl(Manager *mgr, + const OnlineRecognizerConfig &config) + : config_(config) { + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + itn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + itn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + itn_list_.reserve(files.size() + itn_list_.size()); + + for (const auto &f : files) { + if (config.model_config.debug) { + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + itn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) +} + +std::string OnlineRecognizerImpl::ApplyInverseTextNormalization( + std::string text) const { + text = RemoveInvalidUtf8Sequences(text); + + if (!itn_list_.empty()) { + for (const auto &tn : itn_list_) { + text = tn->Normalize(text); + } + } + + return text; +} + +#if __ANDROID_API__ >= 9 +template OnlineRecognizerImpl::OnlineRecognizerImpl( + AAssetManager *mgr, const OnlineRecognizerConfig &config); + +template std::unique_ptr OnlineRecognizerImpl::Create( + AAssetManager *mgr, const OnlineRecognizerConfig &config); +#endif + +#if __OHOS__ +template OnlineRecognizerImpl::OnlineRecognizerImpl( + NativeResourceManager *mgr, const OnlineRecognizerConfig &config); + +template std::unique_ptr OnlineRecognizerImpl::Create( + NativeResourceManager *mgr, const OnlineRecognizerConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-impl.h new file mode 100644 index 00000000..b45e416b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-impl.h @@ -0,0 +1,71 @@ +// sherpa-mnn/csrc/online-recognizer-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ + +#include +#include +#include + +#include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/online-stream.h" + +namespace sherpa_mnn { + +class OnlineRecognizerImpl { + public: + explicit OnlineRecognizerImpl(const OnlineRecognizerConfig &config); + + static std::unique_ptr Create( + const OnlineRecognizerConfig &config); + + template + OnlineRecognizerImpl(Manager *mgr, const OnlineRecognizerConfig &config); + + template 
+ static std::unique_ptr Create( + Manager *mgr, const OnlineRecognizerConfig &config); + + virtual ~OnlineRecognizerImpl() = default; + + virtual std::unique_ptr CreateStream() const = 0; + + virtual std::unique_ptr CreateStream( + const std::string &hotwords) const { + SHERPA_ONNX_LOGE("Only transducer models support contextual biasing."); + exit(-1); + } + + virtual bool IsReady(OnlineStream *s) const = 0; + + virtual void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const { + // ToDo extending to other models + SHERPA_ONNX_LOGE("Only zipformer2 model supports Warm up for now."); + exit(-1); + } + + virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0; + + virtual OnlineRecognizerResult GetResult(OnlineStream *s) const = 0; + + virtual bool IsEndpoint(OnlineStream *s) const = 0; + + virtual void Reset(OnlineStream *s) const = 0; + + std::string ApplyInverseTextNormalization(std::string text) const; + + private: + OnlineRecognizerConfig config_; + // for inverse text normalization. Used only if + // config.rule_fsts is not empty or + // config.rule_fars is not empty + std::vector> itn_list_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-paraformer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-paraformer-impl.h new file mode 100644 index 00000000..8642661d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-paraformer-impl.h @@ -0,0 +1,475 @@ +// sherpa-mnn/csrc/online-recognizer-paraformer-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-lm.h" +#include "sherpa-mnn/csrc/online-paraformer-decoder.h" +#include "sherpa-mnn/csrc/online-paraformer-model.h" +#include "sherpa-mnn/csrc/online-recognizer-impl.h" +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/symbol-table.h" + +namespace sherpa_mnn { + +static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src, + const SymbolTable &sym_table) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + + // When the current token ends with "@@" we set mergeable to true + bool mergeable = false; + + for (int32_t i = 0; i != src.tokens.size(); ++i) { + auto sym = sym_table[src.tokens[i]]; + r.tokens.push_back(sym); + + if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) { + // sym does not end with "@@" + const uint8_t *p = reinterpret_cast(sym.c_str()); + if (p[0] < 0x80) { + // an ascii + if (mergeable) { + mergeable = false; + text.append(sym); + } else { + text.append(" "); + text.append(sym); + } + } else { + // not an ascii + mergeable = false; + + if (i > 0) { + const uint8_t p = reinterpret_cast( + sym_table[src.tokens[i - 1]].c_str())[0]; + if (p < 0x80) { + // put a space between ascii and non-ascii + text.append(" "); + } + } + text.append(sym); + } + } else { + // this sym ends with @@ + sym = std::string(sym.data(), sym.size() - 2); + if (mergeable) { + text.append(sym); + } else { + text.append(" "); + text.append(sym); + mergeable = true; + } + } + } + r.text = std::move(text); + + return r; +} + +// y[i] += x[i] * scale +static void 
ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) { + for (int32_t i = 0; i != n; ++i) { + y[i] += x[i] * scale; + } +} + +// y[i] = x[i] * scale +static void Scale(const float *x, int32_t n, float scale, float *y) { + for (int32_t i = 0; i != n; ++i) { + y[i] = x[i] * scale; + } +} + +class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(config), + config_(config), + model_(config.model_config), + endpoint_(config_.endpoint_config) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + + if (config.decoding_method != "greedy_search") { + SHERPA_ONNX_LOGE( + "Unsupported decoding method: %s. Support only greedy_search at " + "present", + config.decoding_method.c_str()); + exit(-1); + } + + // Paraformer models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + template + explicit OnlineRecognizerParaformerImpl(Manager *mgr, + const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(mgr, config), + config_(config), + model_(mgr, config.model_config), + sym_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (config.decoding_method != "greedy_search") { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + + // Paraformer models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) = + delete; + + OnlineRecognizerParaformerImpl operator=( + const OnlineRecognizerParaformerImpl &) = delete; + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + + OnlineParaformerDecoderResult r; + stream->SetParaformerResult(r); + + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + // TODO(fangjun): Support batch size > 1 + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + auto decoder_result = s->GetParaformerResult(); + + auto r = Convert(decoder_result, sym_); + r.text = ApplyInverseTextNormalization(r.text); + return r; + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + const auto &result = s->GetParaformerResult(); + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + int32_t trailing_silence_frames = + num_processed_frames - result.last_non_blank_frame_index; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + OnlineParaformerDecoderResult r; + s->SetParaformerResult(r); + + s->GetStates().clear(); + s->GetParaformerEncoderOutCache().clear(); + s->GetParaformerAlphaCache().clear(); + + // s->GetParaformerFeatCache().clear(); + + // Note: We 
only update counters. The underlying audio samples + // are not discarded. + s->Reset(); + } + + private: + void DecodeStream(OnlineStream *s) const { + const auto num_processed_frames = s->GetNumProcessedFrames(); + std::vector frames = s->GetFrames(num_processed_frames, chunk_size_); + s->GetNumProcessedFrames() += chunk_size_ - 1; + + frames = ApplyLFR(frames); + ApplyCMVN(&frames); + PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift()); + + int32_t feat_dim = model_.NegativeMean().size(); + + // We have scaled inv_stddev by sqrt(encoder_output_size) + // so the following line can be commented out + // frames *= encoder_output_size ** 0.5 + + // add overlap chunk + std::vector &feat_cache = s->GetParaformerFeatCache(); + if (feat_cache.empty()) { + int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim; + feat_cache.resize(n, 0); + } + + frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end()); + std::copy(frames.end() - feat_cache.size(), frames.end(), + feat_cache.begin()); + + int32_t num_frames = frames.size() / feat_dim; + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{1, num_frames, feat_dim}; + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, frames.data(), frames.size(), + x_shape.data(), x_shape.size()); + + int x_len_shape = 1; + int32_t x_len_val = num_frames; + + MNN::Express::VARP x_length = + MNNUtilsCreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1); + + auto encoder_out_vec = + model_.ForwardEncoder(std::move(x), std::move(x_length)); + + // CIF search + auto &encoder_out = encoder_out_vec[0]; + auto &encoder_out_len = encoder_out_vec[1]; + auto &alpha = encoder_out_vec[2]; + + float *p_alpha = alpha->writeMap(); + + std::vector alpha_shape = + alpha->getInfo()->dim; + + std::fill(p_alpha, p_alpha + left_chunk_size_, 0); + std::fill(p_alpha + alpha_shape[1] - right_chunk_size_, + p_alpha + alpha_shape[1], 0); + + const float *p_encoder_out = encoder_out->readMap(); + + std::vector encoder_out_shape = + encoder_out->getInfo()->dim; + + std::vector &initial_hidden = s->GetParaformerEncoderOutCache(); + if (initial_hidden.empty()) { + initial_hidden.resize(encoder_out_shape[2]); + } + + std::vector &alpha_cache = s->GetParaformerAlphaCache(); + if (alpha_cache.empty()) { + alpha_cache.resize(1); + } + + std::vector acoustic_embedding; + acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]); + + float threshold = 1.0; + + float integrate = alpha_cache[0]; + + for (int32_t i = 0; i != encoder_out_shape[1]; ++i) { + float this_alpha = p_alpha[i]; + if (integrate + this_alpha < threshold) { + integrate += this_alpha; + ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2], + encoder_out_shape[2], this_alpha, + initial_hidden.data()); + continue; + } + + // fire + ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2], + encoder_out_shape[2], threshold - integrate, + initial_hidden.data()); + acoustic_embedding.insert(acoustic_embedding.end(), + initial_hidden.begin(), initial_hidden.end()); + integrate += this_alpha - threshold; + + Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2], + integrate, initial_hidden.data()); + } + + alpha_cache[0] = integrate; + + if (acoustic_embedding.empty()) { + return; + } + + auto &states = s->GetStates(); + if (states.empty()) { + states.reserve(model_.DecoderNumBlocks()); + + std::array shape{1, model_.EncoderOutputSize(), + model_.DecoderKernelSize() - 1}; + + int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * 
shape[2]; + + for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) { + MNN::Express::VARP this_state = MNNUtilsCreateTensor( + model_.Allocator(), shape.data(), shape.size()); + + memset(this_state->writeMap(), 0, num_bytes); + + states.push_back(std::move(this_state)); + } + } + + int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size(); + std::array acoustic_embedding_shape{ + 1, num_tokens, static_cast(initial_hidden.size())}; + + MNN::Express::VARP acoustic_embedding_tensor = MNNUtilsCreateTensor( + memory_info, acoustic_embedding.data(), acoustic_embedding.size(), + acoustic_embedding_shape.data(), acoustic_embedding_shape.size()); + + std::array acoustic_embedding_length_shape{1}; + MNN::Express::VARP acoustic_embedding_length_tensor = MNNUtilsCreateTensor( + memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(), + acoustic_embedding_length_shape.size()); + + auto decoder_out_vec = model_.ForwardDecoder( + std::move(encoder_out), std::move(encoder_out_len), + std::move(acoustic_embedding_tensor), + std::move(acoustic_embedding_length_tensor), std::move(states)); + + states.reserve(model_.DecoderNumBlocks()); + for (int32_t i = 2; i != decoder_out_vec.size(); ++i) { + // TODO(fangjun): When we change chunk_size_, we need to + // slice decoder_out_vec[i] accordingly. + states.push_back(std::move(decoder_out_vec[i])); + } + + const auto &sample_ids = decoder_out_vec[1]; + const int *p_sample_ids = sample_ids->readMap(); + + bool non_blank_detected = false; + + auto &result = s->GetParaformerResult(); + + for (int32_t i = 0; i != num_tokens; ++i) { + int32_t t = p_sample_ids[i]; + if (t == 0) { + continue; + } + + non_blank_detected = true; + result.tokens.push_back(t); + } + + if (non_blank_detected) { + result.last_non_blank_frame_index = num_processed_frames; + } + } + + std::vector ApplyLFR(const std::vector &in) const { + int32_t lfr_window_size = model_.LfrWindowSize(); + int32_t lfr_window_shift = model_.LfrWindowShift(); + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t in_num_frames = in.size() / in_feat_dim; + int32_t out_num_frames = + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; + int32_t out_feat_dim = in_feat_dim * lfr_window_size; + + std::vector out(out_num_frames * out_feat_dim); + + const float *p_in = in.data(); + float *p_out = out.data(); + + for (int32_t i = 0; i != out_num_frames; ++i) { + std::copy(p_in, p_in + out_feat_dim, p_out); + + p_out += out_feat_dim; + p_in += lfr_window_shift * in_feat_dim; + } + + return out; + } + + void ApplyCMVN(std::vector *v) const { + const std::vector &neg_mean = model_.NegativeMean(); + const std::vector &inv_stddev = model_.InverseStdDev(); + + int32_t dim = neg_mean.size(); + int32_t num_frames = v->size() / dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != dim; ++k) { + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; + } + + p += dim; + } + } + + void PositionalEncoding(std::vector *v, int32_t t_offset) const { + int32_t lfr_window_size = model_.LfrWindowSize(); + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t feat_dim = in_feat_dim * lfr_window_size; + int32_t T = v->size() / feat_dim; + + // log(10000)/(7*80/2-1) == 0.03301197265941284 + // 7 is lfr_window_size + // 80 is in_feat_dim + // 7*80 is feat_dim + constexpr float kScale = -0.03301197265941284; + + for (int32_t t = 0; t != T; ++t) { + float *p = v->data() + t * feat_dim; + + int32_t offset = t + 1 + t_offset; + + for (int32_t d = 
0; d < feat_dim / 2; ++d) { + float inv_timescale = offset * std::exp(d * kScale); + + float sin_d = std::sin(inv_timescale); + float cos_d = std::cos(inv_timescale); + + p[d] += sin_d; + p[d + feat_dim / 2] += cos_d; + } + } + } + + private: + OnlineRecognizerConfig config_; + OnlineParaformerModel model_; + SymbolTable sym_; + Endpoint endpoint_; + + // 0.61 seconds + int32_t chunk_size_ = 61; + // (61 - 7) / 6 + 1 = 10 + + int32_t left_chunk_size_ = 5; + int32_t right_chunk_size_ = 3; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-transducer-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-transducer-impl.h new file mode 100644 index 00000000..dbea3406 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-transducer-impl.h @@ -0,0 +1,502 @@ +// sherpa-mnn/csrc/online-recognizer-transducer-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ + +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-lm.h" +#include "sherpa-mnn/csrc/online-recognizer-impl.h" +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-greedy-search-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" +#include "sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" + +namespace sherpa_mnn { + +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, int32_t subsampling_factor, + int32_t segment, int32_t frames_since_start) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + auto sym = sym_table[i]; + + text.append(sym); + + if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { + // for bpe models with byte_fallback + // (but don't rewrite printable characters 0x20..0x7e, + // which collide with standard BPE units) + std::ostringstream os; + os << "<0x" << std::hex << std::uppercase + << (static_cast(sym[0]) & 0xff) << ">"; + sym = os.str(); + } + + r.tokens.push_back(std::move(sym)); + } + + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. 
* subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + r.ys_probs = std::move(src.ys_probs); + r.lm_probs = std::move(src.lm_probs); + r.context_scores = std::move(src.context_scores); + + r.segment = segment; + r.start_time = frames_since_start * frame_shift_ms / 1000.; + + return r; +} + +class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerTransducerImpl(const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(config), + config_(config), + model_(OnlineTransducerModel::Create(config.model_config)), + endpoint_(config_.endpoint_config) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + + if (sym_.Contains("")) { + unk_id_ = sym_[""]; + } + + model_->SetFeatureDim(config.feat_config.feature_dim); + + if (config.decoding_method == "modified_beam_search") { + if (!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + + if (!config_.hotwords_buf.empty()) { + InitHotwordsFromBufStr(); + } else if (!config_.hotwords_file.empty()) { + InitHotwords(); + } + + if (!config_.lm_config.model.empty()) { + lm_ = OnlineLM::Create(config.lm_config); + } + + decoder_ = std::make_unique( + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, + config_.blank_penalty, config_.temperature_scale); + + } else if (config.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), unk_id_, config_.blank_penalty, + config_.temperature_scale); + + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + } + + template + explicit OnlineRecognizerTransducerImpl(Manager *mgr, + const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(mgr, config), + config_(config), + model_(OnlineTransducerModel::Create(mgr, config.model_config)), + sym_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config) { + if (sym_.Contains("")) { + unk_id_ = sym_[""]; + } + + model_->SetFeatureDim(config.feat_config.feature_dim); + + if (config.decoding_method == "modified_beam_search") { +#if 0 + // TODO(fangjun): Implement it + if (!config_.lm_config.model.empty()) { + lm_ = OnlineLM::Create(mgr, config.lm_config); + } +#endif + + if (!config_.model_config.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + + if (!config_.hotwords_file.empty()) { + InitHotwords(mgr); + } + + decoder_ = std::make_unique( + model_.get(), lm_.get(), config_.max_active_paths, + config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, + config_.blank_penalty, config_.temperature_scale); + + } else if (config.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), unk_id_, config_.blank_penalty, + config_.temperature_scale); + + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + auto stream = + std::make_unique(config_.feat_config, hotwords_graph_); + InitOnlineStream(stream.get()); + return stream; + 
} + + std::unique_ptr CreateStream( + const std::string &hotwords) const override { + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); + std::istringstream is(hws); + std::vector> current; + std::vector current_scores; + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), ¤t, ¤t_scores)) { + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", + hotwords.c_str()); + } + + int32_t num_default_hws = hotwords_.size(); + int32_t num_hws = current.size(); + + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); + + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_hws, + config_.hotwords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_hws, + config_.hotwords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing. + } + + auto context_graph = std::make_shared( + current, config_.hotwords_score, current_scores); + auto stream = + std::make_unique(config_.feat_config, context_graph); + InitOnlineStream(stream.get()); + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + + // Warmping up engine with wp: warm_up count and max-batch-size + void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const override { + auto max_batch_size = mbs; + if (warmup <= 0 || warmup > 100) { + return; + } + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + int32_t feature_dim = 80; + std::vector results(max_batch_size); + std::vector features_vec(max_batch_size * chunk_size * feature_dim); + std::vector> states_vec(max_batch_size); + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{max_batch_size, chunk_size, feature_dim}; + + for (int32_t i = 0; i != max_batch_size; ++i) { + states_vec[i] = model_->GetEncoderInitStates(); + results[i] = decoder_->GetEmptyResult(); + } + + for (int32_t i = 0; i != warmup; ++i) { + auto states = model_->StackStates(states_vec); + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, features_vec.data(), + features_vec.size(), + x_shape.data(), x_shape.size()); + auto x_copy = Clone(model_->Allocator(), x); + auto pair = model_->RunEncoder(std::move(x), std::move(states), + std::move(x_copy)); + decoder_->Decode(std::move(pair.first), &results); + } + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = ss[0]->FeatureDim(); + + std::vector results(n); + std::vector features_vec(n * chunk_size * feature_dim); + std::vector> states_vec(n); + std::vector all_processed_frames(n); + bool has_context_graph = false; + + for (int32_t i = 0; i != n; ++i) { + if (!has_context_graph && ss[i]->GetContextGraph()) { + has_context_graph = true; + } + + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); + std::vector features = + ss[i]->GetFrames(num_processed_frames, chunk_size); + + // Question: should num_processed_frames include chunk_shift? 
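+      // The counter is advanced by chunk_shift rather than chunk_size, so
+      // consecutive chunks overlap whenever chunk_shift < chunk_size.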
+ ss[i]->GetNumProcessedFrames() += chunk_shift; + + std::copy(features.begin(), features.end(), + features_vec.data() + i * chunk_size * feature_dim); + + results[i] = std::move(ss[i]->GetResult()); + states_vec[i] = std::move(ss[i]->GetStates()); + all_processed_frames[i] = num_processed_frames; + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{n, chunk_size, feature_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, features_vec.data(), + features_vec.size(), x_shape.data(), + x_shape.size()); + + std::array processed_frames_shape{ + static_cast(all_processed_frames.size())}; + + MNN::Express::VARP processed_frames = MNNUtilsCreateTensor( + memory_info, all_processed_frames.data(), all_processed_frames.size(), + processed_frames_shape.data(), processed_frames_shape.size()); + + auto states = model_->StackStates(states_vec); + + auto pair = model_->RunEncoder(std::move(x), std::move(states), + std::move(processed_frames)); + + if (has_context_graph) { + decoder_->Decode(std::move(pair.first), ss, &results); + } else { + decoder_->Decode(std::move(pair.first), &results); + } + + std::vector> next_states = + model_->UnStackStates(pair.second); + + for (int32_t i = 0; i != n; ++i) { + ss[i]->SetResult(results[i]); + ss[i]->SetStates(std::move(next_states[i])); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + OnlineTransducerDecoderResult decoder_result = s->GetResult(); + decoder_->StripLeadingBlanks(&decoder_result); + + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + return r; + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + int32_t context_size = model_->ContextSize(); + + { + // segment is incremented only when the last + // result is not empty, contains non-blanks and longer than context_size) + const auto &r = s->GetResult(); + if (!r.tokens.empty() && r.tokens.back() != 0 && + r.tokens.size() > context_size) { + s->GetCurrentSegment() += 1; + } + } + + // reset encoder states + // s->SetStates(model_->GetEncoderInitStates()); + + auto r = decoder_->GetEmptyResult(); + auto last_result = s->GetResult(); + // if last result is not empty, then + // preserve last tokens as the context for next result + if (static_cast(last_result.tokens.size()) > context_size) { + std::vector context(last_result.tokens.end() - context_size, + last_result.tokens.end()); + + Hypotheses context_hyp({{context, 0}}); + r.hyps = std::move(context_hyp); + r.tokens = std::move(context); + } + + if (config_.decoding_method == "modified_beam_search" && + nullptr != s->GetContextGraph()) { + for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { + it->second.context_state = s->GetContextGraph()->Root(); + } + } + + s->SetResult(r); + + // Note: We only update counters. 
The underlying audio samples + // are not discarded. + s->Reset(); + } + + private: + void InitHotwords() { + // each line in hotwords_file contains space-separated words + + std::ifstream is(config_.hotwords_file); + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + + template + void InitHotwords(Manager *mgr) { + // each line in hotwords_file contains space-separated words + + auto buf = ReadFile(mgr, config_.hotwords_file); + + std::istringstream is(std::string(buf.begin(), buf.end())); + + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + + void InitHotwordsFromBufStr() { + // each line in hotwords_file contains space-separated words + + std::istringstream iss(config_.hotwords_buf); + if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + + void InitOnlineStream(OnlineStream *stream) const { + auto r = decoder_->GetEmptyResult(); + + if (config_.decoding_method == "modified_beam_search" && + nullptr != stream->GetContextGraph()) { + // r.hyps has only one element. 
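+      // Start hotword (contextual-biasing) matching from the root of the
+      // context graph for this new stream.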
+ for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { + it->second.context_state = stream->GetContextGraph()->Root(); + } + } + + stream->SetResult(r); + stream->SetStates(model_->GetEncoderInitStates()); + } + + private: + OnlineRecognizerConfig config_; + std::vector> hotwords_; + std::vector boost_scores_; + ContextGraphPtr hotwords_graph_; + std::unique_ptr bpe_encoder_; + std::unique_ptr model_; + std::unique_ptr lm_; + std::unique_ptr decoder_; + SymbolTable sym_; + Endpoint endpoint_; + int32_t unk_id_ = -1; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-transducer-nemo-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-transducer-nemo-impl.h new file mode 100644 index 00000000..53a470c4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer-transducer-nemo-impl.h @@ -0,0 +1,251 @@ +// sherpa-mnn/csrc/online-recognizer-transducer-nemo-impl.h +// +// Copyright (c) 2022-2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-recognizer-impl.h" +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-nemo-model.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/transpose.h" +#include "sherpa-mnn/csrc/utils.h" + +namespace sherpa_mnn { + +// defined in ./online-recognizer-transducer-impl.h +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, int32_t subsampling_factor, + int32_t segment, int32_t frames_since_start); + +class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerTransducerNeMoImpl( + const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(config), + config_(config), + endpoint_(config_.endpoint_config), + model_( + std::make_unique(config.model_config)) { + if (!config.model_config.tokens_buf.empty()) { + symbol_table_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + symbol_table_ = SymbolTable(config.model_config.tokens, true); + } + + if (config.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + PostInit(); + } + + template + explicit OnlineRecognizerTransducerNeMoImpl( + Manager *mgr, const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config), + model_(std::make_unique( + mgr, config.model_config)) { + if (config.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + model_.get(), config_.blank_penalty); + } else { + SHERPA_ONNX_LOGE("Unsupported decoding method: %s", + config.decoding_method.c_str()); + exit(-1); + } + + PostInit(); + } + + std::unique_ptr CreateStream() const override { + auto stream = 
std::make_unique(config_.feat_config); + InitOnlineStream(stream.get()); + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = model_->SubsamplingFactor(); + auto r = Convert(s->GetResult(), symbol_table_, frame_shift_ms, + subsampling_factor, s->GetCurrentSegment(), + s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + return r; + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + int32_t trailing_silence_frames = + s->GetResult().num_trailing_blanks * model_->SubsamplingFactor(); + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + { + // segment is incremented only when the last + // result is not empty + const auto &r = s->GetResult(); + if (!r.tokens.empty()) { + s->GetCurrentSegment() += 1; + } + } + + s->SetResult({}); + + s->SetStates(model_->GetEncoderInitStates()); + + s->SetNeMoDecoderStates(model_->GetDecoderInitStates()); + + // Note: We only update counters. The underlying audio samples + // are not discarded. + s->Reset(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = ss[0]->FeatureDim(); + + std::vector features_vec(n * chunk_size * feature_dim); + std::vector> encoder_states(n); + + for (int32_t i = 0; i != n; ++i) { + const auto num_processed_frames = ss[i]->GetNumProcessedFrames(); + std::vector features = + ss[i]->GetFrames(num_processed_frames, chunk_size); + + // Question: should num_processed_frames include chunk_shift? 
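+      // Encoder states are moved out of each stream here; the updated states
+      // produced by the batched encoder run are written back further below.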
+ ss[i]->GetNumProcessedFrames() += chunk_shift; + + std::copy(features.begin(), features.end(), + features_vec.data() + i * chunk_size * feature_dim); + + encoder_states[i] = std::move(ss[i]->GetStates()); + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{n, chunk_size, feature_dim}; + + MNN::Express::VARP x = MNNUtilsCreateTensor(memory_info, features_vec.data(), + features_vec.size(), x_shape.data(), + x_shape.size()); + + auto states = model_->StackStates(std::move(encoder_states)); + int32_t num_states = states.size(); // num_states = 3 + auto t = model_->RunEncoder(std::move(x), std::move(states)); + // t[0] encoder_out, float tensor, (batch_size, dim, T) + // t[1] next states + + std::vector out_states; + out_states.reserve(num_states); + + for (int32_t k = 1; k != num_states + 1; ++k) { + out_states.push_back(std::move(t[k])); + } + + auto unstacked_states = model_->UnStackStates(std::move(out_states)); + for (int32_t i = 0; i != n; ++i) { + ss[i]->SetStates(std::move(unstacked_states[i])); + } + + MNN::Express::VARP encoder_out = Transpose12(model_->Allocator(), t[0]); + + decoder_->Decode(std::move(encoder_out), ss, n); + } + + void InitOnlineStream(OnlineStream *stream) const { + // set encoder states + stream->SetStates(model_->GetEncoderInitStates()); + + // set decoder states + stream->SetNeMoDecoderStates(model_->GetDecoderInitStates()); + } + + private: + void PostInit() { + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + config_.feat_config.low_freq = 0; + // config_.feat_config.high_freq = 8000; + config_.feat_config.is_librosa = true; + config_.feat_config.remove_dc_offset = false; + // config_.feat_config.window_type = "hann"; + config_.feat_config.dither = 0; + config_.feat_config.nemo_normalize_type = + model_->FeatureNormalizationMethod(); + + int32_t vocab_size = model_->VocabSize(); + + // check the blank ID + if (!symbol_table_.Contains("")) { + SHERPA_ONNX_LOGE("tokens.txt does not include the blank token "); + exit(-1); + } + + if (symbol_table_[""] != vocab_size - 1) { + SHERPA_ONNX_LOGE(" is not the last token!"); + exit(-1); + } + + if (symbol_table_.NumSymbols() != vocab_size) { + SHERPA_ONNX_LOGE("number of lines in tokens.txt %d != %d (vocab_size)", + symbol_table_.NumSymbols(), vocab_size); + exit(-1); + } + } + + private: + OnlineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; + Endpoint endpoint_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer.cc new file mode 100644 index 00000000..ca780781 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer.cc @@ -0,0 +1,259 @@ +// sherpa-mnn/csrc/online-recognizer.cc +// +// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo + +#include "sherpa-mnn/csrc/online-recognizer.h" + +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/online-recognizer-impl.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +namespace { + +/// Helper for 
`OnlineRecognizerResult::AsJsonString()` +template +std::string VecToString(const std::vector &vec, int32_t precision = 6) { + std::ostringstream oss; + if (precision != 0) { + oss << std::fixed << std::setprecision(precision); + } + oss << "["; + std::string sep = ""; + for (const auto &item : vec) { + oss << sep << item; + sep = ", "; + } + oss << "]"; + return oss.str(); +} + +/// Helper for `OnlineRecognizerResult::AsJsonString()` +template <> // explicit specialization for T = std::string +std::string VecToString(const std::vector &vec, + int32_t) { // ignore 2nd arg + std::ostringstream oss; + oss << "["; + std::string sep = ""; + for (const auto &item : vec) { + oss << sep << std::quoted(item); + sep = ", "; + } + oss << "]"; + return oss.str(); +} + +} // namespace + +std::string OnlineRecognizerResult::AsJsonString() const { + std::ostringstream os; + os << "{ "; + os << "\"text\": " << std::quoted(text) << ", "; + os << "\"tokens\": " << VecToString(tokens) << ", "; + os << "\"timestamps\": " << VecToString(timestamps, 2) << ", "; + os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", "; + os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", "; + os << "\"context_scores\": " << VecToString(context_scores, 6) << ", "; + os << "\"segment\": " << segment << ", "; + os << "\"words\": " << VecToString(words, 0) << ", "; + os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time + << ", "; + os << "\"is_final\": " << (is_final ? "true" : "false"); + os << "}"; + return os.str(); +} + +void OnlineRecognizerConfig::Register(ParseOptions *po) { + feat_config.Register(po); + model_config.Register(po); + endpoint_config.Register(po); + lm_config.Register(po); + ctc_fst_decoder_config.Register(po); + + po->Register("enable-endpoint", &enable_endpoint, + "True to enable endpoint detection. False to disable it."); + po->Register("max-active-paths", &max_active_paths, + "beam size used in modified beam search."); + po->Register("blank-penalty", &blank_penalty, + "The penalty applied on blank symbol during decoding. " + "Note: It is a positive value. " + "Increasing value will lead to lower deletion at the cost" + "of higher insertions. " + "Currently only applicable for transducer models."); + po->Register("hotwords-score", &hotwords_score, + "The bonus score for each token in context word/phrase. " + "Used only when decoding_method is modified_beam_search"); + po->Register( + "hotwords-file", &hotwords_file, + "The file containing hotwords, one words/phrases per line, For example: " + "HELLO WORLD" + "你好世界"); + po->Register("decoding-method", &decoding_method, + "decoding method," + "now support greedy_search and modified_beam_search."); + po->Register("temperature-scale", &temperature_scale, + "Temperature scale for confidence computation in decoding."); + po->Register( + "rule-fsts", &rule_fsts, + "If not empty, it specifies fsts for inverse text normalization. " + "If there are multiple fsts, they are separated by a comma."); + + po->Register( + "rule-fars", &rule_fars, + "If not empty, it specifies fst archives for inverse text normalization. " + "If there are multiple archives, they are separated by a comma."); +} + +bool OnlineRecognizerConfig::Validate() const { + if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) { + if (max_active_paths <= 0) { + SHERPA_ONNX_LOGE("max_active_paths is less than 0! 
Given: %d", + max_active_paths); + return false; + } + + if (!lm_config.Validate()) { + return false; + } + } + + if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { + SHERPA_ONNX_LOGE( + "Please use --decoding-method=modified_beam_search if you" + " provide --hotwords-file. Given --decoding-method=%s", + decoding_method.c_str()); + return false; + } + + if (!ctc_fst_decoder_config.graph.empty() && + !ctc_fst_decoder_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config"); + return false; + } + + if (!hotwords_file.empty() && !FileExists(hotwords_file)) { + SHERPA_ONNX_LOGE("--hotwords-file: '%s' does not exist", + hotwords_file.c_str()); + return false; + } + + if (!rule_fsts.empty()) { + std::vector files; + SplitStringToVector(rule_fsts, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule fst '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + if (!rule_fars.empty()) { + std::vector files; + SplitStringToVector(rule_fars, ",", false, &files); + for (const auto &f : files) { + if (!FileExists(f)) { + SHERPA_ONNX_LOGE("Rule far '%s' does not exist. ", f.c_str()); + return false; + } + } + } + + return model_config.Validate(); +} + +std::string OnlineRecognizerConfig::ToString() const { + std::ostringstream os; + + os << "OnlineRecognizerConfig("; + os << "feat_config=" << feat_config.ToString() << ", "; + os << "model_config=" << model_config.ToString() << ", "; + os << "lm_config=" << lm_config.ToString() << ", "; + os << "endpoint_config=" << endpoint_config.ToString() << ", "; + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", "; + os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; + os << "max_active_paths=" << max_active_paths << ", "; + os << "hotwords_score=" << hotwords_score << ", "; + os << "hotwords_file=\"" << hotwords_file << "\", "; + os << "decoding_method=\"" << decoding_method << "\", "; + os << "blank_penalty=" << blank_penalty << ", "; + os << "temperature_scale=" << temperature_scale << ", "; + os << "rule_fsts=\"" << rule_fsts << "\", "; + os << "rule_fars=\"" << rule_fars << "\")"; + + return os.str(); +} + +OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config) + : impl_(OnlineRecognizerImpl::Create(config)) {} + +template +OnlineRecognizer::OnlineRecognizer(Manager *mgr, + const OnlineRecognizerConfig &config) + : impl_(OnlineRecognizerImpl::Create(mgr, config)) {} + +OnlineRecognizer::~OnlineRecognizer() = default; + +std::unique_ptr OnlineRecognizer::CreateStream() const { + return impl_->CreateStream(); +} + +std::unique_ptr OnlineRecognizer::CreateStream( + const std::string &hotwords) const { + return impl_->CreateStream(hotwords); +} + +bool OnlineRecognizer::IsReady(OnlineStream *s) const { + return impl_->IsReady(s); +} + +void OnlineRecognizer::WarmpUpRecognizer(int32_t warmup, int32_t mbs) const { + if (warmup > 0) { + impl_->WarmpUpRecognizer(warmup, mbs); + } +} + +void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) const { + impl_->DecodeStreams(ss, n); +} + +OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) const { + return impl_->GetResult(s); +} + +bool OnlineRecognizer::IsEndpoint(OnlineStream *s) const { + return impl_->IsEndpoint(s); +} + +void OnlineRecognizer::Reset(OnlineStream *s) const { impl_->Reset(s); } + +#if __ANDROID_API__ >= 9 +template OnlineRecognizer::OnlineRecognizer( + AAssetManager *mgr, const OnlineRecognizerConfig 
&config); +#endif + +#if __OHOS__ +template OnlineRecognizer::OnlineRecognizer( + NativeResourceManager *mgr, const OnlineRecognizerConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer.h new file mode 100644 index 00000000..f381a9a7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-recognizer.h @@ -0,0 +1,211 @@ +// sherpa-mnn/csrc/online-recognizer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/endpoint.h" +#include "sherpa-mnn/csrc/features.h" +#include "sherpa-mnn/csrc/online-ctc-fst-decoder-config.h" +#include "sherpa-mnn/csrc/online-lm-config.h" +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/online-transducer-model-config.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineRecognizerResult { + /// Recognition results. + /// For English, it consists of space separated words. + /// For Chinese, it consists of Chinese words without spaces. + /// Example 1: "hello world" + /// Example 2: "你好世界" + std::string text; + + /// Decoded results at the token level. + /// For instance, for BPE-based models it consists of a list of BPE tokens. + std::vector tokens; + + /// timestamps.size() == tokens.size() + /// timestamps[i] records the time in seconds when tokens[i] is decoded. + std::vector timestamps; + + std::vector ys_probs; //< log-prob scores from ASR model + std::vector lm_probs; //< log-prob scores from language model + // + /// log-domain scores from "hot-phrase" contextual boosting + std::vector context_scores; + + std::vector words; + + /// ID of this segment + /// When an endpoint is detected, it is incremented + int32_t segment = 0; + + /// Starting time of this segment. + /// When an endpoint is detected, it will change + float start_time = 0; + + /// True if the end of this segment is reached + bool is_final = false; + + /** Return a json string. + * + * The returned string contains: + * { + * "text": "The recognition result", + * "tokens": [x, x, x], + * "timestamps": [x, x, x], + * "ys_probs": [x, x, x], + * "lm_probs": [x, x, x], + * "context_scores": [x, x, x], + * "segment": x, + * "start_time": x, + * "is_final": true|false + * } + */ + std::string AsJsonString() const; +}; + +struct OnlineRecognizerConfig { + FeatureExtractorConfig feat_config; + OnlineModelConfig model_config; + OnlineLMConfig lm_config; + EndpointConfig endpoint_config; + OnlineCtcFstDecoderConfig ctc_fst_decoder_config; + bool enable_endpoint = true; + + std::string decoding_method = "greedy_search"; + // now support modified_beam_search and greedy_search + + // used only for modified_beam_search + int32_t max_active_paths = 4; + + /// used only for modified_beam_search + std::string hotwords_file; + float hotwords_score = 1.5; + + float blank_penalty = 0.0; + + float temperature_scale = 2.0; + + // If there are multiple rules, they are applied from left to right. + std::string rule_fsts; + + // If there are multiple FST archives, they are applied from left to right. 
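+  // e.g. rule_fars = "itn.far,punct.far" (hypothetical archive names; multiple archives are separated by commas and applied left to right).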
+ std::string rule_fars; + + /// used only for modified_beam_search, if hotwords_buf is non-empty, + /// the hotwords will be loaded from the buffered string instead of from the + /// "hotwords_file" + std::string hotwords_buf; + + OnlineRecognizerConfig() = default; + + OnlineRecognizerConfig( + const FeatureExtractorConfig &feat_config, + const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config, + const EndpointConfig &endpoint_config, + const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, + bool enable_endpoint, const std::string &decoding_method, + int32_t max_active_paths, const std::string &hotwords_file, + float hotwords_score, float blank_penalty, float temperature_scale, + const std::string &rule_fsts, const std::string &rule_fars) + : feat_config(feat_config), + model_config(model_config), + lm_config(lm_config), + endpoint_config(endpoint_config), + ctc_fst_decoder_config(ctc_fst_decoder_config), + enable_endpoint(enable_endpoint), + decoding_method(decoding_method), + max_active_paths(max_active_paths), + hotwords_file(hotwords_file), + hotwords_score(hotwords_score), + blank_penalty(blank_penalty), + temperature_scale(temperature_scale), + rule_fsts(rule_fsts), + rule_fars(rule_fars) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OnlineRecognizerImpl; + +class OnlineRecognizer { + public: + explicit OnlineRecognizer(const OnlineRecognizerConfig &config); + + template + OnlineRecognizer(Manager *mgr, const OnlineRecognizerConfig &config); + + ~OnlineRecognizer(); + + /// Create a stream for decoding. + std::unique_ptr CreateStream() const; + + /** Create a stream for decoding. + * + * @param The hotwords for this string, it might contain several hotwords, + * the hotwords are separated by "/". In each of the hotwords, there + * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). + * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: + * + * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" + */ + std::unique_ptr CreateStream(const std::string &hotwords) const; + + /** + * Return true if the given stream has enough frames for decoding. + * Return false otherwise + */ + bool IsReady(OnlineStream *s) const; + + /** Decode a single stream. */ + void DecodeStream(OnlineStream *s) const { + OnlineStream *ss[1] = {s}; + DecodeStreams(ss, 1); + } + + /** + * Warmups up onnxruntime sessions by apply optimization and + * allocating memory prior + * + * @param warmup Number of warmups. + * @param mbs : max-batch-size Max batch size for the models + */ + void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const; + + /** Decode multiple streams in parallel + * + * @param ss Pointer array containing streams to be decoded. + * @param n Number of streams in `ss`. + */ + void DecodeStreams(OnlineStream **ss, int32_t n) const; + + OnlineRecognizerResult GetResult(OnlineStream *s) const; + + // Return true if we detect an endpoint for this stream. + // Note: If this function returns true, you usually want to + // invoke Reset(s). + bool IsEndpoint(OnlineStream *s) const; + + // Clear the state of this stream. 
If IsEndpoint(s) returns true, + // after calling this function, IsEndpoint(s) will return false + void Reset(OnlineStream *s) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-rnn-lm.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-rnn-lm.cc new file mode 100644 index 00000000..c74177fd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-rnn-lm.cc @@ -0,0 +1,236 @@ +// sherpa-mnn/csrc/on-rnn-lm.cc +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-rnn-lm.h" + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OnlineRnnLM::Impl { + public: + explicit Impl(const OnlineLMConfig &config) + : config_(config), + sess_opts_{GetSessionOptions(config)}, + allocator_{} { + Init(config); + } + + // shallow fusion scoring function + void ComputeLMScoreSF(float scale, Hypothesis *hyp) { + if (hyp->nn_lm_states.empty()) { + auto init_states = GetInitStatesSF(); + hyp->nn_lm_scores.value = std::move(init_states.first); + hyp->nn_lm_states = Convert(std::move(init_states.second)); + } + + // get lm score for cur token given the hyp->ys[:-1] and save to lm_log_prob + const float *nn_lm_scores = hyp->nn_lm_scores.value->readMap(); + hyp->lm_log_prob += nn_lm_scores[hyp->ys.back()] * scale; + + // get lm scores for next tokens given the hyp->ys[:] and save to + // nn_lm_scores + std::array x_shape{1, 1}; + MNN::Express::VARP x = MNNUtilsCreateTensor(allocator_, x_shape.data(), + x_shape.size()); + *x->writeMap() = hyp->ys.back(); + auto lm_out = ScoreToken(std::move(x), Convert(hyp->nn_lm_states)); + hyp->nn_lm_scores.value = std::move(lm_out.first); + hyp->nn_lm_states = Convert(std::move(lm_out.second)); + } + + // classic rescore function + void ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps) { + MNNAllocator* allocator; + + for (auto &hyp : *hyps) { + for (auto &h_m : hyp) { + auto &h = h_m.second; + auto &ys = h.ys; + const int32_t token_num_in_chunk = + ys.size() - context_size - h.cur_scored_pos - 1; + + if (token_num_in_chunk < 1) { + continue; + } + + if (h.nn_lm_states.empty()) { + h.nn_lm_states = Convert(GetInitStates()); + } + + if (token_num_in_chunk >= h.lm_rescore_min_chunk) { + std::array x_shape{1, token_num_in_chunk}; + + MNN::Express::VARP x = MNNUtilsCreateTensor( + allocator, x_shape.data(), x_shape.size()); + int *p_x = x->writeMap(); + std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1, + p_x); + + // streaming forward by NN LM + auto out = + ScoreToken(std::move(x), Convert(std::move(h.nn_lm_states))); + + // update NN LM score in hyp + const float *p_nll = out.first->readMap(); + h.lm_log_prob = -scale * (*p_nll); + + // update NN LM states in hyp + h.nn_lm_states = Convert(std::move(out.second)); + + h.cur_scored_pos += token_num_in_chunk; + } + } + } + } + + std::pair> ScoreToken( + MNN::Express::VARP x, std::vector states) { + std::vector inputs = {std::move(x), std::move(states[0]), + std::move(states[1])}; + + auto out = + sess_->onForward(inputs); + + std::vector next_states; + next_states.reserve(2); + next_states.push_back(std::move(out[1])); + next_states.push_back(std::move(out[2])); 
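+    // Output layout of the RNN LM graph, as used here and in ComputeInitStates():
+    //   out[0] -> log-prob scores over the vocabulary
+    //   out[1] -> updated hidden state h
+    //   out[2] -> updated cell state c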
+ + return {std::move(out[0]), std::move(next_states)}; + } + + // get init states for shallow fusion + std::pair> GetInitStatesSF() { + std::vector ans; + ans.reserve(init_states_.size()); + for (auto &s : init_states_) { + ans.emplace_back(View(s)); + } + return {View(init_scores_.value), std::move(ans)}; + } + + // get init states for classic rescore + std::vector GetInitStates() { + std::vector ans; + ans.reserve(init_states_.size()); + + for (const auto &s : init_states_) { + ans.emplace_back(Clone(allocator_, s)); + } + + return ans; + } + + private: + void Init(const OnlineLMConfig &config) { + auto buf = ReadFile(config_.model); + + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)buf.data(), buf.size(), + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + MNNMeta meta_data = sess_->getInfo()->metaData; + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(rnn_num_layers_, "num_layers"); + SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "hidden_size"); + SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id"); + + ComputeInitStates(); + } + + void ComputeInitStates() { + constexpr int32_t kBatchSize = 1; + std::array h_shape{rnn_num_layers_, kBatchSize, + rnn_hidden_size_}; + std::array c_shape{rnn_num_layers_, kBatchSize, + rnn_hidden_size_}; + MNN::Express::VARP h = MNNUtilsCreateTensor(allocator_, h_shape.data(), + h_shape.size()); + MNN::Express::VARP c = MNNUtilsCreateTensor(allocator_, c_shape.data(), + c_shape.size()); + Fill(h, 0); + Fill(c, 0); + std::array x_shape{1, 1}; + MNN::Express::VARP x = MNNUtilsCreateTensor(allocator_, x_shape.data(), + x_shape.size()); + *x->writeMap() = sos_id_; + + std::vector states; + states.push_back(std::move(h)); + states.push_back(std::move(c)); + auto pair = ScoreToken(std::move(x), std::move(states)); + + init_scores_.value = std::move(pair.first); // only used during + // shallow fusion + init_states_ = std::move(pair.second); + } + + private: + OnlineLMConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + CopyableOrtValue init_scores_; + std::vector init_states_; + + int32_t rnn_num_layers_ = 2; + int32_t rnn_hidden_size_ = 512; + int32_t sos_id_ = 1; +}; + +OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config) + : impl_(std::make_unique(config)) {} + +OnlineRnnLM::~OnlineRnnLM() = default; + +// classic rescore state init +std::vector OnlineRnnLM::GetInitStates() { + return impl_->GetInitStates(); +} + +// shallow fusion state init +std::pair> OnlineRnnLM::GetInitStatesSF() { + return impl_->GetInitStatesSF(); +} + +std::pair> OnlineRnnLM::ScoreToken( + MNN::Express::VARP x, std::vector states) { + return impl_->ScoreToken(std::move(x), std::move(states)); +} + +// classic rescore scores +void OnlineRnnLM::ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps) { + return impl_->ComputeLMScore(scale, context_size, hyps); +} + +// shallow fusion scores +void OnlineRnnLM::ComputeLMScoreSF(float scale, Hypothesis *hyp) { + return impl_->ComputeLMScoreSF(scale, hyp); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-rnn-lm.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-rnn-lm.h new file mode 100644 index 
00000000..32310e35 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-rnn-lm.h @@ -0,0 +1,68 @@ +// sherpa-mnn/csrc/online-rnn-lm.h +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ +#define SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-lm-config.h" +#include "sherpa-mnn/csrc/online-lm.h" + +namespace sherpa_mnn { + +class OnlineRnnLM : public OnlineLM { + public: + ~OnlineRnnLM() override; + + explicit OnlineRnnLM(const OnlineLMConfig &config); + + // init scores for classic rescore + std::vector GetInitStates() override; + + // init scores for shallow fusion + std::pair> GetInitStatesSF() override; + + /** ScoreToken a batch of sentences (shallow fusion). + * + * @param x A 2-D tensor of shape (N, L) with data type int64. + * @param states It contains the states for the LM model + * @return Return a pair containing + * - log_prob of NN LM + * - updated states + * + */ + std::pair> ScoreToken( + MNN::Express::VARP x, std::vector states) override; + + /** This function updates hyp.lm_lob_prob of hyps (classic rescore). + * + * @param scale LM score + * @param context_size Context size of the transducer decoder model + * @param hyps It is changed in-place. + * + */ + void ComputeLMScore(float scale, int32_t context_size, + std::vector *hyps) override; + + /** This function updates lm_lob_prob and nn_lm_scores of hyp (shallow fusion). + * + * @param scale LM score + * @param hyps It is changed in-place. + * + */ + void ComputeLMScoreSF(float scale, Hypothesis *hyp) override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-stream.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-stream.cc new file mode 100644 index 00000000..28ff9a01 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-stream.cc @@ -0,0 +1,267 @@ +// sherpa-mnn/csrc/online-stream.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-mnn/csrc/online-stream.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/features.h" +#include "sherpa-mnn/csrc/transducer-keyword-decoder.h" + +namespace sherpa_mnn { + +class OnlineStream::Impl { + public: + explicit Impl(const FeatureExtractorConfig &config, + ContextGraphPtr context_graph) + : feat_extractor_(config), context_graph_(std::move(context_graph)) {} + + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { + feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); + } + + void InputFinished() const { feat_extractor_.InputFinished(); } + + int32_t NumFramesReady() const { + return feat_extractor_.NumFramesReady() - start_frame_index_; + } + + bool IsLastFrame(int32_t frame) const { + return feat_extractor_.IsLastFrame(frame); + } + + std::vector GetFrames(int32_t frame_index, int32_t n) const { + return feat_extractor_.GetFrames(frame_index + start_frame_index_, n); + } + + void Reset() { + // we don't reset the feature extractor + start_frame_index_ += num_processed_frames_; + num_processed_frames_ = 0; + } + + int32_t &GetNumProcessedFrames() { return num_processed_frames_; } + + int32_t GetNumFramesSinceStart() const { return start_frame_index_; } + + int32_t &GetCurrentSegment() { return segment_; } + + void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; } + 
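+  // The result cached here is handed back by reference from GetResult() below, so callers can update it in place.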
+ OnlineTransducerDecoderResult &GetResult() { return result_; } + + void SetKeywordResult(const TransducerKeywordResult &r) { + keyword_result_ = r; + } + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) { + if (remove_duplicates) { + if (!prev_keyword_result_.timestamps.empty() && + !keyword_result_.timestamps.empty() && + keyword_result_.timestamps[0] <= + prev_keyword_result_.timestamps.back()) { + return empty_keyword_result_; + } else { + prev_keyword_result_ = keyword_result_; + } + return keyword_result_; + } else { + return keyword_result_; + } + } + + OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; } + + void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; } + + void SetParaformerResult(const OnlineParaformerDecoderResult &r) { + paraformer_result_ = r; + } + + OnlineParaformerDecoderResult &GetParaformerResult() { + return paraformer_result_; + } + + int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); } + + void SetStates(std::vector states) { + states_ = std::move(states); + } + + std::vector &GetStates() { return states_; } + + void SetNeMoDecoderStates(std::vector decoder_states) { + decoder_states_ = std::move(decoder_states); + } + + std::vector &GetNeMoDecoderStates() { return decoder_states_; } + + const ContextGraphPtr &GetContextGraph() const { return context_graph_; } + + std::vector &GetParaformerFeatCache() { + return paraformer_feat_cache_; + } + + std::vector &GetParaformerEncoderOutCache() { + return paraformer_encoder_out_cache_; + } + + std::vector &GetParaformerAlphaCache() { + return paraformer_alpha_cache_; + } + + void SetFasterDecoder(std::unique_ptr decoder) { + faster_decoder_ = std::move(decoder); + } + + kaldi_decoder::FasterDecoder *GetFasterDecoder() const { + return faster_decoder_.get(); + } + + int32_t &GetFasterDecoderProcessedFrames() { + return faster_decoder_processed_frames_; + } + + private: + FeatureExtractor feat_extractor_; + /// For contextual-biasing + ContextGraphPtr context_graph_; + int32_t num_processed_frames_ = 0; // before subsampling + int32_t start_frame_index_ = 0; // never reset + int32_t segment_ = 0; + OnlineTransducerDecoderResult result_; + TransducerKeywordResult prev_keyword_result_; + TransducerKeywordResult keyword_result_; + TransducerKeywordResult empty_keyword_result_; + OnlineCtcDecoderResult ctc_result_; + std::vector states_; // states for transducer or ctc models + std::vector decoder_states_; // states for nemo transducer models + std::vector paraformer_feat_cache_; + std::vector paraformer_encoder_out_cache_; + std::vector paraformer_alpha_cache_; + OnlineParaformerDecoderResult paraformer_result_; + std::unique_ptr faster_decoder_; + int32_t faster_decoder_processed_frames_ = 0; +}; + +OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/, + ContextGraphPtr context_graph /*= nullptr */) + : impl_(std::make_unique(config, std::move(context_graph))) {} + +OnlineStream::~OnlineStream() = default; + +void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const { + impl_->AcceptWaveform(sampling_rate, waveform, n); +} + +void OnlineStream::InputFinished() const { impl_->InputFinished(); } + +int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); } + +bool OnlineStream::IsLastFrame(int32_t frame) const { + return impl_->IsLastFrame(frame); +} + +std::vector OnlineStream::GetFrames(int32_t frame_index, + int32_t n) const { + return impl_->GetFrames(frame_index, n); +} + +void 
OnlineStream::Reset() { impl_->Reset(); } + +int32_t OnlineStream::FeatureDim() const { return impl_->FeatureDim(); } + +int32_t &OnlineStream::GetNumProcessedFrames() { + return impl_->GetNumProcessedFrames(); +} + +int32_t OnlineStream::GetNumFramesSinceStart() const { + return impl_->GetNumFramesSinceStart(); +} + +int32_t &OnlineStream::GetCurrentSegment() { + return impl_->GetCurrentSegment(); +} + +void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) { + impl_->SetResult(r); +} + +OnlineTransducerDecoderResult &OnlineStream::GetResult() { + return impl_->GetResult(); +} + +void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) { + impl_->SetKeywordResult(r); +} + +TransducerKeywordResult &OnlineStream::GetKeywordResult( + bool remove_duplicates /*=false*/) { + return impl_->GetKeywordResult(remove_duplicates); +} + +OnlineCtcDecoderResult &OnlineStream::GetCtcResult() { + return impl_->GetCtcResult(); +} + +void OnlineStream::SetCtcResult(const OnlineCtcDecoderResult &r) { + impl_->SetCtcResult(r); +} + +void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) { + impl_->SetParaformerResult(r); +} + +OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() { + return impl_->GetParaformerResult(); +} + +void OnlineStream::SetStates(std::vector states) { + impl_->SetStates(std::move(states)); +} + +std::vector &OnlineStream::GetStates() { + return impl_->GetStates(); +} + +void OnlineStream::SetNeMoDecoderStates( + std::vector decoder_states) { + return impl_->SetNeMoDecoderStates(std::move(decoder_states)); +} + +std::vector &OnlineStream::GetNeMoDecoderStates() { + return impl_->GetNeMoDecoderStates(); +} + +const ContextGraphPtr &OnlineStream::GetContextGraph() const { + return impl_->GetContextGraph(); +} + +void OnlineStream::SetFasterDecoder( + std::unique_ptr decoder) { + impl_->SetFasterDecoder(std::move(decoder)); +} + +kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const { + return impl_->GetFasterDecoder(); +} + +int32_t &OnlineStream::GetFasterDecoderProcessedFrames() { + return impl_->GetFasterDecoderProcessedFrames(); +} + +std::vector &OnlineStream::GetParaformerFeatCache() { + return impl_->GetParaformerFeatCache(); +} + +std::vector &OnlineStream::GetParaformerEncoderOutCache() { + return impl_->GetParaformerEncoderOutCache(); +} + +std::vector &OnlineStream::GetParaformerAlphaCache() { + return impl_->GetParaformerAlphaCache(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-stream.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-stream.h new file mode 100644 index 00000000..c25496e8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-stream.h @@ -0,0 +1,121 @@ +// sherpa-mnn/csrc/online-stream.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ +#define SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ + +#include +#include + +#include "kaldi-decoder/csrc/faster-decoder.h" +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/context-graph.h" +#include "sherpa-mnn/csrc/features.h" +#include "sherpa-mnn/csrc/online-ctc-decoder.h" +#include "sherpa-mnn/csrc/online-paraformer-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" + +namespace sherpa_mnn { + +struct TransducerKeywordResult; +class OnlineStream { + public: + explicit OnlineStream(const FeatureExtractorConfig &config = {}, + ContextGraphPtr context_graph = nullptr); + + virtual ~OnlineStream(); + + /** + @param 
sampling_rate The sampling_rate of the input waveform. If it does + not equal to config.sampling_rate, we will do + resampling inside. + @param waveform Pointer to a 1-D array of size n. It must be normalized to + the range [-1, 1]. + @param n Number of entries in waveform + */ + void AcceptWaveform(int32_t sampling_rate, const float *waveform, + int32_t n) const; + + /** + * InputFinished() tells the class you won't be providing any + * more waveform. This will help flush out the last frame or two + * of features, in the case where snip-edges == false; it also + * affects the return value of IsLastFrame(). + */ + void InputFinished() const; + + int32_t NumFramesReady() const; + + /** Note: IsLastFrame() will only ever return true if you have called + * InputFinished() (and this frame is the last frame). + */ + bool IsLastFrame(int32_t frame) const; + + /** Get n frames starting from the given frame index. + * + * @param frame_index The starting frame index + * @param n Number of frames to get. + * @return Return a 2-D tensor of shape (n, feature_dim). + * which is flattened into a 1-D vector (flattened in row major) + */ + std::vector GetFrames(int32_t frame_index, int32_t n) const; + + void Reset(); + + int32_t FeatureDim() const; + + // Return a reference to the number of processed frames so far + // before subsampling.. + // Initially, it is 0. It is always less than NumFramesReady(). + // + // The returned reference is valid as long as this object is alive. + int32_t &GetNumProcessedFrames(); // It's reset after calling Reset() + + int32_t GetNumFramesSinceStart() const; + + int32_t &GetCurrentSegment(); + + void SetResult(const OnlineTransducerDecoderResult &r); + OnlineTransducerDecoderResult &GetResult(); + + void SetKeywordResult(const TransducerKeywordResult &r); + TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false); + + void SetCtcResult(const OnlineCtcDecoderResult &r); + OnlineCtcDecoderResult &GetCtcResult(); + + void SetParaformerResult(const OnlineParaformerDecoderResult &r); + OnlineParaformerDecoderResult &GetParaformerResult(); + + void SetStates(std::vector states); + std::vector &GetStates(); + + void SetNeMoDecoderStates(std::vector decoder_states); + std::vector &GetNeMoDecoderStates(); + + /** + * Get the context graph corresponding to this stream. + * + * @return Return the context graph for this stream. 
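+   * Note: the returned pointer may be null if the stream was created without a context graph (e.g., no hotwords were given).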
+ */ + const ContextGraphPtr &GetContextGraph() const; + + // for online ctc decoder + void SetFasterDecoder(std::unique_ptr decoder); + kaldi_decoder::FasterDecoder *GetFasterDecoder() const; + int32_t &GetFasterDecoderProcessedFrames(); + + // for streaming paraformer + std::vector &GetParaformerFeatCache(); + std::vector &GetParaformerEncoderOutCache(); + std::vector &GetParaformerAlphaCache(); + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_STREAM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-decoder.cc new file mode 100644 index 00000000..a93cd886 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-decoder.cc @@ -0,0 +1,73 @@ +// sherpa-mnn/csrc/online-transducer-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-transducer-decoder.h" + +#include +#include + +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +OnlineTransducerDecoderResult::OnlineTransducerDecoderResult( + const OnlineTransducerDecoderResult &other) + : OnlineTransducerDecoderResult() { + *this = other; +} + +OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( + const OnlineTransducerDecoderResult &other) { + if (this == &other) { + return *this; + } + + tokens = other.tokens; + num_trailing_blanks = other.num_trailing_blanks; + + MNNAllocator* allocator; + if (other.decoder_out.get() != nullptr) { + decoder_out = Clone(allocator, other.decoder_out); + } + + hyps = other.hyps; + + frame_offset = other.frame_offset; + timestamps = other.timestamps; + + ys_probs = other.ys_probs; + lm_probs = other.lm_probs; + context_scores = other.context_scores; + + return *this; +} + +OnlineTransducerDecoderResult::OnlineTransducerDecoderResult( + OnlineTransducerDecoderResult &&other) noexcept + : OnlineTransducerDecoderResult() { + *this = std::move(other); +} + +OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=( + OnlineTransducerDecoderResult &&other) noexcept { + if (this == &other) { + return *this; + } + + tokens = std::move(other.tokens); + num_trailing_blanks = other.num_trailing_blanks; + decoder_out = std::move(other.decoder_out); + hyps = std::move(other.hyps); + + frame_offset = other.frame_offset; + timestamps = std::move(other.timestamps); + + ys_probs = std::move(other.ys_probs); + lm_probs = std::move(other.lm_probs); + context_scores = std::move(other.context_scores); + + return *this; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-decoder.h new file mode 100644 index 00000000..18a871ec --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-decoder.h @@ -0,0 +1,111 @@ +// sherpa-mnn/csrc/online-transducer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/hypothesis.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +struct OnlineTransducerDecoderResult { + /// Number of frames after subsampling we have decoded so far + int32_t frame_offset = 0; + + /// The decoded token IDs so far + std::vector tokens; + + /// number of trailing blank frames decoded 
so far + int32_t num_trailing_blanks = 0; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + std::vector timestamps; + + std::vector ys_probs; + std::vector lm_probs; + std::vector context_scores; + + // Cache decoder_out for endpointing + MNN::Express::VARP decoder_out; + + // used only in modified beam_search + Hypotheses hyps; + + OnlineTransducerDecoderResult() + : tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {} + + OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other); + + OnlineTransducerDecoderResult &operator=( + const OnlineTransducerDecoderResult &other); + + OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other) noexcept; + + OnlineTransducerDecoderResult &operator=( + OnlineTransducerDecoderResult &&other) noexcept; +}; + +class OnlineStream; +class OnlineTransducerDecoder { + public: + virtual ~OnlineTransducerDecoder() = default; + + /* Return an empty result. + * + * To simplify the decoding code, we add `context_size` blanks + * to the beginning of the decoding result, which will be + * stripped by calling `StripPrecedingBlanks()`. + */ + virtual OnlineTransducerDecoderResult GetEmptyResult() const = 0; + + /** Strip blanks added by `GetEmptyResult()`. + * + * @param r It is changed in-place. + */ + virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const { + } + + /** Run transducer beam search given the output from the encoder model. + * + * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim) + * @param result It is modified in-place. + * + * @note There is no need to pass encoder_out_length here since for the + * online decoding case, each utterance has the same number of frames + * and there are no paddings. + */ + virtual void Decode(MNN::Express::VARP encoder_out, + std::vector *result) = 0; + + /** Run transducer beam search given the output from the encoder model. + * + * Note: Currently this interface is for contextual-biasing feature which + * needs a ContextGraph owned by the OnlineStream. + * + * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim) + * @param ss A list of OnlineStreams. + * @param result It is modified in-place. + * + * @note There is no need to pass encoder_out_length here since for the + * online decoding case, each utterance has the same number of frames + * and there are no paddings. + */ + virtual void Decode(MNN::Express::VARP /*encoder_out*/, OnlineStream ** /*ss*/, + std::vector * /*result*/) { + SHERPA_ONNX_LOGE( + "This interface is for OnlineTransducerModifiedBeamSearchDecoder."); + exit(-1); + } + + // used for endpointing. 
We need to keep decoder_out after reset + virtual void UpdateDecoderOut(OnlineTransducerDecoderResult * /*result*/) {} +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-decoder.cc new file mode 100644 index 00000000..76e86c13 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-decoder.cc @@ -0,0 +1,174 @@ +// sherpa-mnn/csrc/online-transducer-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-transducer-greedy-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +static void UseCachedDecoderOut( + const std::vector &results, + MNN::Express::VARP decoder_out) { + std::vector shape = + decoder_out->getInfo()->dim; + float *dst = decoder_out->writeMap(); + for (const auto &r : results) { + if (r.decoder_out.get() != nullptr) { + const float *src = r.decoder_out->readMap(); + std::copy(src, src + shape[1], dst); + } + dst += shape[1]; + } +} + +static void UpdateCachedDecoderOut( + MNNAllocator *allocator, MNN::Express::VARP decoder_out, + std::vector *results) { + std::vector shape = + decoder_out->getInfo()->dim; + auto memory_info = + (MNNAllocator*)(nullptr); + std::array v_shape{1, shape[1]}; + + const float *src = decoder_out->readMap(); + for (auto &r : *results) { + if (r.decoder_out.get() == nullptr) { + r.decoder_out = MNNUtilsCreateTensor(allocator, v_shape.data(), + v_shape.size()); + } + + float *dst = r.decoder_out->writeMap(); + std::copy(src, src + shape[1], dst); + src += shape[1]; + } +} + +OnlineTransducerDecoderResult +OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + OnlineTransducerDecoderResult r; + r.tokens.resize(context_size, -1); + r.tokens.back() = blank_id; + + return r; +} + +void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( + OnlineTransducerDecoderResult *r) const { + int32_t context_size = model_->ContextSize(); + + auto start = r->tokens.begin() + context_size; + auto end = r->tokens.end(); + + r->tokens = std::vector(start, end); +} + +void OnlineTransducerGreedySearchDecoder::Decode( + MNN::Express::VARP encoder_out, + std::vector *result) { + std::vector encoder_out_shape = + encoder_out->getInfo()->dim; + + if (encoder_out_shape[0] != static_cast(result->size())) { + SHERPA_ONNX_LOGE( + "Size mismatch! 
encoder_out.size(0) %d, result.size(0): %d", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); + exit(-1); + } + + int32_t batch_size = static_cast(encoder_out_shape[0]); + int32_t num_frames = static_cast(encoder_out_shape[1]); + int32_t vocab_size = model_->VocabSize(); + + MNN::Express::VARP decoder_out{nullptr}; + bool is_batch_decoder_out_cached = true; + for (const auto &r : *result) { + if (r.decoder_out.get() == nullptr) { + is_batch_decoder_out_cached = false; + break; + } + } + + if (is_batch_decoder_out_cached) { + auto &r = result->front(); + std::vector decoder_out_shape = + r.decoder_out->getInfo()->dim; + decoder_out_shape[0] = batch_size; + decoder_out = MNNUtilsCreateTensor(model_->Allocator(), + decoder_out_shape.data(), + decoder_out_shape.size()); + UseCachedDecoderOut(*result, decoder_out); + } else { + MNN::Express::VARP decoder_input = model_->BuildDecoderInput(*result); + decoder_out = model_->RunDecoder(std::move(decoder_input)); + } + + for (int32_t t = 0; t != num_frames; ++t) { + MNN::Express::VARP cur_encoder_out = + GetEncoderOutFrame(model_->Allocator(), encoder_out, t); + MNN::Express::VARP logit = + model_->RunJoiner(std::move(cur_encoder_out), View(decoder_out)); + + float *p_logit = logit->writeMap(); + + bool emitted = false; + for (int32_t i = 0; i < batch_size; ++i, p_logit += vocab_size) { + auto &r = (*result)[i]; + if (blank_penalty_ > 0.0) { + p_logit[0] -= blank_penalty_; // assuming blank id is 0 + } + + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + // blank id is hardcoded to 0 + // also, it treats unk as blank + if (y != 0 && y != unk_id_) { + emitted = true; + r.tokens.push_back(y); + r.timestamps.push_back(t + r.frame_offset); + r.num_trailing_blanks = 0; + } else { + ++r.num_trailing_blanks; + } + + // export the per-token log scores + if (y != 0 && y != unk_id_) { + // apply temperature-scaling + for (int32_t n = 0; n < vocab_size; ++n) { + p_logit[n] /= temperature_scale_; + } + LogSoftmax(p_logit, vocab_size); // renormalize probabilities, + // save time by doing it only for + // emitted symbols + const float *p_logprob = p_logit; // rename p_logit as p_logprob, + // now it contains normalized + // probability + r.ys_probs.push_back(p_logprob[y]); + } + } + if (emitted) { + MNN::Express::VARP decoder_input = model_->BuildDecoderInput(*result); + decoder_out = model_->RunDecoder(std::move(decoder_input)); + } + } + + UpdateCachedDecoderOut(model_->Allocator(), decoder_out, result); + + // Update frame_offset + for (auto &r : *result) { + r.frame_offset += num_frames; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-decoder.h new file mode 100644 index 00000000..804d8923 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-decoder.h @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/online-transducer-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder { + 
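+  // Greedy (argmax) search: for every encoder frame the joiner output is computed, the
+  // highest-scoring token is taken, non-blank/non-unk tokens are appended to the result,
+  // and the decoder network is re-run only for frames where at least one stream emitted a token.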
public: + OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model, + int32_t unk_id, + float blank_penalty, + float temperature_scale) + : model_(model), + unk_id_(unk_id), + blank_penalty_(blank_penalty), + temperature_scale_(temperature_scale) {} + + OnlineTransducerDecoderResult GetEmptyResult() const override; + + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override; + + void Decode(MNN::Express::VARP encoder_out, + std::vector *result) override; + + private: + OnlineTransducerModel *model_; // Not owned + int32_t unk_id_; + float blank_penalty_; + float temperature_scale_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.cc new file mode 100644 index 00000000..32ebb3b3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.cc @@ -0,0 +1,129 @@ +// sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar + +#include "sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +static MNN::Express::VARP BuildDecoderInput(int32_t token, MNNAllocator *allocator) { + std::array shape{1, 1}; + + MNN::Express::VARP decoder_input = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + + int32_t *p = decoder_input->writeMap(); + + p[0] = token; + + return decoder_input; +} + +static void DecodeOne(const float *encoder_out, int32_t num_rows, + int32_t num_cols, OnlineTransducerNeMoModel *model, + float blank_penalty, OnlineStream *s) { + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t vocab_size = model->VocabSize(); + int32_t blank_id = vocab_size - 1; + + auto &r = s->GetResult(); + + MNN::Express::VARP decoder_out{nullptr}; + + auto decoder_input = BuildDecoderInput( + r.tokens.empty() ? 
blank_id : r.tokens.back(), model->Allocator()); + + std::vector &last_decoder_states = s->GetNeMoDecoderStates(); + + std::vector tmp_decoder_states; + tmp_decoder_states.reserve(last_decoder_states.size()); + for (auto &v : last_decoder_states) { + tmp_decoder_states.push_back(View(v)); + } + + // decoder_output_pair.second returns the next decoder state + std::pair> decoder_output_pair = + model->RunDecoder(std::move(decoder_input), + std::move(tmp_decoder_states)); + + std::array encoder_shape{1, num_cols, 1}; + + bool emitted = false; + + for (int32_t t = 0; t != num_rows; ++t) { + MNN::Express::VARP cur_encoder_out = MNNUtilsCreateTensor( + memory_info, const_cast(encoder_out) + t * num_cols, num_cols, + encoder_shape.data(), encoder_shape.size()); + + MNN::Express::VARP logit = model->RunJoiner(std::move(cur_encoder_out), + View(decoder_output_pair.first)); + + float *p_logit = logit->writeMap(); + if (blank_penalty > 0) { + p_logit[blank_id] -= blank_penalty; + } + + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + + if (y != blank_id) { + emitted = true; + r.tokens.push_back(y); + r.timestamps.push_back(t + r.frame_offset); + r.num_trailing_blanks = 0; + + decoder_input = BuildDecoderInput(y, model->Allocator()); + + // last decoder state becomes the current state for the first chunk + decoder_output_pair = model->RunDecoder( + std::move(decoder_input), std::move(decoder_output_pair.second)); + } else { + ++r.num_trailing_blanks; + } + } + + if (emitted) { + s->SetNeMoDecoderStates(std::move(decoder_output_pair.second)); + } + + r.frame_offset += num_rows; +} + +void OnlineTransducerGreedySearchNeMoDecoder::Decode(MNN::Express::VARP encoder_out, + OnlineStream **ss, + int32_t n) const { + auto shape = encoder_out->getInfo()->dim; + int32_t batch_size = static_cast(shape[0]); // bs = 1 + + if (batch_size != n) { + SHERPA_ONNX_LOGE("Size mismatch! 
encoder_out.size(0) %d, n: %d", + static_cast(shape[0]), n); + exit(-1); + } + + int32_t dim1 = static_cast(shape[1]); // T + int32_t dim2 = static_cast(shape[2]); // encoder_out_dim + + const float *p = encoder_out->readMap(); + + for (int32_t i = 0; i != batch_size; ++i) { + const float *this_p = p + dim1 * dim2 * i; + + DecodeOne(this_p, dim1, dim2, model_, blank_penalty_, ss[i]); + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.h new file mode 100644 index 00000000..fd3fd2b5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.h @@ -0,0 +1,34 @@ +// sherpa-mnn/csrc/online-transducer-greedy-search-nemo-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-nemo-model.h" + +namespace sherpa_mnn { + +class OnlineStream; + +class OnlineTransducerGreedySearchNeMoDecoder { + public: + OnlineTransducerGreedySearchNeMoDecoder(OnlineTransducerNeMoModel *model, + float blank_penalty) + : model_(model), blank_penalty_(blank_penalty) {} + + // @param n number of elements in ss + void Decode(MNN::Express::VARP encoder_out, OnlineStream **ss, int32_t n) const; + + private: + OnlineTransducerNeMoModel *model_; // Not owned + float blank_penalty_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_NEMO_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model-config.cc new file mode 100644 index 00000000..2fd9e3c8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model-config.cc @@ -0,0 +1,51 @@ +// sherpa-mnn/csrc/online-transducer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-mnn/csrc/online-transducer-model-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlineTransducerModelConfig::Register(ParseOptions *po) { + po->Register("encoder", &encoder, "Path to encoder.onnx"); + po->Register("decoder", &decoder, "Path to decoder.onnx"); + po->Register("joiner", &joiner, "Path to joiner.onnx"); +} + +bool OnlineTransducerModelConfig::Validate() const { + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("transducer encoder: '%s' does not exist", + encoder.c_str()); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("transducer decoder: '%s' does not exist", + decoder.c_str()); + return false; + } + + if (!FileExists(joiner)) { + SHERPA_ONNX_LOGE("joiner: '%s' does not exist", joiner.c_str()); + return false; + } + + return true; +} + +std::string OnlineTransducerModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineTransducerModelConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\", "; + os << "joiner=\"" << joiner << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model-config.h 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model-config.h new file mode 100644 index 00000000..6c3f972b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model-config.h @@ -0,0 +1,32 @@ +// sherpa-mnn/csrc/online-transducer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineTransducerModelConfig { + std::string encoder; + std::string decoder; + std::string joiner; + + OnlineTransducerModelConfig() = default; + OnlineTransducerModelConfig(const std::string &encoder, + const std::string &decoder, + const std::string &joiner) + : encoder(encoder), decoder(decoder), joiner(joiner) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model.cc new file mode 100644 index 00000000..0eeaba80 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model.cc @@ -0,0 +1,230 @@ +// sherpa-mnn/csrc/online-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo +#include "sherpa-mnn/csrc/online-transducer-model.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-conformer-transducer-model.h" +#include "sherpa-mnn/csrc/online-ebranchformer-transducer-model.h" +#include "sherpa-mnn/csrc/online-lstm-transducer-model.h" +#include "sherpa-mnn/csrc/online-zipformer-transducer-model.h" +#include "sherpa-mnn/csrc/online-zipformer2-transducer-model.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace { + +enum class ModelType : std::uint8_t { + kConformer, + kEbranchformer, + kLstm, + kZipformer, + kZipformer2, + kUnknown, +}; + +} // namespace + +namespace sherpa_mnn { + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + MNNEnv env; + std::shared_ptr sess_opts; + + + + auto sess = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts)); + + MNNMeta meta_data = sess->getInfo()->metaData; + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; + auto model_type = + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "Please make sure you are using the latest export-onnx.py from icefall " + "to export your transducer models"); + return ModelType::kUnknown; + } + + if (model_type == "conformer") { + return ModelType::kConformer; + } else if (model_type == "ebranchformer") { + return ModelType::kEbranchformer; + } else if (model_type == "lstm") { + return ModelType::kLstm; + } else if (model_type == "zipformer") { + return 
ModelType::kZipformer; + } else if (model_type == "zipformer2") { + return ModelType::kZipformer2; + } else { + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); + return ModelType::kUnknown; + } +} + +std::unique_ptr OnlineTransducerModel::Create( + const OnlineModelConfig &config) { + if (!config.model_type.empty()) { + const auto &model_type = config.model_type; + if (model_type == "conformer") { + return std::make_unique(config); + } else if (model_type == "ebranchformer") { + return std::make_unique(config); + } else if (model_type == "lstm") { + return std::make_unique(config); + } else if (model_type == "zipformer") { + return std::make_unique(config); + } else if (model_type == "zipformer2") { + return std::make_unique(config); + } else { + SHERPA_ONNX_LOGE( + "Invalid model_type: %s. Trying to load the model to get its type", + model_type.c_str()); + } + } + ModelType model_type = ModelType::kUnknown; + + { + auto buffer = ReadFile(config.transducer.encoder); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kConformer: + return std::make_unique(config); + case ModelType::kEbranchformer: + return std::make_unique(config); + case ModelType::kLstm: + return std::make_unique(config); + case ModelType::kZipformer: + return std::make_unique(config); + case ModelType::kZipformer2: + return std::make_unique(config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); + return nullptr; + } + + // unreachable code + return nullptr; +} + +MNN::Express::VARP OnlineTransducerModel::BuildDecoderInput( + const std::vector &results) { + int32_t batch_size = static_cast(results.size()); + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + MNN::Express::VARP decoder_input = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + int *p = decoder_input->writeMap(); + + for (const auto &r : results) { + const int *begin = r.tokens.data() + r.tokens.size() - context_size; + const int *end = r.tokens.data() + r.tokens.size(); + std::copy(begin, end, p); + p += context_size; + } + return decoder_input; +} + +MNN::Express::VARP OnlineTransducerModel::BuildDecoderInput( + const std::vector &hyps) { + int32_t batch_size = static_cast(hyps.size()); + int32_t context_size = ContextSize(); + std::array shape{batch_size, context_size}; + MNN::Express::VARP decoder_input = MNNUtilsCreateTensor( + Allocator(), shape.data(), shape.size()); + int *p = decoder_input->writeMap(); + + for (const auto &h : hyps) { + std::copy(h.ys.end() - context_size, h.ys.end(), p); + p += context_size; + } + return decoder_input; +} + +template +std::unique_ptr OnlineTransducerModel::Create( + Manager *mgr, const OnlineModelConfig &config) { + if (!config.model_type.empty()) { + const auto &model_type = config.model_type; + if (model_type == "conformer") { + return std::make_unique(mgr, config); + } else if (model_type == "ebranchformer") { + return std::make_unique(mgr, config); + } else if (model_type == "lstm") { + return std::make_unique(mgr, config); + } else if (model_type == "zipformer") { + return std::make_unique(mgr, config); + } else if (model_type == "zipformer2") { + return std::make_unique(mgr, config); + } else { + SHERPA_ONNX_LOGE( + "Invalid model_type: %s. 
Trying to load the model to get its type", + model_type.c_str()); + } + } + + auto buffer = ReadFile(mgr, config.transducer.encoder); + auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + + switch (model_type) { + case ModelType::kConformer: + return std::make_unique(mgr, config); + case ModelType::kEbranchformer: + return std::make_unique(mgr, config); + case ModelType::kLstm: + return std::make_unique(mgr, config); + case ModelType::kZipformer: + return std::make_unique(mgr, config); + case ModelType::kZipformer2: + return std::make_unique(mgr, config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE("Unknown model type in online transducer!"); + return nullptr; + } + + // unreachable code + return nullptr; +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr OnlineTransducerModel::Create( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr OnlineTransducerModel::Create( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model.h new file mode 100644 index 00000000..2613c9f7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-model.h @@ -0,0 +1,145 @@ +// sherpa-mnn/csrc/online-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/hypothesis.h" +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-model-config.h" + +namespace sherpa_mnn { + +struct OnlineTransducerDecoderResult; + +class OnlineTransducerModel { + public: + virtual ~OnlineTransducerModel() = default; + + static std::unique_ptr Create( + const OnlineModelConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const OnlineModelConfig &config); + + /** Stack a list of individual states into a batch. + * + * It is the inverse operation of `UnStackStates`. + * + * @param states states[i] contains the state for the i-th utterance. + * @return Return a single value representing the batched state. + */ + virtual std::vector StackStates( + const std::vector> &states) const = 0; + + /** Unstack a batch state into a list of individual states. + * + * It is the inverse operation of `StackStates`. + * + * @param states A batched state. + * @return ans[i] contains the state for the i-th utterance. + */ + virtual std::vector> UnStackStates( + const std::vector &states) const = 0; + + /** Get the initial encoder states. + * + * @return Return the initial encoder state. + */ + virtual std::vector GetEncoderInitStates() = 0; + + /** Set feature dim. + * + * This is used in `OnlineZipformer2TransducerModel`, + * to pass `feature_dim` for `GetEncoderInitStates()`. + * + * This has to be called before GetEncoderInitStates(), so the `encoder_embed` + * init state has the correct `embed_dim` of its output. + */ + virtual void SetFeatureDim(int32_t /*feature_dim*/) {} + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param states Encoder state of the previous chunk. It is changed in-place. + * @param processed_frames Processed frames before subsampling. 
It is a 1-D + * tensor with data type int. + * + * @return Return a tuple containing: + * - encoder_out, a tensor of shape (N, T', encoder_out_dim) + * - next_states Encoder state for the next chunk. + */ + virtual std::pair> RunEncoder( + MNN::Express::VARP features, std::vector states, + MNN::Express::VARP processed_frames) = 0; // NOLINT + + /** Run the decoder network. + * + * Caution: We assume there are no recurrent connections in the decoder and + * the decoder is stateless. See + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py + * for an example + * + * @param decoder_input It is usually of shape (N, context_size) + * @return Return a tensor of shape (N, decoder_dim). + */ + virtual MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input) = 0; + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. A tensor of shape + * (N, joiner_dim). + * @param decoder_out Output of the decoder network. A tensor of shape + * (N, joiner_dim). + * @return Return a tensor of shape (N, vocab_size). In icefall, the last + * last layer of the joint network is `nn.Linear`, + * not `nn.LogSoftmax`. + */ + virtual MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) = 0; + + /** If we are using a stateless decoder and if it contains a + * Conv1D, this function returns the kernel size of the convolution layer. + */ + virtual int32_t ContextSize() const = 0; + + /** We send this number of feature frames to the encoder at a time. */ + virtual int32_t ChunkSize() const = 0; + + /** Number of input frames to discard after each call to RunEncoder. + * + * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6. + * + * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8. + * Then we discard frame 0~5 since chunk_shift is 6. + * In the second call of RunEncoder, we use frames 6~13; and then we discard + * frames 6~11. + * In the third call of RunEncoder, we use frames 12~19; and then we discard + * frames 12~16. 
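+ * In this example the encoder therefore advances by ChunkShift() = 6 frames
+ * per call and keeps ChunkSize() - ChunkShift() = 2 frames of right context.
+ * A caller would typically loop along the lines of the sketch below
+ * (variable names are illustrative, not part of this API):
+ *
+ *   int32_t start = 0;
+ *   while (num_frames_ready - start >= model->ChunkSize()) {
+ *     // feed frames [start, start + ChunkSize()) to RunEncoder()
+ *     start += model->ChunkShift();
+ *   }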
+ * + * Note: ChunkSize() - ChunkShift() == right context size + */ + virtual int32_t ChunkShift() const = 0; + + virtual int32_t VocabSize() const = 0; + + virtual int32_t SubsamplingFactor() const { return 4; } + + virtual MNNAllocator *Allocator() = 0; + + MNN::Express::VARP BuildDecoderInput( + const std::vector &results); + + MNN::Express::VARP BuildDecoderInput(const std::vector &hyps); +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.cc new file mode 100644 index 00000000..3e6ff534 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.cc @@ -0,0 +1,270 @@ +// sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.cc +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.h" + +#include +#include +#include + +#include "sherpa-mnn/csrc/log.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +static void UseCachedDecoderOut( + const std::vector &hyps_row_splits, + const std::vector &results, + MNN::Express::VARP decoder_out) { + std::vector shape = + decoder_out->getInfo()->dim; + + float *dst = decoder_out->writeMap(); + + int32_t batch_size = static_cast(results.size()); + for (int32_t i = 0; i != batch_size; ++i) { + int32_t num_hyps = hyps_row_splits[i + 1] - hyps_row_splits[i]; + if (num_hyps > 1 || nullptr == results[i].decoder_out.get()) { + dst += num_hyps * shape[1]; + continue; + } + + const float *src = results[i].decoder_out->readMap(); + std::copy(src, src + shape[1], dst); + dst += shape[1]; + } +} + +OnlineTransducerDecoderResult +OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + OnlineTransducerDecoderResult r; + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + + Hypotheses blank_hyp({{blanks, 0}}); + r.hyps = std::move(blank_hyp); + r.tokens = std::move(blanks); + return r; +} + +void OnlineTransducerModifiedBeamSearchDecoder::StripLeadingBlanks( + OnlineTransducerDecoderResult *r) const { + int32_t context_size = model_->ContextSize(); + auto hyp = r->hyps.GetMostProbable(true); + + std::vector tokens(hyp.ys.begin() + context_size, hyp.ys.end()); + r->tokens = std::move(tokens); + r->timestamps = std::move(hyp.timestamps); + + // export per-token scores + r->ys_probs = std::move(hyp.ys_probs); + r->lm_probs = std::move(hyp.lm_probs); + r->context_scores = std::move(hyp.context_scores); + + r->num_trailing_blanks = hyp.num_trailing_blanks; +} + +void OnlineTransducerModifiedBeamSearchDecoder::Decode( + MNN::Express::VARP encoder_out, + std::vector *result) { + Decode(std::move(encoder_out), nullptr, result); +} + +void OnlineTransducerModifiedBeamSearchDecoder::Decode( + MNN::Express::VARP encoder_out, OnlineStream **ss, + std::vector *result) { + std::vector encoder_out_shape = + encoder_out->getInfo()->dim; + + if (static_cast(encoder_out_shape[0]) != + static_cast(result->size())) { + SHERPA_ONNX_LOGE( + "Size mismatch! 
encoder_out.size(0) %d, result.size(0): %d\n", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); + exit(-1); + } + + int32_t batch_size = static_cast(encoder_out_shape[0]); + + int32_t num_frames = static_cast(encoder_out_shape[1]); + int32_t vocab_size = model_->VocabSize(); + + std::vector cur; + for (auto &r : *result) { + cur.push_back(std::move(r.hyps)); + } + std::vector prev; + + for (int32_t t = 0; t != num_frames; ++t) { + // Due to merging paths with identical token sequences, + // not all utterances have "num_active_paths" paths. + auto hyps_row_splits = GetHypsRowSplits(cur); + int32_t num_hyps = + hyps_row_splits.back(); // total num hyps for all utterance + prev.clear(); + for (auto &hyps : cur) { + for (auto &h : hyps) { + prev.push_back(std::move(h.second)); + } + } + cur.clear(); + cur.reserve(batch_size); + + MNN::Express::VARP decoder_input = model_->BuildDecoderInput(prev); + MNN::Express::VARP decoder_out = model_->RunDecoder(std::move(decoder_input)); + if (t == 0) { + UseCachedDecoderOut(hyps_row_splits, *result, decoder_out); + } + + MNN::Express::VARP cur_encoder_out = + GetEncoderOutFrame(model_->Allocator(), encoder_out, t); + cur_encoder_out = + Repeat(model_->Allocator(), cur_encoder_out, hyps_row_splits); + MNN::Express::VARP logit = + model_->RunJoiner(std::move(cur_encoder_out), View(decoder_out)); + + float *p_logit = logit->writeMap(); + + // copy raw logits, apply temperature-scaling (for confidences) + // Note: temperature scaling is used only for the confidences, + // the decoding algorithm uses the original logits + int32_t p_logit_items = vocab_size * num_hyps; + std::vector logit_with_temperature(p_logit_items); + { + std::copy(p_logit, p_logit + p_logit_items, + logit_with_temperature.begin()); + for (float &elem : logit_with_temperature) { + elem /= temperature_scale_; + } + LogSoftmax(logit_with_temperature.data(), vocab_size, num_hyps); + } + + if (blank_penalty_ > 0.0) { + // assuming blank id is 0 + SubtractBlank(p_logit, vocab_size, num_hyps, 0, blank_penalty_); + } + LogSoftmax(p_logit, vocab_size, num_hyps); + + // now p_logit contains log_softmax output, we rename it to p_logprob + // to match what it actually contains + float *p_logprob = p_logit; + + // add log_prob of each hypothesis to p_logprob before taking top_k + for (int32_t i = 0; i != num_hyps; ++i) { + float log_prob = prev[i].log_prob; + if (lm_ && shallow_fusion_) { + log_prob += prev[i].lm_log_prob; + } + + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { + *p_logprob += log_prob; + } + } + p_logprob = p_logit; // we changed p_logprob in the above for loop + + for (int32_t b = 0; b != batch_size; ++b) { + int32_t frame_offset = (*result)[b].frame_offset; + int32_t start = hyps_row_splits[b]; + int32_t end = hyps_row_splits[b + 1]; + auto topk = + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_); + + Hypotheses hyps; + for (auto k : topk) { + int32_t hyp_index = k / vocab_size + start; + int32_t new_token = k % vocab_size; + + Hypothesis new_hyp = prev[hyp_index]; + const float prev_lm_log_prob = new_hyp.lm_log_prob; + float context_score = 0; + auto context_state = new_hyp.context_state; + + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { + new_hyp.ys.push_back(new_token); + new_hyp.timestamps.push_back(t + frame_offset); + new_hyp.num_trailing_blanks = 0; + if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) { + auto context_res = 
ss[b]->GetContextGraph()->ForwardOneStep( + context_state, new_token, false /*strict mode*/); + context_score = std::get<0>(context_res); + new_hyp.context_state = std::get<1>(context_res); + } + if (lm_ && shallow_fusion_) { + lm_->ComputeLMScoreSF(lm_scale_, &new_hyp); + } + } else { + ++new_hyp.num_trailing_blanks; + } + if (lm_ && shallow_fusion_) { + new_hyp.log_prob = p_logprob[k] + context_score - + prev_lm_log_prob; // log_prob only includes the + // score of the transducer + } else { + new_hyp.log_prob = p_logprob[k] + context_score; // rescore or no LM + // previous token + // score is ignored + } + + // export the per-token log scores + if (new_token != 0 && new_token != unk_id_) { + float y_prob = logit_with_temperature[start * vocab_size + k]; + new_hyp.ys_probs.push_back(y_prob); + + if (lm_ && shallow_fusion_) { // export only if + // LM shallow fusion is used + float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob; + + if (lm_scale_ != 0.0) { + lm_prob /= lm_scale_; // remove lm-scale + } + new_hyp.lm_probs.push_back(lm_prob); + } + + // export only when `ContextGraph` is used + if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) { + new_hyp.context_scores.push_back(context_score); + } + } + + hyps.Add(std::move(new_hyp)); + } // for (auto k : topk) + cur.push_back(std::move(hyps)); + p_logprob += (end - start) * vocab_size; + } // for (int32_t b = 0; b != batch_size; ++b) + } // for (int32_t t = 0; t != num_frames; ++t) + + // classic lm rescore + if (lm_ && !shallow_fusion_) { + lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur); + } + + for (int32_t b = 0; b != batch_size; ++b) { + auto &hyps = cur[b]; + auto best_hyp = hyps.GetMostProbable(true); + auto &r = (*result)[b]; + + r.hyps = std::move(hyps); + r.tokens = std::move(best_hyp.ys); + r.num_trailing_blanks = best_hyp.num_trailing_blanks; + r.frame_offset += num_frames; + } +} + +void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut( + OnlineTransducerDecoderResult *result) { + if (static_cast(result->tokens.size()) == model_->ContextSize()) { + result->decoder_out = MNN::Express::VARP{nullptr}; + return; + } + MNN::Express::VARP decoder_input = model_->BuildDecoderInput({*result}); + result->decoder_out = model_->RunDecoder(std::move(decoder_input)); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.h new file mode 100644 index 00000000..9c941047 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-modified-beam-search-decoder.h @@ -0,0 +1,64 @@ +// sherpa-mnn/csrc/online-transducer-modified_beam-search-decoder.h +// +// Copyright (c) 2023 Pingfeng Luo +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ + +#include + +#include "sherpa-mnn/csrc/online-lm.h" +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +class OnlineTransducerModifiedBeamSearchDecoder + : public OnlineTransducerDecoder { + public: + OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model, + OnlineLM *lm, + int32_t max_active_paths, + float lm_scale, + bool shallow_fusion, + int32_t unk_id, + float blank_penalty, + float 
temperature_scale) + : model_(model), + lm_(lm), + max_active_paths_(max_active_paths), + lm_scale_(lm_scale), + shallow_fusion_(shallow_fusion), + unk_id_(unk_id), + blank_penalty_(blank_penalty), + temperature_scale_(temperature_scale) {} + + OnlineTransducerDecoderResult GetEmptyResult() const override; + + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override; + + void Decode(MNN::Express::VARP encoder_out, + std::vector *result) override; + + void Decode(MNN::Express::VARP encoder_out, OnlineStream **ss, + std::vector *result) override; + + void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override; + + private: + OnlineTransducerModel *model_; // Not owned + OnlineLM *lm_; // Not owned + + int32_t max_active_paths_; + float lm_scale_; // used only when lm_ is not nullptr + bool shallow_fusion_; // used only when lm_ is not nullptr + int32_t unk_id_; + float blank_penalty_; + float temperature_scale_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-nemo-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-nemo-model.cc new file mode 100644 index 00000000..d8f90379 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-nemo-model.cc @@ -0,0 +1,538 @@ +// sherpa-mnn/csrc/online-transducer-nemo-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar + +#include "sherpa-mnn/csrc/online-transducer-nemo-model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/transpose.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +class OnlineTransducerNeMoModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } + } + + std::vector RunEncoder(MNN::Express::VARP features, + std::vector states) { + MNN::Express::VARP &cache_last_channel = states[0]; + MNN::Express::VARP &cache_last_time = states[1]; + MNN::Express::VARP &cache_last_channel_len = states[2]; + + int32_t batch_size = features->getInfo()->dim[0]; + + std::array length_shape{batch_size}; + + MNN::Express::VARP length = 
MNNUtilsCreateTensor( + allocator_, length_shape.data(), length_shape.size()); + + int *p_length = length->writeMap(); + + std::fill(p_length, p_length + batch_size, ChunkSize()); + + // (B, T, C) -> (B, C, T) + features = Transpose12(allocator_, features); + + std::vector inputs = { + std::move(features), View(length), std::move(cache_last_channel), + std::move(cache_last_time), std::move(cache_last_channel_len)}; + + auto out = encoder_sess_->onForward(inputs); + // out[0]: logit + // out[1] logit_length + // out[2:] states_next + // + // we need to remove out[1] + + std::vector ans; + ans.reserve(out.size() - 1); + + for (int32_t i = 0; i != out.size(); ++i) { + if (i == 1) { + continue; + } + + ans.push_back(std::move(out[i])); + } + + return ans; + } + + std::pair> RunDecoder( + MNN::Express::VARP targets, std::vector states) { + MNNAllocator* memory_info = nullptr; + + auto shape = targets->getInfo()->dim; + int32_t batch_size = static_cast(shape[0]); + + std::vector length_shape = {batch_size}; + std::vector length_value(batch_size, 1); + + MNN::Express::VARP targets_length = MNNUtilsCreateTensor( + memory_info, length_value.data(), batch_size, length_shape.data(), + length_shape.size()); + + std::vector decoder_inputs; + decoder_inputs.reserve(2 + states.size()); + + decoder_inputs.push_back(std::move(targets)); + decoder_inputs.push_back(std::move(targets_length)); + + for (auto &s : states) { + decoder_inputs.push_back(std::move(s)); + } + + auto decoder_out = decoder_sess_->onForward(decoder_inputs); + + std::vector states_next; + states_next.reserve(states.size()); + + // decoder_out[0]: decoder_output + // decoder_out[1]: decoder_output_length (discarded) + // decoder_out[2:] states_next + + for (int32_t i = 0; i != states.size(); ++i) { + states_next.push_back(std::move(decoder_out[i + 2])); + } + + // we discard decoder_out[1] + return {std::move(decoder_out[0]), std::move(states_next)}; + } + + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = joiner_sess_->onForward( + joiner_input); + + return std::move(logit[0]); + } + + std::vector GetDecoderInitStates() { + std::vector ans; + ans.reserve(2); + ans.push_back(View(lstm0_)); + ans.push_back(View(lstm1_)); + + return ans; + } + + int32_t ChunkSize() const { return window_size_; } + + int32_t ChunkShift() const { return chunk_shift_; } + + int32_t SubsamplingFactor() const { return subsampling_factor_; } + + int32_t VocabSize() const { return vocab_size_; } + + MNNAllocator *Allocator() { return allocator_; } + + std::string FeatureNormalizationMethod() const { return normalize_type_; } + + // Return a vector containing 3 tensors + // - cache_last_channel + // - cache_last_time_ + // - cache_last_channel_len + std::vector GetEncoderInitStates() { + std::vector ans; + ans.reserve(3); + ans.push_back(View(cache_last_channel_)); + ans.push_back(View(cache_last_time_)); + ans.push_back(View(cache_last_channel_len_)); + + return ans; + } + + std::vector StackStates( + std::vector> states) const { + int32_t batch_size = static_cast(states.size()); + if (batch_size == 1) { + return std::move(states[0]); + } + + std::vector ans; + + auto allocator = const_cast(this)->allocator_; + + // stack cache_last_channel + std::vector buf(batch_size); + + // there are 3 states to be stacked + for (int32_t i = 0; i != 3; ++i) { + buf.clear(); + buf.reserve(batch_size); + + for (int32_t b = 0; b != 
batch_size; ++b) { + assert(states[b].size() == 3); + buf.push_back(states[b][i]); + } + + MNN::Express::VARP c{nullptr}; + if (i == 2) { + c = Cat(allocator, buf, 0); + } else { + c = Cat(allocator, buf, 0); + } + + ans.push_back(std::move(c)); + } + + return ans; + } + + std::vector> UnStackStates( + std::vector states) { + assert(states.size() == 3); + + std::vector> ans; + + auto shape = states[0]->getInfo()->dim; + int32_t batch_size = shape[0]; + ans.resize(batch_size); + + if (batch_size == 1) { + ans[0] = std::move(states); + return ans; + } + + for (int32_t i = 0; i != 3; ++i) { + std::vector v; + if (i == 2) { + v = Unbind(allocator_, states[i], 0); + } else { + v = Unbind(allocator_, states[i], 0); + } + + assert(v.size() == batch_size); + + for (int32_t b = 0; b != batch_size; ++b) { + ans[b].push_back(std::move(v[b])); + } + } + + return ans; + } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + + // need to increase by 1 since the blank token is not included in computing + // vocab_size in NeMo. 
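+ // e.g. a model whose metadata reports vocab_size=1024 actually produces
+ // 1024 + 1 = 1025 joiner logits once blank is accounted for (1024 is only
+ // an illustrative value).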
+ vocab_size_ += 1; + + SHERPA_ONNX_READ_META_DATA(window_size_, "window_size"); + SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type"); + SHERPA_ONNX_READ_META_DATA(pred_rnn_layers_, "pred_rnn_layers"); + SHERPA_ONNX_READ_META_DATA(pred_hidden_, "pred_hidden"); + + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_, + "cache_last_channel_dim1"); + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_, + "cache_last_channel_dim2"); + SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_, + "cache_last_channel_dim3"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2"); + SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3"); + + if (normalize_type_ == "NA") { + normalize_type_ = ""; + } + + InitEncoderStates(); + } + + void InitEncoderStates() { + std::array cache_last_channel_shape{1, cache_last_channel_dim1_, + cache_last_channel_dim2_, + cache_last_channel_dim3_}; + + cache_last_channel_ = MNNUtilsCreateTensor( + allocator_, cache_last_channel_shape.data(), + cache_last_channel_shape.size()); + + Fill(cache_last_channel_, 0); + + std::array cache_last_time_shape{ + 1, cache_last_time_dim1_, cache_last_time_dim2_, cache_last_time_dim3_}; + + cache_last_time_ = MNNUtilsCreateTensor( + allocator_, cache_last_time_shape.data(), cache_last_time_shape.size()); + + Fill(cache_last_time_, 0); + + int shape = 1; + cache_last_channel_len_ = + MNNUtilsCreateTensor(allocator_, &shape, 1); + + cache_last_channel_len_->writeMap()[0] = 0; + } + + void InitDecoder(void *model_data, size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + InitDecoderStates(); + } + + void InitDecoderStates() { + int32_t batch_size = 1; + std::array s0_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + lstm0_ = MNNUtilsCreateTensor(allocator_, s0_shape.data(), + s0_shape.size()); + + Fill(lstm0_, 0); + + std::array s1_shape{pred_rnn_layers_, batch_size, pred_hidden_}; + + lstm1_ = MNNUtilsCreateTensor(allocator_, s1_shape.data(), + s1_shape.size()); + + Fill(lstm1_, 0); + } + + void InitJoiner(void *model_data, size_t model_data_length) { + joiner_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + } + + private: + OnlineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + 
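// Each *_names_ptr_ member below caches raw C-string views of the
+ // corresponding *_names_ vector; GetInputNames()/GetOutputNames() fill
+ // both forms (see InitEncoder()/InitDecoder()/InitJoiner()). +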
std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + int32_t window_size_ = 0; + int32_t chunk_shift_ = 0; + int32_t vocab_size_ = 0; + int32_t subsampling_factor_ = 8; + std::string normalize_type_; + int32_t pred_rnn_layers_ = -1; + int32_t pred_hidden_ = -1; + + // encoder states + int32_t cache_last_channel_dim1_ = 0; + int32_t cache_last_channel_dim2_ = 0; + int32_t cache_last_channel_dim3_ = 0; + int32_t cache_last_time_dim1_ = 0; + int32_t cache_last_time_dim2_ = 0; + int32_t cache_last_time_dim3_ = 0; + + // init encoder states + MNN::Express::VARP cache_last_channel_{nullptr}; + MNN::Express::VARP cache_last_time_{nullptr}; + MNN::Express::VARP cache_last_channel_len_{nullptr}; + + // init decoder states + MNN::Express::VARP lstm0_{nullptr}; + MNN::Express::VARP lstm1_{nullptr}; +}; + +OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + Manager *mgr, const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OnlineTransducerNeMoModel::~OnlineTransducerNeMoModel() = default; + +std::vector OnlineTransducerNeMoModel::RunEncoder( + MNN::Express::VARP features, std::vector states) const { + return impl_->RunEncoder(std::move(features), std::move(states)); +} + +std::pair> +OnlineTransducerNeMoModel::RunDecoder(MNN::Express::VARP targets, + std::vector states) const { + return impl_->RunDecoder(std::move(targets), std::move(states)); +} + +std::vector OnlineTransducerNeMoModel::GetDecoderInitStates() + const { + return impl_->GetDecoderInitStates(); +} + +MNN::Express::VARP OnlineTransducerNeMoModel::RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) const { + return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out)); +} + +int32_t OnlineTransducerNeMoModel::ChunkSize() const { + return impl_->ChunkSize(); +} + +int32_t OnlineTransducerNeMoModel::ChunkShift() const { + return impl_->ChunkShift(); +} + +int32_t OnlineTransducerNeMoModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +int32_t OnlineTransducerNeMoModel::VocabSize() const { + return impl_->VocabSize(); +} + +MNNAllocator *OnlineTransducerNeMoModel::Allocator() const { + return impl_->Allocator(); +} + +std::string OnlineTransducerNeMoModel::FeatureNormalizationMethod() const { + return impl_->FeatureNormalizationMethod(); +} + +std::vector OnlineTransducerNeMoModel::GetEncoderInitStates() + const { + return impl_->GetEncoderInitStates(); +} + +std::vector OnlineTransducerNeMoModel::StackStates( + std::vector> states) const { + return impl_->StackStates(std::move(states)); +} + +std::vector> OnlineTransducerNeMoModel::UnStackStates( + std::vector states) const { + return impl_->UnStackStates(std::move(states)); +} + +#if __ANDROID_API__ >= 9 +template OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineTransducerNeMoModel::OnlineTransducerNeMoModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-nemo-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-nemo-model.h new file mode 100644 index 00000000..a7ea654c --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-transducer-nemo-model.h @@ -0,0 +1,124 @@ +// sherpa-mnn/csrc/online-transducer-nemo-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Sangeet Sagar + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-model-config.h" + +namespace sherpa_mnn { + +// see +// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py#L40 +// Its decoder is stateful, not stateless. +class OnlineTransducerNeMoModel { + public: + explicit OnlineTransducerNeMoModel(const OnlineModelConfig &config); + + template + OnlineTransducerNeMoModel(Manager *mgr, const OnlineModelConfig &config); + + ~OnlineTransducerNeMoModel(); + // A list of 3 tensors: + // - cache_last_channel + // - cache_last_time + // - cache_last_channel_len + std::vector GetEncoderInitStates() const; + + // stack encoder states + std::vector StackStates( + std::vector> states) const; + + // unstack encoder states + std::vector> UnStackStates( + std::vector states) const; + + /** Run the encoder. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param states It is from GetEncoderInitStates() or returned from this + * method. + * + * @return Return a tuple containing: + * - ans[0]: encoder_out, a tensor of shape (N, encoder_out_dim, T') + * - ans[1:]: contains next states + */ + std::vector RunEncoder( + MNN::Express::VARP features, std::vector states) const; // NOLINT + + /** Run the decoder network. + * + * @param targets A int32 tensor of shape (batch_size, 1) + * @param states The states for the decoder model. + * @return Return a vector: + * - ans[0] is the decoder_out (a float tensor) + * - ans[1:] is the next states + */ + std::pair> RunDecoder( + MNN::Express::VARP targets, std::vector states) const; + + std::vector GetDecoderInitStates() const; + + /** Run the joint network. + * + * @param encoder_out Output of the encoder network. + * @param decoder_out Output of the decoder network. + * @return Return a tensor of shape (N, 1, 1, vocab_size) containing logits. + */ + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) const; + + /** We send this number of feature frames to the encoder at a time. */ + int32_t ChunkSize() const; + + /** Number of input frames to discard after each call to RunEncoder. + * + * For instance, if we have 30 frames, chunk_size=8, chunk_shift=6. + * + * In the first call of RunEncoder, we use frames 0~7 since chunk_size is 8. + * Then we discard frame 0~5 since chunk_shift is 6. + * In the second call of RunEncoder, we use frames 6~13; and then we discard + * frames 6~11. + * In the third call of RunEncoder, we use frames 12~19; and then we discard + * frames 12~16. + * + * Note: ChunkSize() - ChunkShift() == right context size + */ + int32_t ChunkShift() const; + + /** Return the subsampling factor of the model. 
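+ * For example, with a subsampling factor of 8 (the value this implementation
+ * defaults to), about 80 input feature frames are reduced to roughly
+ * 80 / 8 = 10 encoder output frames.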
+ */ + int32_t SubsamplingFactor() const; + + int32_t VocabSize() const; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const; + + // Possible values: + // - per_feature + // - all_features (not implemented yet) + // - fixed_mean (not implemented) + // - fixed_std (not implemented) + // - or just leave it to empty + // See + // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59 + // for details + std::string FeatureNormalizationMethod() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_NEMO_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-client.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-client.cc new file mode 100644 index 00000000..0559c1d8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-client.cc @@ -0,0 +1,274 @@ +// sherpa/cpp_api/websocket/online-websocket-client.cc +// +// Copyright (c) 2022 Xiaomi Corporation +#include // NOLINT +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-reader.h" +#include "websocketpp/client.hpp" +#include "websocketpp/config/asio_no_tls_client.hpp" +#include "websocketpp/uri.hpp" + +using client = websocketpp::client; + +using message_ptr = client::message_ptr; +using websocketpp::connection_hdl; + +static constexpr const char *kUsageMessage = R"( +Automatic speech recognition with sherpa-mnn using websocket. + +Usage: + +./bin/sherpa-mnn-online-websocket-client --help + +./bin/sherpa-mnn-online-websocket-client \ + --server-ip=127.0.0.1 \ + --server-port=6006 \ + --samples-per-message=8000 \ + --seconds-per-message=0.2 \ + /path/to/foo.wav + +It support only wave of with a single channel, 16kHz, 16-bit samples. 
+)"; + +class Client { + public: + Client(asio::io_context &io, // NOLINT + const std::string &ip, int16_t port, const std::vector &samples, + int32_t samples_per_message, float seconds_per_message) + : io_(io), + uri_(/*secure*/ false, ip, port, /*resource*/ "/"), + samples_(samples), + samples_per_message_(samples_per_message), + seconds_per_message_(seconds_per_message) { + c_.clear_access_channels(websocketpp::log::alevel::all); + // c_.set_access_channels(websocketpp::log::alevel::connect); + // c_.set_access_channels(websocketpp::log::alevel::disconnect); + + c_.init_asio(&io_); + c_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); + c_.set_close_handler( + [](connection_hdl /*hdl*/) { SHERPA_ONNX_LOGE("Disconnected"); }); + c_.set_message_handler( + [this](connection_hdl hdl, message_ptr msg) { OnMessage(hdl, msg); }); + + Run(); + } + + private: + void Run() { + websocketpp::lib::error_code ec; + client::connection_ptr con = c_.get_connection(uri_.str(), ec); + if (ec) { + SHERPA_ONNX_LOGE("Could not create connection to %s because %s", + uri_.str().c_str(), ec.message().c_str()); + exit(EXIT_FAILURE); + } + + c_.connect(con); + } + + void OnOpen(connection_hdl hdl) { + auto start_time = std::chrono::steady_clock::now(); + asio::post( + io_, [this, hdl, start_time]() { this->SendMessage(hdl, start_time); }); + } + + void OnMessage(connection_hdl hdl, message_ptr msg) { + const std::string &payload = msg->get_payload(); + + if (payload == "Done!") { + websocketpp::lib::error_code ec; + c_.close(hdl, websocketpp::close::status::normal, "I'm exiting now", ec); + if (ec) { + SHERPA_ONNX_LOGE("Failed to close because %s", ec.message().c_str()); + exit(EXIT_FAILURE); + } + } else { + SHERPA_ONNX_LOGE("%s", payload.c_str()); + } + } + + void SendMessage( + connection_hdl hdl, + std::chrono::time_point start_time) { + int32_t num_samples = samples_.size(); + int32_t num_messages = num_samples / samples_per_message_; + + websocketpp::lib::error_code ec; + auto time = std::chrono::steady_clock::now(); + int elapsed_time_ms = + std::chrono::duration_cast(time - start_time) + .count(); + + if (elapsed_time_ms < + static_cast(seconds_per_message_ * num_sent_messages_ * 1000)) { + std::this_thread::sleep_for(std::chrono::milliseconds(int( + seconds_per_message_ * num_sent_messages_ * 1000 - elapsed_time_ms))); + } + + if (num_sent_messages_ < 1) { + SHERPA_ONNX_LOGE("Starting to send audio"); + } + + if (num_sent_messages_ < num_messages) { + c_.send(hdl, samples_.data() + num_sent_messages_ * samples_per_message_, + samples_per_message_ * sizeof(float), + websocketpp::frame::opcode::binary, ec); + + if (ec) { + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", + ec.message().c_str()); + exit(EXIT_FAILURE); + } + + ec.clear(); + + ++num_sent_messages_; + } + + if (num_sent_messages_ == num_messages) { + int32_t remaining_samples = num_samples % samples_per_message_; + if (remaining_samples) { + c_.send(hdl, + samples_.data() + num_sent_messages_ * samples_per_message_, + remaining_samples * sizeof(float), + websocketpp::frame::opcode::binary, ec); + + if (ec) { + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", + ec.message().c_str()); + exit(EXIT_FAILURE); + } + ec.clear(); + } + + // To signal that we have send all the messages + c_.send(hdl, "Done", websocketpp::frame::opcode::text, ec); + SHERPA_ONNX_LOGE("Sent Done Signal"); + + if (ec) { + SHERPA_ONNX_LOGE("Failed to send audio samples because %s", + ec.message().c_str()); + exit(EXIT_FAILURE); + } + } 
else { + asio::post(io_, [this, hdl, start_time]() { + this->SendMessage(hdl, start_time); + }); + } + } + + private: + client c_; + asio::io_context &io_; + websocketpp::uri uri_; + std::vector samples_; + int32_t samples_per_message_ = 8000; // 0.5 seconds + float seconds_per_message_ = 0.2; + int32_t num_sent_messages_ = 0; +}; + +int32_t main(int32_t argc, char *argv[]) { + std::string server_ip = "127.0.0.1"; + int32_t server_port = 6006; + + // Sample rate of the input wave. No resampling is made. + int32_t sample_rate = 16000; + int32_t samples_per_message = 8000; + float seconds_per_message = 0.2; + + sherpa_mnn::ParseOptions po(kUsageMessage); + + po.Register("server-ip", &server_ip, "IP address of the websocket server"); + po.Register("server-port", &server_port, "Port of the websocket server"); + po.Register("sample-rate", &sample_rate, + "Sample rate of the input wave. Should be the one expected by " + "the server"); + + po.Register("samples-per-message", &samples_per_message, + "Send this number of samples per message."); + + po.Register("seconds-per-message", &seconds_per_message, + "We will simulate that each message takes this number of seconds " + "to send. If you select a very large value, it will take a long " + "time to send all the samples"); + + po.Read(argc, argv); + + if (!websocketpp::uri_helper::ipv4_literal(server_ip.begin(), + server_ip.end())) { + SHERPA_ONNX_LOGE("Invalid server IP: %s", server_ip.c_str()); + return -1; + } + + if (server_port <= 0 || server_port > 65535) { + SHERPA_ONNX_LOGE("Invalid server port: %d", server_port); + return -1; + } + + // 0.01 is an arbitrary value. You can change it. + if (samples_per_message <= 0.01 * sample_rate) { + SHERPA_ONNX_LOGE("--samples-per-message is too small: %d", + samples_per_message); + return -1; + } + + // 100 is an arbitrary value. You can change it. + if (samples_per_message >= sample_rate * 100) { + SHERPA_ONNX_LOGE("--samples-per-message is too small: %d", + samples_per_message); + return -1; + } + + if (seconds_per_message < 0) { + SHERPA_ONNX_LOGE("--seconds-per-message is too small: %.3f", + seconds_per_message); + return -1; + } + + // 1 is an arbitrary value. + if (seconds_per_message > 1) { + SHERPA_ONNX_LOGE( + "--seconds-per-message is too large: %.3f. 
You will wait a long time " + "to " + "send all the samples", + seconds_per_message); + return -1; + } + + if (po.NumArgs() != 1) { + po.PrintUsage(); + return -1; + } + + std::string wave_filename = po.GetArg(1); + + bool is_ok = false; + int32_t actual_sample_rate = -1; + std::vector samples = + sherpa_mnn::ReadWave(wave_filename, &actual_sample_rate, &is_ok); + + if (!is_ok) { + SHERPA_ONNX_LOGE("Failed to read '%s'", wave_filename.c_str()); + return -1; + } + + if (actual_sample_rate != sample_rate) { + SHERPA_ONNX_LOGE("Expected sample rate: %d, given %d", sample_rate, + actual_sample_rate); + return -1; + } + + asio::io_context io_conn; // for network connections + Client c(io_conn, server_ip, server_port, samples, samples_per_message, + seconds_per_message); + + io_conn.run(); // will exit when the above connection is closed + + SHERPA_ONNX_LOGE("Done!"); + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server-impl.cc new file mode 100644 index 00000000..19ba4f73 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server-impl.cc @@ -0,0 +1,365 @@ +// sherpa-mnn/csrc/online-websocket-server-impl.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-websocket-server-impl.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/log.h" + +namespace sherpa_mnn { + +void OnlineWebsocketDecoderConfig::Register(ParseOptions *po) { + recognizer_config.Register(po); + + po->Register("loop-interval-ms", &loop_interval_ms, + "It determines how often the decoder loop runs. "); + + po->Register("max-batch-size", &max_batch_size, + "Max batch size for recognition."); + + po->Register("end-tail-padding", &end_tail_padding, + "It determines the length of tail_padding at the end of audio."); +} + +void OnlineWebsocketDecoderConfig::Validate() const { + recognizer_config.Validate(); + SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0); + SHERPA_ONNX_CHECK_GT(max_batch_size, 0); + SHERPA_ONNX_CHECK_GT(end_tail_padding, 0); +} + +void OnlineWebsocketServerConfig::Register(sherpa_mnn::ParseOptions *po) { + decoder_config.Register(po); + + po->Register("log-file", &log_file, + "Path to the log file. 
Logs are " + "appended to this file"); +} + +void OnlineWebsocketServerConfig::Validate() const { + decoder_config.Validate(); +} + +OnlineWebsocketDecoder::OnlineWebsocketDecoder(OnlineWebsocketServer *server) + : server_(server), + config_(server->GetConfig().decoder_config), + timer_(server->GetWorkContext()) { + recognizer_ = std::make_unique(config_.recognizer_config); +} + +std::shared_ptr OnlineWebsocketDecoder::GetOrCreateConnection( + connection_hdl hdl) { + std::lock_guard lock(mutex_); + auto it = connections_.find(hdl); + if (it != connections_.end()) { + return it->second; + } else { + // create a new connection + std::shared_ptr s = recognizer_->CreateStream(); + auto c = std::make_shared(hdl, s); + connections_.insert({hdl, c}); + return c; + } +} + +void OnlineWebsocketDecoder::AcceptWaveform(std::shared_ptr c) { + std::lock_guard lock(c->mutex); + float sample_rate = config_.recognizer_config.feat_config.sampling_rate; + while (!c->samples.empty()) { + const auto &s = c->samples.front(); + c->s->AcceptWaveform(sample_rate, s.data(), s.size()); + c->samples.pop_front(); + } +} + +void OnlineWebsocketDecoder::InputFinished(std::shared_ptr c) { + std::lock_guard lock(c->mutex); + + float sample_rate = config_.recognizer_config.feat_config.sampling_rate; + + while (!c->samples.empty()) { + const auto &s = c->samples.front(); + c->s->AcceptWaveform(sample_rate, s.data(), s.size()); + c->samples.pop_front(); + } + + std::vector tail_padding( + static_cast(config_.end_tail_padding * sample_rate)); + + c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size()); + + c->s->InputFinished(); + c->eof = true; +} + +void OnlineWebsocketDecoder::Warmup() const { + recognizer_->WarmpUpRecognizer(config_.recognizer_config.model_config.warm_up, + config_.max_batch_size); +} + +void OnlineWebsocketDecoder::Run() { + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); + + timer_.async_wait( + [this](const asio::error_code &ec) { ProcessConnections(ec); }); +} + +void OnlineWebsocketDecoder::ProcessConnections(const asio::error_code &ec) { + if (ec) { + SHERPA_ONNX_LOG(FATAL) << "The decoder loop is aborted!"; + } + + std::lock_guard lock(mutex_); + std::vector to_remove; + for (auto &p : connections_) { + auto hdl = p.first; + auto c = p.second; + + // The order of `if` below matters! + if (!server_->Contains(hdl)) { + // If the connection is disconnected, we stop processing it + to_remove.push_back(hdl); + continue; + } + + if (active_.count(hdl)) { + // Another thread is decoding this stream, so skip it + continue; + } + + if (!recognizer_->IsReady(c->s.get()) && !c->eof) { + // this stream has not enough frames to decode, so skip it + continue; + } + + if (!recognizer_->IsReady(c->s.get()) && c->eof) { + // We won't receive samples from the client, so send a Done! 
to client + + asio::post(server_->GetWorkContext(), + [this, hdl = c->hdl]() { server_->Send(hdl, "Done!"); }); + + to_remove.push_back(hdl); + continue; + } + + // TODO(fangun): If the connection is timed out, we need to also + // add it to `to_remove` + + // this stream has enough frames and is currently not processed by any + // threads, so put it into the ready queue + ready_connections_.push_back(c); + + // In `Decode()`, it will remove hdl from `active_` + active_.insert(c->hdl); + } + + for (auto hdl : to_remove) { + connections_.erase(hdl); + } + + if (!ready_connections_.empty()) { + asio::post(server_->GetWorkContext(), [this]() { Decode(); }); + } + + // Schedule another call + timer_.expires_after(std::chrono::milliseconds(config_.loop_interval_ms)); + + timer_.async_wait( + [this](const asio::error_code &ec) { ProcessConnections(ec); }); +} + +void OnlineWebsocketDecoder::Decode() { + std::unique_lock lock(mutex_); + if (ready_connections_.empty()) { + // There are no connections that are ready for decoding, + // so we return directly + return; + } + + std::vector> c_vec; + std::vector s_vec; + while (!ready_connections_.empty() && + static_cast(s_vec.size()) < config_.max_batch_size) { + auto c = ready_connections_.front(); + ready_connections_.pop_front(); + + c_vec.push_back(c); + s_vec.push_back(c->s.get()); + } + + if (!ready_connections_.empty()) { + // there are too many ready connections but this thread can only handle + // max_batch_size connections at a time, so we schedule another call + // to Decode() and let other threads to process the ready connections + asio::post(server_->GetWorkContext(), [this]() { Decode(); }); + } + + lock.unlock(); + recognizer_->DecodeStreams(s_vec.data(), s_vec.size()); + lock.lock(); + + for (auto c : c_vec) { + auto result = recognizer_->GetResult(c->s.get()); + if (recognizer_->IsEndpoint(c->s.get())) { + result.is_final = true; + recognizer_->Reset(c->s.get()); + } + + if (!recognizer_->IsReady(c->s.get()) && c->eof) { + result.is_final = true; + } + + asio::post(server_->GetConnectionContext(), + [this, hdl = c->hdl, str = result.AsJsonString()]() { + server_->Send(hdl, str); + }); + active_.erase(c->hdl); + } +} + +OnlineWebsocketServer::OnlineWebsocketServer( + asio::io_context &io_conn, asio::io_context &io_work, + const OnlineWebsocketServerConfig &config) + : config_(config), + io_conn_(io_conn), + io_work_(io_work), + log_(config.log_file, std::ios::app), + tee_(std::cout, log_), + decoder_(this) { + SetupLog(); + + server_.init_asio(&io_conn_); + + server_.set_open_handler([this](connection_hdl hdl) { OnOpen(hdl); }); + + server_.set_close_handler([this](connection_hdl hdl) { OnClose(hdl); }); + + server_.set_message_handler( + [this](connection_hdl hdl, server::message_ptr msg) { + OnMessage(hdl, msg); + }); +} + +void OnlineWebsocketServer::Run(uint16_t port) { + server_.set_reuse_addr(true); + server_.listen(asio::ip::tcp::v4(), port); + server_.start_accept(); + auto recognizer_config = config_.decoder_config.recognizer_config; + int32_t warm_up = recognizer_config.model_config.warm_up; + const std::string &model_type = recognizer_config.model_config.model_type; + if (0 < warm_up && warm_up < 100) { + if (model_type == "zipformer2") { + decoder_.Warmup(); + SHERPA_ONNX_LOGE("Warm up completed : %d times.", warm_up); + } else { + SHERPA_ONNX_LOGE("Only Zipformer2 has warmup support for now."); + SHERPA_ONNX_LOGE("Given: %s", model_type.c_str()); + exit(0); + } + } else if (warm_up == 0) { + SHERPA_ONNX_LOGE("Starting 
without warmup!"); + } else { + SHERPA_ONNX_LOGE("Invalid Warm up Value!. Expected 0 < warm_up < 100"); + exit(0); + } + decoder_.Run(); +} + +void OnlineWebsocketServer::SetupLog() { + server_.clear_access_channels(websocketpp::log::alevel::all); + // server_.set_access_channels(websocketpp::log::alevel::connect); + // server_.set_access_channels(websocketpp::log::alevel::disconnect); + + // So that it also prints to std::cout and std::cerr + server_.get_alog().set_ostream(&tee_); + server_.get_elog().set_ostream(&tee_); +} + +void OnlineWebsocketServer::Send(connection_hdl hdl, const std::string &text) { + websocketpp::lib::error_code ec; + if (!Contains(hdl)) { + return; + } + + server_.send(hdl, text, websocketpp::frame::opcode::text, ec); + if (ec) { + server_.get_alog().write(websocketpp::log::alevel::app, ec.message()); + } +} + +void OnlineWebsocketServer::OnOpen(connection_hdl hdl) { + std::lock_guard lock(mutex_); + connections_.insert(hdl); + + std::ostringstream os; + os << "New connection: " + << server_.get_con_from_hdl(hdl)->get_remote_endpoint() << ". " + << "Number of active connections: " << connections_.size() << ".\n"; + SHERPA_ONNX_LOG(INFO) << os.str(); +} + +void OnlineWebsocketServer::OnClose(connection_hdl hdl) { + std::lock_guard lock(mutex_); + connections_.erase(hdl); + + SHERPA_ONNX_LOG(INFO) << "Number of active connections: " + << connections_.size() << "\n"; +} + +bool OnlineWebsocketServer::Contains(connection_hdl hdl) const { + std::lock_guard lock(mutex_); + return connections_.count(hdl); +} + +void OnlineWebsocketServer::OnMessage(connection_hdl hdl, + server::message_ptr msg) { + auto c = decoder_.GetOrCreateConnection(hdl); + + const std::string &payload = msg->get_payload(); + + switch (msg->get_opcode()) { + case websocketpp::frame::opcode::text: + if (payload == "Done") { + asio::post(io_work_, [this, c]() { decoder_.InputFinished(c); }); + } + break; + case websocketpp::frame::opcode::binary: { + auto p = reinterpret_cast(payload.data()); + int32_t num_samples = payload.size() / sizeof(float); + std::vector samples(p, p + num_samples); + + { + std::lock_guard lock(c->mutex); + c->samples.push_back(std::move(samples)); + } + + asio::post(io_work_, [this, c]() { decoder_.AcceptWaveform(c); }); + break; + } + default: + break; + } +} + +void OnlineWebsocketServer::Close(connection_hdl hdl, + websocketpp::close::status::value code, + const std::string &reason) { + auto con = server_.get_con_from_hdl(hdl); + + std::ostringstream os; + os << "Closing " << con->get_remote_endpoint() << " with reason: " << reason + << "\n"; + + websocketpp::lib::error_code ec; + server_.close(hdl, code, reason, ec); + if (ec) { + os << "Failed to close" << con->get_remote_endpoint() << ". 
" + << ec.message() << "\n"; + } + server_.get_alog().write(websocketpp::log::alevel::app, os.str()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server-impl.h new file mode 100644 index 00000000..1eedbd02 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server-impl.h @@ -0,0 +1,181 @@ +// sherpa-mnn/csrc/online-websocket-server-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_ + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include + +#include "asio.hpp" +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/tee-stream.h" +#include "websocketpp/config/asio_no_tls.hpp" // TODO(fangjun): support TLS +#include "websocketpp/server.hpp" +using server = websocketpp::server; +using connection_hdl = websocketpp::connection_hdl; + +namespace sherpa_mnn { + +struct Connection { + // handle to the connection. We can use it to send messages to the client + connection_hdl hdl; + std::shared_ptr s; + + // set it to true when InputFinished() is called + bool eof = false; + + // The last time we received a message from the client + // TODO(fangjun): Use it to disconnect from a client if it is inactive + // for a specified time. + std::chrono::steady_clock::time_point last_active; + + std::mutex mutex; // protect samples + + // Audio samples received from the client. + // + // The I/O threads receive audio samples into this queue + // and invoke work threads to compute features + std::deque> samples; + + Connection() = default; + Connection(connection_hdl hdl, std::shared_ptr s) + : hdl(hdl), s(s), last_active(std::chrono::steady_clock::now()) {} +}; + +struct OnlineWebsocketDecoderConfig { + OnlineRecognizerConfig recognizer_config; + + // It determines how often the decoder loop runs. + int32_t loop_interval_ms = 10; + + int32_t max_batch_size = 5; + + float end_tail_padding = 0.8; + + void Register(ParseOptions *po); + void Validate() const; +}; + +class OnlineWebsocketServer; + +class OnlineWebsocketDecoder { + public: + /** + * @param server Not owned. + */ + explicit OnlineWebsocketDecoder(OnlineWebsocketServer *server); + + std::shared_ptr GetOrCreateConnection(connection_hdl hdl); + + // Compute features for a stream given audio samples + void AcceptWaveform(std::shared_ptr c); + + // signal that there will be no more audio samples for a stream + void InputFinished(std::shared_ptr c); + + void Warmup() const; + + void Run(); + + private: + void ProcessConnections(const asio::error_code &ec); + + /** It is called by one of the worker thread. + */ + void Decode(); + + private: + OnlineWebsocketServer *server_; // not owned + std::unique_ptr recognizer_; + OnlineWebsocketDecoderConfig config_; + asio::steady_timer timer_; + + // It protects `connections_`, `ready_connections_`, and `active_` + std::mutex mutex_; + + std::map, + std::owner_less> + connections_; + + // Whenever a connection has enough feature frames for decoding, we put + // it in this queue + std::deque> ready_connections_; + + // If we are decoding a stream, we put it in the active_ set so that + // only one thread can decode a stream at a time. 
+ std::set> active_; +}; + +struct OnlineWebsocketServerConfig { + OnlineWebsocketDecoderConfig decoder_config; + + std::string log_file = "./log.txt"; + + void Register(sherpa_mnn::ParseOptions *po); + void Validate() const; +}; + +class OnlineWebsocketServer { + public: + explicit OnlineWebsocketServer(asio::io_context &io_conn, // NOLINT + asio::io_context &io_work, // NOLINT + const OnlineWebsocketServerConfig &config); + + void Run(uint16_t port); + + const OnlineWebsocketServerConfig &GetConfig() const { return config_; } + asio::io_context &GetConnectionContext() { return io_conn_; } + asio::io_context &GetWorkContext() { return io_work_; } + server &GetServer() { return server_; } + + void Send(connection_hdl hdl, const std::string &text); + + bool Contains(connection_hdl hdl) const; + + private: + void SetupLog(); + + // When a websocket client is connected, it will invoke this method + // (Not for HTTP) + void OnOpen(connection_hdl hdl); + + // When a websocket client is disconnected, it will invoke this method + void OnClose(connection_hdl hdl); + + void OnMessage(connection_hdl hdl, server::message_ptr msg); + + // Close a websocket connection with given code and reason + void Close(connection_hdl hdl, websocketpp::close::status::value code, + const std::string &reason); + + private: + OnlineWebsocketServerConfig config_; + asio::io_context &io_conn_; + asio::io_context &io_work_; + server server_; + + std::ofstream log_; + sherpa_mnn::TeeStream tee_; + + OnlineWebsocketDecoder decoder_; + + mutable std::mutex mutex_; + + std::set> connections_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_WEBSOCKET_SERVER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server.cc new file mode 100644 index 00000000..9017e727 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-websocket-server.cc @@ -0,0 +1,109 @@ +// sherpa-mnn/csrc/online-websocket-server.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "asio.hpp" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-websocket-server-impl.h" +#include "sherpa-mnn/csrc/parse-options.h" + +static constexpr const char *kUsageMessage = R"( +Automatic speech recognition with sherpa-mnn using websocket. + +Usage: + +./bin/sherpa-mnn-online-websocket-server --help + +./bin/sherpa-mnn-online-websocket-server \ + --port=6006 \ + --num-work-threads=5 \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --log-file=./log.txt \ + --max-batch-size=5 \ + --loop-interval-ms=10 + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. 
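+
+Once the server is running, you can test it with the matching client, e.g.:
+
+./bin/sherpa-mnn-online-websocket-client \
+  --server-ip=127.0.0.1 \
+  --server-port=6006 \
+  /path/to/foo.wav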
+)"; + +int32_t main(int32_t argc, char *argv[]) { + sherpa_mnn::ParseOptions po(kUsageMessage); + + sherpa_mnn::OnlineWebsocketServerConfig config; + + // the server will listen on this port + int32_t port = 6006; + + // size of the thread pool for handling network connections + int32_t num_io_threads = 1; + + // size of the thread pool for neural network computation and decoding + int32_t num_work_threads = 3; + + po.Register("num-io-threads", &num_io_threads, + "Thread pool size for network connections."); + + po.Register("num-work-threads", &num_work_threads, + "Thread pool size for for neural network " + "computation and decoding."); + + po.Register("port", &port, "The port on which the server will listen."); + + config.Register(&po); + + if (argc == 1) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + po.Read(argc, argv); + + if (po.NumArgs() != 0) { + SHERPA_ONNX_LOGE("Unrecognized positional arguments!"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + config.Validate(); + + asio::io_context io_conn; // for network connections + asio::io_context io_work; // for neural network and decoding + + sherpa_mnn::OnlineWebsocketServer server(io_conn, io_work, config); + server.Run(port); + + SHERPA_ONNX_LOGE("Started!"); + SHERPA_ONNX_LOGE("Listening on: %d", port); + SHERPA_ONNX_LOGE("Number of work threads: %d", num_work_threads); + + // give some work to do for the io_work pool + auto work_guard = asio::make_work_guard(io_work); + + std::vector io_threads; + + // decrement since the main thread is also used for network communications + for (int32_t i = 0; i < num_io_threads - 1; ++i) { + io_threads.emplace_back([&io_conn]() { io_conn.run(); }); + } + + std::vector work_threads; + for (int32_t i = 0; i < num_work_threads; ++i) { + work_threads.emplace_back([&io_work]() { io_work.run(); }); + } + + io_conn.run(); + + for (auto &t : io_threads) { + t.join(); + } + + for (auto &t : work_threads) { + t.join(); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model-config.cc new file mode 100644 index 00000000..acec71c9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model-config.cc @@ -0,0 +1,59 @@ +// sherpa-mnn/csrc/online-wenet-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-wenet-ctc-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlineWenetCtcModelConfig::Register(ParseOptions *po) { + po->Register("wenet-ctc-model", &model, + "Path to CTC model.onnx from WeNet. Please see " + "https://github.com/k2-fsa/sherpa-mnn/pull/425"); + po->Register("wenet-ctc-chunk-size", &chunk_size, + "Chunk size after subsampling used for decoding."); + po->Register("wenet-ctc-num-left-chunks", &num_left_chunks, + "Number of left chunks after subsampling used for decoding."); +} + +bool OnlineWenetCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("WeNet CTC model '%s' does not exist", model.c_str()); + return false; + } + + if (chunk_size <= 0) { + SHERPA_ONNX_LOGE( + "Please specify a positive value for --wenet-ctc-chunk-size. Currently " + "given: %d", + chunk_size); + return false; + } + + if (num_left_chunks <= 0) { + SHERPA_ONNX_LOGE( + "Please specify a positive value for --wenet-ctc-num-left-chunks. " + "Currently given: %d. 
Note that if you want to use -1, please consider " + "using a non-streaming model.", + num_left_chunks); + return false; + } + + return true; +} + +std::string OnlineWenetCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineWenetCtcModelConfig("; + os << "model=\"" << model << "\", "; + os << "chunk_size=" << chunk_size << ", "; + os << "num_left_chunks=" << num_left_chunks << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model-config.h new file mode 100644 index 00000000..41f7c27a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model-config.h @@ -0,0 +1,38 @@ +// sherpa-mnn/csrc/online-wenet-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineWenetCtcModelConfig { + std::string model; + + // --chunk_size from wenet + int32_t chunk_size = 16; + + // --num_left_chunks from wenet + int32_t num_left_chunks = 4; + + OnlineWenetCtcModelConfig() = default; + + OnlineWenetCtcModelConfig(const std::string &model, int32_t chunk_size, + int32_t num_left_chunks) + : model(model), + chunk_size(chunk_size), + num_left_chunks(num_left_chunks) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model.cc new file mode 100644 index 00000000..8468ebe0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model.cc @@ -0,0 +1,275 @@ +// sherpa-mnn/csrc/online-wenet-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-wenet-ctc-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +class OnlineWenetCtcModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.wenet_ctc.model); + Init(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.wenet_ctc.model); + Init(buf.data(), buf.size()); + } + } + + std::vector Forward(MNN::Express::VARP x, + std::vector states) { + MNN::Express::VARP &attn_cache = states[0]; + MNN::Express::VARP &conv_cache = states[1]; + MNN::Express::VARP &offset = states[2]; + + int32_t chunk_size = config_.wenet_ctc.chunk_size; + int32_t left_chunks = config_.wenet_ctc.num_left_chunks; + // build attn_mask + std::array attn_mask_shape{1, 1, + required_cache_size_ + chunk_size}; + MNN::Express::VARP attn_mask = MNNUtilsCreateTensor( + 
allocator_, attn_mask_shape.data(), attn_mask_shape.size()); + bool *p = attn_mask->writeMap(); + int32_t chunk_idx = + offset->readMap()[0] / chunk_size - left_chunks; + if (chunk_idx < left_chunks) { + std::fill(p, p + required_cache_size_ - chunk_idx * chunk_size, 0); + std::fill(p + required_cache_size_ - chunk_idx * chunk_size, + p + attn_mask_shape[2], 1); + } else { + std::fill(p, p + attn_mask_shape[2], 1); + } + + std::vector inputs = {std::move(x), + View(offset), + View(required_cache_size_tensor_), + std::move(attn_cache), + std::move(conv_cache), + std::move(attn_mask)}; + + auto out = + sess_->onForward(inputs); + + offset->writeMap()[0] += + out[0]->getInfo()->dim[1]; + out.push_back(std::move(offset)); + + return out; + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t ChunkLength() const { + // When chunk_size is 16, subsampling_factor_ is 4, right_context_ is 6, + // the returned value is (16 - 1)*4 + 6 + 1 = 67 + return (config_.wenet_ctc.chunk_size - 1) * subsampling_factor_ + + right_context_ + 1; + } + + int32_t ChunkShift() const { + return config_.wenet_ctc.chunk_size * subsampling_factor_; + } + + MNNAllocator *Allocator() { return allocator_; } + + // Return a vector containing 3 tensors + // - attn_cache + // - conv_cache + // - offset + std::vector GetInitStates() { + std::vector ans; + ans.reserve(3); + ans.push_back(View(attn_cache_)); + ans.push_back(View(conv_cache_)); + + int offset_shape = 1; + + MNN::Express::VARP offset = + MNNUtilsCreateTensor(allocator_, &offset_shape, 1); + + offset->writeMap()[0] = required_cache_size_; + + ans.push_back(std::move(offset)); + + return ans; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(head_, "head"); + SHERPA_ONNX_READ_META_DATA(num_blocks_, "num_blocks"); + SHERPA_ONNX_READ_META_DATA(output_size_, "output_size"); + SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel"); + SHERPA_ONNX_READ_META_DATA(right_context_, "right_context"); + SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor"); + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + + required_cache_size_ = + config_.wenet_ctc.chunk_size * config_.wenet_ctc.num_left_chunks; + + InitStates(); + } + + void InitStates() { + std::array attn_cache_shape{ + num_blocks_, head_, required_cache_size_, output_size_ / head_ * 2}; + attn_cache_ = MNNUtilsCreateTensor( + allocator_, attn_cache_shape.data(), attn_cache_shape.size()); + + Fill(attn_cache_, 0); + + std::array conv_cache_shape{num_blocks_, 1, output_size_, + cnn_module_kernel_ - 1}; + conv_cache_ = MNNUtilsCreateTensor( + allocator_, conv_cache_shape.data(), conv_cache_shape.size()); + + Fill(conv_cache_, 0); + + int shape = 1; + required_cache_size_tensor_ = + MNNUtilsCreateTensor(allocator_, &shape, 1); + + required_cache_size_tensor_->writeMap()[0] = + required_cache_size_; + } + 
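+
+  // Layout of the streaming states created above and consumed by Forward():
+  //   - attn_cache: (num_blocks, head, required_cache_size, output_size / head * 2)
+  //   - conv_cache: (num_blocks, 1, output_size, cnn_module_kernel - 1)
+  //   - offset: a single int initialized to required_cache_size,
+  //     where required_cache_size = chunk_size * num_left_chunks.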
+ private: + OnlineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t head_ = 0; + int32_t num_blocks_ = 0; + int32_t output_size_ = 0; + int32_t cnn_module_kernel_ = 0; + int32_t right_context_ = 0; + int32_t subsampling_factor_ = 0; + int32_t vocab_size_ = 0; + + int32_t required_cache_size_ = 0; + + MNN::Express::VARP attn_cache_{nullptr}; + MNN::Express::VARP conv_cache_{nullptr}; + MNN::Express::VARP required_cache_size_tensor_{nullptr}; +}; + +OnlineWenetCtcModel::OnlineWenetCtcModel(const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OnlineWenetCtcModel::OnlineWenetCtcModel(Manager *mgr, + const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OnlineWenetCtcModel::~OnlineWenetCtcModel() = default; + +std::vector OnlineWenetCtcModel::Forward( + MNN::Express::VARP x, std::vector states) const { + return impl_->Forward(std::move(x), std::move(states)); +} + +int32_t OnlineWenetCtcModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OnlineWenetCtcModel::ChunkLength() const { + return impl_->ChunkLength(); +} + +int32_t OnlineWenetCtcModel::ChunkShift() const { return impl_->ChunkShift(); } + +MNNAllocator *OnlineWenetCtcModel::Allocator() const { + return impl_->Allocator(); +} + +std::vector OnlineWenetCtcModel::GetInitStates() const { + return impl_->GetInitStates(); +} + +std::vector OnlineWenetCtcModel::StackStates( + std::vector> states) const { + if (states.size() != 1) { + SHERPA_ONNX_LOGE("wenet CTC model supports only batch_size==1. Given: %d", + static_cast(states.size())); + } + + return std::move(states[0]); +} + +std::vector> OnlineWenetCtcModel::UnStackStates( + std::vector states) const { + std::vector> ans(1); + ans[0] = std::move(states); + return ans; +} + +#if __ANDROID_API__ >= 9 +template OnlineWenetCtcModel::OnlineWenetCtcModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineWenetCtcModel::OnlineWenetCtcModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model.h new file mode 100644 index 00000000..3dc7ca83 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-wenet-ctc-model.h @@ -0,0 +1,75 @@ +// sherpa-mnn/csrc/online-wenet-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-ctc-model.h" +#include "sherpa-mnn/csrc/online-model-config.h" + +namespace sherpa_mnn { + +class OnlineWenetCtcModel : public OnlineCtcModel { + public: + explicit OnlineWenetCtcModel(const OnlineModelConfig &config); + + template + OnlineWenetCtcModel(Manager *mgr, const OnlineModelConfig &config); + + ~OnlineWenetCtcModel() override; + + // A list of 3 tensors: + // - attn_cache + // - conv_cache + // - offset + std::vector GetInitStates() const override; + + std::vector StackStates( + std::vector> states) const override; + + std::vector> UnStackStates( + std::vector states) const override; + + /** + * + * @param x A 3-D tensor of shape (N, T, C). 
N has to be 1. + * @param states It is from GetInitStates() or returned from this method. + * + * @return Return a list of tensors + * - ans[0] contains log_probs, of shape (N, T, C) + * - ans[1:] contains next_states + */ + std::vector Forward( + MNN::Express::VARP x, std::vector states) const override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + // The model accepts this number of frames before subsampling as input + int32_t ChunkLength() const override; + + // Similar to frame_shift in feature extractor, after processing + // ChunkLength() frames, we advance by ChunkShift() frames + // before we process the next chunk. + int32_t ChunkShift() const override; + + bool SupportBatchProcessing() const override { return false; } + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_WENET_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer-transducer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer-transducer-model.cc new file mode 100644 index 00000000..6a512f87 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer-transducer-model.cc @@ -0,0 +1,510 @@ +// sherpa-mnn/csrc/online-zipformer-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-zipformer-transducer-model.h" + +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +template +OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + Manager *mgr, const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +void OnlineZipformerTransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + 
&encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(attention_dims_, "attention_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers"); + SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels"); + SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len"); + + SHERPA_ONNX_READ_META_DATA(T_, "T"); + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + + if (config_.debug) { + auto print = [](const std::vector &v, const char *name) { + std::ostringstream os; + os << name << ": "; + for (auto i : v) { + os << i << " "; + } +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + }; + print(encoder_dims_, "encoder_dims"); + print(attention_dims_, "attention_dims"); + print(num_encoder_layers_, "num_encoder_layers"); + print(cnn_module_kernels_, "cnn_module_kernels"); + print(left_context_len_, "left_context_len"); +#if __OHOS__ + SHERPA_ONNX_LOGE("T: %{public}d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); +#else + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); +#endif + } +} + +void OnlineZipformerTransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = decoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); +} + +void OnlineZipformerTransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, + model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + MNNMeta meta_data = joiner_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } +} + +std::vector OnlineZipformerTransducerModel::StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + int32_t num_encoders = 
static_cast(num_encoder_layers_.size()); + + std::vector buf(batch_size); + + std::vector ans; + ans.reserve(states[0].size()); + + auto allocator = + const_cast(this)->allocator_; + + // cached_len + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][i]; + } + auto v = Cat(allocator, buf, 1); // (num_layers, 1) + ans.push_back(std::move(v)); + } + + // cached_avg + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_encoders + i]; + } + auto v = Cat(allocator, buf, 1); // (num_layers, 1, encoder_dims) + ans.push_back(std::move(v)); + } + + // cached_key + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_encoders * 2 + i]; + } + // (num_layers, left_context_len, 1, attention_dims) + auto v = Cat(allocator, buf, 2); + ans.push_back(std::move(v)); + } + + // cached_val + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_encoders * 3 + i]; + } + // (num_layers, left_context_len, 1, attention_dims/2) + auto v = Cat(allocator, buf, 2); + ans.push_back(std::move(v)); + } + + // cached_val2 + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_encoders * 4 + i]; + } + // (num_layers, left_context_len, 1, attention_dims/2) + auto v = Cat(allocator, buf, 2); + ans.push_back(std::move(v)); + } + + // cached_conv1 + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_encoders * 5 + i]; + } + // (num_layers, 1, encoder_dims, cnn_module_kernels-1) + auto v = Cat(allocator, buf, 1); + ans.push_back(std::move(v)); + } + + // cached_conv2 + for (int32_t i = 0; i != num_encoders; ++i) { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_encoders * 6 + i]; + } + // (num_layers, 1, encoder_dims, cnn_module_kernels-1) + auto v = Cat(allocator, buf, 1); + ans.push_back(std::move(v)); + } + + return ans; +} + +std::vector> +OnlineZipformerTransducerModel::UnStackStates( + const std::vector &states) const { + assert(states.size() == num_encoder_layers_.size() * 7); + + int32_t batch_size = states[0]->getInfo()->dim[1]; + int32_t num_encoders = num_encoder_layers_.size(); + + auto allocator = + const_cast(this)->allocator_; + + std::vector> ans; + ans.resize(batch_size); + + // cached_len + for (int32_t i = 0; i != num_encoders; ++i) { + auto v = Unbind(allocator, states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_avg + for (int32_t i = num_encoders; i != 2 * num_encoders; ++i) { + auto v = Unbind(allocator, states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_key + for (int32_t i = 2 * num_encoders; i != 3 * num_encoders; ++i) { + auto v = Unbind(allocator, states[i], 2); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_val + for (int32_t i = 3 * num_encoders; i != 4 * num_encoders; ++i) { + auto v = Unbind(allocator, states[i], 2); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_val2 + for (int32_t i = 4 * num_encoders; i 
!= 5 * num_encoders; ++i) { + auto v = Unbind(allocator, states[i], 2); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_conv1 + for (int32_t i = 5 * num_encoders; i != 6 * num_encoders; ++i) { + auto v = Unbind(allocator, states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + // cached_conv2 + for (int32_t i = 6 * num_encoders; i != 7 * num_encoders; ++i) { + auto v = Unbind(allocator, states[i], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; +} + +std::vector OnlineZipformerTransducerModel::GetEncoderInitStates() { + // Please see + // https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py#L673 + // for details + + int32_t n = static_cast(encoder_dims_.size()); + std::vector cached_len_vec; + std::vector cached_avg_vec; + std::vector cached_key_vec; + std::vector cached_val_vec; + std::vector cached_val2_vec; + std::vector cached_conv1_vec; + std::vector cached_conv2_vec; + + cached_len_vec.reserve(n); + cached_avg_vec.reserve(n); + cached_key_vec.reserve(n); + cached_val_vec.reserve(n); + cached_val2_vec.reserve(n); + cached_conv1_vec.reserve(n); + cached_conv2_vec.reserve(n); + + for (int32_t i = 0; i != n; ++i) { + { + std::array s{num_encoder_layers_[i], 1}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + cached_len_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], 1, encoder_dims_[i]}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + cached_avg_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], left_context_len_[i], 1, + attention_dims_[i]}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + cached_key_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], left_context_len_[i], 1, + attention_dims_[i] / 2}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + cached_val_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], left_context_len_[i], 1, + attention_dims_[i] / 2}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + cached_val2_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], 1, encoder_dims_[i], + cnn_module_kernels_[i] - 1}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + cached_conv1_vec.push_back(std::move(v)); + } + + { + std::array s{num_encoder_layers_[i], 1, encoder_dims_[i], + cnn_module_kernels_[i] - 1}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + cached_conv2_vec.push_back(std::move(v)); + } + } + + std::vector ans; + ans.reserve(n * 7); + + for (auto &v : cached_len_vec) ans.push_back(std::move(v)); + for (auto &v : cached_avg_vec) ans.push_back(std::move(v)); + for (auto &v : cached_key_vec) ans.push_back(std::move(v)); + for (auto &v : cached_val_vec) ans.push_back(std::move(v)); + for (auto &v : cached_val2_vec) ans.push_back(std::move(v)); + for (auto &v : cached_conv1_vec) ans.push_back(std::move(v)); + for (auto &v : cached_conv2_vec) ans.push_back(std::move(v)); + + return ans; +} + +std::pair> +OnlineZipformerTransducerModel::RunEncoder(MNN::Express::VARP features, + 
std::vector states, + MNN::Express::VARP /* processed_frames */) { + std::vector encoder_inputs; + encoder_inputs.reserve(1 + states.size()); + + encoder_inputs.push_back(std::move(features)); + for (auto &v : states) { + encoder_inputs.push_back(std::move(v)); + } + + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + std::vector next_states; + next_states.reserve(states.size()); + + for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { + next_states.push_back(std::move(encoder_out[i])); + } + + return {std::move(encoder_out[0]), std::move(next_states)}; +} + +MNN::Express::VARP OnlineZipformerTransducerModel::RunDecoder( + MNN::Express::VARP decoder_input) { + auto decoder_out = decoder_sess_->onForward({decoder_input}); + return std::move(decoder_out[0]); +} + +MNN::Express::VARP OnlineZipformerTransducerModel::RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->onForward(joiner_input); + + return std::move(logit[0]); +} + +#if __ANDROID_API__ >= 9 +template OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformerTransducerModel::OnlineZipformerTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer-transducer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer-transducer-model.h new file mode 100644 index 00000000..43f51a6a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer-transducer-model.h @@ -0,0 +1,99 @@ +// sherpa-mnn/csrc/online-zipformer-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +class OnlineZipformerTransducerModel : public OnlineTransducerModel { + public: + explicit OnlineZipformerTransducerModel(const OnlineModelConfig &config); + + template + OnlineZipformerTransducerModel(Manager *mgr, const OnlineModelConfig &config); + + std::vector StackStates( + const std::vector> &states) const override; + + std::vector> UnStackStates( + const std::vector &states) const override; + + std::vector GetEncoderInitStates() override; + + std::pair> RunEncoder( + MNN::Express::VARP features, std::vector states, + MNN::Express::VARP processed_frames) override; + + MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input) override; + + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) override; + + int32_t ContextSize() const override { return context_size_; } + + int32_t ChunkSize() const override { return T_; } + + int32_t ChunkShift() const override { return decode_chunk_len_; } + + int32_t VocabSize() const override { return vocab_size_; } + MNNAllocator *Allocator() override { return allocator_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length); + void InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); + + private: + MNNEnv env_; + MNNConfig 
sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + OnlineModelConfig config_; + + std::vector encoder_dims_; + std::vector attention_dims_; + std::vector num_encoder_layers_; + std::vector cnn_module_kernels_; + std::vector left_context_len_; + + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + + int32_t context_size_ = 0; + int32_t vocab_size_ = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model-config.cc new file mode 100644 index 00000000..c6a743cb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model-config.cc @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/online-zipformer2-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-zipformer2-ctc-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void OnlineZipformer2CtcModelConfig::Register(ParseOptions *po) { + po->Register("zipformer2-ctc-model", &model, + "Path to CTC model.onnx. See also " + "https://github.com/k2-fsa/icefall/pull/1413"); +} + +bool OnlineZipformer2CtcModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("--zipformer2-ctc-model is empty!"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("--zipformer2-ctc-model '%s' does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string OnlineZipformer2CtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlineZipformer2CtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model-config.h new file mode 100644 index 00000000..3e6bff89 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model-config.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/online-zipformer2-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct OnlineZipformer2CtcModelConfig { + std::string model; + + OnlineZipformer2CtcModelConfig() = default; + + explicit OnlineZipformer2CtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model.cc 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model.cc new file mode 100644 index 00000000..6838fec1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model.cc @@ -0,0 +1,473 @@ +// sherpa-mnn/csrc/online-zipformer2-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-zipformer2-ctc-model.h" + +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +class OnlineZipformer2CtcModel::Impl { + public: + explicit Impl(const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.zipformer2_ctc.model); + Init(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const OnlineModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.zipformer2_ctc.model); + Init(buf.data(), buf.size()); + } + } + + std::vector Forward(MNN::Express::VARP features, + std::vector states) { + std::vector inputs; + inputs.reserve(1 + states.size()); + + inputs.push_back(std::move(features)); + for (auto &v : states) { + inputs.push_back(std::move(v)); + } + + return sess_->onForward(inputs); + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t ChunkLength() const { return T_; } + + int32_t ChunkShift() const { return decode_chunk_len_; } + + MNNAllocator *Allocator() { return allocator_; } + + // Return a vector containing 3 tensors + // - attn_cache + // - conv_cache + // - offset + std::vector GetInitStates() { + std::vector ans; + ans.reserve(initial_states_.size()); + for (auto &s : initial_states_) { + ans.push_back(View(s)); + } + return ans; + } + + std::vector StackStates( + std::vector> states) { + int32_t batch_size = static_cast(states.size()); + + std::vector buf(batch_size); + + std::vector ans; + int32_t num_states = static_cast(states[0].size()); + ans.reserve(num_states); + + for (int32_t i = 0; i != (num_states - 2) / 6; ++i) { + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 1]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 2]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 3]; + } + auto v = Cat(allocator_, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 4]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 5]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + } + + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_states - 2]; + } + auto v 
= Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_states - 1]; + } + auto v = Cat(allocator_, buf, 0); + ans.push_back(std::move(v)); + } + return ans; + } + + std::vector> UnStackStates( + std::vector states) { + int32_t m = std::accumulate(num_encoder_layers_.begin(), + num_encoder_layers_.end(), 0); + assert(states.size() == m * 6 + 2); + + int32_t batch_size = states[0]->getInfo()->dim[1]; + + std::vector> ans; + ans.resize(batch_size); + + for (int32_t i = 0; i != m; ++i) { + { + auto v = Unbind(allocator_, states[i * 6], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, states[i * 6 + 1], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, states[i * 6 + 2], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, states[i * 6 + 3], 1); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, states[i * 6 + 4], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, states[i * 6 + 5], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + } + + { + auto v = Unbind(allocator_, states[m * 6], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator_, states[m * 6 + 1], 0); + assert(v.size() == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---zipformer2_ctc---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads"); + SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers"); + SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels"); + SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len"); + + SHERPA_ONNX_READ_META_DATA(T_, "T"); + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + + if (meta_data.find("vocab_size") != meta_data.end()) { + vocab_size_ = std::stoi(meta_data["vocab_size"]); + } + + if (config_.debug) { + auto print 
= [](const std::vector &v, const char *name) { + std::ostringstream os; + os << name << ": "; + for (auto i : v) { + os << i << " "; + } + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + }; + print(encoder_dims_, "encoder_dims"); + print(query_head_dims_, "query_head_dims"); + print(value_head_dims_, "value_head_dims"); + print(num_heads_, "num_heads"); + print(num_encoder_layers_, "num_encoder_layers"); + print(cnn_module_kernels_, "cnn_module_kernels"); + print(left_context_len_, "left_context_len"); + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); + SHERPA_ONNX_LOGE("vocab_size_: %d", vocab_size_); + } + + InitStates(); + } + + void InitStates() { + int32_t n = static_cast(encoder_dims_.size()); + int32_t m = std::accumulate(num_encoder_layers_.begin(), + num_encoder_layers_.end(), 0); + initial_states_.reserve(m * 6 + 2); + + for (int32_t i = 0; i != n; ++i) { + int32_t num_layers = num_encoder_layers_[i]; + int32_t key_dim = query_head_dims_[i] * num_heads_[i]; + int32_t value_dim = value_head_dims_[i] * num_heads_[i]; + int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4; + + for (int32_t j = 0; j != num_layers; ++j) { + { + std::array s{left_context_len_[i], 1, key_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + + { + std::array s{1, 1, left_context_len_[i], + nonlin_attn_head_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + + { + std::array s{left_context_len_[i], 1, value_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + + { + std::array s{left_context_len_[i], 1, value_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + + { + std::array s{1, encoder_dims_[i], + cnn_module_kernels_[i] / 2}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + + { + std::array s{1, encoder_dims_[i], + cnn_module_kernels_[i] / 2}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + } + } + + { + std::array s{1, 128, 3, 19}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + + { + std::array s{1}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + initial_states_.push_back(std::move(v)); + } + } + + private: + OnlineModelConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + std::vector initial_states_; + + std::vector encoder_dims_; + std::vector query_head_dims_; + std::vector value_head_dims_; + std::vector num_heads_; + std::vector num_encoder_layers_; + std::vector cnn_module_kernels_; + std::vector left_context_len_; + + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + int32_t vocab_size_ = 0; +}; + +OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( + const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( + Manager *mgr, const OnlineModelConfig &config) + : 
impl_(std::make_unique(mgr, config)) {} + +OnlineZipformer2CtcModel::~OnlineZipformer2CtcModel() = default; + +std::vector OnlineZipformer2CtcModel::Forward( + MNN::Express::VARP x, std::vector states) const { + return impl_->Forward(std::move(x), std::move(states)); +} + +int32_t OnlineZipformer2CtcModel::VocabSize() const { + return impl_->VocabSize(); +} + +int32_t OnlineZipformer2CtcModel::ChunkLength() const { + return impl_->ChunkLength(); +} + +int32_t OnlineZipformer2CtcModel::ChunkShift() const { + return impl_->ChunkShift(); +} + +MNNAllocator *OnlineZipformer2CtcModel::Allocator() const { + return impl_->Allocator(); +} + +std::vector OnlineZipformer2CtcModel::GetInitStates() const { + return impl_->GetInitStates(); +} + +std::vector OnlineZipformer2CtcModel::StackStates( + std::vector> states) const { + return impl_->StackStates(std::move(states)); +} + +std::vector> OnlineZipformer2CtcModel::UnStackStates( + std::vector states) const { + return impl_->UnStackStates(std::move(states)); +} + +#if __ANDROID_API__ >= 9 +template OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformer2CtcModel::OnlineZipformer2CtcModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model.h new file mode 100644 index 00000000..2ffa3a3b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-ctc-model.h @@ -0,0 +1,74 @@ +// sherpa-mnn/csrc/online-zipformer2-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_ + +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/online-ctc-model.h" +#include "sherpa-mnn/csrc/online-model-config.h" + +namespace sherpa_mnn { + +class OnlineZipformer2CtcModel : public OnlineCtcModel { + public: + explicit OnlineZipformer2CtcModel(const OnlineModelConfig &config); + + template + OnlineZipformer2CtcModel(Manager *mgr, const OnlineModelConfig &config); + + ~OnlineZipformer2CtcModel() override; + + // A list of tensors. + // See also + // https://github.com/k2-fsa/icefall/pull/1413 + // and + // https://github.com/k2-fsa/icefall/pull/1415 + std::vector GetInitStates() const override; + + std::vector StackStates( + std::vector> states) const override; + + std::vector> UnStackStates( + std::vector states) const override; + + /** + * + * @param x A 3-D tensor of shape (N, T, C). N has to be 1. + * @param states It is from GetInitStates() or returned from this method. + * + * @return Return a list of tensors + * - ans[0] contains log_probs, of shape (N, T, C) + * - ans[1:] contains next_states + */ + std::vector Forward( + MNN::Express::VARP x, std::vector states) const override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + MNNAllocator *Allocator() const override; + + // The model accepts this number of frames before subsampling as input + int32_t ChunkLength() const override; + + // Similar to frame_shift in feature extractor, after processing + // ChunkLength() frames, we advance by ChunkShift() frames + // before we process the next chunk. 
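+  // For example, with hypothetical values ChunkLength() == 45 and
+  // ChunkShift() == 32, the first call consumes frames [0, 45) and the
+  // next call consumes frames [32, 77).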
+ int32_t ChunkShift() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-transducer-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-transducer-model.cc new file mode 100644 index 00000000..e0aaf467 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-transducer-model.cc @@ -0,0 +1,493 @@ +// sherpa-mnn/csrc/online-zipformer2-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-zipformer2-transducer-model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/online-transducer-decoder.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/text-utils.h" +#include "sherpa-mnn/csrc/unbind.h" + +namespace sherpa_mnn { + +OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( + const OnlineModelConfig &config) + : + sess_opts_(GetSessionOptions(config)), + config_(config), + allocator_{} { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +template +OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( + Manager *mgr, const OnlineModelConfig &config) + : + config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +void OnlineZipformer2TransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = encoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims"); + SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads"); + SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers"); + SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, 
"cnn_module_kernels"); + SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len"); + + SHERPA_ONNX_READ_META_DATA(T_, "T"); + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + + if (config_.debug) { + auto print = [](const std::vector &v, const char *name) { + std::ostringstream os; + os << name << ": "; + for (auto i : v) { + os << i << " "; + } +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + }; + print(encoder_dims_, "encoder_dims"); + print(query_head_dims_, "query_head_dims"); + print(value_head_dims_, "value_head_dims"); + print(num_heads_, "num_heads"); + print(num_encoder_layers_, "num_encoder_layers"); + print(cnn_module_kernels_, "cnn_module_kernels"); + print(left_context_len_, "left_context_len"); + +#if __OHOS__ + SHERPA_ONNX_LOGE("T: %{public}d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); +#else + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); +#endif + } +} + +void OnlineZipformer2TransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + MNNMeta meta_data = decoder_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); +} + +void OnlineZipformer2TransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::unique_ptr( + MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + MNNMeta meta_data = joiner_sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } +} + +std::vector OnlineZipformer2TransducerModel::StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + + std::vector buf(batch_size); + + auto allocator = + const_cast(this)->allocator_; + + std::vector ans; + int32_t num_states = static_cast(states[0].size()); + ans.reserve(num_states); + + for (int32_t i = 0; i != (num_states - 2) / 6; ++i) { + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i]; + } + auto v = Cat(allocator, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 1]; + } + auto v = Cat(allocator, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 2]; + } + auto v = Cat(allocator, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = 
states[n][6 * i + 3]; + } + auto v = Cat(allocator, buf, 1); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 4]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][6 * i + 5]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + } + + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_states - 2]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + + { + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = states[n][num_states - 1]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + return ans; +} + +std::vector> +OnlineZipformer2TransducerModel::UnStackStates( + const std::vector &states) const { + int32_t m = std::accumulate(num_encoder_layers_.begin(), + num_encoder_layers_.end(), 0); + assert(static_cast(states.size()) == m * 6 + 2); + + int32_t batch_size = states[0]->getInfo()->dim[1]; + + auto allocator = + const_cast(this)->allocator_; + + std::vector> ans; + ans.resize(batch_size); + + for (int32_t i = 0; i != m; ++i) { + { + auto v = Unbind(allocator, states[i * 6], 1); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator, states[i * 6 + 1], 1); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator, states[i * 6 + 2], 1); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator, states[i * 6 + 3], 1); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator, states[i * 6 + 4], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator, states[i * 6 + 5], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + } + + { + auto v = Unbind(allocator, states[m * 6], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { + auto v = Unbind(allocator, states[m * 6 + 1], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; +} + +std::vector +OnlineZipformer2TransducerModel::GetEncoderInitStates() { + std::vector ans; + int32_t n = static_cast(encoder_dims_.size()); + int32_t m = std::accumulate(num_encoder_layers_.begin(), + num_encoder_layers_.end(), 0); + ans.reserve(m * 6 + 2); + + for (int32_t i = 0; i != n; ++i) { + int32_t num_layers = num_encoder_layers_[i]; + int32_t key_dim = query_head_dims_[i] * num_heads_[i]; + int32_t value_dim = value_head_dims_[i] * num_heads_[i]; + int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4; + + for (int32_t j = 0; j != num_layers; ++j) { + { + std::array s{left_context_len_[i], 1, key_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { + std::array s{1, 1, 
left_context_len_[i], + nonlin_attn_head_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { + std::array s{left_context_len_[i], 1, value_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { + std::array s{left_context_len_[i], 1, value_dim}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { + std::array s{1, encoder_dims_[i], + cnn_module_kernels_[i] / 2}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { + std::array s{1, encoder_dims_[i], + cnn_module_kernels_[i] / 2}; + auto v = + MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + } + } + + { + SHERPA_ONNX_CHECK_NE(feature_dim_, 0); + int32_t embed_dim = (((feature_dim_ - 1) / 2) - 1) / 2; + std::array s{1, 128, 3, embed_dim}; + + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + + { + std::array s{1}; + auto v = MNNUtilsCreateTensor(allocator_, s.data(), s.size()); + Fill(v, 0); + ans.push_back(std::move(v)); + } + return ans; +} + +std::pair> +OnlineZipformer2TransducerModel::RunEncoder(MNN::Express::VARP features, + std::vector states, + MNN::Express::VARP /* processed_frames */) { + std::vector encoder_inputs; + encoder_inputs.reserve(1 + states.size()); + + encoder_inputs.push_back(std::move(features)); + for (auto &v : states) { + encoder_inputs.push_back(std::move(v)); + } + + auto encoder_out = encoder_sess_->onForward(encoder_inputs); + + std::vector next_states; + next_states.reserve(states.size()); + + for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { + next_states.push_back(std::move(encoder_out[i])); + } + return {std::move(encoder_out[0]), std::move(next_states)}; +} + +MNN::Express::VARP OnlineZipformer2TransducerModel::RunDecoder( + MNN::Express::VARP decoder_input) { + auto decoder_out = decoder_sess_->onForward({decoder_input}); + return std::move(decoder_out[0]); +} + +MNN::Express::VARP OnlineZipformer2TransducerModel::RunJoiner(MNN::Express::VARP encoder_out, + MNN::Express::VARP decoder_out) { + std::vector joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->onForward(joiner_input); + + return std::move(logit[0]); +} + +#if __ANDROID_API__ >= 9 +template OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-transducer-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-transducer-model.h new file mode 100644 index 00000000..c1e2aeb2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/online-zipformer2-transducer-model.h @@ -0,0 +1,108 @@ +// sherpa-mnn/csrc/online-zipformer2-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ + +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT +#include 
"sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +class OnlineZipformer2TransducerModel : public OnlineTransducerModel { + public: + explicit OnlineZipformer2TransducerModel(const OnlineModelConfig &config); + + template + OnlineZipformer2TransducerModel(Manager *mgr, + const OnlineModelConfig &config); + + std::vector StackStates( + const std::vector> &states) const override; + + std::vector> UnStackStates( + const std::vector &states) const override; + + std::vector GetEncoderInitStates() override; + + void SetFeatureDim(int32_t feature_dim) override { + feature_dim_ = feature_dim; + } + + std::pair> RunEncoder( + MNN::Express::VARP features, std::vector states, + MNN::Express::VARP processed_frames) override; + + MNN::Express::VARP RunDecoder(MNN::Express::VARP decoder_input) override; + + MNN::Express::VARP RunJoiner(MNN::Express::VARP encoder_out, MNN::Express::VARP decoder_out) override; + + int32_t ContextSize() const override { return context_size_; } + + int32_t ChunkSize() const override { return T_; } + + int32_t ChunkShift() const override { return decode_chunk_len_; } + + int32_t VocabSize() const override { return vocab_size_; } + MNNAllocator *Allocator() override { return allocator_; } + + private: + void InitEncoder(void *model_data, size_t model_data_length); + void InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); + + private: + MNNEnv env_; + MNNConfig sess_opts_; + + MNNAllocator* allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + OnlineModelConfig config_; + + std::vector encoder_dims_; + std::vector query_head_dims_; + std::vector value_head_dims_; + std::vector num_heads_; + std::vector num_encoder_layers_; + std::vector cnn_module_kernels_; + std::vector left_context_len_; + + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + + int32_t context_size_ = 0; + int32_t vocab_size_ = 0; + int32_t feature_dim_ = 80; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/onnx-utils.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/onnx-utils.h new file mode 100644 index 00000000..44a4508e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/onnx-utils.h @@ -0,0 +1,24 @@ +// sherpa-mnn/csrc/onnx-utils.h +// +// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023 Pingfeng Luo +#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_ +#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_ + +#ifdef _MSC_VER +// For ToWide() below +#include +#include +#endif + +#include +#include +#include +#include +#include + +#include "MNNUtils.hpp" // NOLINT + + + +#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence-test.cc new file mode 100644 index 00000000..1f3510a4 --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence-test.cc @@ -0,0 +1,52 @@ +// sherpa-mnn/csrc/packed-sequence-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/packed-sequence.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TEST(PackedSequence, Case1) { + MNNAllocator* allocator; + std::array shape{5, 5, 4}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); + + MNN::Express::VARP length = + MNNUtilsCreateTensor(allocator, shape.data(), 1); + int *p_length = length->writeMap(); + p_length[0] = 1; + p_length[1] = 2; + p_length[2] = 3; + p_length[3] = 5; + p_length[4] = 2; + + auto packed_seq = PackPaddedSequence(allocator, &v, &length); + fprintf(stderr, "sorted indexes: "); + for (auto i : packed_seq.sorted_indexes) { + fprintf(stderr, "%d ", static_cast(i)); + } + fprintf(stderr, "\n"); + // output index: 0 1 2 3 4 + // sorted indexes: 3 2 1 4 0 + // length: 5 3 2 2 1 + Print3D(&v); + Print2D(&packed_seq.data); + fprintf(stderr, "batch sizes per time step: "); + for (auto i : packed_seq.batch_sizes) { + fprintf(stderr, "%d ", static_cast(i)); + } + fprintf(stderr, "\n"); + + // TODO(fangjun): Check that the return value is correct +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence.cc new file mode 100644 index 00000000..e3679e98 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence.cc @@ -0,0 +1,106 @@ +// sherpa-mnn/csrc/packed-sequence.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/packed-sequence.h" + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/slice.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +static MNN::Express::VARP IndexSelect(MNNAllocator *allocator, MNN::Express::VARP value, + const std::vector &sorted_indexes) { + auto shape = value->getInfo()->dim; + assert(shape.size() == 3); + std::array ans_shape{static_cast(sorted_indexes.size()), + shape[1], shape[2]}; + + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + float *dst = ans->writeMap(); + const float *src = value->readMap(); + + for (auto i : sorted_indexes) { + const float *start = src + i * shape[1] * shape[2]; + std::copy(start, start + shape[1] * shape[2], dst); + dst += shape[1] * shape[2]; + } + return ans; +} + +PackedSequence PackPaddedSequence(MNNAllocator *allocator, + MNN::Express::VARP value, MNN::Express::VARP length) { + std::vector v_shape = value->getInfo()->dim; + std::vector l_shape = length->getInfo()->dim; + + assert(v_shape.size() == 3); + assert(l_shape.size() == 1); + assert(v_shape[0] == l_shape[0]); + + std::vector indexes(v_shape[0]); + std::iota(indexes.begin(), indexes.end(), 0); + + const int *p_length = length->readMap(); + // sort in descending order + std::sort(indexes.begin(), indexes.end(), [p_length](int32_t i, int32_t j) { + return p_length[i] > p_length[j]; + }); + + int32_t n = static_cast(v_shape[0]); + + int max_T = p_length[indexes[0]]; + + auto sum_T = std::accumulate(p_length, p_length + n, static_cast(0)); + + std::array data_shape{sum_T, v_shape[2]}; + + MNN::Express::VARP data = MNNUtilsCreateTensor( + allocator, data_shape.data(), data_shape.size()); + float *dst = data->writeMap(); 
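+  // What follows is the packing step itself: the input is first gathered in
+  // descending-length order, transposed from (B, T, C) to (T, B, C), and then,
+  // for each time range, the rows of the still-active sequences are copied
+  // contiguously into `data`. `batch_sizes[t]` records how many sequences are
+  // still active at time step t.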
+ + MNN::Express::VARP tensor = IndexSelect(allocator, value, indexes); + tensor = Transpose01(allocator, tensor); + + // batch size at each time step + std::vector batch_sizes; + batch_sizes.reserve(max_T); + + int prev_l = 0; + for (int32_t i = 0; i != n; ++i) { + auto cur_l = p_length[indexes[n - 1 - i]]; + assert(cur_l >= prev_l); + if (cur_l == prev_l) { + continue; + } + + auto cur_batch_size = n - i; + + MNN::Express::VARP cur_batch = + Slice(allocator, tensor, prev_l, cur_l, 0, cur_batch_size); + auto count = cur_batch->getInfo()->size; + const float *src = cur_batch->readMap(); + std::copy(src, src + count, dst); + dst += count; + + for (int32_t j = prev_l; j < cur_l; ++j) { + batch_sizes.push_back(cur_batch_size); + } + + prev_l = cur_l; + } + + PackedSequence packed_seq; + packed_seq.sorted_indexes = std::move(indexes); + packed_seq.data = std::move(data); + packed_seq.batch_sizes = std::move(batch_sizes); + + return packed_seq; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence.h new file mode 100644 index 00000000..0ef47373 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/packed-sequence.h @@ -0,0 +1,52 @@ +// sherpa-mnn/csrc/packed-sequence.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_ +#define SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +struct PackedSequence { + std::vector sorted_indexes; + std::vector batch_sizes; + + // data is a 2-D tensor of shape (sum(batch_sizes), channels) + MNN::Express::VARP data{nullptr}; + + // Return a shallow copy of data[start:start+size, :] + MNN::Express::VARP Get(int32_t start, int32_t size) { + auto shape = data->getInfo()->dim; + + std::array ans_shape{size, shape[1]}; + + float *p = data->writeMap(); + + auto memory_info = + (MNNAllocator*)(nullptr); + + // a shallow copy + return MNNUtilsCreateTensor(memory_info, p + start * shape[1], + size * shape[1], ans_shape.data(), + ans_shape.size()); + } +}; + +/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only + * batch_first=true. + * + * @param allocator + * @param value A 3-D tensor of shape (B, T, C). Its dtype is float. + * @param length A 1-D tensor of shape (B,). Its dtype is int. Each + * element in it specifies the valid length of the corresponding + * entry in value before padding. 
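+ *
+ *        The behaviour is close to torch.nn.utils.rnn.pack_padded_sequence
+ *        with batch_first=true and enforce_sorted=false: sequences are sorted
+ *        by length in descending order before packing.
+ *
+ *        For example (mirroring packed-sequence-test.cc): for value of shape
+ *        (5, 5, 4) and length {1, 2, 3, 5, 2}, sorted_indexes is
+ *        {3, 2, 1, 4, 0}, data has shape (13, 4) and batch_sizes is
+ *        {5, 4, 2, 1, 1}.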
+ */ +PackedSequence PackPaddedSequence(MNNAllocator *allocator, + MNN::Express::VARP value, MNN::Express::VARP length); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_PACKED_SEQUENCE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence-test.cc new file mode 100644 index 00000000..2f66cc62 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence-test.cc @@ -0,0 +1,43 @@ +// sherpa-mnn/csrc/pad-sequence-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/pad-sequence.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TEST(PadSequence, ThreeTensors) { + MNNAllocator* allocator; + + std::array shape1{3, 5}; + MNN::Express::VARP v1 = + MNNUtilsCreateTensor(allocator, shape1.data(), shape1.size()); + float *p1 = v1->writeMap(); + std::iota(p1, p1 + shape1[0] * shape1[1], 0); + + std::array shape2{4, 5}; + MNN::Express::VARP v2 = + MNNUtilsCreateTensor(allocator, shape2.data(), shape2.size()); + float *p2 = v2->writeMap(); + std::iota(p2, p2 + shape2[0] * shape2[1], 0); + + std::array shape3{2, 5}; + MNN::Express::VARP v3 = + MNNUtilsCreateTensor(allocator, shape3.data(), shape3.size()); + float *p3 = v3->writeMap(); + std::iota(p3, p3 + shape3[0] * shape3[1], 0); + + auto ans = PadSequence(allocator, {&v1, &v2, &v3}, -1); + + Print2D(&v1); + Print2D(&v2); + Print2D(&v3); + Print3D(&ans); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence.cc new file mode 100644 index 00000000..b4beae0f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence.cc @@ -0,0 +1,52 @@ +// sherpa-mnn/csrc/pad-sequence.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/pad-sequence.h" + +#include +#include +#include + +namespace sherpa_mnn { + +MNN::Express::VARP PadSequence(MNNAllocator *allocator, + const std::vector &values, + float padding_value) { + int32_t batch_size = static_cast(values.size()); + + std::vector shape0 = + values[0]->getInfo()->dim; + assert(shape0.size() == 2); + + auto feature_dim = shape0[1]; + auto max_T = shape0[0]; + + for (int32_t i = 1; i != batch_size; ++i) { + auto shape = values[i]->getInfo()->dim; + + assert(shape.size() == 2); + assert(shape[1] == feature_dim); + + max_T = std::max(max_T, shape[0]); + } + std::array ans_shape{batch_size, max_T, feature_dim}; + + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + float *dst = ans->writeMap(); + std::fill(dst, dst + batch_size * max_T * feature_dim, padding_value); + + for (const auto v : values) { + const float *src = v->readMap(); + auto shape = v->getInfo()->dim; + std::copy(src, src + shape[0] * shape[1], dst); + dst += max_T * feature_dim; + } + + return ans; + + // TODO(fangjun): Check that the returned value is correct. 
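+  // Illustration (matching pad-sequence-test.cc): padding inputs of shape
+  // (3, 5), (4, 5) and (2, 5) with padding_value -1 yields a (3, 4, 5) tensor
+  // in which rows beyond each input's own T keep the value -1.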
+} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence.h new file mode 100644 index 00000000..c78bd9bd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/pad-sequence.h @@ -0,0 +1,31 @@ +// sherpa-mnn/csrc/pad-sequence.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_ +#define SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only + * batch_first=true. + * + * @param allocator + * @param values A list of 2-D tensors. Each tensor's second dimension + * must be the same and the data type of each tensor should + * be float. + * @param padding_value Value used for padding. For log-fbank, you usually use + * -23.025850929940457f as the padding value. + * + * @return Return a 3-D tensor of shape (B, max_T, C). + */ +MNN::Express::VARP PadSequence(MNNAllocator *allocator, + const std::vector &values, + float padding_value); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_PAD_SEQUENCE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/parse-options.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/parse-options.cc new file mode 100644 index 00000000..e62a552d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/parse-options.cc @@ -0,0 +1,678 @@ +// sherpa-mnn/csrc/parse-options.cc +/** + * Copyright 2009-2011 Karel Vesely; Microsoft Corporation; + * Saarland University (Author: Arnab Ghoshal); + * Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey); + * Frantisek Skala; Arnab Ghoshal + * Copyright 2013 Tanel Alumae + */ + +// This file is copied and modified from kaldi/src/util/parse-options.cu + +#include "sherpa-mnn/csrc/parse-options.h" + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/log.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) + : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { + if (po != nullptr && po->other_parser_ != nullptr) { + // we get here if this constructor is used twice, recursively. 
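+    // For instance, with ParseOptions po("usage"); ParseOptions po_a("a", &po);
+    // ParseOptions po_b("b", &po_a);, options registered through po_b are
+    // forwarded to the root parser po under the prefix "a.b".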
+ other_parser_ = po->other_parser_; + } else { + other_parser_ = po; + } + if (po != nullptr && !po->prefix_.empty()) { + prefix_ = po->prefix_ + std::string(".") + prefix; + } else { + prefix_ = prefix; + } +} + +void ParseOptions::Register(const std::string &name, bool *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, int *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, uint32_t *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, float *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, double *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, std::string *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +// old-style, used for registering application-specific parameters +template +void ParseOptions::RegisterTmpl(const std::string &name, T *ptr, + const std::string &doc) { + if (other_parser_ == nullptr) { + this->RegisterCommon(name, ptr, doc, false); + } else { + SHERPA_ONNX_CHECK(prefix_ != "") + << "prefix: " << prefix_ << "\n" + << "Cannot use empty prefix when registering with prefix."; + std::string new_name = prefix_ + '.' + name; // name becomes prefix.name + other_parser_->Register(new_name, ptr, doc); + } +} + +// does the common part of the job of registering a parameter +template +void ParseOptions::RegisterCommon(const std::string &name, T *ptr, + const std::string &doc, bool is_standard) { + SHERPA_ONNX_CHECK(ptr != nullptr); + std::string idx = name; + NormalizeArgName(&idx); + if (doc_map_.find(idx) != doc_map_.end()) { + SHERPA_ONNX_LOGE("Registering option twice, ignoring second time: %s", + name.c_str()); + } else { + this->RegisterSpecific(name, idx, ptr, doc, is_standard); + } +} + +// used to register standard parameters (those that are present in all of the +// applications) +template +void ParseOptions::RegisterStandard(const std::string &name, T *ptr, + const std::string &doc) { + this->RegisterCommon(name, ptr, doc, true); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, bool *b, + const std::string &doc, bool is_standard) { + bool_map_[idx] = b; + doc_map_[idx] = + DocInfo(name, doc + " (bool, default = " + ((*b) ? 
"true)" : "false)"), + is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, int *i, + const std::string &doc, bool is_standard) { + int_map_[idx] = i; + std::ostringstream ss; + ss << doc << " (int64, default = " << *i << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, uint32_t *u, + const std::string &doc, bool is_standard) { + uint_map_[idx] = u; + std::ostringstream ss; + ss << doc << " (uint, default = " << *u << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, float *f, + const std::string &doc, bool is_standard) { + float_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (float, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, double *f, + const std::string &doc, bool is_standard) { + double_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (double, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, std::string *s, + const std::string &doc, bool is_standard) { + string_map_[idx] = s; + doc_map_[idx] = + DocInfo(name, doc + " (string, default = \"" + *s + "\")", is_standard); +} + +void ParseOptions::DisableOption(const std::string &name) { + if (argv_ != nullptr) { + SHERPA_ONNX_LOGE("DisableOption must not be called after calling Read()."); + exit(-1); + } + if (doc_map_.erase(name) == 0) { + SHERPA_ONNX_LOGE("Option %s was not registered so cannot be disabled: ", + name.c_str()); + exit(-1); + } + bool_map_.erase(name); + int_map_.erase(name); + int64_map_.erase(name); + uint_map_.erase(name); + float_map_.erase(name); + double_map_.erase(name); + string_map_.erase(name); +} + +int32_t ParseOptions::NumArgs() const { return positional_args_.size(); } + +std::string ParseOptions::GetArg(int32_t i) const { + if (i < 1 || i > static_cast(positional_args_.size())) { + SHERPA_ONNX_LOGE("ParseOptions::GetArg, invalid index %d", i); + exit(-1); + } + + return positional_args_[i - 1]; +} + +// We currently do not support any other options. +enum ShellType : std::uint8_t { kBash = 0 }; + +// This can be changed in the code if it ever does need to be changed (as it's +// unlikely that one compilation of this tool-set would use both shells). +static ShellType kShellType = kBash; + +// Returns true if we need to escape a string before putting it into +// a shell (mainly thinking of bash shell, but should work for others) +// This is for the convenience of the user so command-lines that are +// printed out by ParseOptions::Read (with --print-args=true) are +// paste-able into the shell and will run. If you use a different type of +// shell, it might be necessary to change this function. +// But it's mostly a cosmetic issue as it basically affects how +// the program echoes its command-line arguments to the screen. +static bool MustBeQuoted(const std::string &str, ShellType st) { + // Only Bash is supported (for the moment). 
+ SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type."; + + const char *c = str.c_str(); + if (*c == '\0') { + return true; // Must quote empty string + } else { + std::array ok_chars{}; + + // These seem not to be interpreted as long as there are no other "bad" + // characters involved (e.g. "," would be interpreted as part of something + // like a{b,c}, but not on its own. + ok_chars[kBash] = "[]~#^_-+=:.,/"; + + // Just want to make sure that a space character doesn't get automatically + // inserted here via an automated style-checking script, like it did before. + SHERPA_ONNX_CHECK(!strchr(ok_chars[kBash], ' ')); + + for (; *c != '\0'; ++c) { + // For non-alphanumeric characters we have a list of characters which + // are OK. All others are forbidden (this is easier since the shell + // interprets most non-alphanumeric characters). + if (!isalnum(*c)) { + const char *d = nullptr; + for (d = ok_chars[st]; *d != '\0'; ++d) { + if (*c == *d) break; + } + // If not alphanumeric or one of the "ok_chars", it must be escaped. + if (*d == '\0') return true; + } + } + return false; // The string was OK. No quoting or escaping. + } +} + +// Returns a quoted and escaped version of "str" +// which has previously been determined to need escaping. +// Our aim is to print out the command line in such a way that if it's +// pasted into a shell of ShellType "st" (only bash for now), it +// will get passed to the program in the same way. +static std::string QuoteAndEscape(const std::string &str, ShellType /*st*/) { + // Only Bash is supported (for the moment). + SHERPA_ONNX_CHECK_EQ(st, kBash) << "Invalid shell type."; + + // For now we use the following rules: + // In the normal case, we quote with single-quote "'", and to escape + // a single-quote we use the string: '\'' (interpreted as closing the + // single-quote, putting an escaped single-quote from the shell, and + // then reopening the single quote). + char quote_char = '\''; + const char *escape_str = "'\\''"; // e.g. echo 'a'\''b' returns a'b + + // If the string contains single-quotes that would need escaping this + // way, and we determine that the string could be safely double-quoted + // without requiring any escaping, then we double-quote the string. + // This is the case if the characters "`$\ do not appear in the string. + // e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html + const char *c_str = str.c_str(); + if (strchr(c_str, '\'') && !strpbrk(c_str, "\"`$\\")) { + quote_char = '"'; + escape_str = "\\\""; // should never be accessed. + } + + std::array buf{}; + buf[1] = '\0'; + + buf[0] = quote_char; + std::string ans = buf.data(); + const char *c = str.c_str(); + for (; *c != '\0'; ++c) { + if (*c == quote_char) { + ans += escape_str; + } else { + buf[0] = *c; + ans += buf.data(); + } + } + buf[0] = quote_char; + ans += buf.data(); + return ans; +} + +// static function +std::string ParseOptions::Escape(const std::string &str) { + return MustBeQuoted(str, kShellType) ? 
QuoteAndEscape(str, kShellType) : str; +} + +int32_t ParseOptions::Read(int32_t argc, const char *const *argv) { + argc_ = argc; + argv_ = argv; + std::string key, value; + int32_t i = 0; + + // first pass: look for config parameter, look for priority + for (i = 1; i < argc; ++i) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // a lone "--" marks the end of named options + break; + } + bool has_equal_sign = false; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (key == "config") { + ReadConfigFile(value); + } else if (key == "help") { + PrintUsage(); + exit(0); + } + } + } + + bool double_dash_seen = false; + // second pass: add the command line options + for (i = 1; i < argc; ++i) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // A lone "--" marks the end of named options. + // Skip that option and break the processing of named options + i += 1; + double_dash_seen = true; + break; + } + bool has_equal_sign = false; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + SHERPA_ONNX_LOGE("Invalid option %s", argv[i]); + exit(-1); + } + } else { + break; + } + } + + // process remaining arguments as positional + for (; i < argc; ++i) { + if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) { + double_dash_seen = true; + } else { + positional_args_.emplace_back(argv[i]); + } + } + + // if the user did not suppress this with --print-args = false.... + if (print_args_) { + std::ostringstream strm; + for (int32_t j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; + strm << '\n'; + SHERPA_ONNX_LOGE("%s", strm.str().c_str()); + } + return i; +} + +void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { + std::ostringstream os; + os << '\n' << usage_ << '\n'; + // first we print application-specific options + bool app_specific_header_printed = false; + for (const auto &it : doc_map_) { + if (it.second.is_standard_ == false) { // application-specific option + if (app_specific_header_printed == false) { // header was not yet printed + os << "Options:" << '\n'; + app_specific_header_printed = true; + } + os << " --" << std::setw(25) << std::left << it.second.name_ << " : " + << it.second.use_msg_ << '\n'; + } + } + if (app_specific_header_printed == true) { + os << '\n'; + } + + // then the standard options + os << "Standard options:" << '\n'; + for (const auto &it : doc_map_) { + if (it.second.is_standard_ == true) { // we have standard option + os << " --" << std::setw(25) << std::left << it.second.name_ << " : " + << it.second.use_msg_ << '\n'; + } + } + os << '\n'; + if (print_command_line) { + std::ostringstream strm; + strm << "Command line was: "; + for (int32_t j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " "; + strm << '\n'; + os << strm.str(); + } + + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +} + +void ParseOptions::PrintConfig(std::ostream &os) const { + os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n'; + std::string key; + for (const auto &it : doc_map_) { + key = it.first; + os << it.second.name_ << " = "; + if (bool_map_.end() != bool_map_.find(key)) { + os << (*bool_map_.at(key) ? 
"true" : "false"); + } else if (int_map_.end() != int_map_.find(key)) { + os << (*int_map_.at(key)); + } else if (int64_map_.end() != int64_map_.find(key)) { + os << (*int64_map_.at(key)); + } else if (uint_map_.end() != uint_map_.find(key)) { + os << (*uint_map_.at(key)); + } else if (float_map_.end() != float_map_.find(key)) { + os << (*float_map_.at(key)); + } else if (double_map_.end() != double_map_.find(key)) { + os << (*double_map_.at(key)); + } else if (string_map_.end() != string_map_.find(key)) { + os << "'" << *string_map_.at(key) << "'"; + } else { + SHERPA_ONNX_LOGE("PrintConfig: unrecognized option %s [code error]", + key.c_str()); + exit(-1); + } + os << '\n'; + } + os << '\n'; +} + +void ParseOptions::ReadConfigFile(const std::string &filename) { + std::ifstream is(filename.c_str(), std::ifstream::in); + if (!is.good()) { + SHERPA_ONNX_LOGE("Cannot open config file: %s", filename.c_str()); + exit(-1); + } + + std::string line, key, value; + int32_t line_number = 0; + while (std::getline(is, line)) { + ++line_number; + // trim out the comments + size_t pos = line.find_first_of('#'); + if (pos != std::string::npos) { + line.erase(pos); + } + // skip empty lines + Trim(&line); + if (line.empty()) continue; + + if (line.substr(0, 2) != "--") { + SHERPA_ONNX_LOGE( + "Reading config file %s: line %d does not look like a line " + "from a sherpa-mnn command-line program's config file: should " + "be of the form --x=y. Note: config files intended to " + "be sourced by shell scripts lack the '--'.", + filename.c_str(), line_number); + exit(-1); + } + + // parse option + bool has_equal_sign = false; + SplitLongArg(line, &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + SHERPA_ONNX_LOGE("Invalid option %s in config file %s: line %d", + line.c_str(), filename.c_str(), line_number); + exit(-1); + } + } +} + +void ParseOptions::SplitLongArg(const std::string &in, std::string *key, + std::string *value, + bool *has_equal_sign) const { + SHERPA_ONNX_CHECK(in.substr(0, 2) == "--") << in; // precondition. + size_t pos = in.find_first_of('=', 0); + if (pos == std::string::npos) { // we allow --option for bools + // defaults to empty. We handle this differently in different cases. + *key = in.substr(2, in.size() - 2); // 2 because starts with --. + *value = ""; + *has_equal_sign = false; + } else if (pos == 2) { // we also don't allow empty keys: --=value + PrintUsage(true); + SHERPA_ONNX_LOGE("Invalid option (no key): %s", in.c_str()); + exit(-1); + } else { // normal case: --option=value + *key = in.substr(2, pos - 2); // 2 because starts with --. 
+ *value = in.substr(pos + 1); + *has_equal_sign = true; + } +} + +void ParseOptions::NormalizeArgName(std::string *str) const { + std::string out; + std::string::iterator it; + + for (it = str->begin(); it != str->end(); ++it) { + if (*it == '_') { + out += '-'; // convert _ to - + } else { + out += std::tolower(*it); + } + } + *str = out; + + SHERPA_ONNX_CHECK_GT(str->length(), 0); +} + +void ParseOptions::Trim(std::string *str) const { + const char *white_chars = " \t\n\r\f\v"; + + std::string::size_type pos = str->find_last_not_of(white_chars); + if (pos != std::string::npos) { + str->erase(pos + 1); + pos = str->find_first_not_of(white_chars); + if (pos != std::string::npos) str->erase(0, pos); + } else { + str->erase(str->begin(), str->end()); + } +} + +bool ParseOptions::SetOption(const std::string &key, const std::string &value, + bool has_equal_sign) { + if (bool_map_.end() != bool_map_.find(key)) { + if (has_equal_sign && value.empty()) { + SHERPA_ONNX_LOGE("Invalid option --%s=", key.c_str()); + exit(-1); + } + *(bool_map_[key]) = ToBool(value); + } else if (int_map_.end() != int_map_.find(key)) { + *(int_map_[key]) = ToInt(value); + } else if (int64_map_.end() != int64_map_.find(key)) { + *(int64_map_[key]) = ToInt64(value); + } else if (uint_map_.end() != uint_map_.find(key)) { + *(uint_map_[key]) = ToUint(value); + } else if (float_map_.end() != float_map_.find(key)) { + *(float_map_[key]) = ToFloat(value); + } else if (double_map_.end() != double_map_.find(key)) { + *(double_map_[key]) = ToDouble(value); + } else if (string_map_.end() != string_map_.find(key)) { + if (!has_equal_sign) { + SHERPA_ONNX_LOGE("Invalid option --%s (option format is --x=y).", + key.c_str()); + exit(-1); + } + *(string_map_[key]) = value; + } else { + return false; + } + return true; +} + +bool ParseOptions::ToBool(std::string str) const { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + + // allow "" as a valid option for "true", so that --x is the same as --x=true + if (str == "true" || str == "t" || str == "1" || str.empty()) { + return true; + } + if (str == "false" || str == "f" || str == "0") { + return false; + } + // if it is neither true nor false: + PrintUsage(true); + SHERPA_ONNX_LOGE( + "Invalid format for boolean argument [expected true or false]: %s", + str.c_str()); + exit(-1); + return false; // never reached +} + +int32_t ParseOptions::ToInt(const std::string &str) const { + int32_t ret = 0; + if (!ConvertStringToInteger(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str()); + exit(-1); + } + return ret; +} + +int64_t ParseOptions::ToInt64(const std::string &str) const { + int64_t ret = 0; + if (!ConvertStringToInteger(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid integer int64 option \"%s\"", str.c_str()); + exit(-1); + } + return ret; +} + +uint32_t ParseOptions::ToUint(const std::string &str) const { + uint32_t ret = 0; + if (!ConvertStringToInteger(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid integer option \"%s\"", str.c_str()); + exit(-1); + } + return ret; +} + +float ParseOptions::ToFloat(const std::string &str) const { + float ret = 0; + if (!ConvertStringToReal(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str()); + exit(-1); + } + return ret; +} + +double ParseOptions::ToDouble(const std::string &str) const { + double ret = 0; + if (!ConvertStringToReal(str, &ret)) { + SHERPA_ONNX_LOGE("Invalid floating-point option \"%s\"", str.c_str()); + exit(-1); + } + return ret; +} + +// instantiate templates 
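+// (explicit instantiations for the pointer types accepted by
+// ParseOptions::Register: bool, int, uint32_t, float, double, std::string)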
+template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, int *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, + std::string *ptr, + const std::string &doc); + +template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + int *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + uint32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + float *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + double *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + std::string *ptr, + const std::string &doc); + +template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + int *ptr, const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + uint32_t *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, float *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, double *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + std::string *ptr, + const std::string &doc, + bool is_standard); + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/parse-options.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/parse-options.h new file mode 100644 index 00000000..b4bf9900 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/parse-options.h @@ -0,0 +1,259 @@ +// sherpa-mnn/csrc/parse-options.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +// +// This file is copied and modified from kaldi/src/util/parse-options.h + +#ifndef SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ +#define SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ + +#include +#include +#include +#include +#include + +namespace sherpa_mnn { + +class ParseOptions { + public: + explicit ParseOptions(const char *usage) + : print_args_(true), + help_(false), + usage_(usage), + argc_(0), + argv_(nullptr), + prefix_(""), + other_parser_(nullptr) { +#if !defined(_MSC_VER) && !defined(__CYGWIN__) + // This is just a convenient place to set the stderr to line + // buffering mode, since it's called at program start. + // This helps ensure different programs' output is not mixed up. 
+ setlinebuf(stderr); +#endif + RegisterStandard("config", &config_, + "Configuration file to read (this " + "option may be repeated)"); + RegisterStandard("print-args", &print_args_, + "Print the command line arguments (to stderr)"); + RegisterStandard("help", &help_, "Print out usage message"); + } + + /** + This is a constructor for the special case where some options are + registered with a prefix to avoid conflicts. The object thus created will + only be used temporarily to register an options class with the original + options parser (which is passed as the *other pointer) using the given + prefix. It should not be used for any other purpose, and the prefix must + not be the empty string. It seems to be the least bad way of implementing + options with prefixes at this point. + Example of usage is: + ParseOptions po; // original ParseOptions object + ParseOptions po_mfcc("mfcc", &po); // object with prefix. + MfccOptions mfcc_opts; + mfcc_opts.Register(&po_mfcc); + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 + instead of just --frame-shift=10.0 + */ + ParseOptions(const std::string &prefix, ParseOptions *other); + + ParseOptions(const ParseOptions &) = delete; + ParseOptions &operator=(const ParseOptions &) = delete; + ~ParseOptions() = default; + + void Register(const std::string &name, bool *ptr, const std::string &doc); + void Register(const std::string &name, int32_t *ptr, const std::string &doc); + void Register(const std::string &name, int64_t *ptr, const std::string &doc); + void Register(const std::string &name, uint32_t *ptr, const std::string &doc); + void Register(const std::string &name, float *ptr, const std::string &doc); + void Register(const std::string &name, double *ptr, const std::string &doc); + void Register(const std::string &name, std::string *ptr, + const std::string &doc); + + /// If called after registering an option and before calling + /// Read(), disables that option from being used. Will crash + /// at runtime if that option had not been registered. + void DisableOption(const std::string &name); + + /// This one is used for registering standard parameters of all the programs + template + void RegisterStandard(const std::string &name, T *ptr, + const std::string &doc); + + /** + Parses the command line options and fills the ParseOptions-registered + variables. This must be called after all the variables were registered!!! + + Initially the variables have implicit values, + then the config file values are set-up, + finally the command line values given. + Returns the first position in argv that was not used. + [typically not useful: use NumParams() and GetParam(). ] + */ + int Read(int argc, const char *const *argv); + + /// Prints the usage documentation [provided in the constructor]. + void PrintUsage(bool print_command_line = false) const; + + /// Prints the actual configuration of all the registered variables + void PrintConfig(std::ostream &os) const; + + /// Reads the options values from a config file. Must be called after + /// registering all options. This is usually used internally after the + /// standard --config option is used, but it may also be called from a + /// program. + void ReadConfigFile(const std::string &filename); + + /// Number of positional parameters (c.f. argc-1). + int NumArgs() const; + + /// Returns one of the positional parameters; 1-based indexing for argc/argv + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). + /// + /// Note: Index is 1 based. 
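+  /// e.g. GetArg(1) returns the first positional (non --option) argument
+  /// that was passed on the command line.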
+ std::string GetArg(int param) const; + + std::string GetOptArg(int param) const { + return (param <= NumArgs() ? GetArg(param) : ""); + } + + /// The following function will return a possibly quoted and escaped + /// version of "str", according to the current shell. Currently + /// this is just hardwired to bash. It's useful for debug output. + static std::string Escape(const std::string &str); + + private: + /// Template to register various variable types, + /// used for program-specific parameters + template + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); + + // Following functions do just the datatype-specific part of the job + /// Register boolean variable + void RegisterSpecific(const std::string &name, const std::string &idx, + bool *b, const std::string &doc, bool is_standard); + /// Register int32_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int32_t *i, const std::string &doc, bool is_standard); + /// Register int64_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int64_t *i, const std::string &doc, bool is_standard); + /// Register unsigned int32_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + uint32_t *u, const std::string &doc, bool is_standard); + /// Register float variable + void RegisterSpecific(const std::string &name, const std::string &idx, + float *f, const std::string &doc, bool is_standard); + /// Register double variable [useful as we change BaseFloat type]. + void RegisterSpecific(const std::string &name, const std::string &idx, + double *f, const std::string &doc, bool is_standard); + /// Register string variable + void RegisterSpecific(const std::string &name, const std::string &idx, + std::string *s, const std::string &doc, + bool is_standard); + + /// Does the actual job for both kinds of parameters + /// Does the common part of the job for all datatypes, + /// then calls RegisterSpecific + template + void RegisterCommon(const std::string &name, T *ptr, const std::string &doc, + bool is_standard); + + /// Set option with name "key" to "value"; will crash if can't do it. + /// "has_equal_sign" is used to allow --x for a boolean option x, + /// and --y=, for a string option y. 
+ bool SetOption(const std::string &key, const std::string &value, + bool has_equal_sign); + + bool ToBool(std::string str) const; + int32_t ToInt(const std::string &str) const; + int64_t ToInt64(const std::string &str) const; + uint32_t ToUint(const std::string &str) const; + float ToFloat(const std::string &str) const; + double ToDouble(const std::string &str) const; + + // maps for option variables + std::unordered_map bool_map_; + std::unordered_map int_map_; + std::unordered_map int64_map_; + std::unordered_map uint_map_; + std::unordered_map float_map_; + std::unordered_map double_map_; + std::unordered_map string_map_; + + /** + Structure for options' documentation + */ + struct DocInfo { + DocInfo() = default; + DocInfo(const std::string &name, const std::string &usemsg) + : name_(name), use_msg_(usemsg), is_standard_(false) {} + DocInfo(const std::string &name, const std::string &usemsg, + bool is_standard) + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} + + std::string name_; + std::string use_msg_; + bool is_standard_; + }; + using DocMapType = std::unordered_map; + DocMapType doc_map_; ///< map for the documentation + + bool print_args_; ///< variable for the implicit --print-args parameter + bool help_; ///< variable for the implicit --help parameter + std::string config_; ///< variable for the implicit --config parameter + std::vector positional_args_; + const char *usage_; + int argc_; + const char *const *argv_; + + /// These members are not normally used. They are only used when the object + /// is constructed with a prefix + std::string prefix_; + ParseOptions *other_parser_; + + protected: + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. + /// this is needed in order to correctly allow --x for a boolean option + /// x, and --y= for a string option y, and to disallow --x= and --y. + void SplitLongArg(const std::string &in, std::string *key, std::string *value, + bool *has_equal_sign) const; + + void NormalizeArgName(std::string *str) const; + + /// Removes the beginning and trailing whitespaces from a string + void Trim(std::string *str) const; +}; + +/// This template is provided for convenience in reading config classes from +/// files; this is not the standard way to read configuration options, but may +/// occasionally be needed. This function assumes the config has a function +/// "void Register(ParseOptions *opts)" which it can call to register the +/// ParseOptions object. +template +void ReadConfigFromFile(const std::string &config_filename, C *c) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c->Register(&po); + po.ReadConfigFile(config_filename); +} + +/// This variant of the template ReadConfigFromFile is for if you need to read +/// two config classes from the same file. 
+template +void ReadConfigsFromFile(const std::string &conf, C1 *c1, C2 *c2) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << conf << "'"; + ParseOptions po(usage_str.str().c_str()); + c1->Register(&po); + c2->Register(&po); + po.ReadConfigFile(conf); +} + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_PARSE_OPTIONS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-lexicon.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-lexicon.cc new file mode 100644 index 00000000..c284d051 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-lexicon.cc @@ -0,0 +1,477 @@ +// sherpa-mnn/csrc/piper-phonemize-lexicon.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/piper-phonemize-lexicon.h" + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "espeak-ng/speak_lib.h" +#include "phoneme_ids.hpp" +#include "phonemize.hpp" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void CallPhonemizeEspeak(const std::string &text, + piper::eSpeakPhonemeConfig &config, // NOLINT + std::vector> *phonemes) { + static std::mutex espeak_mutex; + + std::lock_guard lock(espeak_mutex); + + // keep multi threads from calling into piper::phonemize_eSpeak + piper::phonemize_eSpeak(text, config, *phonemes); +} + +static std::unordered_map ReadTokens(std::istream &is) { + std::wstring_convert, char32_t> conv; + std::unordered_map token2id; + + std::string line; + + std::string sym; + std::u32string s; + int32_t id = 0; + while (std::getline(is, line)) { + std::istringstream iss(line); + iss >> sym; + if (iss.eof()) { + id = atoi(sym.c_str()); + sym = " "; + } else { + iss >> id; + } + + // eat the trailing \r\n on windows + iss >> std::ws; + if (!iss.eof()) { + SHERPA_ONNX_LOGE("Error when reading tokens: %s", line.c_str()); + exit(-1); + } + + s = conv.from_bytes(sym); + if (s.size() != 1) { + // for tokens.txt from coqui-ai/TTS, the last token is + if (s.size() == 6 && s[0] == '<' && s[1] == 'B' && s[2] == 'L' && + s[3] == 'N' && s[4] == 'K' && s[5] == '>') { + continue; + } + + SHERPA_ONNX_LOGE("Error when reading tokens at Line %s. size: %d", + line.c_str(), static_cast(s.size())); + exit(-1); + } + + char32_t c = s[0]; + + if (token2id.count(c)) { + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", + sym.c_str(), line.c_str(), token2id.at(c)); + exit(-1); + } + + token2id.insert({c, id}); + } + + return token2id; +} + +// see the function "phonemes_to_ids" from +// https://github.com/rhasspy/piper/blob/master/notebooks/piper_inference_(ONNX).ipynb +static std::vector PiperPhonemesToIdsVits( + const std::unordered_map &token2id, + const std::vector &phonemes) { + // see + // https://github.com/rhasspy/piper-phonemize/blob/master/src/phoneme_ids.hpp#L17 + int32_t pad = token2id.at(U'_'); + int32_t bos = token2id.at(U'^'); + int32_t eos = token2id.at(U'$'); + + std::vector ans; + ans.reserve(phonemes.size()); + + ans.push_back(bos); + for (auto p : phonemes) { + if (token2id.count(p)) { + ans.push_back(token2id.at(p)); + ans.push_back(pad); + } else { + SHERPA_ONNX_LOGE("Skip unknown phonemes. 
Unicode codepoint: \\U+%04x.", + static_cast(p)); + } + } + ans.push_back(eos); + + return ans; +} + +static std::vector PiperPhonemesToIdsMatcha( + const std::unordered_map &token2id, + const std::vector &phonemes, bool use_eos_bos) { + std::vector ans; + ans.reserve(phonemes.size()); + + int32_t bos = token2id.at(U'^'); + int32_t eos = token2id.at(U'$'); + + if (use_eos_bos) { + ans.push_back(bos); + } + + for (auto p : phonemes) { + if (token2id.count(p)) { + ans.push_back(token2id.at(p)); + } else { + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.", + static_cast(p)); + } + } + + if (use_eos_bos) { + ans.push_back(eos); + } + + return ans; +} + +static std::vector> PiperPhonemesToIdsKokoro( + const std::unordered_map &token2id, + const std::vector &phonemes, int32_t max_len) { + std::vector> ans; + + std::vector current; + current.reserve(phonemes.size()); + + current.push_back(0); + + for (auto p : phonemes) { + if (token2id.count(p)) { + if (current.size() > max_len - 1) { + current.push_back(0); + ans.push_back(std::move(current)); + + current.reserve(phonemes.size()); + current.push_back(0); + } + + current.push_back(token2id.at(p)); + } else { + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.", + static_cast(p)); + } + } + + current.push_back(0); + ans.push_back(std::move(current)); + return ans; +} + +static std::vector CoquiPhonemesToIds( + const std::unordered_map &token2id, + const std::vector &phonemes, + const OfflineTtsVitsModelMetaData &vits_meta_data) { + // see + // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 + int32_t use_eos_bos = vits_meta_data.use_eos_bos; + int32_t bos_id = vits_meta_data.bos_id; + int32_t eos_id = vits_meta_data.eos_id; + int32_t blank_id = vits_meta_data.blank_id; + int32_t add_blank = vits_meta_data.add_blank; + int32_t comma_id = token2id.at(','); + + std::vector ans; + if (add_blank) { + ans.reserve(phonemes.size() * 2 + 3); + } else { + ans.reserve(phonemes.size() + 2); + } + + if (use_eos_bos) { + ans.push_back(bos_id); + } + + if (add_blank) { + ans.push_back(blank_id); + + for (auto p : phonemes) { + if (token2id.count(p)) { + ans.push_back(token2id.at(p)); + ans.push_back(blank_id); + } else { + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.", + static_cast(p)); + } + } + } else { + // not adding blank + for (auto p : phonemes) { + if (token2id.count(p)) { + ans.push_back(token2id.at(p)); + } else { + SHERPA_ONNX_LOGE("Skip unknown phonemes. Unicode codepoint: \\U+%04x.", + static_cast(p)); + } + } + } + + // add a comma at the end of a sentence so that we can have a longer pause. + ans.push_back(comma_id); + + if (use_eos_bos) { + ans.push_back(eos_id); + } + + return ans; +} + +void InitEspeak(const std::string &data_dir) { + static std::once_flag init_flag; + std::call_once(init_flag, [data_dir]() { + int32_t result = + espeak_Initialize(AUDIO_OUTPUT_SYNCHRONOUS, 0, data_dir.c_str(), 0); + if (result != 22050) { + SHERPA_ONNX_LOGE( + "Failed to initialize espeak-ng with data dir: %s. 
Return code is: " + "%d", + data_dir.c_str(), result); + exit(-1); + } + }); +} + +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + const std::string &tokens, const std::string &data_dir, + const OfflineTtsVitsModelMetaData &vits_meta_data) + : vits_meta_data_(vits_meta_data) { + { + std::ifstream is(tokens); + token2id_ = ReadTokens(is); + } + + InitEspeak(data_dir); +} + +template +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + Manager *mgr, const std::string &tokens, const std::string &data_dir, + const OfflineTtsVitsModelMetaData &vits_meta_data) + : vits_meta_data_(vits_meta_data) { + { + auto buf = ReadFile(mgr, tokens); + std::istrstream is(buf.data(), buf.size()); + token2id_ = ReadTokens(is); + } + + // We should copy the directory of espeak-ng-data from the asset to + // some internal or external storage and then pass the directory to + // data_dir. + InitEspeak(data_dir); +} + +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + const std::string &tokens, const std::string &data_dir, + const OfflineTtsMatchaModelMetaData &matcha_meta_data) + : matcha_meta_data_(matcha_meta_data), is_matcha_(true) { + { + std::ifstream is(tokens); + token2id_ = ReadTokens(is); + } + + InitEspeak(data_dir); +} + +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + const std::string &tokens, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &kokoro_meta_data) + : kokoro_meta_data_(kokoro_meta_data), is_kokoro_(true) { + { + std::ifstream is(tokens); + token2id_ = ReadTokens(is); + } + + InitEspeak(data_dir); +} + +template +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + Manager *mgr, const std::string &tokens, const std::string &data_dir, + const OfflineTtsMatchaModelMetaData &matcha_meta_data) + : matcha_meta_data_(matcha_meta_data), is_matcha_(true) { + { + auto buf = ReadFile(mgr, tokens); + std::istrstream is(buf.data(), buf.size()); + token2id_ = ReadTokens(is); + } + + // We should copy the directory of espeak-ng-data from the asset to + // some internal or external storage and then pass the directory to + // data_dir. + InitEspeak(data_dir); +} + +template +PiperPhonemizeLexicon::PiperPhonemizeLexicon( + Manager *mgr, const std::string &tokens, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &kokoro_meta_data) + : kokoro_meta_data_(kokoro_meta_data), is_kokoro_(true) { + { + auto buf = ReadFile(mgr, tokens); + std::istrstream is(buf.data(), buf.size()); + token2id_ = ReadTokens(is); + } + + // We should copy the directory of espeak-ng-data from the asset to + // some internal or external storage and then pass the directory to + // data_dir. 
+ InitEspeak(data_dir); +} + +std::vector PiperPhonemizeLexicon::ConvertTextToTokenIds( + const std::string &text, const std::string &voice /*= ""*/) const { + if (is_matcha_) { + return ConvertTextToTokenIdsMatcha(text, voice); + } else if (is_kokoro_) { + return ConvertTextToTokenIdsKokoro(text, voice); + } else { + return ConvertTextToTokenIdsVits(text, voice); + } +} + +std::vector PiperPhonemizeLexicon::ConvertTextToTokenIdsMatcha( + const std::string &text, const std::string &voice /*= ""*/) const { + piper::eSpeakPhonemeConfig config; + + // ./bin/espeak-ng-bin --path ./install/share/espeak-ng-data/ --voices + // to list available voices + config.voice = voice; // e.g., voice is en-us + + std::vector> phonemes; + + CallPhonemizeEspeak(text, config, &phonemes); + + std::vector ans; + + std::vector phoneme_ids; + + for (const auto &p : phonemes) { + phoneme_ids = + PiperPhonemesToIdsMatcha(token2id_, p, matcha_meta_data_.use_eos_bos); + ans.emplace_back(std::move(phoneme_ids)); + } + + return ans; +} + +std::vector PiperPhonemizeLexicon::ConvertTextToTokenIdsKokoro( + const std::string &text, const std::string &voice /*= ""*/) const { + piper::eSpeakPhonemeConfig config; + + // ./bin/espeak-ng-bin --path ./install/share/espeak-ng-data/ --voices + // to list available voices + config.voice = voice; // e.g., voice is en-us + + std::vector> phonemes; + + CallPhonemizeEspeak(text, config, &phonemes); + + std::vector ans; + + for (const auto &p : phonemes) { + auto phoneme_ids = + PiperPhonemesToIdsKokoro(token2id_, p, kokoro_meta_data_.max_token_len); + + for (auto &ids : phoneme_ids) { + ans.emplace_back(std::move(ids)); + } + } + + return ans; +} + +std::vector PiperPhonemizeLexicon::ConvertTextToTokenIdsVits( + const std::string &text, const std::string &voice /*= ""*/) const { + piper::eSpeakPhonemeConfig config; + + // ./bin/espeak-ng-bin --path ./install/share/espeak-ng-data/ --voices + // to list available voices + config.voice = voice; // e.g., voice is en-us + + std::vector> phonemes; + + CallPhonemizeEspeak(text, config, &phonemes); + + std::vector ans; + + std::vector phoneme_ids; + + if (vits_meta_data_.is_piper || vits_meta_data_.is_icefall) { + for (const auto &p : phonemes) { + phoneme_ids = PiperPhonemesToIdsVits(token2id_, p); + ans.emplace_back(std::move(phoneme_ids)); + } + } else if (vits_meta_data_.is_coqui) { + for (const auto &p : phonemes) { + phoneme_ids = CoquiPhonemesToIds(token2id_, p, vits_meta_data_); + ans.emplace_back(std::move(phoneme_ids)); + } + + } else { + SHERPA_ONNX_LOGE("Unsupported model"); + exit(-1); + } + + return ans; +} + +#if __ANDROID_API__ >= 9 +template PiperPhonemizeLexicon::PiperPhonemizeLexicon( + AAssetManager *mgr, const std::string &tokens, const std::string &data_dir, + const OfflineTtsVitsModelMetaData &vits_meta_data); + +template PiperPhonemizeLexicon::PiperPhonemizeLexicon( + AAssetManager *mgr, const std::string &tokens, const std::string &data_dir, + const OfflineTtsMatchaModelMetaData &matcha_meta_data); + +template PiperPhonemizeLexicon::PiperPhonemizeLexicon( + AAssetManager *mgr, const std::string &tokens, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &kokoro_meta_data); +#endif + +#if __OHOS__ +template PiperPhonemizeLexicon::PiperPhonemizeLexicon( + NativeResourceManager *mgr, const std::string &tokens, + const std::string &data_dir, + const OfflineTtsVitsModelMetaData &vits_meta_data); + +template PiperPhonemizeLexicon::PiperPhonemizeLexicon( + NativeResourceManager *mgr, const std::string &tokens, 
+ const std::string &data_dir, + const OfflineTtsMatchaModelMetaData &matcha_meta_data); + +template PiperPhonemizeLexicon::PiperPhonemizeLexicon( + NativeResourceManager *mgr, const std::string &tokens, + const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &kokoro_meta_data); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-lexicon.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-lexicon.h new file mode 100644 index 00000000..cd0a1a92 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-lexicon.h @@ -0,0 +1,70 @@ +// sherpa-mnn/csrc/piper-phonemize-lexicon.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_PIPER_PHONEMIZE_LEXICON_H_ +#define SHERPA_ONNX_CSRC_PIPER_PHONEMIZE_LEXICON_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/offline-tts-frontend.h" +#include "sherpa-mnn/csrc/offline-tts-kokoro-model-meta-data.h" +#include "sherpa-mnn/csrc/offline-tts-matcha-model-meta-data.h" +#include "sherpa-mnn/csrc/offline-tts-vits-model-meta-data.h" + +namespace sherpa_mnn { + +class PiperPhonemizeLexicon : public OfflineTtsFrontend { + public: + PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir, + const OfflineTtsVitsModelMetaData &vits_meta_data); + + PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir, + const OfflineTtsMatchaModelMetaData &matcha_meta_data); + + PiperPhonemizeLexicon(const std::string &tokens, const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &kokoro_meta_data); + + template + PiperPhonemizeLexicon(Manager *mgr, const std::string &tokens, + const std::string &data_dir, + const OfflineTtsVitsModelMetaData &vits_meta_data); + + template + PiperPhonemizeLexicon(Manager *mgr, const std::string &tokens, + const std::string &data_dir, + const OfflineTtsMatchaModelMetaData &matcha_meta_data); + + template + PiperPhonemizeLexicon(Manager *mgr, const std::string &tokens, + const std::string &data_dir, + const OfflineTtsKokoroModelMetaData &kokoro_meta_data); + + std::vector ConvertTextToTokenIds( + const std::string &text, const std::string &voice = "") const override; + + private: + std::vector ConvertTextToTokenIdsVits( + const std::string &text, const std::string &voice = "") const; + + std::vector ConvertTextToTokenIdsMatcha( + const std::string &text, const std::string &voice = "") const; + + std::vector ConvertTextToTokenIdsKokoro( + const std::string &text, const std::string &voice = "") const; + + private: + // map unicode codepoint to an integer ID + std::unordered_map token2id_; + OfflineTtsVitsModelMetaData vits_meta_data_; + OfflineTtsMatchaModelMetaData matcha_meta_data_; + OfflineTtsKokoroModelMetaData kokoro_meta_data_; + bool is_matcha_ = false; + bool is_kokoro_ = false; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_PIPER_PHONEMIZE_LEXICON_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-test.cc new file mode 100644 index 00000000..f246b55a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/piper-phonemize-test.cc @@ -0,0 +1,78 @@ +// sherpa-mnn/csrc/piper-phonemize-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "espeak-ng/speak_lib.h" +#include "gtest/gtest.h" +#include "phoneme_ids.hpp" +#include "phonemize.hpp" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + 
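+// The test below exercises the piper-phonemize pipeline end to end: it
+// initializes espeak-ng with a local data directory, converts text to
+// phonemes with piper::phonemize_eSpeak(), and then maps those phonemes to
+// integer IDs via phonemes_to_ids(). The test is skipped when
+// ./install/share/espeak-ng-data is not present.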
+namespace sherpa_mnn { + +TEST(PiperPhonemize, Case1) { + std::string data_dir = "./install/share/espeak-ng-data"; + if (!FileExists(data_dir + "/en_dict")) { + SHERPA_ONNX_LOGE("%s/en_dict does not exist. Skipping test", + data_dir.c_str()); + return; + } + + if (!FileExists(data_dir + "/phontab")) { + SHERPA_ONNX_LOGE("%s/phontab does not exist. Skipping test", + data_dir.c_str()); + return; + } + + if (!FileExists(data_dir + "/phonindex")) { + SHERPA_ONNX_LOGE("%s/phonindex does not exist. Skipping test", + data_dir.c_str()); + return; + } + + if (!FileExists(data_dir + "/phondata")) { + SHERPA_ONNX_LOGE("%s/phondata does not exist. Skipping test", + data_dir.c_str()); + return; + } + + if (!FileExists(data_dir + "/intonations")) { + SHERPA_ONNX_LOGE("%s/intonations does not exist. Skipping test", + data_dir.c_str()); + return; + } + int32_t result = + espeak_Initialize(AUDIO_OUTPUT_SYNCHRONOUS, 0, data_dir.c_str(), 0); + EXPECT_EQ(result, 22050); + + piper::eSpeakPhonemeConfig config; + + // ./bin/espeak-ng-bin --path ./install/share/espeak-ng-data/ --voices + // to list available voices + config.voice = "en-us"; + + std::vector> phonemes; + std::string text = "how are you doing?"; + piper::phonemize_eSpeak(text, config, phonemes); + + for (int32_t p : phonemes[0]) { + std::cout << p << " "; + } + std::cout << "\n"; + + std::vector phoneme_ids; + std::map missing_phonemes; + + { + piper::PhonemeIdConfig config; + phonemes_to_ids(phonemes[0], config, phoneme_ids, missing_phonemes); + } + + for (int32_t p : phoneme_ids) { + std::cout << p << " "; + } + std::cout << "\n"; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider-config.cc new file mode 100644 index 00000000..a7d8d9f4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider-config.cc @@ -0,0 +1,142 @@ +// sherpa-mnn/csrc/provider-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela) + +#include "sherpa-mnn/csrc/provider-config.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void CudaConfig::Register(ParseOptions *po) { + po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search, + "CuDNN convolution algrorithm search"); +} + +bool CudaConfig::Validate() const { + if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) { + SHERPA_ONNX_LOGE( + "cudnn_conv_algo_search: '%d' is not a valid option." + "Options : [1,3]. 
Check OnnxRT docs", + cudnn_conv_algo_search); + return false; + } + return true; +} + +std::string CudaConfig::ToString() const { + std::ostringstream os; + + os << "CudaConfig("; + os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")"; + + return os.str(); +} + +void TensorrtConfig::Register(ParseOptions *po) { + po->Register("trt-max-workspace-size", &trt_max_workspace_size, + "Set TensorRT EP GPU memory usage limit."); + po->Register("trt-max-partition-iterations", &trt_max_partition_iterations, + "Limit partitioning iterations for model conversion."); + po->Register("trt-min-subgraph-size", &trt_min_subgraph_size, + "Set minimum size for subgraphs in partitioning."); + po->Register("trt-fp16-enable", &trt_fp16_enable, + "Enable FP16 precision for faster performance."); + po->Register("trt-detailed-build-log", &trt_detailed_build_log, + "Enable detailed logging of build steps."); + po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, + "Enable caching of TensorRT engines."); + po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, + "Enable use of timing cache to speed up builds."); + po->Register("trt-engine-cache-path", &trt_engine_cache_path, + "Set path to store cached TensorRT engines."); + po->Register("trt-timing-cache-path", &trt_timing_cache_path, + "Set path for storing timing cache."); + po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, + "Dump optimized subgraphs for debugging."); +} + +bool TensorrtConfig::Validate() const { + if (trt_max_workspace_size < 0) { + std::ostringstream os; + os << "trt_max_workspace_size: " << trt_max_workspace_size + << " is not valid."; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + return false; + } + if (trt_max_partition_iterations < 0) { + SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.", + trt_max_partition_iterations); + return false; + } + if (trt_min_subgraph_size < 0) { + SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.", + trt_min_subgraph_size); + return false; + } + + return true; +} + +std::string TensorrtConfig::ToString() const { + std::ostringstream os; + + os << "TensorrtConfig("; + os << "trt_max_workspace_size=" << trt_max_workspace_size << ", "; + os << "trt_max_partition_iterations=" << trt_max_partition_iterations << ", "; + os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", "; + os << "trt_fp16_enable=\"" << (trt_fp16_enable ? "True" : "False") << "\", "; + os << "trt_detailed_build_log=\"" + << (trt_detailed_build_log ? "True" : "False") << "\", "; + os << "trt_engine_cache_enable=\"" + << (trt_engine_cache_enable ? "True" : "False") << "\", "; + os << "trt_engine_cache_path=\"" << trt_engine_cache_path.c_str() << "\", "; + os << "trt_timing_cache_enable=\"" + << (trt_timing_cache_enable ? "True" : "False") << "\", "; + os << "trt_timing_cache_path=\"" << trt_timing_cache_path.c_str() << "\","; + os << "trt_dump_subgraphs=\"" << (trt_dump_subgraphs ? 
"True" : "False") + << "\" )"; + return os.str(); +} + +void ProviderConfig::Register(ParseOptions *po) { + cuda_config.Register(po); + trt_config.Register(po); + + po->Register("device", &device, "GPU device index for CUDA and Trt EP"); + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool ProviderConfig::Validate() const { + if (device < 0) { + SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); + return false; + } + + if (provider == "cuda" && !cuda_config.Validate()) { + return false; + } + + if (provider == "trt" && !trt_config.Validate()) { + return false; + } + + return true; +} + +std::string ProviderConfig::ToString() const { + std::ostringstream os; + + os << "ProviderConfig("; + os << "device=" << device << ", "; + os << "provider=\"" << provider << "\", "; + os << "cuda_config=" << cuda_config.ToString() << ", "; + os << "trt_config=" << trt_config.ToString() << ")"; + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider-config.h new file mode 100644 index 00000000..dfb29097 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider-config.h @@ -0,0 +1,93 @@ +// sherpa-mnn/csrc/provider-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela) + +#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ +#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct CudaConfig { + int32_t cudnn_conv_algo_search = 0; + + CudaConfig() = default; + explicit CudaConfig(int32_t cudnn_conv_algo_search) + : cudnn_conv_algo_search(cudnn_conv_algo_search) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct TensorrtConfig { + int trt_max_workspace_size = 2147483647; + int32_t trt_max_partition_iterations = 10; + int32_t trt_min_subgraph_size = 5; + bool trt_fp16_enable = true; + bool trt_detailed_build_log = false; + bool trt_engine_cache_enable = true; + bool trt_timing_cache_enable = true; + std::string trt_engine_cache_path = "."; + std::string trt_timing_cache_path = "."; + bool trt_dump_subgraphs = false; + + TensorrtConfig() = default; + TensorrtConfig(int trt_max_workspace_size, + int32_t trt_max_partition_iterations, + int32_t trt_min_subgraph_size, bool trt_fp16_enable, + bool trt_detailed_build_log, bool trt_engine_cache_enable, + bool trt_timing_cache_enable, + const std::string &trt_engine_cache_path, + const std::string &trt_timing_cache_path, + bool trt_dump_subgraphs) + : trt_max_workspace_size(trt_max_workspace_size), + trt_max_partition_iterations(trt_max_partition_iterations), + trt_min_subgraph_size(trt_min_subgraph_size), + trt_fp16_enable(trt_fp16_enable), + trt_detailed_build_log(trt_detailed_build_log), + trt_engine_cache_enable(trt_engine_cache_enable), + trt_timing_cache_enable(trt_timing_cache_enable), + trt_engine_cache_path(trt_engine_cache_path), + trt_timing_cache_path(trt_timing_cache_path), + trt_dump_subgraphs(trt_dump_subgraphs) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +struct ProviderConfig { + TensorrtConfig trt_config; + CudaConfig cuda_config; + std::string provider = "cpu"; + int32_t device = 0; + // device only used for cuda and trt + + ProviderConfig() = default; + ProviderConfig(const std::string 
&provider, int32_t device) + : provider(provider), device(device) {} + ProviderConfig(const TensorrtConfig &trt_config, + const CudaConfig &cuda_config, const std::string &provider, + int32_t device) + : trt_config(trt_config), + cuda_config(cuda_config), + provider(provider), + device(device) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider.cc new file mode 100644 index 00000000..daf0eeef --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider.cc @@ -0,0 +1,37 @@ +// sherpa-mnn/csrc/provider.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/provider.h" + +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +Provider StringToProvider(std::string s) { + std::transform(s.cbegin(), s.cend(), s.begin(), + [](unsigned char c) { return std::tolower(c); }); + if (s == "cpu") { + return Provider::kCPU; + } else if (s == "cuda") { + return Provider::kCUDA; + } else if (s == "coreml") { + return Provider::kCoreML; + } else if (s == "xnnpack") { + return Provider::kXnnpack; + } else if (s == "nnapi") { + return Provider::kNNAPI; + } else if (s == "trt") { + return Provider::kTRT; + } else if (s == "directml") { + return Provider::kDirectML; + } else { + SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); + return Provider::kCPU; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider.h new file mode 100644 index 00000000..d6d50772 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/provider.h @@ -0,0 +1,36 @@ +// sherpa-mnn/csrc/provider.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_PROVIDER_H_ +#define SHERPA_ONNX_CSRC_PROVIDER_H_ + +#include + +#include "sherpa-mnn/csrc/provider-config.h" +namespace sherpa_mnn { + +// Please refer to +// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java +// for a list of available providers +enum class Provider { + kCPU = 0, // CPUExecutionProvider + kCUDA = 1, // CUDAExecutionProvider + kCoreML = 2, // CoreMLExecutionProvider + kXnnpack = 3, // XnnpackExecutionProvider + kNNAPI = 4, // NnapiExecutionProvider + kTRT = 5, // TensorRTExecutionProvider + kDirectML = 6, // DmlExecutionProvider +}; + +/** + * Convert a string to an enum. + * + * @param s We will convert it to lowercase before comparing. + * @return Return an instance of Provider. 
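+ *
+ * A hypothetical usage sketch (not from the original sources): the comparison
+ * is case-insensitive, and an unrecognized string falls back to Provider::kCPU
+ * with a warning.
+ *
+ *   Provider p = StringToProvider("CUDA");  // -> Provider::kCUDA
+ *   Provider q = StringToProvider("npu");   // unsupported -> Provider::kCPU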
+ */ +Provider StringToProvider(std::string s); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_PROVIDER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/regex-lang-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/regex-lang-test.cc new file mode 100644 index 00000000..425f7e66 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/regex-lang-test.cc @@ -0,0 +1,86 @@ +// sherpa-mnn/csrc/regex-lang-test.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include // NOLINT + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/text-utils.cc" + +namespace sherpa_mnn { + +static void TestLang(const std::string &expr, const std::string &text, + const std::vector &expected) { + auto ws = ToWideString(text); + std::wstring wexpr = ToWideString(expr); + std::wregex we(wexpr); + + auto begin = std::wsregex_iterator(ws.begin(), ws.end(), we); + auto end = std::wsregex_iterator(); + int32_t k = 0; + for (std::wsregex_iterator i = begin; i != end; ++i) { + std::wsmatch match = *i; + std::wstring match_str = match.str(); + auto ms = ToString(match_str); + std::cout << ms << "\n"; + EXPECT_EQ(ms, expected[k]); + k++; + } + EXPECT_EQ(k, expected.size()); +} + +TEST(German, Case1) { + std::cout << "----------Test German----------"; + // see https://character-table.netlify.app/german/ + std::string expr = + "([\\u0020-\\u005f\\u0061-" + "\\u007d\\u00a0\\u00a7\\u00a9\\u00ab\\u00bb\\u00c4\\u00d6\\u00dc\\u00df\\" + "u00e4\\u00f6\\u00fc\\u2010-\\u2011\\u2013-" + "\\u2014\\u2018\\u201a\\u201c\\u201e\\u2026\\u2030\\u20ac]+)"; + + std::string text = + "开始Übeltäter übergibt Ärzten 中间öfters äußerst ätzende Öle结束3€"; + + std::vector expected = {"Übeltäter übergibt Ärzten ", + "öfters äußerst ätzende Öle", "3€"}; + + TestLang(expr, text, expected); +} + +TEST(French, Case1) { + std::string expr = + "([\\u0020-\\u005f\\u0061-" + "\\u007a\\u007c\\u00a0\\u00a7\\u00a9\\u00ab\\u00b2-" + "\\u00b3\\u00bb\\u00c0\\u00c2\\u00c6-\\u00cb\\u00ce-" + "\\u00cf\\u00d4\\u00d9\\u00db-\\u00dc\\u00e0\\u00e2\\u00e6-" + "\\u00eb\\u00ee-\\u00ef\\u00f4\\u00f9\\u00fb-\\u00fc\\u00ff\\u0152-" + "\\u0153\\u0178\\u02b3\\u02e2\\u1d48-\\u1d49\\u2010-\\u2011\\u2013-" + "\\u2014\\u2019\\u201c-\\u201d\\u2020-\\u2021\\u2026\\u202f-" + "\\u2030\\u20ac\\u2212]+)"; + std::string text = + "L'été, 一avec son ciel bleuâtre, 二est un moment où, 三Noël, maçon"; + std::vector expected = { + "L'été, ", + "avec son ciel bleuâtre, ", + "est un moment où, ", + "Noël, maçon", + }; + TestLang(expr, text, expected); +} + +TEST(English, Case1) { + // https://character-table.netlify.app/english/ + std::string expr = + "([\\u0020-\\u005f\\u0061-\\u007a\\u007c\\u00a0\\u00a7\\u00a9\\u2010-" + "\\u2011\\u2013-\\u2014\\u2018-\\u2019\\u201c-\\u201d\\u2020-" + "\\u2021\\u2026\\u2030\\u2032-\\u2033\\u20ac]+)"; + std::string text = "一how are you doing? 二Thank you!"; + + std::vector expected = { + "how are you doing? 
", + "Thank you!", + }; + TestLang(expr, text, expected); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/resample.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/resample.cc new file mode 100644 index 00000000..f0fd5c7a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/resample.cc @@ -0,0 +1,308 @@ +/** + * Copyright 2013 Pegah Ghahremani + * 2014 IMSL, PKU-HKUST (author: Wei Shi) + * 2014 Yanqing Sun, Junjie Wang + * 2014 Johns Hopkins University (author: Daniel Povey) + * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// this file is copied and modified from +// kaldi/src/feat/resample.cc + +#include "sherpa-mnn/csrc/resample.h" + +#include +#include +#include +#include +#include + +#ifndef M_2PI +#define M_2PI 6.283185307179586476925286766559005 +#endif + +#ifndef M_PI +#define M_PI 3.1415926535897932384626433832795 +#endif + +namespace sherpa_mnn { + +template +static I Gcd(I m, I n) { + // this function is copied from kaldi/src/base/kaldi-math.h + if (m == 0 || n == 0) { + if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. + fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n"); + exit(-1); + } + return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m)); + // return absolute value of whichever is nonzero + } + // could use compile-time assertion + // but involves messing with complex template stuff. + static_assert(std::is_integral_v); + while (true) { + m %= n; + if (m == 0) return (n > 0 ? n : -n); + n %= m; + if (n == 0) return (m > 0 ? m : -m); + } +} + +/// Returns the least common multiple of two integers. Will +/// crash unless the inputs are positive. +template +static I Lcm(I m, I n) { + // This function is copied from kaldi/src/base/kaldi-math.h + assert(m > 0 && n > 0); + I gcd = Gcd(m, n); + return gcd * (m / gcd) * (n / gcd); +} + +static float DotProduct(const float *a, const float *b, int32_t n) { + float sum = 0; + for (int32_t i = 0; i != n; ++i) { + sum += a[i] * b[i]; + } + return sum; +} + +LinearResample::LinearResample(int32_t samp_rate_in_hz, + int32_t samp_rate_out_hz, float filter_cutoff_hz, + int32_t num_zeros) + : samp_rate_in_(samp_rate_in_hz), + samp_rate_out_(samp_rate_out_hz), + filter_cutoff_(filter_cutoff_hz), + num_zeros_(num_zeros) { + assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 && + filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz && + filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0); + + // base_freq is the frequency of the repeating unit, which is the gcd + // of the input frequencies. 
+ int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_); + input_samples_in_unit_ = samp_rate_in_ / base_freq; + output_samples_in_unit_ = samp_rate_out_ / base_freq; + + SetIndexesAndWeights(); + Reset(); +} + +void LinearResample::SetIndexesAndWeights() { + first_index_.resize(output_samples_in_unit_); + weights_.resize(output_samples_in_unit_); + + double window_width = num_zeros_ / (2.0 * filter_cutoff_); + + for (int32_t i = 0; i < output_samples_in_unit_; i++) { + double output_t = i / static_cast(samp_rate_out_); + double min_t = output_t - window_width, max_t = output_t + window_width; + // we do ceil on the min and floor on the max, because if we did it + // the other way around we would unnecessarily include indexes just + // outside the window, with zero coefficients. It's possible + // if the arguments to the ceil and floor expressions are integers + // (e.g. if filter_cutoff_ has an exact ratio with the sample rates), + // that we unnecessarily include something with a zero coefficient, + // but this is only a slight efficiency issue. + int32_t min_input_index = ceil(min_t * samp_rate_in_), + max_input_index = floor(max_t * samp_rate_in_), + num_indices = max_input_index - min_input_index + 1; + first_index_[i] = min_input_index; + weights_[i].resize(num_indices); + for (int32_t j = 0; j < num_indices; j++) { + int32_t input_index = min_input_index + j; + double input_t = input_index / static_cast(samp_rate_in_), + delta_t = input_t - output_t; + // sign of delta_t doesn't matter. + weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_; + } + } +} + +/** Here, t is a time in seconds representing an offset from + the center of the windowed filter function, and FilterFunction(t) + returns the windowed filter function, described + in the header as h(t) = f(t)g(t), evaluated at t. +*/ +float LinearResample::FilterFunc(float t) const { + float window = 0, // raised-cosine (Hanning) window of width + // num_zeros_/2*filter_cutoff_ + filter = 0; // sinc filter function + if (std::fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) + window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); + else + window = 0.0; // outside support of window function + if (t != 0) + filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); + else + filter = 2 * filter_cutoff_; // limit of the function at t = 0 + return filter * window; +} + +void LinearResample::Reset() { + input_sample_offset_ = 0; + output_sample_offset_ = 0; + input_remainder_.resize(0); +} + +void LinearResample::Resample(const float *input, int32_t input_dim, bool flush, + std::vector *output) { + int tot_input_samp = input_sample_offset_ + input_dim, + tot_output_samp = GetNumOutputSamples(tot_input_samp, flush); + + assert(tot_output_samp >= output_sample_offset_); + + output->resize(tot_output_samp - output_sample_offset_); + + // samp_out is the index into the total output signal, not just the part + // of it we are producing here. + for (int samp_out = output_sample_offset_; samp_out < tot_output_samp; + samp_out++) { + int first_samp_in = 0; + int32_t samp_out_wrapped = 0; + GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped); + const std::vector &weights = weights_[samp_out_wrapped]; + // first_input_index is the first index into "input" that we have a weight + // for. 
+ int32_t first_input_index = + static_cast(first_samp_in - input_sample_offset_); + float this_output = 0; + if (first_input_index >= 0 && + first_input_index + static_cast(weights.size()) <= input_dim) { + this_output = + DotProduct(input + first_input_index, weights.data(), weights.size()); + } else { // Handle edge cases. + this_output = 0.0; + for (int32_t i = 0; i < static_cast(weights.size()); i++) { + float weight = weights[i]; + int32_t input_index = first_input_index + i; + if (input_index < 0 && + static_cast(input_remainder_.size()) + input_index >= 0) { + this_output += + weight * input_remainder_[input_remainder_.size() + input_index]; + } else if (input_index >= 0 && input_index < input_dim) { + this_output += weight * input[input_index]; + } else if (input_index >= input_dim) { + // We're past the end of the input and are adding zero; should only + // happen if the user specified flush == true, or else we would not + // be trying to output this sample. + assert(flush); + } + } + } + int32_t output_index = + static_cast(samp_out - output_sample_offset_); + (*output)[output_index] = this_output; + } + + if (flush) { + Reset(); // Reset the internal state. + } else { + SetRemainder(input, input_dim); + input_sample_offset_ = tot_input_samp; + output_sample_offset_ = tot_output_samp; + } +} + +int LinearResample::GetNumOutputSamples(int input_num_samp, + bool flush) const { + // For exact computation, we measure time in "ticks" of 1.0 / tick_freq, + // where tick_freq is the least common multiple of samp_rate_in_ and + // samp_rate_out_. + int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_); + int32_t ticks_per_input_period = tick_freq / samp_rate_in_; + + // work out the number of ticks in the time interval + // [ 0, input_num_samp/samp_rate_in_ ). + int interval_length_in_ticks = input_num_samp * ticks_per_input_period; + if (!flush) { + float window_width = num_zeros_ / (2.0 * filter_cutoff_); + // To count the window-width in ticks we take the floor. This + // is because since we're looking for the largest integer num-out-samp + // that fits in the interval, which is open on the right, a reduction + // in interval length of less than a tick will never make a difference. + // For example, the largest integer in the interval [ 0, 2 ) and the + // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one). + // So when we're subtracting the window-width we can ignore the fractional + // part. + int32_t window_width_ticks = std::floor(window_width * tick_freq); + // The time-period of the output that we can sample gets reduced + // by the window-width (which is actually the distance from the + // center to the edge of the windowing function) if we're not + // "flushing the output". + interval_length_in_ticks -= window_width_ticks; + } + if (interval_length_in_ticks <= 0) return 0; + + int32_t ticks_per_output_period = tick_freq / samp_rate_out_; + // Get the last output-sample in the closed interval, i.e. replacing [ ) with + // [ ]. Note: integer division rounds down. See + // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of + // the notation. + int last_output_samp = interval_length_in_ticks / ticks_per_output_period; + // We need the last output-sample in the open interval, so if it takes us to + // the end of the interval exactly, subtract one. 
+ if (last_output_samp * ticks_per_output_period == interval_length_in_ticks) + last_output_samp--; + + // First output-sample index is zero, so the number of output samples + // is the last output-sample plus one. + int num_output_samp = last_output_samp + 1; + return num_output_samp; +} + +// inline +void LinearResample::GetIndexes(int samp_out, int *first_samp_in, + int32_t *samp_out_wrapped) const { + // A unit is the smallest nonzero amount of time that is an exact + // multiple of the input and output sample periods. The unit index + // is the answer to "which numbered unit we are in". + int unit_index = samp_out / output_samples_in_unit_; + // samp_out_wrapped is equal to samp_out % output_samples_in_unit_ + *samp_out_wrapped = + static_cast(samp_out - unit_index * output_samples_in_unit_); + *first_samp_in = + first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_; +} + +void LinearResample::SetRemainder(const float *input, int32_t input_dim) { + std::vector old_remainder(input_remainder_); + // max_remainder_needed is the width of the filter from side to side, + // measured in input samples. you might think it should be half that, + // but you have to consider that you might be wanting to output samples + // that are "in the past" relative to the beginning of the latest + // input... anyway, storing more remainder than needed is not harmful. + int32_t max_remainder_needed = + std::ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_); + input_remainder_.resize(max_remainder_needed); + for (int32_t index = -static_cast(input_remainder_.size()); + index < 0; index++) { + // we interpret "index" as an offset from the end of "input" and + // from the end of input_remainder_. + int32_t input_index = index + input_dim; + if (input_index >= 0) { + input_remainder_[index + static_cast(input_remainder_.size())] = + input[input_index]; + } else if (input_index + static_cast(old_remainder.size()) >= 0) { + input_remainder_[index + static_cast(input_remainder_.size())] = + old_remainder[input_index + + static_cast(old_remainder.size())]; + // else leave it at zero. + } + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/resample.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/resample.h new file mode 100644 index 00000000..3be66785 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/resample.h @@ -0,0 +1,144 @@ +/** + * Copyright 2013 Pegah Ghahremani + * 2014 IMSL, PKU-HKUST (author: Wei Shi) + * 2014 Yanqing Sun, Junjie Wang + * 2014 Johns Hopkins University (author: Daniel Povey) + * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// this file is copied and modified from +// kaldi/src/feat/resample.h +#ifndef SHERPA_ONNX_CSRC_RESAMPLE_H_ +#define SHERPA_ONNX_CSRC_RESAMPLE_H_ + +#include +#include + +namespace sherpa_mnn { + +/* + We require that the input and output sampling rate be specified as + integers, as this is an easy way to specify that their ratio be rational. +*/ + +class LinearResample { + public: + /// Constructor. We make the input and output sample rates integers, because + /// we are going to need to find a common divisor. This should just remind + /// you that they need to be integers. The filter cutoff needs to be less + /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros + /// controls the sharpness of the filter, more == sharper but less efficient. + /// We suggest around 4 to 10 for normal use. + LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz, + float filter_cutoff_hz, int32_t num_zeros); + + /// Calling the function Reset() resets the state of the object prior to + /// processing a new signal; it is only necessary if you have called + /// Resample(x, x_size, false, y) for some signal, leading to a remainder of + /// the signal being called, but then abandon processing the signal before + /// calling Resample(x, x_size, true, y) for the last piece. Call it + /// unnecessarily between signals will not do any harm. + void Reset(); + + /// This function does the resampling. If you call it with flush == true and + /// you have never called it with flush == false, it just resamples the input + /// signal (it resizes the output to a suitable number of samples). + /// + /// You can also use this function to process a signal a piece at a time. + /// suppose you break it into piece1, piece2, ... pieceN. You can call + /// \code{.cc} + /// Resample(piece1, piece1_size, false, &output1); + /// Resample(piece2, piece2_size, false, &output2); + /// Resample(piece3, piece3_size, true, &output3); + /// \endcode + /// If you call it with flush == false, it won't output the last few samples + /// but will remember them, so that if you later give it a second piece of + /// the input signal it can process it correctly. + /// If your most recent call to the object was with flush == false, it will + /// have internal state; you can remove this by calling Reset(). + /// Empty input is acceptable. + void Resample(const float *input, int32_t input_dim, bool flush, + std::vector *output); + + //// Return the input and output sampling rates (for checks, for example) + int32_t GetInputSamplingRate() const { return samp_rate_in_; } + int32_t GetOutputSamplingRate() const { return samp_rate_out_; } + + private: + void SetIndexesAndWeights(); + + float FilterFunc(float) const; + + /// This function outputs the number of output samples we will output + /// for a signal with "input_num_samp" input samples. If flush == true, + /// we return the largest n such that + /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ), + /// and note that the interval is half-open. If flush == false, + /// define window_width as num_zeros / (2.0 * filter_cutoff_); + /// we return the largest n such that (n/samp_rate_out_) is in the interval + /// [ 0, input_num_samp/samp_rate_in_ - window_width ). 
+ int GetNumOutputSamples(int input_num_samp, bool flush) const; + + /// Given an output-sample index, this function outputs to *first_samp_in the + /// first input-sample index that we have a weight on (may be negative), + /// and to *samp_out_wrapped the index into weights_ where we can get the + /// corresponding weights on the input. + inline void GetIndexes(int samp_out, int *first_samp_in, + int32_t *samp_out_wrapped) const; + + void SetRemainder(const float *input, int32_t input_dim); + + private: + // The following variables are provided by the user. + int32_t samp_rate_in_; + int32_t samp_rate_out_; + float filter_cutoff_; + int32_t num_zeros_; + + int32_t input_samples_in_unit_; ///< The number of input samples in the + ///< smallest repeating unit: num_samp_in_ = + ///< samp_rate_in_hz / Gcd(samp_rate_in_hz, + ///< samp_rate_out_hz) + + int32_t output_samples_in_unit_; ///< The number of output samples in the + ///< smallest repeating unit: num_samp_out_ + ///< = samp_rate_out_hz / + ///< Gcd(samp_rate_in_hz, samp_rate_out_hz) + + /// The first input-sample index that we sum over, for this output-sample + /// index. May be negative; any truncation at the beginning is handled + /// separately. This is just for the first few output samples, but we can + /// extrapolate the correct input-sample index for arbitrary output samples. + std::vector first_index_; + + /// Weights on the input samples, for this output-sample index. + std::vector> weights_; + + // the following variables keep track of where we are in a particular signal, + // if it is being provided over multiple calls to Resample(). + + int input_sample_offset_ = 0; ///< The number of input samples we have + ///< already received for this signal + ///< (including anything in remainder_) + int output_sample_offset_ = 0; ///< The number of samples we have already + ///< output for this signal. + std::vector input_remainder_; ///< A small trailing part of the + ///< previously seen input signal. 
+}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_RESAMPLE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/session.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/session.cc new file mode 100644 index 00000000..e5357d56 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/session.cc @@ -0,0 +1,72 @@ +// sherpa-mnn/csrc/session.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/session.h" + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/provider.h" + + +#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1 +#include "dml_provider_factory.h" // NOLINT +#endif + +namespace sherpa_mnn { + + +MNNConfig GetSessionOptionsImpl( + int32_t num_threads, const std::string &provider_str, + const ProviderConfig *provider_config /*= nullptr*/) { + MNN::ScheduleConfig config; + config.numThread = num_threads; + MNN::BackendConfig bnConfig; + bnConfig.memory = MNN::BackendConfig::Memory_Low; + config.backendConfig = &bnConfig; + MNNConfig sess_opts; + sess_opts.pManager.reset(MNN::Express::Executor::RuntimeManager::createRuntimeManager(config)); + sess_opts.pConfig.rearrange = true; + return sess_opts; +} + +MNNConfig GetSessionOptions(const OnlineModelConfig &config) { + return GetSessionOptionsImpl(config.num_threads, + config.provider_config.provider, + &config.provider_config); +} + +MNNConfig GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type) { + /* + Transducer models : Only encoder will run with tensorrt, + decoder and joiner will run with cuda + */ + if (config.provider_config.provider == "trt" && + (model_type == "decoder" || model_type == "joiner")) { + return GetSessionOptionsImpl(config.num_threads, "cuda", + &config.provider_config); + } + return GetSessionOptionsImpl(config.num_threads, + config.provider_config.provider, + &config.provider_config); +} + +MNNConfig GetSessionOptions(const OfflineLMConfig &config) { + return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); +} + +MNNConfig GetSessionOptions(const OnlineLMConfig &config) { + return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); +} + +MNNConfig GetSessionOptions(int32_t num_threads, + const std::string &provider_str) { + return GetSessionOptionsImpl(num_threads, provider_str); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/session.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/session.h new file mode 100644 index 00000000..20198cc7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/session.h @@ -0,0 +1,39 @@ +// sherpa-mnn/csrc/session.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SESSION_H_ +#define SHERPA_ONNX_CSRC_SESSION_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/offline-lm-config.h" +#include "sherpa-mnn/csrc/online-lm-config.h" +#include "sherpa-mnn/csrc/online-model-config.h" + +namespace sherpa_mnn { + +MNNConfig GetSessionOptionsImpl( + int32_t num_threads, const std::string &provider_str, + const ProviderConfig *provider_config = nullptr); + +MNNConfig GetSessionOptions(const OfflineLMConfig &config); +MNNConfig GetSessionOptions(const OnlineLMConfig &config); + +MNNConfig GetSessionOptions(const OnlineModelConfig &config); + +MNNConfig GetSessionOptions(const OnlineModelConfig &config, + const std::string &model_type); + +MNNConfig GetSessionOptions(int32_t num_threads, + const std::string &provider_str); + 
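+// A minimal sketch (an assumption for illustration, not part of the original
+// header) of what the template overload below requires from its argument: any
+// config struct that exposes `num_threads` and `provider` members can be
+// passed to GetSessionOptions(). The struct name is hypothetical.
+//
+//   struct MyFrontendConfig {
+//     int32_t num_threads = 2;
+//     std::string provider = "cpu";
+//   };
+//   MNNConfig opts = GetSessionOptions(MyFrontendConfig{});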
+template +MNNConfig GetSessionOptions(const T &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline-audio-tagging.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline-audio-tagging.cc new file mode 100644 index 00000000..b76c3e7b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline-audio-tagging.cc @@ -0,0 +1,190 @@ +// sherpa-mnn/csrc/sherpa-mnn-alsa-offline-audio-tagging.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include +#include +#include + +#include +#include // NOLINT +#include // NOLINT + +#include "sherpa-mnn/csrc/alsa.h" +#include "sherpa-mnn/csrc/audio-tagging.h" +#include "sherpa-mnn/csrc/macros.h" + +enum class State { + kIdle, + kRecording, + kDecoding, +}; + +State state = State::kIdle; + +// true to stop the program and exit +bool stop = false; + +std::vector samples; +std::mutex samples_mutex; + +static void DetectKeyPress() { + SHERPA_ONNX_LOGE("Press Enter to start"); + int32_t key; + while (!stop && (key = getchar())) { + if (key != 0x0a) { + continue; + } + + switch (state) { + case State::kIdle: + SHERPA_ONNX_LOGE("Start recording. Press Enter to stop recording"); + state = State::kRecording; + { + std::lock_guard lock(samples_mutex); + samples.clear(); + } + break; + case State::kRecording: + SHERPA_ONNX_LOGE("Stop recording. Decoding ..."); + state = State::kDecoding; + break; + case State::kDecoding: + break; + } + } +} + +static void Record(const char *device_name, int32_t expected_sample_rate) { + sherpa_mnn::Alsa alsa(device_name); + + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + expected_sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + while (!stop) { + const std::vector &s = alsa.Read(chunk); + std::lock_guard lock(samples_mutex); + samples.insert(samples.end(), s.begin(), s.end()); + } +} + +static void Handler(int32_t sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Press Enter to exit\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +Audio tagging from microphone (Linux only). +Usage: + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/audio-tagging-models/sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 +tar xvf sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 +rm sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 + +./bin/sherpa-mnn-alsa-offline-audio-tagging \ + --zipformer-model=./sherpa-mnn-zipformer-audio-tagging-2024-04-09/model.onnx \ + --labels=./sherpa-mnn-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \ + device_name + +Please refer to +https://github.com/k2-fsa/sherpa-mnn/releases/tag/audio-tagging-models +for a list of pre-trained models to download. + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. 
+)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::AudioTaggingConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + SHERPA_ONNX_LOGE("Creating audio tagger ..."); + sherpa_mnn::AudioTagging tagger(config); + SHERPA_ONNX_LOGE("Audio tagger created created!"); + + std::string device_name = po.GetArg(1); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); + + int32_t sample_rate = 16000; // fixed to 16000Hz for all models from icefall + + std::thread t2(Record, device_name.c_str(), sample_rate); + using namespace std::chrono_literals; // NOLINT + std::this_thread::sleep_for(100ms); // sleep for 100ms + std::thread t(DetectKeyPress); + + while (!stop) { + switch (state) { + case State::kIdle: + break; + case State::kRecording: + break; + case State::kDecoding: { + std::vector buf; + { + std::lock_guard lock(samples_mutex); + buf = std::move(samples); + } + SHERPA_ONNX_LOGE("Computing..."); + auto s = tagger.CreateStream(); + s->AcceptWaveform(sample_rate, buf.data(), buf.size()); + auto results = tagger.Compute(s.get()); + SHERPA_ONNX_LOGE("Result is:"); + + int32_t i = 0; + std::ostringstream os; + for (const auto &event : results) { + os << i << ": " << event.ToString() << "\n"; + i += 1; + } + + SHERPA_ONNX_LOGE("\n%s\n", os.str().c_str()); + + state = State::kIdle; + SHERPA_ONNX_LOGE("Press Enter to start"); + break; + } + } + + std::this_thread::sleep_for(20ms); // sleep for 20ms + } + t.join(); + t2.join(); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc new file mode 100644 index 00000000..4c245edf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline-speaker-identification.cc @@ -0,0 +1,287 @@ +// sherpa-mnn/csrc/sherpa-mnn-alsa-offline-speaker-identification.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include +#include +#include // NOLINT +#include +#include // NOLINT + +#include "sherpa-mnn/csrc/alsa.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/microphone.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" +#include "sherpa-mnn/csrc/speaker-embedding-manager.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +enum class State { + kIdle, + kRecording, + kComputing, +}; + +State state = State::kIdle; + +// true to stop the program and exit +bool stop = false; + +std::vector samples; +std::mutex samples_mutex; + +static void DetectKeyPress() { + SHERPA_ONNX_LOGE("\nPress Enter to start"); + int32_t key; + while (!stop && (key = getchar())) { + if (key != 0x0a) { + continue; + } + + switch (state) { + case State::kIdle: + SHERPA_ONNX_LOGE("\nStart recording. Press Enter to stop recording"); + state = State::kRecording; + { + std::lock_guard lock(samples_mutex); + samples.clear(); + } + break; + case State::kRecording: + SHERPA_ONNX_LOGE("\nStop recording. 
Computing ..."); + state = State::kComputing; + break; + case State::kComputing: + break; + } + } +} + +static void Record(const char *device_name, int32_t expected_sample_rate) { + sherpa_mnn::Alsa alsa(device_name); + + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + expected_sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + while (!stop) { + const std::vector &s = alsa.Read(chunk); + std::lock_guard lock(samples_mutex); + samples.insert(samples.end(), s.begin(), s.end()); + } +} + +static void Handler(int32_t sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Press Enter to exit\n"); +} + +static std::vector> ComputeEmbeddings( + const std::vector &filenames, + sherpa_mnn::SpeakerEmbeddingExtractor *extractor) { + std::vector> embedding_list; + embedding_list.reserve(filenames.size()); + + for (const auto &f : filenames) { + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(f, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", f.c_str()); + exit(-1); + } + + auto s = extractor->CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + s->InputFinished(); + auto embedding = extractor->Compute(s.get()); + embedding_list.push_back(embedding); + } + return embedding_list; +} + +static std::unordered_map> +ReadSpeakerFile(const std::string &filename) { + std::unordered_map> ans; + + std::ifstream is(filename); + if (!is) { + fprintf(stderr, "Failed to open %s", filename.c_str()); + exit(0); + } + + std::string line; + std::string name; + std::string path; + + while (std::getline(is, line)) { + std::istringstream iss(line); + name.clear(); + path.clear(); + + iss >> name >> path; + if (!iss || !iss.eof() || name.empty() || path.empty()) { + fprintf(stderr, "Invalid line: %s\n", line.c_str()); + exit(-1); + } + ans[name].push_back(path); + } + + return ans; +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program shows how to use non-streaming speaker identification. +Usage: + +(1) Prepare a text file containing speaker related files. + +Each line in the text file contains two columns. The first column is the +speaker name, while the second column contains the wave file of the speaker. + +If the text file contains multiple wave files for the same speaker, then the +embeddings of these files are averaged. + +An example text file is given below: + + foo /path/to/a.wav + bar /path/to/b.wav + foo /path/to/c.wav + foobar /path/to/d.wav + +Each wave file should contain only a single channel; the sample format +should be int16_t; the sample rate can be arbitrary. + +(2) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-mnn/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-mnn/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx + +Note that `zh` means Chinese, while `en` means English. + +(3) Run it ! + + ./bin/sherpa-mnn-alsa-offline-speaker-identification \ + --model=/path/to/your-model.onnx \ + --speaker-file=/path/to/speaker.txt \ + device_name + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. 
For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + plughw:3,0 +as the device_name. + +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + float threshold = 0.5; + std::string speaker_file; + + po.Register("threshold", &threshold, + "Threshold for comparing embedding scores."); + + po.Register("speaker-file", &speaker_file, "Path to speaker.txt"); + + sherpa_mnn::SpeakerEmbeddingExtractorConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config! Please use --help to view the usage.\n"); + return -1; + } + + SHERPA_ONNX_LOGE("\nCreating extractor ..."); + sherpa_mnn::SpeakerEmbeddingExtractor extractor(config); + SHERPA_ONNX_LOGE("\nextractor created!"); + + sherpa_mnn::SpeakerEmbeddingManager manager(extractor.Dim()); + + auto name2files = ReadSpeakerFile(speaker_file); + for (const auto &p : name2files) { + SHERPA_ONNX_LOGE("\nProcessing speaker %s", p.first.c_str()); + auto embedding_list = ComputeEmbeddings(p.second, &extractor); + manager.Add(p.first, embedding_list); + } + + std::string device_name = po.GetArg(1); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); + int32_t sample_rate = 16000; + + std::thread t(DetectKeyPress); + std::thread t2(Record, device_name.c_str(), sample_rate); + + while (!stop) { + switch (state) { + case State::kIdle: + break; + case State::kRecording: + break; + case State::kComputing: { + std::vector buf; + { + std::lock_guard lock(samples_mutex); + buf = std::move(samples); + } + + auto s = extractor.CreateStream(); + s->AcceptWaveform(sample_rate, buf.data(), buf.size()); + s->InputFinished(); + auto embedding = extractor.Compute(s.get()); + auto name = manager.Search(embedding.data(), threshold); + + if (name.empty()) { + name = "--Unknown--"; + } + + SHERPA_ONNX_LOGE("\nDone!\nDetected speaker is: %s", name.c_str()); + + state = State::kIdle; + SHERPA_ONNX_LOGE("\nPress Enter to start"); + break; + } + } + + using namespace std::chrono_literals; // NOLINT + std::this_thread::sleep_for(20ms); // sleep for 20ms + } + + t.join(); + t2.join(); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline.cc new file mode 100644 index 00000000..0e2f8264 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa-offline.cc @@ -0,0 +1,202 @@ +// sherpa-mnn/csrc/sherpa-mnn-alsa-offline.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include +#include +#include + +#include +#include // std::tolower +#include // NOLINT +#include // NOLINT +#include // NOLINT + +#include "sherpa-mnn/csrc/alsa.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" + +enum class State { + kIdle, + kRecording, + kDecoding, +}; + +State state = State::kIdle; + +// true to stop the program and exit +bool stop = false; + +std::vector samples; +std::mutex samples_mutex; + +static void DetectKeyPress() { + SHERPA_ONNX_LOGE("Press Enter to start"); + int32_t key; + while (!stop && (key = getchar())) { + if (key != 
0x0a) { + continue; + } + + switch (state) { + case State::kIdle: + SHERPA_ONNX_LOGE("Start recording. Press Enter to stop recording"); + state = State::kRecording; + { + std::lock_guard lock(samples_mutex); + samples.clear(); + } + break; + case State::kRecording: + SHERPA_ONNX_LOGE("Stop recording. Decoding ..."); + state = State::kDecoding; + break; + case State::kDecoding: + break; + } + } +} + +static void Record(const char *device_name, int32_t expected_sample_rate) { + sherpa_mnn::Alsa alsa(device_name); + + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + expected_sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + while (!stop) { + const std::vector &s = alsa.Read(chunk); + std::lock_guard lock(samples_mutex); + samples.insert(samples.end(), s.begin(), s.end()); + } +} + +static void Handler(int32_t sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Press Enter to exit\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program uses non-streaming models with microphone for speech recognition. +Usage: + +(1) Transducer from icefall + + ./bin/sherpa-mnn-alsa-offline \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + device_name + +(2) Paraformer from FunASR + + ./bin/sherpa-mnn-alsa-offline \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --num-threads=1 \ + device_name + +(3) Whisper models + + ./bin/sherpa-mnn-alsa-offline \ + --whisper-encoder=./sherpa-mnn-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-mnn-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-mnn-whisper-base.en/base.en-tokens.txt \ + --num-threads=1 \ + device_name + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. 
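+
+For example, to use card 3, device 0 from the output above with a transducer
+model (all model paths below are placeholders), you could run:
+
+  ./bin/sherpa-mnn-alsa-offline \
+    --tokens=/path/to/tokens.txt \
+    --encoder=/path/to/encoder.onnx \
+    --decoder=/path/to/decoder.onnx \
+    --joiner=/path/to/joiner.onnx \
+    plughw:3,0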
+)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OfflineRecognizerConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + SHERPA_ONNX_LOGE("Creating recognizer ..."); + sherpa_mnn::OfflineRecognizer recognizer(config); + SHERPA_ONNX_LOGE("Recognizer created!"); + + std::string device_name = po.GetArg(1); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); + + int32_t sample_rate = config.feat_config.sampling_rate; + + std::thread t(DetectKeyPress); + std::thread t2(Record, device_name.c_str(), sample_rate); + + while (!stop) { + switch (state) { + case State::kIdle: + break; + case State::kRecording: + break; + case State::kDecoding: { + std::vector buf; + { + std::lock_guard lock(samples_mutex); + buf = std::move(samples); + } + + auto s = recognizer.CreateStream(); + s->AcceptWaveform(sample_rate, buf.data(), buf.size()); + recognizer.DecodeStream(s.get()); + SHERPA_ONNX_LOGE("Decoding Done! Result is:"); + SHERPA_ONNX_LOGE("%s", s->GetResult().text.c_str()); + + state = State::kIdle; + SHERPA_ONNX_LOGE("Press Enter to start"); + break; + } + } + + using namespace std::chrono_literals; // NOLINT + std::this_thread::sleep_for(20ms); // sleep for 20ms + } + t.join(); + t2.join(); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa.cc new file mode 100644 index 00000000..305213bd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-alsa.cc @@ -0,0 +1,152 @@ +// sherpa-mnn/csrc/sherpa-mnn-alsa.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#include +#include +#include + +#include +#include // std::tolower +#include + +#include "sherpa-mnn/csrc/alsa.h" +#include "sherpa-mnn/csrc/display.h" +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/parse-options.h" + +bool stop = false; + +static void Handler(int sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +Usage: + ./bin/sherpa-mnn-alsa \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --provider=cpu \ + --num-threads=2 \ + --decoding-method=greedy_search \ + device_name + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. 
+)usage"; + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OnlineRecognizerConfig config; + + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + sherpa_mnn::OnlineRecognizer recognizer(config); + + int32_t expected_sample_rate = config.feat_config.sampling_rate; + + std::string device_name = po.GetArg(1); + sherpa_mnn::Alsa alsa(device_name.c_str()); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); + + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + expected_sample_rate); + exit(-1); + } + + fprintf(stderr, "Started! Please speak\n"); + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + + std::string last_text; + + auto stream = recognizer.CreateStream(); + + sherpa_mnn::Display display; + + int32_t segment_index = 0; + while (!stop) { + const std::vector &samples = alsa.Read(chunk); + + stream->AcceptWaveform(expected_sample_rate, samples.data(), + samples.size()); + + while (recognizer.IsReady(stream.get())) { + recognizer.DecodeStream(stream.get()); + } + + auto text = recognizer.GetResult(stream.get()).text; + + bool is_endpoint = recognizer.IsEndpoint(stream.get()); + + if (is_endpoint && !config.model_config.paraformer.encoder.empty()) { + // For streaming paraformer models, since it has a large right chunk size + // we need to pad it on endpointing so that the last character + // can be recognized + std::vector tail_paddings( + static_cast(1.0 * expected_sample_rate)); + stream->AcceptWaveform(expected_sample_rate, tail_paddings.data(), + tail_paddings.size()); + while (recognizer.IsReady(stream.get())) { + recognizer.DecodeStream(stream.get()); + } + text = recognizer.GetResult(stream.get()).text; + } + + if (!text.empty() && last_text != text) { + last_text = text; + + std::transform(text.begin(), text.end(), text.begin(), + [](auto c) { return std::tolower(c); }); + + display.Print(segment_index, text); + fflush(stderr); + } + + if (is_endpoint) { + if (!text.empty()) { + ++segment_index; + } + + recognizer.Reset(stream.get()); + } + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter-alsa.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter-alsa.cc new file mode 100644 index 00000000..13ecfca7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter-alsa.cc @@ -0,0 +1,122 @@ +// sherpa-mnn/csrc/sherpa-mnn-keyword-spotter-alsa.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include +#include +#include + +#include +#include + +#include "sherpa-mnn/csrc/alsa.h" +#include "sherpa-mnn/csrc/display.h" +#include "sherpa-mnn/csrc/keyword-spotter.h" +#include "sherpa-mnn/csrc/parse-options.h" + +bool stop = false; + +static void Handler(int sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. 
Exiting...\n"); +} + +int main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +Usage: + ./bin/sherpa-mnn-keyword-spotter-alsa \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --provider=cpu \ + --num-threads=2 \ + --keywords-file=keywords.txt \ + device_name + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html +for a list of pre-trained models to download. + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. +)usage"; + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::KeywordSpotterConfig config; + + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + sherpa_mnn::KeywordSpotter spotter(config); + + int32_t expected_sample_rate = config.feat_config.sampling_rate; + + std::string device_name = po.GetArg(1); + sherpa_mnn::Alsa alsa(device_name.c_str()); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); + + if (alsa.GetExpectedSampleRate() != expected_sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + expected_sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + + std::string last_text; + + auto stream = spotter.CreateStream(); + + sherpa_mnn::Display display; + + int32_t keyword_index = 0; + while (!stop) { + const std::vector &samples = alsa.Read(chunk); + + stream->AcceptWaveform(expected_sample_rate, samples.data(), + samples.size()); + + while (spotter.IsReady(stream.get())) { + spotter.DecodeStream(stream.get()); + + const auto r = spotter.GetResult(stream.get()); + if (!r.keyword.empty()) { + display.Print(keyword_index, r.AsJsonString()); + fflush(stderr); + keyword_index++; + + spotter.Reset(stream.get()); + } + } + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter-microphone.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter-microphone.cc new file mode 100644 index 00000000..1a8c4210 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter-microphone.cc @@ -0,0 +1,174 @@ +// sherpa-mnn/csrc/sherpa-mnn-keyword-spotter-microphone.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/display.h" +#include "sherpa-mnn/csrc/keyword-spotter.h" +#include "sherpa-mnn/csrc/microphone.h" + +bool stop = false; +float mic_sample_rate = 16000; + +static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void *user_data) { + auto stream = 
reinterpret_cast(user_data); + + stream->AcceptWaveform(mic_sample_rate, + reinterpret_cast(input_buffer), + frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t /*sig*/) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program uses streaming models with microphone for keyword spotting. +Usage: + + ./bin/sherpa-mnn-keyword-spotter-microphone \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --provider=cpu \ + --num-threads=1 \ + --keywords-file=keywords.txt + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::KeywordSpotterConfig config; + + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 0) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_mnn::KeywordSpotter spotter(config); + auto s = spotter.CreateStream(); + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + int32_t device_index = Pa_GetDefaultInputDevice(); + + if (device_index == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + fprintf(stderr, "If you are using Linux, please switch to \n"); + fprintf(stderr, " ./bin/sherpa-mnn-keyword-spotter-alsa \n"); + exit(EXIT_FAILURE); + } + + const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE"); + if (pDeviceIndex) { + fprintf(stderr, "Use specified device: %s\n", pDeviceIndex); + device_index = atoi(pDeviceIndex); + } + + for (int32_t i = 0; i != num_devices; ++i) { + const PaDeviceInfo *info = Pa_GetDeviceInfo(i); + fprintf(stderr, " %s %d %s\n", (i == device_index) ? 
"*" : " ", i, + info->name); + } + + PaStreamParameters param; + param.device = device_index; + + fprintf(stderr, "Use device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + + const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); + if (pSampleRateStr) { + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); + mic_sample_rate = atof(pSampleRateStr); + } + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + mic_sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, s.get()); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + fprintf(stderr, "Started\n"); + + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + int32_t keyword_index = 0; + sherpa_mnn::Display display; + while (!stop) { + while (spotter.IsReady(s.get())) { + spotter.DecodeStream(s.get()); + + const auto r = spotter.GetResult(s.get()); + if (!r.keyword.empty()) { + display.Print(keyword_index, r.AsJsonString()); + fflush(stderr); + keyword_index++; + + spotter.Reset(s.get()); + } + } + + Pa_Sleep(20); // sleep for 20ms + } + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter.cc new file mode 100644 index 00000000..e54a634c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-keyword-spotter.cc @@ -0,0 +1,121 @@ +// sherpa-mnn/csrc/sherpa-mnn-keyword-spotter.cc +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#include + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/keyword-spotter.h" +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +typedef struct { + std::unique_ptr online_stream; + std::string filename; +} Stream; + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Usage: + +(1) Streaming transducer + + ./bin/sherpa-mnn-keyword-spotter \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --provider=cpu \ + --num-threads=2 \ + --keywords-file=keywords.txt \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +Note: It supports decoding multiple files in batches + +Default value for num_threads is 2. +Valid values for provider: cpu (default), cuda, coreml. +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. 
+)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::KeywordSpotterConfig config; + + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() < 1) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_mnn::KeywordSpotter keyword_spotter(config); + + std::vector ss; + + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + const std::string wav_filename = po.GetArg(i); + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + return -1; + } + + auto s = keyword_spotter.CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + std::vector tail_paddings(static_cast(0.8 * sampling_rate)); + // Note: We can call AcceptWaveform() multiple times. + s->AcceptWaveform(sampling_rate, tail_paddings.data(), + tail_paddings.size()); + + // Call InputFinished() to indicate that no audio samples are available + s->InputFinished(); + ss.push_back({std::move(s), wav_filename}); + } + + std::vector ready_streams; + for (;;) { + ready_streams.clear(); + for (auto &s : ss) { + const auto p_ss = s.online_stream.get(); + if (keyword_spotter.IsReady(p_ss)) { + ready_streams.push_back(p_ss); + } + std::ostringstream os; + const auto r = keyword_spotter.GetResult(p_ss); + if (!r.keyword.empty()) { + os << s.filename << "\n"; + os << r.AsJsonString() << "\n\n"; + fprintf(stderr, "%s", os.str().c_str()); + } + } + + if (ready_streams.empty()) { + break; + } + keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size()); + } + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline-audio-tagging.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline-audio-tagging.cc new file mode 100644 index 00000000..3517f77b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline-audio-tagging.cc @@ -0,0 +1,236 @@ +// sherpa-mnn/csrc/sherpa-mnn-microphone-offline-audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include +#include // std::tolower +#include // NOLINT +#include // NOLINT + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/audio-tagging.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/microphone.h" + +enum class State { + kIdle, + kRecording, + kDecoding, +}; + +State state = State::kIdle; + +// true to stop the program and exit +bool stop = false; + +std::vector samples; +std::mutex samples_mutex; + +static void DetectKeyPress() { + SHERPA_ONNX_LOGE("Press Enter to start"); + int32_t key; + while (!stop && (key = getchar())) { + if (key != 0x0a) { + continue; + } + + switch (state) { + case State::kIdle: + SHERPA_ONNX_LOGE("Start recording. Press Enter to stop recording"); + state = State::kRecording; + { + std::lock_guard lock(samples_mutex); + samples.clear(); + } + break; + case State::kRecording: + SHERPA_ONNX_LOGE("Stop recording. 
Decoding ..."); + state = State::kDecoding; + break; + case State::kDecoding: + break; + } + } +} + +static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void * /*user_data*/) { + std::lock_guard lock(samples_mutex); + + auto p = reinterpret_cast(input_buffer); + samples.insert(samples.end(), p, p + frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t /*sig*/) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Press Enter to exit\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +Audio tagging from microphone. +Usage: + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/audio-tagging-models/sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 +tar xvf sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 +rm sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 + +./bin/sherpa-mnn-microphone-offline-audio-tagging \ + --zipformer-model=./sherpa-mnn-zipformer-audio-tagging-2024-04-09/model.onnx \ + --labels=./sherpa-mnn-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv + +Please see +https://github.com/k2-fsa/sherpa-mnn/releases/tag/audio-tagging-models +for more models. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::AudioTaggingConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 0) { + fprintf(stderr, "\nThis program does not support positional arguments\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + SHERPA_ONNX_LOGE("Creating audio tagger ..."); + sherpa_mnn::AudioTagging tagger(config); + SHERPA_ONNX_LOGE("Audio tagger created created!"); + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + int32_t device_index = Pa_GetDefaultInputDevice(); + + if (device_index == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + fprintf(stderr, "If you are using Linux, please switch to \n"); + fprintf(stderr, " ./bin/sherpa-mnn-alsa-offline-audio-tagging \n"); + exit(EXIT_FAILURE); + } + + const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE"); + if (pDeviceIndex) { + fprintf(stderr, "Use specified device: %s\n", pDeviceIndex); + device_index = atoi(pDeviceIndex); + } + + for (int32_t i = 0; i != num_devices; ++i) { + const PaDeviceInfo *info = Pa_GetDeviceInfo(i); + fprintf(stderr, " %s %d %s\n", (i == device_index) ? 
"*" : " ", i, + info->name); + } + + PaStreamParameters param; + param.device = device_index; + + fprintf(stderr, "Use device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + float mic_sample_rate = 16000; + const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); + if (pSampleRateStr) { + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); + mic_sample_rate = atof(pSampleRateStr); + } + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + mic_sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, nullptr); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + fprintf(stderr, "Started\n"); + + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + std::thread t(DetectKeyPress); + while (!stop) { + switch (state) { + case State::kIdle: + break; + case State::kRecording: + break; + case State::kDecoding: { + std::vector buf; + { + std::lock_guard lock(samples_mutex); + buf = std::move(samples); + } + + SHERPA_ONNX_LOGE("Computing..."); + auto s = tagger.CreateStream(); + s->AcceptWaveform(mic_sample_rate, buf.data(), buf.size()); + auto results = tagger.Compute(s.get()); + + SHERPA_ONNX_LOGE("Result is:"); + + int32_t i = 0; + std::ostringstream os; + for (const auto &event : results) { + os << i << ": " << event.ToString() << "\n"; + i += 1; + } + + SHERPA_ONNX_LOGE("\n%s\n", os.str().c_str()); + + state = State::kIdle; + SHERPA_ONNX_LOGE("Press Enter to start"); + break; + } + } + + Pa_Sleep(20); // sleep for 20ms + } + t.join(); + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc new file mode 100644 index 00000000..9b31acb1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline-speaker-identification.cc @@ -0,0 +1,333 @@ +// sherpa-mnn/csrc/sherpa-mnn-microphone-offline-speaker-identification.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include +#include +#include // NOLINT +#include +#include // NOLINT + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/microphone.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" +#include "sherpa-mnn/csrc/speaker-embedding-manager.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +enum class State { + kIdle, + kRecording, + kComputing, +}; + +State state = State::kIdle; + +// true to stop the program and exit +bool stop = false; + +std::vector samples; +std::mutex samples_mutex; + +static void DetectKeyPress() { + SHERPA_ONNX_LOGE("\nPress Enter to start"); + int32_t key; + while (!stop && (key = getchar())) { + if (key != 0x0a) { + continue; + } + + 
switch (state) { + case State::kIdle: + SHERPA_ONNX_LOGE("\nStart recording. Press Enter to stop recording"); + state = State::kRecording; + { + std::lock_guard lock(samples_mutex); + samples.clear(); + } + break; + case State::kRecording: + SHERPA_ONNX_LOGE("\nStop recording. Computing ..."); + state = State::kComputing; + break; + case State::kComputing: + break; + } + } +} + +static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void *user_data) { + std::lock_guard lock(samples_mutex); + + auto p = reinterpret_cast(input_buffer); + samples.insert(samples.end(), p, p + frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Press Enter to exit\n"); +} + +static std::vector> ComputeEmbeddings( + const std::vector &filenames, + sherpa_mnn::SpeakerEmbeddingExtractor *extractor) { + std::vector> embedding_list; + embedding_list.reserve(filenames.size()); + + for (const auto &f : filenames) { + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(f, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", f.c_str()); + exit(-1); + } + + auto s = extractor->CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + s->InputFinished(); + auto embedding = extractor->Compute(s.get()); + embedding_list.push_back(embedding); + } + return embedding_list; +} + +static std::unordered_map> +ReadSpeakerFile(const std::string &filename) { + std::unordered_map> ans; + + std::ifstream is(filename); + if (!is) { + fprintf(stderr, "Failed to open %s", filename.c_str()); + exit(0); + } + + std::string line; + std::string name; + std::string path; + + while (std::getline(is, line)) { + std::istringstream iss(line); + name.clear(); + path.clear(); + + iss >> name >> path; + if (!iss || !iss.eof() || name.empty() || path.empty()) { + fprintf(stderr, "Invalid line: %s\n", line.c_str()); + exit(-1); + } + ans[name].push_back(path); + } + + return ans; +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program shows how to use non-streaming speaker identification. +Usage: + +(1) Prepare a text file containing speaker related files. + +Each line in the text file contains two columns. The first column is the +speaker name, while the second column contains the wave file of the speaker. + +If the text file contains multiple wave files for the same speaker, then the +embeddings of these files are averaged. + +An example text file is given below: + + foo /path/to/a.wav + bar /path/to/b.wav + foo /path/to/c.wav + foobar /path/to/d.wav + +Each wave file should contain only a single channel; the sample format +should be int16_t; the sample rate can be arbitrary. + +(2) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-mnn/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-mnn/releases/download/speaker-recongition-models/wespeaker_zh_cnceleb_resnet34.onnx + +Note that `zh` means Chinese, while `en` means English. + +(3) Run it ! 
+ + ./bin/sherpa-mnn-microphone-offline-speaker-identification \ + --model=/path/to/your-model.onnx \ + --speaker-file=/path/to/speaker.txt +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + float threshold = 0.5; + std::string speaker_file; + + po.Register("threshold", &threshold, + "Threshold for comparing embedding scores."); + + po.Register("speaker-file", &speaker_file, "Path to speaker.txt"); + + sherpa_mnn::SpeakerEmbeddingExtractorConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 0) { + fprintf(stderr, + "This program does not support any positional arguments.\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config! Please use --help to view the usage.\n"); + return -1; + } + + SHERPA_ONNX_LOGE("\nCreating extractor ..."); + sherpa_mnn::SpeakerEmbeddingExtractor extractor(config); + SHERPA_ONNX_LOGE("\nextractor created!"); + + sherpa_mnn::SpeakerEmbeddingManager manager(extractor.Dim()); + + auto name2files = ReadSpeakerFile(speaker_file); + for (const auto &p : name2files) { + SHERPA_ONNX_LOGE("\nProcessing speaker %s", p.first.c_str()); + auto embedding_list = ComputeEmbeddings(p.second, &extractor); + manager.Add(p.first, embedding_list); + } + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + int32_t device_index = Pa_GetDefaultInputDevice(); + if (device_index == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + fprintf(stderr, "If you are using Linux, please switch to \n"); + fprintf(stderr, + " ./bin/sherpa-mnn-alsa-offline-speaker-identification \n"); + exit(EXIT_FAILURE); + } + + const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE"); + if (pDeviceIndex) { + fprintf(stderr, "Use specified device: %s\n", pDeviceIndex); + device_index = atoi(pDeviceIndex); + } + + for (int32_t i = 0; i != num_devices; ++i) { + const PaDeviceInfo *info = Pa_GetDeviceInfo(i); + fprintf(stderr, " %s %d %s\n", (i == device_index) ? 
"*" : " ", i, + info->name); + } + + PaStreamParameters param; + param.device = device_index; + + fprintf(stderr, "Use device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + float mic_sample_rate = 16000; + const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); + if (pSampleRateStr) { + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); + mic_sample_rate = atof(pSampleRateStr); + } + float sample_rate = 16000; + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + mic_sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, nullptr); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + fprintf(stderr, "Started\n"); + + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + std::thread t(DetectKeyPress); + while (!stop) { + switch (state) { + case State::kIdle: + break; + case State::kRecording: + break; + case State::kComputing: { + std::vector buf; + { + std::lock_guard lock(samples_mutex); + buf = std::move(samples); + } + + auto s = extractor.CreateStream(); + s->AcceptWaveform(mic_sample_rate, buf.data(), buf.size()); + s->InputFinished(); + auto embedding = extractor.Compute(s.get()); + auto name = manager.Search(embedding.data(), threshold); + + if (name.empty()) { + name = "--Unknown--"; + } + + SHERPA_ONNX_LOGE("\nDone!\nDetected speaker is: %s", name.c_str()); + + state = State::kIdle; + SHERPA_ONNX_LOGE("\nPress Enter to start"); + break; + } + } + + Pa_Sleep(20); // sleep for 20ms + } + t.join(); + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline.cc new file mode 100644 index 00000000..379295af --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone-offline.cc @@ -0,0 +1,242 @@ +// sherpa-mnn/csrc/sherpa-mnn-microphone-offline.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include +#include +#include + +#include +#include // std::tolower +#include // NOLINT +#include // NOLINT + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/microphone.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" + +enum class State { + kIdle, + kRecording, + kDecoding, +}; + +State state = State::kIdle; + +// true to stop the program and exit +bool stop = false; + +std::vector samples; +std::mutex samples_mutex; + +static void DetectKeyPress() { + SHERPA_ONNX_LOGE("Press Enter to start"); + int32_t key; + while (!stop && (key = getchar())) { + if (key != 0x0a) { + continue; + } + + switch (state) { + case State::kIdle: + SHERPA_ONNX_LOGE("Start recording. 
Press Enter to stop recording"); + state = State::kRecording; + { + std::lock_guard lock(samples_mutex); + samples.clear(); + } + break; + case State::kRecording: + SHERPA_ONNX_LOGE("Stop recording. Decoding ..."); + state = State::kDecoding; + break; + case State::kDecoding: + break; + } + } +} + +static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void * /*user_data*/) { + std::lock_guard lock(samples_mutex); + + auto p = reinterpret_cast(input_buffer); + samples.insert(samples.end(), p, p + frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t /*sig*/) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Press Enter to exit\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program uses non-streaming models with microphone for speech recognition. +Usage: + +(1) Transducer from icefall + + ./bin/sherpa-mnn-microphone-offline \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search + +(2) Paraformer from FunASR + + ./bin/sherpa-mnn-microphone-offline \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --num-threads=1 + +(3) Whisper models + + ./bin/sherpa-mnn-microphone-offline \ + --whisper-encoder=./sherpa-mnn-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-mnn-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-mnn-whisper-base.en/base.en-tokens.txt \ + --num-threads=1 + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OfflineRecognizerConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 0) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + SHERPA_ONNX_LOGE("Creating recognizer ..."); + sherpa_mnn::OfflineRecognizer recognizer(config); + SHERPA_ONNX_LOGE("Recognizer created!"); + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + int32_t device_index = Pa_GetDefaultInputDevice(); + + if (device_index == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + fprintf(stderr, "If you are using Linux, please switch to \n"); + fprintf(stderr, " ./bin/sherpa-mnn-alsa-offline \n"); + exit(EXIT_FAILURE); + } + + const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE"); + if (pDeviceIndex) { + fprintf(stderr, "Use specified device: %s\n", pDeviceIndex); + device_index = atoi(pDeviceIndex); + } + + for (int32_t i = 0; i != num_devices; ++i) { + const PaDeviceInfo *info = Pa_GetDeviceInfo(i); + fprintf(stderr, " %s %d %s\n", (i == device_index) ? 
"*" : " ", i, + info->name); + } + + PaStreamParameters param; + param.device = device_index; + + fprintf(stderr, "Use device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + float mic_sample_rate = 16000; + const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); + if (pSampleRateStr) { + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); + mic_sample_rate = atof(pSampleRateStr); + } + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + mic_sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, nullptr); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + fprintf(stderr, "Started\n"); + + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + std::thread t(DetectKeyPress); + while (!stop) { + switch (state) { + case State::kIdle: + break; + case State::kRecording: + break; + case State::kDecoding: { + std::vector buf; + { + std::lock_guard lock(samples_mutex); + buf = std::move(samples); + } + + auto s = recognizer.CreateStream(); + s->AcceptWaveform(mic_sample_rate, buf.data(), buf.size()); + recognizer.DecodeStream(s.get()); + SHERPA_ONNX_LOGE("Decoding Done! Result is:"); + SHERPA_ONNX_LOGE("%s", s->GetResult().text.c_str()); + + state = State::kIdle; + SHERPA_ONNX_LOGE("Press Enter to start"); + break; + } + } + + Pa_Sleep(20); // sleep for 20ms + } + t.join(); + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone.cc new file mode 100644 index 00000000..f7666d8c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-microphone.cc @@ -0,0 +1,223 @@ +// sherpa-mnn/csrc/sherpa-mnn-microphone.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include +#include +#include + +#include +#include +#include + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/display.h" +#include "sherpa-mnn/csrc/microphone.h" +#include "sherpa-mnn/csrc/online-recognizer.h" + +bool stop = false; +float mic_sample_rate = 16000; + +static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void *user_data) { + auto stream = reinterpret_cast(user_data); + + stream->AcceptWaveform(mic_sample_rate, + reinterpret_cast(input_buffer), + frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t /*sig*/) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. 
Exiting...\n"); +} + +static std::string tolowerUnicode(const std::string &input_str) { + // Use system locale + std::setlocale(LC_ALL, ""); + + // From char string to wchar string + std::wstring input_wstr(input_str.size() + 1, '\0'); + std::mbstowcs(&input_wstr[0], input_str.c_str(), input_str.size()); + std::wstring lowercase_wstr; + + for (wchar_t wc : input_wstr) { + if (std::iswupper(wc)) { + lowercase_wstr += std::towlower(wc); + } else { + lowercase_wstr += wc; + } + } + + // Back to char string + std::string lowercase_str(input_str.size() + 1, '\0'); + std::wcstombs(&lowercase_str[0], lowercase_wstr.c_str(), + lowercase_wstr.size()); + + return lowercase_str; +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program uses streaming models with microphone for speech recognition. +Usage: + + ./bin/sherpa-mnn-microphone \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --provider=cpu \ + --num-threads=1 \ + --decoding-method=greedy_search + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OnlineRecognizerConfig config; + + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 0) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_mnn::OnlineRecognizer recognizer(config); + auto s = recognizer.CreateStream(); + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + int32_t device_index = Pa_GetDefaultInputDevice(); + + if (device_index == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + fprintf(stderr, "If you are using Linux, please switch to \n"); + fprintf(stderr, " ./bin/sherpa-mnn-alsa \n"); + exit(EXIT_FAILURE); + } + + const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE"); + if (pDeviceIndex) { + fprintf(stderr, "Use specified device: %s\n", pDeviceIndex); + device_index = atoi(pDeviceIndex); + } + + for (int32_t i = 0; i != num_devices; ++i) { + const PaDeviceInfo *info = Pa_GetDeviceInfo(i); + fprintf(stderr, " %s %d %s\n", (i == device_index) ? 
"*" : " ", i, + info->name); + } + + PaStreamParameters param; + param.device = device_index; + + fprintf(stderr, "Use device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); + if (pSampleRateStr) { + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); + mic_sample_rate = atof(pSampleRateStr); + } + float sample_rate = 16000; + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, s.get()); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + fprintf(stderr, "Started\n"); + + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + std::string last_text; + int32_t segment_index = 0; + sherpa_mnn::Display display(30); + while (!stop) { + while (recognizer.IsReady(s.get())) { + recognizer.DecodeStream(s.get()); + } + + auto text = recognizer.GetResult(s.get()).text; + bool is_endpoint = recognizer.IsEndpoint(s.get()); + + if (is_endpoint && !config.model_config.paraformer.encoder.empty()) { + // For streaming paraformer models, since it has a large right chunk size + // we need to pad it on endpointing so that the last character + // can be recognized + std::vector tail_paddings(static_cast(1.0 * mic_sample_rate)); + s->AcceptWaveform(mic_sample_rate, tail_paddings.data(), + tail_paddings.size()); + while (recognizer.IsReady(s.get())) { + recognizer.DecodeStream(s.get()); + } + text = recognizer.GetResult(s.get()).text; + } + + if (!text.empty() && last_text != text) { + last_text = text; + display.Print(segment_index, tolowerUnicode(text)); + fflush(stderr); + } + + if (is_endpoint) { + if (!text.empty()) { + ++segment_index; + } + + recognizer.Reset(s.get()); + } + + Pa_Sleep(20); // sleep for 20ms + } + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-audio-tagging.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-audio-tagging.cc new file mode 100644 index 00000000..8b9e20ed --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-audio-tagging.cc @@ -0,0 +1,97 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include + +#include "sherpa-mnn/csrc/audio-tagging.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +int32_t main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Audio tagging from a file. 
+ +Usage: + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/audio-tagging-models/sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 +tar xvf sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 +rm sherpa-mnn-zipformer-audio-tagging-2024-04-09.tar.bz2 + +./bin/sherpa-mnn-offline-audio-tagging \ + --zipformer-model=./sherpa-mnn-zipformer-audio-tagging-2024-04-09/model.onnx \ + --labels=./sherpa-mnn-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \ + sherpa-mnn-zipformer-audio-tagging-2024-04-09/test_wavs/0.wav + +Input wave files should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. + +Please see +https://github.com/k2-fsa/sherpa-mnn/releases/tag/audio-tagging-models +for more models. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::AudioTaggingConfig config; + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() != 1) { + fprintf(stderr, "\nError: Please provide 1 wave file\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_mnn::AudioTagging tagger(config); + std::string wav_filename = po.GetArg(1); + + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + return -1; + } + + const float duration = samples.size() / static_cast(sampling_rate); + + fprintf(stderr, "Start to compute\n"); + const auto begin = std::chrono::steady_clock::now(); + + auto stream = tagger.CreateStream(); + + stream->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + auto results = tagger.Compute(stream.get()); + const auto end = std::chrono::steady_clock::now(); + fprintf(stderr, "Done\n"); + + int32_t i = 0; + + for (const auto &event : results) { + fprintf(stderr, "%d: %s\n", i, event.ToString().c_str()); + i += 1; + } + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Wave duration: %.3f\n", duration); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-denoiser.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-denoiser.cc new file mode 100644 index 00000000..8a8fb163 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-denoiser.cc @@ -0,0 +1,95 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-denoiser.cc +// +// Copyright (c) 2025 Xiaomi Corporation +#include + +#include // NOLINT + +#include "sherpa-mnn/csrc/offline-speech-denoiser.h" +#include "sherpa-mnn/csrc/wave-reader.h" +#include "sherpa-mnn/csrc/wave-writer.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Non-stremaing speech denoising with sherpa-mnn. + +Please visit +https://github.com/k2-fsa/sherpa-mnn/releases/tag/speech-enhancement-models +to download models. 
+ +Usage: + +(1) Use gtcrn models + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/speech-enhancement-models/gtcrn_simple.onnx +./bin/sherpa-mnn-offline-denoiser \ + --speech-denoiser-gtcrn-model=gtcrn_simple.onnx \ + --input-wav input.wav \ + --output-wav output_16k.wav +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OfflineSpeechDenoiserConfig config; + std::string input_wave; + std::string output_wave; + + config.Register(&po); + po.Register("input-wav", &input_wave, "Path to input wav."); + po.Register("output-wav", &output_wave, "Path to output wav"); + + po.Read(argc, argv); + if (po.NumArgs() != 0) { + fprintf(stderr, "Please don't give positional arguments\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (input_wave.empty()) { + fprintf(stderr, "Please provide --input-wav\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (output_wave.empty()) { + fprintf(stderr, "Please provide --output-wav\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + sherpa_mnn::OfflineSpeechDenoiser denoiser(config); + int32_t sampling_rate = -1; + bool is_ok = false; + std::vector samples = + sherpa_mnn::ReadWave(input_wave, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", input_wave.c_str()); + return -1; + } + + fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); + auto result = denoiser.Run(samples.data(), samples.size(), sampling_rate); + const auto end = std::chrono::steady_clock::now(); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Done\n"); + is_ok = sherpa_mnn::WriteWave(output_wave, result.sample_rate, + result.samples.data(), result.samples.size()); + if (is_ok) { + fprintf(stderr, "Saved to %s\n", output_wave.c_str()); + } else { + fprintf(stderr, "Failed to save to %s\n", output_wave.c_str()); + } + + float duration = samples.size() / static_cast(sampling_rate); + fprintf(stderr, "num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-language-identification.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-language-identification.cc new file mode 100644 index 00000000..3455d90a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-language-identification.cc @@ -0,0 +1,107 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-language-identification.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include + +#include // NOLINT +#include +#include + +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/spoken-language-identification.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Spoken language identification with sherpa-mnn. + +Usage: + +(1) Use a whisper multilingual model + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/asr-models/sherpa-mnn-whisper-tiny.tar.bz2 +tar xvf sherpa-mnn-whisper-tiny.tar.bz2 +rm sherpa-mnn-whisper-tiny.tar.bz2 + +We only use the int8.onnx models below. 
+ +./bin/sherpa-mnn-offline-spoken-language-identification \ + --whisper-encoder=sherpa-mnn-whisper-tiny/tiny-encoder.int8.onnx \ + --whisper-decoder=sherpa-mnn-whisper-tiny/tiny-decoder.int8.onnx \ + --num-threads=1 \ + /path/to/foo.wav + +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. +You can find test waves for different languages at +https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html +Note that only whisper multilingual models are supported. For instance, +"tiny" is supported but "tiny.en" is not. +for a list of pre-trained models to download. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::SpokenLanguageIdentificationConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Error: Please provide 1 wave file.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + fprintf(stderr, "Creating spoken language identifier ...\n"); + sherpa_mnn::SpokenLanguageIdentification slid(config); + + fprintf(stderr, "Started\n"); + const std::string wav_filename = po.GetArg(1); + + int32_t sampling_rate = -1; + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + return -1; + } + float duration = samples.size() / static_cast(sampling_rate); + + const auto begin = std::chrono::steady_clock::now(); + + auto s = slid.CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + auto language = slid.Compute(s.get()); + + const auto end = std::chrono::steady_clock::now(); + + fprintf(stderr, "Done!\n\n"); + fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(), + language.c_str()); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "num threads: %d\n", config.num_threads); + + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-parallel.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-parallel.cc new file mode 100644 index 00000000..02d98a94 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-parallel.cc @@ -0,0 +1,305 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-parallel.cc +// +// Copyright (c) 2022-2023 cuidc + +#include + +#include +#include // NOLINT +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +std::atomic wav_index(0); +std::mutex mtx; + +std::vector> SplitToBatches( + const std::vector &input, int32_t batch_size) { + std::vector> outputs; + auto itr = input.cbegin(); + int32_t process_num = 0; + + while (process_num + batch_size <= static_cast(input.size())) { + auto chunk_end = itr + batch_size; + outputs.emplace_back(itr, chunk_end); + itr = chunk_end; + process_num += 
batch_size; + } + if (itr != input.cend()) { + outputs.emplace_back(itr, input.cend()); + } + return outputs; +} + +std::vector LoadScpFile(const std::string &wav_scp_path) { + std::vector wav_paths; + std::ifstream in(wav_scp_path); + if (!in.is_open()) { + fprintf(stderr, "Failed to open file: %s.\n", wav_scp_path.c_str()); + return wav_paths; + } + std::string line, column1, column2; + while (std::getline(in, line)) { + std::istringstream iss(line); + iss >> column1 >> column2; + wav_paths.emplace_back(std::move(column2)); + } + + return wav_paths; +} + +void AsrInference(const std::vector> &chunk_wav_paths, + sherpa_mnn::OfflineRecognizer *recognizer, + float *total_length, float *total_time) { + std::vector> ss; + std::vector ss_pointers; + float duration = 0.0f; + float elapsed_seconds_batch = 0.0f; + + // warm up + for (const auto &wav_filename : chunk_wav_paths[0]) { + int32_t sampling_rate = -1; + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + continue; + } + duration += samples.size() / static_cast(sampling_rate); + auto s = recognizer->CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + ss.push_back(std::move(s)); + ss_pointers.push_back(ss.back().get()); + } + recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size()); + ss_pointers.clear(); + ss.clear(); + + while (true) { + int chunk = wav_index.fetch_add(1); + if (chunk >= static_cast(chunk_wav_paths.size())) { + break; + } + const auto &wav_paths = chunk_wav_paths[chunk]; + const auto begin = std::chrono::steady_clock::now(); + for (const auto &wav_filename : wav_paths) { + int32_t sampling_rate = -1; + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + continue; + } + duration += samples.size() / static_cast(sampling_rate); + auto s = recognizer->CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + ss.push_back(std::move(s)); + ss_pointers.push_back(ss.back().get()); + } + recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size()); + const auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + elapsed_seconds_batch += elapsed_seconds; + int i = 0; + for (const auto &wav_filename : wav_paths) { + fprintf(stderr, "%s\n%s\n----\n", wav_filename.c_str(), + ss[i]->GetResult().AsJsonString().c_str()); + i = i + 1; + } + ss_pointers.clear(); + ss.clear(); + } + + { + std::lock_guard guard(mtx); + *total_length += duration; + if (*total_time < elapsed_seconds_batch) { + *total_time = elapsed_seconds_batch; + } + } +} + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Speech recognition using non-streaming models with sherpa-mnn. 
+ +Usage: + +(1) Transducer from icefall + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html + + ./bin/sherpa-mnn-offline-parallel \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=1 \ + --decoding-method=greedy_search \ + --batch-size=8 \ + --nj=1 \ + --wav-scp=wav.scp + + ./bin/sherpa-mnn-offline-parallel \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=1 \ + --decoding-method=greedy_search \ + --batch-size=1 \ + --nj=8 \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +(2) Paraformer from FunASR + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html + + ./bin/sherpa-mnn-offline-parallel \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --num-threads=1 \ + --decoding-method=greedy_search \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +(3) Whisper models + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html + + ./bin/sherpa-mnn-offline-parallel \ + --whisper-encoder=./sherpa-mnn-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-mnn-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-mnn-whisper-base.en/base.en-tokens.txt \ + --num-threads=1 \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +(4) NeMo CTC models + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html + + ./bin/sherpa-mnn-offline-parallel \ + --tokens=./sherpa-mnn-nemo-ctc-en-conformer-medium/tokens.txt \ + --nemo-ctc-model=./sherpa-mnn-nemo-ctc-en-conformer-medium/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-mnn-nemo-ctc-en-conformer-medium/test_wavs/0.wav \ + ./sherpa-mnn-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ + ./sherpa-mnn-nemo-ctc-en-conformer-medium/test_wavs/8k.wav + +(5) TDNN CTC model for the yesno recipe from icefall + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html + // + ./bin/sherpa-mnn-offline-parallel \ + --sample-rate=8000 \ + --feat-dim=23 \ + --tokens=./sherpa-mnn-tdnn-yesno/tokens.txt \ + --tdnn-model=./sherpa-mnn-tdnn-yesno/model-epoch-14-avg-2.onnx \ + ./sherpa-mnn-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-mnn-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav + +Note: It supports decoding multiple files in batches + +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + std::string wav_scp = ""; // file path, kaldi style wav list. + int32_t nj = 1; // thread number + int32_t batch_size = 1; // number of wav files processed at once. + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OfflineRecognizerConfig config; + config.Register(&po); + po.Register("wav-scp", &wav_scp, + "a file including wav-id and wav-path, kaldi style wav list." + "default=" + ". when it is not empty, wav files which positional " + "parameters provide are invalid."); + po.Register("nj", &nj, "multi-thread num for decoding, default=1"); + po.Register("batch-size", &batch_size, + "number of wav files processed at once during the decoding" + "process. 
default=1"); + + po.Read(argc, argv); + if (po.NumArgs() < 1 && wav_scp.empty()) { + fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s + fprintf(stderr, "Creating recognizer ...\n"); + const auto begin = std::chrono::steady_clock::now(); + sherpa_mnn::OfflineRecognizer recognizer(config); + const auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + fprintf(stderr, + "Started nj: %d, batch_size: %d, wav_path: %s. recognizer init time: " + "%.6f\n", + nj, batch_size, wav_scp.c_str(), elapsed_seconds); + std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s + std::vector wav_paths; + if (!wav_scp.empty()) { + wav_paths = LoadScpFile(wav_scp); + } else { + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + wav_paths.emplace_back(po.GetArg(i)); + } + } + if (wav_paths.empty()) { + fprintf(stderr, "wav files is empty.\n"); + return -1; + } + std::vector threads; + std::vector> batch_wav_paths = + SplitToBatches(wav_paths, batch_size); + float total_length = 0.0f; + float total_time = 0.0f; + for (int i = 0; i < nj; i++) { + threads.emplace_back(std::thread(AsrInference, batch_wav_paths, &recognizer, + &total_length, &total_time)); + } + + for (auto &thread : threads) { + thread.join(); + } + + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); + if (config.decoding_method == "modified_beam_search") { + fprintf(stderr, "max active paths: %d\n", config.max_active_paths); + } + fprintf(stderr, "Elapsed seconds: %.3f s\n", total_time); + float rtf = total_time / total_length; + fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n", total_time, + total_length, rtf); + fprintf(stderr, "SPEEDUP: %.4f\n", 1.0 / rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-punctuation.cc new file mode 100644 index 00000000..8fc46f71 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-punctuation.cc @@ -0,0 +1,68 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-punctuation.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation +#include + +#include // NOLINT + +#include "sherpa-mnn/csrc/offline-punctuation.h" +#include "sherpa-mnn/csrc/parse-options.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Add punctuations to the input text. + +The input text can contain both Chinese and English words. 
+ +Usage: + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/punctuation-models/sherpa-mnn-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +tar xvf sherpa-mnn-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +rm sherpa-mnn-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 + +./bin/sherpa-mnn-offline-punctuation \ + --ct-transformer=./sherpa-mnn-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx + "你好吗how are you Fantasitic 谢谢我很好你怎么样呢" + +The output text should look like below: +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OfflinePunctuationConfig config; + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, + "Error: Please provide only 1 position argument containing the " + "input text.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + fprintf(stderr, "Creating OfflinePunctuation ...\n"); + sherpa_mnn::OfflinePunctuation punct(config); + fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); + + std::string text = po.GetArg(1); + std::string text_with_punct = punct.AddPunctuation(text); + fprintf(stderr, "Done\n"); + const auto end = std::chrono::steady_clock::now(); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Input text: %s\n", text.c_str()); + fprintf(stderr, "Output text: %s\n", text_with_punct.c_str()); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-speaker-diarization.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-speaker-diarization.cc new file mode 100644 index 00000000..88c1669a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -0,0 +1,133 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-speaker-diarization.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks, + void *) { + float progress = 100.0 * processed_chunks / num_chunks; + fprintf(stderr, "progress %.2f%%\n", progress); + + // the return value is currently ignored + return 0; +} + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Offline/Non-streaming speaker diarization with sherpa-mnn +Usage example: + +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-mnn/releases/tag/speaker-segmentation-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-mnn/releases/download/speaker-segmentation-models/sherpa-mnn-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-mnn-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-mnn-pyannote-segmentation-3-0.tar.bz2 + +Step 2: Download a speaker embedding extractor model + +Please visit https://github.com/k2-fsa/sherpa-mnn/releases/tag/speaker-recongition-models +for a list of available models. 
The following is an example + + wget https://github.com/k2-fsa/sherpa-mnn/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +Step 3. Download test wave files + +Please visit https://github.com/k2-fsa/sherpa-mnn/releases/tag/speaker-segmentation-models +for a list of available test wave files. The following is an example + + wget https://github.com/k2-fsa/sherpa-mnn/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +Step 4. Build sherpa-mnn + +Step 5. Run it + + ./bin/sherpa-mnn-offline-speaker-diarization \ + --clustering.num-clusters=4 \ + --segmentation.pyannote-model=./sherpa-mnn-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav + +Since we know that there are four speakers in the test wave file, we use +--clustering.num-clusters=4 in the above example. + +If we don't know number of speakers in the given wave file, we can use +the argument --clustering.cluster-threshold. The following is an example: + + ./bin/sherpa-mnn-offline-speaker-diarization \ + --clustering.cluster-threshold=0.90 \ + --segmentation.pyannote-model=./sherpa-mnn-pyannote-segmentation-3-0/model.onnx \ + --embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + ./0-four-speakers-zh.wav + +A larger threshold leads to few clusters, i.e., few speakers; +a smaller threshold leads to more clusters, i.e., more speakers + )usage"; + sherpa_mnn::OfflineSpeakerDiarizationConfig config; + sherpa_mnn::ParseOptions po(kUsageMessage); + config.Register(&po); + po.Read(argc, argv); + + std::cout << config.ToString() << "\n"; + + if (!config.Validate()) { + po.PrintUsage(); + std::cerr << "Errors in config!\n"; + return -1; + } + + if (po.NumArgs() != 1) { + std::cerr << "Error: Please provide exactly 1 wave file.\n\n"; + po.PrintUsage(); + return -1; + } + + sherpa_mnn::OfflineSpeakerDiarization sd(config); + + std::cout << "Started\n"; + const auto begin = std::chrono::steady_clock::now(); + const std::string wav_filename = po.GetArg(1); + int32_t sample_rate = -1; + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sample_rate, &is_ok); + if (!is_ok) { + std::cerr << "Failed to read " << wav_filename.c_str() << "\n"; + return -1; + } + + if (sample_rate != sd.SampleRate()) { + std::cerr << "Expect sample rate " << sd.SampleRate() + << ". 
Given: " << sample_rate << "\n"; + return -1; + } + + float duration = samples.size() / static_cast(sample_rate); + + auto result = + sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr) + .SortByStartTime(); + + for (const auto &r : result) { + std::cout << r.ToString() << "\n"; + } + + const auto end = std::chrono::steady_clock::now(); + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Duration : %.3f s\n", duration); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts-play-alsa.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts-play-alsa.cc new file mode 100644 index 00000000..7585fac4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts-play-alsa.cc @@ -0,0 +1,226 @@ +// sherpa-mnn/csrc/sherpa-mnn-tts-play-alsa.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +// see https://www.alsa-project.org/alsa-doc/alsa-lib/group___p_c_m.html +// https://www.alsa-project.org/alsa-doc/alsa-lib/group___p_c_m___h_w___params.html +// https://www.alsa-project.org/alsa-doc/alsa-lib/group___p_c_m.html + +#include + +#include +#include // NOLINT +#include // NOLINT +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "sherpa-mnn/csrc/alsa-play.h" +#include "sherpa-mnn/csrc/offline-tts.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-writer.h" + +static std::condition_variable g_cv; +static std::mutex g_cv_m; + +struct Buffer { + std::queue> samples; + std::mutex mutex; +}; + +static Buffer g_buffer; + +static bool g_stopped = false; +static bool g_killed = false; + +static void Handler(int32_t /*sig*/) { + if (g_killed) { + exit(0); + } + + g_killed = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); +} + +static int32_t AudioGeneratedCallback(const float *s, int32_t n, + float /*progress*/) { + if (n > 0) { + std::lock_guard lock(g_buffer.mutex); + g_buffer.samples.push({s, s + n}); + g_cv.notify_all(); + } + + if (g_killed) { + return 0; // stop generating + } + + // continue generating + return 1; +} + +static void StartPlayback(const std::string &device_name, int32_t sample_rate) { + sherpa_mnn::AlsaPlay alsa(device_name.c_str(), sample_rate); + + std::unique_lock lock(g_cv_m); + while (!g_killed && !g_stopped) { + while (!g_buffer.samples.empty()) { + auto &p = g_buffer.samples.front(); + alsa.Play(p); + g_buffer.samples.pop(); + } + + g_cv.wait(lock); + } + + if (g_killed) { + return; + } + + if (g_stopped) { + while (!g_buffer.samples.empty()) { + auto &p = g_buffer.samples.front(); + alsa.Play(p); + g_buffer.samples.pop(); + } + } + + alsa.Drain(); +} + +int main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +Offline text-to-speech with sherpa-mnn. + +It plays the generated audio as the model is processing. + +Note that it is alsa so it works only on **Linux**. For instance, you can +use it on Raspberry Pi. 
+ +Usage example: + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 +tar xf vits-piper-en_US-amy-low.tar.bz2 + +./bin/sherpa-mnn-offline-tts-play-alsa \ + --vits-model=./vits-piper-en_US-amy-low/en_US-amy-low.onnx \ + --vits-tokens=./vits-piper-en_US-amy-low/tokens.txt \ + --vits-data-dir=./vits-piper-en_US-amy-low/espeak-ng-data \ + --output-filename=./generated.wav \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." + +It will generate a file ./generated.wav as specified by --output-filename. + +You can find more models at +https://github.com/k2-fsa/sherpa-mnn/releases/tag/tts-models + +Please see +https://k2-fsa.github.io/sherpa/onnx/tts/index.html +or details. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + std::string device_name = "default"; + std::string output_filename = "./generated.wav"; + int32_t sid = 0; + + po.Register("output-filename", &output_filename, + "Path to save the generated audio"); + + po.Register("device-name", &device_name, + "Name of the device to play the generated audio"); + + po.Register("sid", &sid, + "Speaker ID. Used only for multi-speaker models, e.g., models " + "trained using the VCTK dataset. Not used for single-speaker " + "models, e.g., models trained using the LJSpeech dataset"); + + sherpa_mnn::OfflineTtsConfig config; + + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() == 0) { + fprintf(stderr, "Error: Please provide the text to generate audio.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (po.NumArgs() > 1) { + fprintf(stderr, + "Error: Accept only one positional argument. Please use single " + "quotes to wrap your text\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + exit(EXIT_FAILURE); + } + + if (config.max_num_sentences != 1) { + fprintf(stderr, "Setting config.max_num_sentences to 1\n"); + config.max_num_sentences = 1; + } + + fprintf(stderr, "Loading the model\n"); + sherpa_mnn::OfflineTts tts(config); + + fprintf(stderr, "Start the playback thread\n"); + std::thread playback_thread(StartPlayback, device_name, tts.SampleRate()); + + float speed = 1.0; + + fprintf(stderr, "Generating ...\n"); + const auto begin = std::chrono::steady_clock::now(); + auto audio = tts.Generate(po.GetArg(1), sid, speed, AudioGeneratedCallback); + const auto end = std::chrono::steady_clock::now(); + g_stopped = true; + g_cv.notify_all(); + fprintf(stderr, "Generating done!\n"); + if (audio.samples.empty()) { + fprintf( + stderr, + "Error in generating audio. Please read previous error messages.\n"); + exit(EXIT_FAILURE); + } + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = audio.samples.size() / static_cast(audio.sample_rate); + + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Audio duration: %.3f s\n", duration); + fprintf(stderr, "Real-time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + bool ok = sherpa_mnn::WriteWave(output_filename, audio.sample_rate, + audio.samples.data(), audio.samples.size()); + if (!ok) { + fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str()); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "The text is: %s. 
Speaker ID: %d\n\n", po.GetArg(1).c_str(), + sid); + fprintf(stderr, "\n**** Saved to %s successfully! ****\n", + output_filename.c_str()); + + fprintf(stderr, "\n"); + fprintf( + stderr, + "Wait for the playback to finish. You can safely press ctrl + C to stop " + "the playback.\n"); + playback_thread.join(); + + fprintf(stderr, "Done!\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts-play.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts-play.cc new file mode 100644 index 00000000..f08dcac9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts-play.cc @@ -0,0 +1,330 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-tts-play.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include + +#include +#include // NOLINT +#include // NOLINT +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/microphone.h" +#include "sherpa-mnn/csrc/offline-tts.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-writer.h" + +static std::condition_variable g_cv; +static std::mutex g_cv_m; + +struct Samples { + std::vector data; + int32_t consumed = 0; +}; + +struct Buffer { + std::queue samples; + std::mutex mutex; +}; + +static Buffer g_buffer; + +static bool g_started = false; +static bool g_stopped = false; +static bool g_killed = false; + +static void Handler(int32_t /*sig*/) { + if (g_killed) { + exit(0); + } + + g_killed = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting\n"); +} + +static int32_t AudioGeneratedCallback(const float *s, int32_t n, + float /*progress*/) { + if (n > 0) { + Samples samples; + samples.data = std::vector{s, s + n}; + + std::lock_guard lock(g_buffer.mutex); + g_buffer.samples.push(std::move(samples)); + g_started = true; + } + if (g_killed) { + return 0; // stop generating + } + + // continue generating + return 1; +} + +static int PlayCallback(const void * /*in*/, void *out, + unsigned long n, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void * /*user_data*/) { + if (g_killed) { + return paComplete; + } + + float *pout = reinterpret_cast(out); + std::lock_guard lock(g_buffer.mutex); + + if (g_buffer.samples.empty()) { + if (g_stopped) { + // no more data is available and we have processed all of the samples + return paComplete; + } + + // The current sentence is so long, though very unlikely, that + // the model has not finished processing it yet. 
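+    // In that case, write silence for this callback so the stream keeps
+    // running until the generator thread produces more samples.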
+ std::fill_n(pout, n, 0); + + return paContinue; + } + + int32_t k = 0; + for (; k < static_cast(n) && !g_buffer.samples.empty();) { + int32_t this_block = n - k; + + auto &p = g_buffer.samples.front(); + + int32_t remaining = p.data.size() - p.consumed; + + if (this_block <= remaining) { + std::copy(p.data.begin() + p.consumed, + p.data.begin() + p.consumed + this_block, pout + k); + p.consumed += this_block; + + k = n; + + if (p.consumed == static_cast(p.data.size())) { + g_buffer.samples.pop(); + } + break; + } + + std::copy(p.data.begin() + p.consumed, p.data.end(), pout + k); + k += p.data.size() - p.consumed; + g_buffer.samples.pop(); + } + + if (k < static_cast(n)) { + std::fill_n(pout + k, n - k, 0); + } + + if (g_stopped && g_buffer.samples.empty()) { + return paComplete; + } + + return paContinue; +} + +static void PlayCallbackFinished(void * /*userData*/) { g_cv.notify_all(); } + +static void StartPlayback(int32_t sample_rate) { + int32_t frames_per_buffer = 1024; + PaStreamParameters outputParameters; + PaStream *stream; + PaError err; + + outputParameters.device = + Pa_GetDefaultOutputDevice(); /* default output device */ + + outputParameters.channelCount = 1; /* stereo output */ + outputParameters.sampleFormat = paFloat32; /* 32 bit floating point output */ + outputParameters.suggestedLatency = + Pa_GetDeviceInfo(outputParameters.device)->defaultLowOutputLatency; + outputParameters.hostApiSpecificStreamInfo = nullptr; + + err = Pa_OpenStream(&stream, nullptr, /* no input */ + &outputParameters, sample_rate, frames_per_buffer, + paClipOff, // we won't output out of range samples so + // don't bother clipping them + PlayCallback, nullptr); + if (err != paNoError) { + fprintf(stderr, "%d portaudio error: %s\n", __LINE__, Pa_GetErrorText(err)); + return; + } + + err = Pa_SetStreamFinishedCallback(stream, &PlayCallbackFinished); + if (err != paNoError) { + fprintf(stderr, "%d portaudio error: %s\n", __LINE__, Pa_GetErrorText(err)); + return; + } + + err = Pa_StartStream(stream); + if (err != paNoError) { + fprintf(stderr, "%d portaudio error: %s\n", __LINE__, Pa_GetErrorText(err)); + return; + } + + std::unique_lock lock(g_cv_m); + while (!g_killed && !g_stopped && + (!g_started || (g_started && !g_buffer.samples.empty()))) { + g_cv.wait(lock); + } + + err = Pa_StopStream(stream); + if (err != paNoError) { + return; + } + + err = Pa_CloseStream(stream); + if (err != paNoError) { + return; + } +} + +int main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +Offline text-to-speech with sherpa-mnn. + +It plays the generated audio as the model is processing. + +Usage example: + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 +tar xf vits-piper-en_US-amy-low.tar.bz2 + +./bin/sherpa-mnn-offline-tts-play \ + --vits-model=./vits-piper-en_US-amy-low/en_US-amy-low.onnx \ + --vits-tokens=./vits-piper-en_US-amy-low/tokens.txt \ + --vits-data-dir=./vits-piper-en_US-amy-low/espeak-ng-data \ + --output-filename=./generated.wav \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." + +It will generate a file ./generated.wav as specified by --output-filename. + +You can find more models at +https://github.com/k2-fsa/sherpa-mnn/releases/tag/tts-models + +Please see +https://k2-fsa.github.io/sherpa/onnx/tts/index.html +or details. 
+)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + std::string output_filename = "./generated.wav"; + int32_t sid = 0; + + po.Register("output-filename", &output_filename, + "Path to save the generated audio"); + + po.Register("sid", &sid, + "Speaker ID. Used only for multi-speaker models, e.g., models " + "trained using the VCTK dataset. Not used for single-speaker " + "models, e.g., models trained using the LJSpeech dataset"); + + sherpa_mnn::OfflineTtsConfig config; + + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() == 0) { + fprintf(stderr, "Error: Please provide the text to generate audio.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (po.NumArgs() > 1) { + fprintf(stderr, + "Error: Accept only one positional argument. Please use single " + "quotes to wrap your text\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + exit(EXIT_FAILURE); + } + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + PaStreamParameters param; + + param.device = Pa_GetDefaultOutputDevice(); + if (param.device == paNoDevice) { + fprintf(stderr, "No default output device found\n"); + exit(EXIT_FAILURE); + } + fprintf(stderr, "Use default device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max output channels: %d\n", info->maxOutputChannels); + + if (config.max_num_sentences != 1) { + fprintf(stderr, "Setting config.max_num_sentences to 1\n"); + config.max_num_sentences = 1; + } + + fprintf(stderr, "Loading the model\n"); + sherpa_mnn::OfflineTts tts(config); + + fprintf(stderr, "Start the playback thread\n"); + std::thread playback_thread(StartPlayback, tts.SampleRate()); + + float speed = 1.0; + + fprintf(stderr, "Generating ...\n"); + const auto begin = std::chrono::steady_clock::now(); + auto audio = tts.Generate(po.GetArg(1), sid, speed, AudioGeneratedCallback); + const auto end = std::chrono::steady_clock::now(); + g_stopped = true; + fprintf(stderr, "Generating done!\n"); + if (audio.samples.empty()) { + fprintf( + stderr, + "Error in generating audio. Please read previous error messages.\n"); + exit(EXIT_FAILURE); + } + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = audio.samples.size() / static_cast(audio.sample_rate); + + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Audio duration: %.3f s\n", duration); + fprintf(stderr, "Real-time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + bool ok = sherpa_mnn::WriteWave(output_filename, audio.sample_rate, + audio.samples.data(), audio.samples.size()); + if (!ok) { + fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str()); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "The text is: %s. Speaker ID: %d\n\n", po.GetArg(1).c_str(), + sid); + fprintf(stderr, "\n**** Saved to %s successfully! ****\n", + output_filename.c_str()); + + fprintf(stderr, "\n"); + fprintf( + stderr, + "Wait for the playback to finish. 
You can safely press ctrl + C to stop " + "the playback.\n"); + playback_thread.join(); + + fprintf(stderr, "Done!\n"); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts.cc new file mode 100644 index 00000000..429cd9cf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline-tts.cc @@ -0,0 +1,121 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline-tts.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include // NOLINT +#include + +#include "sherpa-mnn/csrc/offline-tts.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-writer.h" + +static int32_t AudioCallback(const float * /*samples*/, int32_t n, + float progress) { + printf("sample=%d, progress=%f\n", n, progress); + return 1; +} + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Offline/Non-streaming text-to-speech with sherpa-mnn + +Usage example: + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 +tar xf vits-piper-en_US-amy-low.tar.bz2 + +./bin/sherpa-mnn-offline-tts \ + --vits-model=./vits-piper-en_US-amy-low/en_US-amy-low.onnx \ + --vits-tokens=./vits-piper-en_US-amy-low/tokens.txt \ + --vits-data-dir=./vits-piper-en_US-amy-low/espeak-ng-data \ + --output-filename=./generated.wav \ + "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." + +It will generate a file ./generated.wav as specified by --output-filename. + +You can find more models at +https://github.com/k2-fsa/sherpa-mnn/releases/tag/tts-models + +Please see +https://k2-fsa.github.io/sherpa/onnx/tts/index.html +or details. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + std::string output_filename = "./generated.wav"; + int32_t sid = 0; + + po.Register("output-filename", &output_filename, + "Path to save the generated audio"); + + po.Register("sid", &sid, + "Speaker ID. Used only for multi-speaker models, e.g., models " + "trained using the VCTK dataset. Not used for single-speaker " + "models, e.g., models trained using the LJSpeech dataset"); + + sherpa_mnn::OfflineTtsConfig config; + + config.Register(&po); + po.Read(argc, argv); + + if (po.NumArgs() == 0) { + fprintf(stderr, "Error: Please provide the text to generate audio.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (po.NumArgs() > 1) { + fprintf(stderr, + "Error: Accept only one positional argument. Please use single " + "quotes to wrap your text\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + if (config.model.debug) { + fprintf(stderr, "%s\n", config.model.ToString().c_str()); + } + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + exit(EXIT_FAILURE); + } + + sherpa_mnn::OfflineTts tts(config); + + const auto begin = std::chrono::steady_clock::now(); + auto audio = tts.Generate(po.GetArg(1), sid, 1.0, AudioCallback); + const auto end = std::chrono::steady_clock::now(); + + if (audio.samples.empty()) { + fprintf( + stderr, + "Error in generating audio. 
Please read previous error messages.\n"); + exit(EXIT_FAILURE); + } + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + float duration = audio.samples.size() / static_cast(audio.sample_rate); + + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Audio duration: %.3f s\n", duration); + fprintf(stderr, "Real-time factor (RTF): %.3f/%.3f = %.3f\n", elapsed_seconds, + duration, rtf); + + bool ok = sherpa_mnn::WriteWave(output_filename, audio.sample_rate, + audio.samples.data(), audio.samples.size()); + if (!ok) { + fprintf(stderr, "Failed to write wave to %s\n", output_filename.c_str()); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "The text is: %s. Speaker ID: %d\n", po.GetArg(1).c_str(), + sid); + fprintf(stderr, "Saved to %s successfully!\n", output_filename.c_str()); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline.cc new file mode 100644 index 00000000..73dbcef3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-offline.cc @@ -0,0 +1,179 @@ +// sherpa-mnn/csrc/sherpa-mnn-offline.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include + +#include // NOLINT +#include +#include + +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Speech recognition using non-streaming models with sherpa-mnn. + +Usage: + +(1) Transducer from icefall + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html + + ./bin/sherpa-mnn-offline \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=1 \ + --decoding-method=greedy_search \ + /path/to/foo.wav [bar.wav foobar.wav ...] + + +(2) Paraformer from FunASR + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html + + ./bin/sherpa-mnn-offline \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --num-threads=1 \ + --decoding-method=greedy_search \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +(3) Moonshine models + +See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html + + ./bin/sherpa-mnn-offline \ + --moonshine-preprocessor=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/preprocess.onnx \ + --moonshine-encoder=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/encode.int8.onnx \ + --moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/cached_decode.int8.onnx \ + --tokens=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/tokens.txt \ + --num-threads=1 \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +(4) Whisper models + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html + + ./bin/sherpa-mnn-offline \ + --whisper-encoder=./sherpa-mnn-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-mnn-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-mnn-whisper-base.en/base.en-tokens.txt \ + --num-threads=1 \ + /path/to/foo.wav [bar.wav foobar.wav ...] 
+ +(5) NeMo CTC models + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html + + ./bin/sherpa-mnn-offline \ + --tokens=./sherpa-mnn-nemo-ctc-en-conformer-medium/tokens.txt \ + --nemo-ctc-model=./sherpa-mnn-nemo-ctc-en-conformer-medium/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-mnn-nemo-ctc-en-conformer-medium/test_wavs/0.wav \ + ./sherpa-mnn-nemo-ctc-en-conformer-medium/test_wavs/1.wav \ + ./sherpa-mnn-nemo-ctc-en-conformer-medium/test_wavs/8k.wav + +(6) TDNN CTC model for the yesno recipe from icefall + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html + // + ./build/bin/sherpa-mnn-offline \ + --sample-rate=8000 \ + --feat-dim=23 \ + --tokens=./sherpa-mnn-tdnn-yesno/tokens.txt \ + --tdnn-model=./sherpa-mnn-tdnn-yesno/model-epoch-14-avg-2.onnx \ + ./sherpa-mnn-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \ + ./sherpa-mnn-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav + +Note: It supports decoding multiple files in batches + +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OfflineRecognizerConfig config; + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() < 1) { + fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + fprintf(stderr, "Creating recognizer ...\n"); + sherpa_mnn::OfflineRecognizer recognizer(config); + + fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); + + std::vector> ss; + std::vector ss_pointers; + float duration = 0; + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + std::string wav_filename = po.GetArg(i); + int32_t sampling_rate = -1; + bool is_ok = false; + std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + return -1; + } + duration += samples.size() / static_cast(sampling_rate); + + auto s = recognizer.CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + ss.push_back(std::move(s)); + ss_pointers.push_back(ss.back().get()); + } + + recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size()); + + const auto end = std::chrono::steady_clock::now(); + + fprintf(stderr, "Done!\n\n"); + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + fprintf(stderr, "%s\n%s\n----\n", po.GetArg(i).c_str(), + ss[i - 1]->GetResult().AsJsonString().c_str()); + } + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); + fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); + if (config.decoding_method == "modified_beam_search") { + fprintf(stderr, "max active paths: %d\n", config.max_active_paths); + } + + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git 
a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-online-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-online-punctuation.cc new file mode 100644 index 00000000..5d1d7c79 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-online-punctuation.cc @@ -0,0 +1,73 @@ +// sherpa-mnn/csrc/sherpa-mnn-online-punctuation.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include + +#include // NOLINT +#include + +#include "sherpa-mnn/csrc/online-punctuation.h" +#include "sherpa-mnn/csrc/parse-options.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Add punctuations to the input text. + +The input text can contain English words. + +Usage: + +Please download the model from: +https://github.com/k2-fsa/sherpa-mnn/releases/download/punctuation-models/sherpa-mnn-online-punct-en-2024-08-06.tar.bz2 + +./bin/Release/sherpa-mnn-online-punctuation \ + --cnn-bilstm=/path/to/model.onnx \ + --bpe-vocab=/path/to/bpe.vocab \ + "how are you i am fine thank you" + +The output text should look like below: + "How are you? I am fine. Thank you." +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OnlinePunctuationConfig config; + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, + "Error: Please provide only 1 positional argument containing the " + "input text.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + fprintf(stderr, "Creating OnlinePunctuation ...\n"); + sherpa_mnn::OnlinePunctuation punct(config); + fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); + + std::string text = po.GetArg(1); + + std::string text_with_punct_case = punct.AddPunctuationWithCase(text); + + const auto end = std::chrono::steady_clock::now(); + fprintf(stderr, "Done\n"); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Input text: %s\n", text.c_str()); + fprintf(stderr, "Output text: %s\n", text_with_punct_case.c_str()); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-alsa.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-alsa.cc new file mode 100644 index 00000000..b90e6bbf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-alsa.cc @@ -0,0 +1,132 @@ +// sherpa-mnn/csrc/sherpa-mnn-vad-alsa.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include + +#include "sherpa-mnn/csrc/alsa.h" +#include "sherpa-mnn/csrc/circular-buffer.h" +#include "sherpa-mnn/csrc/voice-activity-detector.h" +#include "sherpa-mnn/csrc/wave-writer.h" + +bool stop = false; +static void Handler(int32_t sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program shows how to use VAD in sherpa-mnn. 
+ + ./bin/sherpa-mnn-vad-alsa \ + --silero-vad-model=/path/to/silero_vad.onnx \ + device_name + +Please download silero_vad.onnx from +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +For instance, use +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +The device name specifies which microphone to use in case there are several +on your system. You can use + + arecord -l + +to find all available microphones on your computer. For instance, if it outputs + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and device 0 on that card, please use: + + plughw:3,0 + +as the device_name. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::VadModelConfig config; + + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Please provide only 1 argument: the device name\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + std::string device_name = po.GetArg(1); + sherpa_mnn::Alsa alsa(device_name.c_str()); + fprintf(stderr, "Use recording device: %s\n", device_name.c_str()); + + int32_t sample_rate = 16000; + + if (alsa.GetExpectedSampleRate() != sample_rate) { + fprintf(stderr, "sample rate: %d != %d\n", alsa.GetExpectedSampleRate(), + sample_rate); + exit(-1); + } + + int32_t chunk = 0.1 * alsa.GetActualSampleRate(); + + auto vad = std::make_unique(config); + + fprintf(stderr, "Started. Please speak\n"); + + int32_t window_size = config.silero_vad.window_size; + bool printed = false; + + int32_t k = 0; + while (!stop) { + { + const std::vector &samples = alsa.Read(chunk); + + vad->AcceptWaveform(samples.data(), samples.size()); + + if (vad->IsSpeechDetected() && !printed) { + printed = true; + fprintf(stderr, "\nDetected speech!\n"); + } + if (!vad->IsSpeechDetected()) { + printed = false; + } + + while (!vad->Empty()) { + const auto &segment = vad->Front(); + float duration = + segment.samples.size() / static_cast(sample_rate); + + fprintf(stderr, "Duration: %.3f seconds\n", duration); + + char filename[128]; + snprintf(filename, sizeof(filename), "seg-%d-%.3fs.wav", k, duration); + k += 1; + sherpa_mnn::WriteWave(filename, 16000, segment.samples.data(), + segment.samples.size()); + fprintf(stderr, "Saved to %s\n", filename); + fprintf(stderr, "----------\n"); + + vad->Pop(); + } + } + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-microphone-offline-asr.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-microphone-offline-asr.cc new file mode 100644 index 00000000..ee973ecd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-microphone-offline-asr.cc @@ -0,0 +1,237 @@ +// sherpa-mnn/csrc/sherpa-mnn-vad-microphone-offline-asr.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include +#include +#include + +#include +#include // NOLINT + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/circular-buffer.h" +#include "sherpa-mnn/csrc/microphone.h" +#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/resample.h" +#include "sherpa-mnn/csrc/voice-activity-detector.h" + +bool stop = false; +std::mutex mutex; +sherpa_mnn::CircularBuffer buffer(16000 * 60); + 
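+// Runs on the PortAudio capture thread: it appends the recorded samples to
+// the shared circular buffer while holding `mutex`. The main loop below
+// drains the buffer in window-size chunks, resamples them to 16 kHz if
+// needed, and feeds them to the VAD; the buffer can hold roughly 60 seconds
+// of 16 kHz audio (16000 * 60 samples).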
+static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void * /*user_data*/) { + std::lock_guard lock(mutex); + buffer.Push(reinterpret_cast(input_buffer), frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t /*sig*/) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program shows how to use a streaming VAD with non-streaming ASR in +sherpa-mnn. + +Please download silero_vad.onnx from +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +For instance, use +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +Please refer to ./sherpa-mnn-microphone-offline.cc +to download models for offline ASR. + +(1) Transducer from icefall + + ./bin/sherpa-mnn-vad-microphone-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx + +(2) Paraformer from FunASR + + ./bin/sherpa-mnn-vad-microphone-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --num-threads=1 + +(3) Whisper models + + ./bin/sherpa-mnn-vad-microphone-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --whisper-encoder=./sherpa-mnn-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-mnn-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-mnn-whisper-base.en/base.en-tokens.txt \ + --num-threads=1 +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::VadModelConfig vad_config; + + sherpa_mnn::OfflineRecognizerConfig asr_config; + + vad_config.Register(&po); + asr_config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 0) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", vad_config.ToString().c_str()); + fprintf(stderr, "%s\n", asr_config.ToString().c_str()); + + if (!vad_config.Validate()) { + fprintf(stderr, "Errors in vad_config!\n"); + return -1; + } + + if (!asr_config.Validate()) { + fprintf(stderr, "Errors in asr_config!\n"); + return -1; + } + + fprintf(stderr, "Creating recognizer ...\n"); + sherpa_mnn::OfflineRecognizer recognizer(asr_config); + fprintf(stderr, "Recognizer created!\n"); + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + int32_t device_index = Pa_GetDefaultInputDevice(); + + if (device_index == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + exit(EXIT_FAILURE); + } + + const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE"); + if (pDeviceIndex) { + fprintf(stderr, "Use specified device: %s\n", pDeviceIndex); + device_index = atoi(pDeviceIndex); + } + + for (int32_t i = 0; i != num_devices; ++i) { + const PaDeviceInfo *info = Pa_GetDeviceInfo(i); + fprintf(stderr, " %s %d %s\n", (i == device_index) ? 
"*" : " ", i, + info->name); + } + + PaStreamParameters param; + param.device = device_index; + + fprintf(stderr, "Use device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + float mic_sample_rate = 16000; + const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); + if (pSampleRateStr) { + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); + mic_sample_rate = atof(pSampleRateStr); + } + float sample_rate = 16000; + std::unique_ptr resampler; + if (mic_sample_rate != sample_rate) { + float min_freq = std::min(mic_sample_rate, sample_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + resampler = std::make_unique( + mic_sample_rate, sample_rate, lowpass_cutoff, lowpass_filter_width); + } + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + mic_sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, nullptr); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + auto vad = std::make_unique(vad_config); + + fprintf(stderr, "Started. Please speak\n"); + + int32_t window_size = vad_config.silero_vad.window_size; + int32_t index = 0; + + while (!stop) { + { + std::lock_guard lock(mutex); + + while (buffer.Size() >= window_size) { + std::vector samples = buffer.Get(buffer.Head(), window_size); + buffer.Pop(window_size); + + if (resampler) { + std::vector tmp; + resampler->Resample(samples.data(), samples.size(), true, &tmp); + samples = std::move(tmp); + } + + vad->AcceptWaveform(samples.data(), samples.size()); + } + } + + while (!vad->Empty()) { + const auto &segment = vad->Front(); + auto s = recognizer.CreateStream(); + s->AcceptWaveform(sample_rate, segment.samples.data(), + segment.samples.size()); + recognizer.DecodeStream(s.get()); + const auto &result = s->GetResult(); + if (!result.text.empty()) { + fprintf(stderr, "%2d: %s\n", index, result.text.c_str()); + ++index; + } + vad->Pop(); + } + + Pa_Sleep(100); // sleep for 100ms + } + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-microphone.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-microphone.cc new file mode 100644 index 00000000..8b6a0331 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-microphone.cc @@ -0,0 +1,212 @@ +// sherpa-mnn/csrc/sherpa-mnn-vad-microphone.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include +#include +#include + +#include +#include // NOLINT + +#include "portaudio.h" // NOLINT +#include "sherpa-mnn/csrc/circular-buffer.h" +#include "sherpa-mnn/csrc/microphone.h" +#include "sherpa-mnn/csrc/resample.h" +#include "sherpa-mnn/csrc/voice-activity-detector.h" +#include "sherpa-mnn/csrc/wave-writer.h" + +bool stop = false; 
+std::mutex mutex; +sherpa_mnn::CircularBuffer buffer(16000 * 60); + +static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void * /*user_data*/) { + std::lock_guard lock(mutex); + buffer.Push(reinterpret_cast(input_buffer), frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t /*sig*/) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program shows how to use VAD in sherpa-mnn. + + ./bin/sherpa-mnn-vad-microphone \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --vad-provider=cpu \ + --vad-num-threads=1 + +Please download silero_vad.onnx from +https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx + +For instance, use +wget https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::VadModelConfig config; + + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 0) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_mnn::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + int32_t device_index = Pa_GetDefaultInputDevice(); + + if (device_index == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + fprintf(stderr, "If you are using Linux, please switch to \n"); + fprintf(stderr, " ./bin/sherpa-mnn-vad-alsa \n"); + exit(EXIT_FAILURE); + } + + const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE"); + if (pDeviceIndex) { + fprintf(stderr, "Use specified device: %s\n", pDeviceIndex); + device_index = atoi(pDeviceIndex); + } + + for (int32_t i = 0; i != num_devices; ++i) { + const PaDeviceInfo *info = Pa_GetDeviceInfo(i); + fprintf(stderr, " %s %d %s\n", (i == device_index) ? 
"*" : " ", i, + info->name); + } + + PaStreamParameters param; + param.device = device_index; + + fprintf(stderr, "Use device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + float mic_sample_rate = 16000; + const char *pSampleRateStr = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE"); + if (pSampleRateStr) { + fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate); + mic_sample_rate = atof(pSampleRateStr); + } + float sample_rate = 16000; + + std::unique_ptr resampler; + if (mic_sample_rate != sample_rate) { + float min_freq = std::min(mic_sample_rate, sample_rate); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + resampler = std::make_unique( + mic_sample_rate, sample_rate, lowpass_cutoff, lowpass_filter_width); + } + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + mic_sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, nullptr); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + + auto vad = std::make_unique(config); + + fprintf(stderr, "Started\n"); + + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + int32_t window_size = config.silero_vad.window_size; + bool printed = false; + + int32_t k = 0; + while (!stop) { + { + std::lock_guard lock(mutex); + + while (buffer.Size() >= window_size) { + std::vector samples = buffer.Get(buffer.Head(), window_size); + buffer.Pop(window_size); + + if (resampler) { + std::vector tmp; + resampler->Resample(samples.data(), samples.size(), true, &tmp); + samples = std::move(tmp); + } + + vad->AcceptWaveform(samples.data(), samples.size()); + + if (vad->IsSpeechDetected() && !printed) { + printed = true; + fprintf(stderr, "\nDetected speech!\n"); + } + if (!vad->IsSpeechDetected()) { + printed = false; + } + + while (!vad->Empty()) { + const auto &segment = vad->Front(); + float duration = segment.samples.size() / sample_rate; + fprintf(stderr, "Duration: %.3f seconds\n", duration); + + char filename[128]; + snprintf(filename, sizeof(filename), "seg-%d-%.3fs.wav", k, duration); + k += 1; + sherpa_mnn::WriteWave(filename, sample_rate, segment.samples.data(), + segment.samples.size()); + fprintf(stderr, "Saved to %s\n", filename); + fprintf(stderr, "----------\n"); + + vad->Pop(); + } + } + } + Pa_Sleep(100); // sleep for 100ms + } + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-with-offline-asr.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-with-offline-asr.cc new file mode 100644 index 00000000..55d46730 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx-vad-with-offline-asr.cc @@ -0,0 +1,238 @@ +// sherpa-mnn/csrc/sherpa-mnn-vad-with-offline-asr.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include + +#include // NOLINT +#include +#include + 
+#include "sherpa-mnn/csrc/offline-recognizer.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/resample.h" +#include "sherpa-mnn/csrc/voice-activity-detector.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Speech recognition using VAD + non-streaming models with sherpa-mnn. + +Usage: + +Note you can download silero_vad.onnx using + +wget https://github.com/k2-fsa/sherpa-mnn/releases/download/asr-models/silero_vad.onnx + +(0) FireRedAsr + +See https://k2-fsa.github.io/sherpa/onnx/FireRedAsr/pretrained.html + + ./bin/sherpa-mnn-vad-with-offline-asr \ + --tokens=./sherpa-mnn-fire-red-asr-large-zh_en-2025-02-16/tokens.txt \ + --fire-red-asr-encoder=./sherpa-mnn-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx \ + --fire-red-asr-decoder=./sherpa-mnn-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx \ + --num-threads=1 \ + --silero-vad-model=/path/to/silero_vad.onnx \ + /path/to/foo.wav + +(1) Transducer from icefall + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html + + ./bin/sherpa-mnn-vad-with-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=1 \ + --decoding-method=greedy_search \ + /path/to/foo.wav + + +(2) Paraformer from FunASR + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html + + ./bin/sherpa-mnn-vad-with-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --num-threads=1 \ + --decoding-method=greedy_search \ + /path/to/foo.wav + +(3) Moonshine models + +See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html + + ./bin/sherpa-mnn-vad-with-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --moonshine-preprocessor=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/preprocess.onnx \ + --moonshine-encoder=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/encode.int8.onnx \ + --moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/uncached_decode.int8.onnx \ + --moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/cached_decode.int8.onnx \ + --tokens=/Users/fangjun/open-source/sherpa-mnn/scripts/moonshine/tokens.txt \ + --num-threads=1 \ + /path/to/foo.wav + +(4) Whisper models + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html + + ./bin/sherpa-mnn-vad-with-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --whisper-encoder=./sherpa-mnn-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-mnn-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-mnn-whisper-base.en/base.en-tokens.txt \ + --num-threads=1 \ + /path/to/foo.wav + +(5) NeMo CTC models + +See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html + + ./bin/sherpa-mnn-vad-with-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --tokens=./sherpa-mnn-nemo-ctc-en-conformer-medium/tokens.txt \ + --nemo-ctc-model=./sherpa-mnn-nemo-ctc-en-conformer-medium/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-mnn-nemo-ctc-en-conformer-medium/test_wavs/0.wav + +(6) TDNN CTC model for the yesno recipe from icefall + +See 
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html + + ./bin/sherpa-mnn-vad-with-offline-asr \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --sample-rate=8000 \ + --feat-dim=23 \ + --tokens=./sherpa-mnn-tdnn-yesno/tokens.txt \ + --tdnn-model=./sherpa-mnn-tdnn-yesno/model-epoch-14-avg-2.onnx \ + ./sherpa-mnn-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav + +The input wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. + +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OfflineRecognizerConfig asr_config; + asr_config.Register(&po); + + sherpa_mnn::VadModelConfig vad_config; + vad_config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, "Error: Please provide at only 1 wave file. Given: %d\n\n", + po.NumArgs()); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", vad_config.ToString().c_str()); + fprintf(stderr, "%s\n", asr_config.ToString().c_str()); + + if (!vad_config.Validate()) { + fprintf(stderr, "Errors in vad_config!\n"); + return -1; + } + + if (!asr_config.Validate()) { + fprintf(stderr, "Errors in ASR config!\n"); + return -1; + } + + fprintf(stderr, "Creating recognizer ...\n"); + sherpa_mnn::OfflineRecognizer recognizer(asr_config); + fprintf(stderr, "Recognizer created!\n"); + + auto vad = std::make_unique(vad_config); + + fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); + + std::string wave_filename = po.GetArg(1); + fprintf(stderr, "Reading: %s\n", wave_filename.c_str()); + int32_t sampling_rate = -1; + bool is_ok = false; + auto samples = sherpa_mnn::ReadWave(wave_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wave_filename.c_str()); + return -1; + } + + if (sampling_rate != 16000) { + fprintf(stderr, "Resampling from %d Hz to 16000 Hz", sampling_rate); + float min_freq = std::min(sampling_rate, 16000); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + auto resampler = std::make_unique( + sampling_rate, 16000, lowpass_cutoff, lowpass_filter_width); + std::vector out_samples; + resampler->Resample(samples.data(), samples.size(), true, &out_samples); + samples = std::move(out_samples); + fprintf(stderr, "Resampling done\n"); + } + + fprintf(stderr, "Started!\n"); + int32_t window_size = vad_config.silero_vad.window_size; + int32_t i = 0; + while (i + window_size < samples.size()) { + vad->AcceptWaveform(samples.data() + i, window_size); + i += window_size; + if (i >= samples.size()) { + vad->Flush(); + } + + while (!vad->Empty()) { + const auto &segment = vad->Front(); + float duration = segment.samples.size() / 16000.; + float start_time = segment.start / 16000.; + float end_time = start_time + duration; + if (duration < 0.1) { + vad->Pop(); + continue; + } + + auto s = recognizer.CreateStream(); + s->AcceptWaveform(16000, segment.samples.data(), segment.samples.size()); + recognizer.DecodeStream(s.get()); + const auto &result = s->GetResult(); + if (!result.text.empty()) { + fprintf(stderr, "%.3f -- %.3f: %s\n", start_time, end_time, + result.text.c_str()); + } + vad->Pop(); + } + } + + const auto end = std::chrono::steady_clock::now(); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + 
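+  // Report the decoding statistics. The real time factor (RTF) is the ratio
+  // of processing time to audio duration; an RTF below 1 means the file was
+  // processed faster than real time.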
fprintf(stderr, "num threads: %d\n", asr_config.model_config.num_threads); + fprintf(stderr, "decoding method: %s\n", asr_config.decoding_method.c_str()); + if (asr_config.decoding_method == "modified_beam_search") { + fprintf(stderr, "max active paths: %d\n", asr_config.max_active_paths); + } + + float duration = samples.size() / 16000.; + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + float rtf = elapsed_seconds / duration; + fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", + elapsed_seconds, duration, rtf); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx.cc new file mode 100644 index 00000000..8293613d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/sherpa-onnx.cc @@ -0,0 +1,175 @@ +// sherpa-mnn/csrc/sherpa-mnn.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include + +#include // NOLINT +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/online-recognizer.h" +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/symbol-table.h" +#include "sherpa-mnn/csrc/wave-reader.h" + +typedef struct { + std::unique_ptr online_stream; + float duration; + float elapsed_seconds; +} Stream; + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Usage: + +(1) Streaming transducer + + ./bin/sherpa-mnn \ + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --provider=cpu \ + --num-threads=2 \ + --decoding-method=greedy_search \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +(2) Streaming zipformer2 CTC + + wget -q https://github.com/k2-fsa/sherpa-mnn/releases/download/asr-models/sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + tar xvf sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2 + + ./bin/sherpa-mnn \ + --debug=1 \ + --zipformer2-ctc-model=./sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx \ + --tokens=./sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/tokens.txt \ + ./sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000000.wav \ + ./sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000001.wav \ + ./sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13/test_wavs/DEV_T0000000002.wav + +(3) Streaming paraformer + + wget https://github.com/k2-fsa/sherpa-mnn/releases/download/asr-models/sherpa-mnn-streaming-paraformer-bilingual-zh-en.tar.bz2 + tar xvf sherpa-mnn-streaming-paraformer-bilingual-zh-en.tar.bz2 + + ./bin/sherpa-mnn \ + --tokens=./sherpa-mnn-streaming-paraformer-bilingual-zh-en/tokens.txt \ + --paraformer-encoder=./sherpa-mnn-streaming-paraformer-bilingual-zh-en/encoder.onnx \ + --paraformer-decoder=./sherpa-mnn-streaming-paraformer-bilingual-zh-en/decoder.onnx \ + ./sherpa-mnn-streaming-paraformer-bilingual-zh-en/test_wavs/0.wav + +Note: It supports decoding multiple files in batches + +Default value for num_threads is 2. +Valid values for decoding_method: greedy_search (default), modified_beam_search. +Valid values for provider: cpu (default), cuda, coreml. +foo.wav should be of single channel, 16-bit PCM encoded wave file; its +sampling rate can be arbitrary and does not need to be 16kHz. 
+ +Please refer to +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models to download. +)usage"; + + sherpa_mnn::ParseOptions po(kUsageMessage); + sherpa_mnn::OnlineRecognizerConfig config; + + config.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() < 1) { + po.PrintUsage(); + fprintf(stderr, "Error! Please provide at lease 1 wav file\n"); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_mnn::OnlineRecognizer recognizer(config); + + std::vector ss; + + const auto begin = std::chrono::steady_clock::now(); + std::vector durations; + + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + const std::string wav_filename = po.GetArg(i); + int32_t sampling_rate = -1; + + bool is_ok = false; + const std::vector samples = + sherpa_mnn::ReadWave(wav_filename, &sampling_rate, &is_ok); + + if (!is_ok) { + fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str()); + return -1; + } + + const float duration = samples.size() / static_cast(sampling_rate); + + auto s = recognizer.CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + + std::vector tail_paddings(static_cast(0.8 * sampling_rate)); + // Note: We can call AcceptWaveform() multiple times. + s->AcceptWaveform(sampling_rate, tail_paddings.data(), + tail_paddings.size()); + + // Call InputFinished() to indicate that no audio samples are available + s->InputFinished(); + ss.push_back({std::move(s), duration, 0}); + } + + std::vector ready_streams; + for (;;) { + ready_streams.clear(); + for (auto &s : ss) { + const auto p_ss = s.online_stream.get(); + if (recognizer.IsReady(p_ss)) { + ready_streams.push_back(p_ss); + } else if (s.elapsed_seconds == 0) { + const auto end = std::chrono::steady_clock::now(); + const float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + s.elapsed_seconds = elapsed_seconds; + } + } + + if (ready_streams.empty()) { + break; + } + + recognizer.DecodeStreams(ready_streams.data(), ready_streams.size()); + } + + std::ostringstream os; + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + const auto &s = ss[i - 1]; + const float rtf = s.elapsed_seconds / s.duration; + + os << po.GetArg(i) << "\n"; + os << "Number of threads: " << config.model_config.num_threads << ", " + << std::setprecision(2) << "Elapsed seconds: " << s.elapsed_seconds + << ", Audio duration (s): " << s.duration + << ", Real time factor (RTF) = " << s.elapsed_seconds << "/" + << s.duration << " = " << rtf << "\n"; + const auto r = recognizer.GetResult(s.online_stream.get()); + os << r.text << "\n"; + os << r.AsJsonString() << "\n\n"; + } + + std::cerr << os.str(); + + return 0; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model-config.cc new file mode 100644 index 00000000..b689afd0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model-config.cc @@ -0,0 +1,116 @@ +// sherpa-mnn/csrc/silero-vad-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/silero-vad-model-config.h" + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +void SileroVadModelConfig::Register(ParseOptions *po) { + po->Register("silero-vad-model", &model, "Path to silero VAD ONNX model."); + + po->Register("silero-vad-threshold", 
&threshold,
+               "Speech threshold. Silero VAD outputs speech probabilities for "
+               "each audio chunk, probabilities ABOVE this value are "
+               "considered as SPEECH. It is best to tune this parameter for "
+               "each dataset, but 0.5 is a good default for most datasets.");
+
+  po->Register(
+      "silero-vad-min-silence-duration", &min_silence_duration,
+      "In seconds. At the end of each speech chunk, wait for "
+      "--silero-vad-min-silence-duration seconds of silence before "
+      "splitting it off");
+
+  po->Register("silero-vad-min-speech-duration", &min_speech_duration,
+               "In seconds. Final speech chunks shorter than "
+               "--silero-vad-min-speech-duration seconds are discarded");
+
+  po->Register(
+      "silero-vad-max-speech-duration", &max_speech_duration,
+      "In seconds. If a speech segment is longer than this value, then we "
+      "increase the threshold to 0.9. After finishing detecting the segment, "
+      "the threshold value is reset to its original value.");
+
+  po->Register(
+      "silero-vad-window-size", &window_size,
+      "In samples. Audio chunks of --silero-vad-window-size samples are fed "
+      "to the silero VAD model. WARNING! Silero VAD models were trained using "
+      "512, 1024, 1536 samples for a 16000 Hz sample rate and 256, 512, 768 "
+      "samples for an 8000 Hz sample rate. Values other than these may "
+      "affect model performance!");
+}
+
+bool SileroVadModelConfig::Validate() const {
+  if (model.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --silero-vad-model");
+    return false;
+  }
+
+  if (!FileExists(model)) {
+    SHERPA_ONNX_LOGE("Silero VAD model file '%s' does not exist",
+                     model.c_str());
+    return false;
+  }
+
+  if (threshold < 0.01) {
+    SHERPA_ONNX_LOGE(
+        "Please use a larger value for --silero-vad-threshold. Given: %f",
+        threshold);
+    return false;
+  }
+
+  if (threshold >= 1) {
+    SHERPA_ONNX_LOGE(
+        "Please use a smaller value for --silero-vad-threshold. Given: %f",
+        threshold);
+    return false;
+  }
+
+  if (min_silence_duration <= 0) {
+    SHERPA_ONNX_LOGE(
+        "Please use a larger value for --silero-vad-min-silence-duration. "
+        "Given: "
+        "%f",
+        min_silence_duration);
+    return false;
+  }
+
+  if (min_speech_duration <= 0) {
+    SHERPA_ONNX_LOGE(
+        "Please use a larger value for --silero-vad-min-speech-duration. "
+        "Given: "
+        "%f",
+        min_speech_duration);
+    return false;
+  }
+
+  if (max_speech_duration <= 0) {
+    SHERPA_ONNX_LOGE(
+        "Please use a larger value for --silero-vad-max-speech-duration. 
" + "Given: " + "%f", + max_speech_duration); + return false; + } + + return true; +} + +std::string SileroVadModelConfig::ToString() const { + std::ostringstream os; + + os << "SileroVadModelConfig("; + os << "model=\"" << model << "\", "; + os << "threshold=" << threshold << ", "; + os << "min_silence_duration=" << min_silence_duration << ", "; + os << "min_speech_duration=" << min_speech_duration << ", "; + os << "max_speech_duration=" << max_speech_duration << ", "; + os << "window_size=" << window_size << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model-config.h new file mode 100644 index 00000000..98112e24 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model-config.h @@ -0,0 +1,46 @@ +// sherpa-mnn/csrc/silero-vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct SileroVadModelConfig { + std::string model; + + // threshold to classify a segment as speech + // + // If the predicted probability of a segment is larger than this + // value, then it is classified as speech. + float threshold = 0.5; + + float min_silence_duration = 0.5; // in seconds + + float min_speech_duration = 0.25; // in seconds + + // 512, 1024, 1536 samples for 16000 Hz + // 256, 512, 768 samples for 800 Hz + int32_t window_size = 512; // in samples + + // If a speech segment is longer than this value, then we increase + // the threshold to 0.9. After finishing detecting the segment, + // the threshold value is reset to its original value. + float max_speech_duration = 20; // in seconds + + SileroVadModelConfig() = default; + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model.cc new file mode 100644 index 00000000..ea7a085b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model.cc @@ -0,0 +1,482 @@ +// sherpa-mnn/csrc/silero-vad-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/silero-vad-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" + +namespace sherpa_mnn { + +class SileroVadModel::Impl { + public: + explicit Impl(const VadModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{}, + sample_rate_(config.sample_rate) { + auto buf = ReadFile(config.silero_vad.model); + Init(buf.data(), buf.size()); + + if (sample_rate_ != 16000) { + SHERPA_ONNX_LOGE("Expected sample rate 16000. 
Given: %d", + config.sample_rate); + exit(-1); + } + + min_silence_samples_ = + sample_rate_ * config_.silero_vad.min_silence_duration; + + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration; + } + + template + Impl(Manager *mgr, const VadModelConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{}, + sample_rate_(config.sample_rate) { + auto buf = ReadFile(mgr, config.silero_vad.model); + Init(buf.data(), buf.size()); + + if (sample_rate_ != 16000) { + SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", + config.sample_rate); + exit(-1); + } + + min_silence_samples_ = + sample_rate_ * config_.silero_vad.min_silence_duration; + + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration; + } + + void Reset() { + if (is_v5_) { + ResetV5(); + } else { + ResetV4(); + } + + triggered_ = false; + current_sample_ = 0; + temp_start_ = 0; + temp_end_ = 0; + } + + bool IsSpeech(const float *samples, int32_t n) { + if (n != WindowSize()) { + SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize()); + exit(-1); + } + + float prob = Run(samples, n); + + float threshold = config_.silero_vad.threshold; + + current_sample_ += config_.silero_vad.window_size; + + if (prob > threshold && temp_end_ != 0) { + temp_end_ = 0; + } + + if (prob > threshold && temp_start_ == 0) { + // start speaking, but we require that it must satisfy + // min_speech_duration + temp_start_ = current_sample_; + return false; + } + + if (prob > threshold && temp_start_ != 0 && !triggered_) { + if (current_sample_ - temp_start_ < min_speech_samples_) { + return false; + } + + triggered_ = true; + + return true; + } + + if ((prob < threshold) && !triggered_) { + // silence + temp_start_ = 0; + temp_end_ = 0; + return false; + } + + if ((prob > threshold - 0.15) && triggered_) { + // speaking + return true; + } + + if ((prob > threshold) && !triggered_) { + // start speaking + triggered_ = true; + + return true; + } + + if ((prob < threshold) && triggered_) { + // stop to speak + if (temp_end_ == 0) { + temp_end_ = current_sample_; + } + + if (current_sample_ - temp_end_ < min_silence_samples_) { + // continue speaking + return true; + } + // stopped speaking + temp_start_ = 0; + temp_end_ = 0; + triggered_ = false; + return false; + } + + return false; + } + + int32_t WindowShift() const { return config_.silero_vad.window_size; } + + int32_t WindowSize() const { + return config_.silero_vad.window_size + window_overlap_; + } + + int32_t MinSilenceDurationSamples() const { return min_silence_samples_; } + + int32_t MinSpeechDurationSamples() const { return min_speech_samples_; } + + void SetMinSilenceDuration(float s) { + min_silence_samples_ = sample_rate_ * s; + } + + void SetThreshold(float threshold) { + config_.silero_vad.threshold = threshold; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + if (input_names_.size() == 4 && output_names_.size() == 3) { + is_v5_ = false; + } else if (input_names_.size() == 3 && output_names_.size() == 2) { + is_v5_ = true; + + // 64 for 16kHz + // 32 for 8kHz + window_overlap_ = 64; + + if (config_.silero_vad.window_size != 512) { + SHERPA_ONNX_LOGE( + "For silero_vad v5, we require window_size to 
be 512 for 16kHz"); + exit(-1); + } + } else { + SHERPA_ONNX_LOGE("Unsupported silero vad model"); + exit(-1); + } + + Check(); + + Reset(); + } + + void ResetV5() { + // 2 - number of LSTM layer + // 1 - batch size + // 128 - hidden dim + std::array shape{2, 1, 128}; + + MNN::Express::VARP s = + MNNUtilsCreateTensor(allocator_, shape.data(), shape.size()); + + Fill(s, 0); + states_.clear(); + states_.push_back(std::move(s)); + } + + void ResetV4() { + // 2 - number of LSTM layer + // 1 - batch size + // 64 - hidden dim + std::array shape{2, 1, 64}; + + MNN::Express::VARP h = + MNNUtilsCreateTensor(allocator_, shape.data(), shape.size()); + + MNN::Express::VARP c = + MNNUtilsCreateTensor(allocator_, shape.data(), shape.size()); + + Fill(h, 0); + Fill(c, 0); + + states_.clear(); + + states_.reserve(2); + states_.push_back(std::move(h)); + states_.push_back(std::move(c)); + } + + void Check() const { + if (is_v5_) { + CheckV5(); + } else { + CheckV4(); + } + } + + void CheckV4() const { + if (input_names_.size() != 4) { + SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", + static_cast(input_names_.size())); + exit(-1); + } + + if (input_names_[0] != "input") { + SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input", + input_names_[0].c_str()); + exit(-1); + } + + if (input_names_[1] != "sr") { + SHERPA_ONNX_LOGE("Input[1]: %s. Expected: sr", input_names_[1].c_str()); + exit(-1); + } + + if (input_names_[2] != "h") { + SHERPA_ONNX_LOGE("Input[2]: %s. Expected: h", input_names_[2].c_str()); + exit(-1); + } + + if (input_names_[3] != "c") { + SHERPA_ONNX_LOGE("Input[3]: %s. Expected: c", input_names_[3].c_str()); + exit(-1); + } + + // Now for outputs + if (output_names_.size() != 3) { + SHERPA_ONNX_LOGE("Expect 3 outputs. Given: %d", + static_cast(output_names_.size())); + exit(-1); + } + + if (output_names_[0] != "output") { + SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output", + output_names_[0].c_str()); + exit(-1); + } + + if (output_names_[1] != "hn") { + SHERPA_ONNX_LOGE("Output[1]: %s. Expected: sr", output_names_[1].c_str()); + exit(-1); + } + + if (output_names_[2] != "cn") { + SHERPA_ONNX_LOGE("Output[2]: %s. Expected: sr", output_names_[2].c_str()); + exit(-1); + } + } + + void CheckV5() const { + if (input_names_.size() != 3) { + SHERPA_ONNX_LOGE("Expect 3 inputs. Given: %d", + static_cast(input_names_.size())); + exit(-1); + } + + if (input_names_[0] != "input") { + SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input", + input_names_[0].c_str()); + exit(-1); + } + + if (input_names_[1] != "state") { + SHERPA_ONNX_LOGE("Input[1]: %s. Expected: state", + input_names_[1].c_str()); + exit(-1); + } + + if (input_names_[2] != "sr") { + SHERPA_ONNX_LOGE("Input[2]: %s. Expected: sr", input_names_[2].c_str()); + exit(-1); + } + + // Now for outputs + if (output_names_.size() != 2) { + SHERPA_ONNX_LOGE("Expect 2 outputs. Given: %d", + static_cast(output_names_.size())); + exit(-1); + } + + if (output_names_[0] != "output") { + SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output", + output_names_[0].c_str()); + exit(-1); + } + + if (output_names_[1] != "stateN") { + SHERPA_ONNX_LOGE("Output[1]: %s. 
Expected: stateN", + output_names_[1].c_str()); + exit(-1); + } + } + + float Run(const float *samples, int32_t n) { + if (is_v5_) { + return RunV5(samples, n); + } else { + return RunV4(samples, n); + } + } + + float RunV5(const float *samples, int32_t n) { + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape = {1, n}; + + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, const_cast(samples), n, + x_shape.data(), x_shape.size()); + + int sr_shape = 1; + MNN::Express::VARP sr = + MNNUtilsCreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); + + std::vector inputs = {std::move(x), std::move(states_[0]), + std::move(sr)}; + + auto out = + sess_->onForward(inputs); + + states_[0] = std::move(out[1]); + + float prob = out[0]->readMap()[0]; + return prob; + } + + float RunV4(const float *samples, int32_t n) { + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape = {1, n}; + + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, const_cast(samples), n, + x_shape.data(), x_shape.size()); + + int sr_shape = 1; + MNN::Express::VARP sr = + MNNUtilsCreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); + + std::vector inputs = {std::move(x), std::move(sr), + std::move(states_[0]), + std::move(states_[1])}; + + auto out = + sess_->onForward(inputs); + + states_[0] = std::move(out[1]); + states_[1] = std::move(out[2]); + + float prob = out[0]->readMap()[0]; + return prob; + } + + private: + VadModelConfig config_; + + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + std::vector states_; + int sample_rate_; + int32_t min_silence_samples_; + int32_t min_speech_samples_; + + bool triggered_ = false; + int32_t current_sample_ = 0; + int32_t temp_start_ = 0; + int32_t temp_end_ = 0; + + int32_t window_overlap_ = 0; + + bool is_v5_ = false; +}; + +SileroVadModel::SileroVadModel(const VadModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +SileroVadModel::SileroVadModel(Manager *mgr, const VadModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +SileroVadModel::~SileroVadModel() = default; + +void SileroVadModel::Reset() { return impl_->Reset(); } + +bool SileroVadModel::IsSpeech(const float *samples, int32_t n) { + return impl_->IsSpeech(samples, n); +} + +int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); } + +int32_t SileroVadModel::WindowShift() const { return impl_->WindowShift(); } + +int32_t SileroVadModel::MinSilenceDurationSamples() const { + return impl_->MinSilenceDurationSamples(); +} + +int32_t SileroVadModel::MinSpeechDurationSamples() const { + return impl_->MinSpeechDurationSamples(); +} + +void SileroVadModel::SetMinSilenceDuration(float s) { + impl_->SetMinSilenceDuration(s); +} + +void SileroVadModel::SetThreshold(float threshold) { + impl_->SetThreshold(threshold); +} + +#if __ANDROID_API__ >= 9 +template SileroVadModel::SileroVadModel(AAssetManager *mgr, + const VadModelConfig &config); +#endif + +#if __OHOS__ +template SileroVadModel::SileroVadModel(NativeResourceManager *mgr, + const VadModelConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model.h new file mode 100644 index 00000000..07bbff67 --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/silero-vad-model.h @@ -0,0 +1,55 @@ +// sherpa-mnn/csrc/silero-vad-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_H_ +#define SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_H_ + +#include + +#include "sherpa-mnn/csrc/vad-model.h" + +namespace sherpa_mnn { + +class SileroVadModel : public VadModel { + public: + explicit SileroVadModel(const VadModelConfig &config); + + template + SileroVadModel(Manager *mgr, const VadModelConfig &config); + + ~SileroVadModel() override; + + // reset the internal model states + void Reset() override; + + /** + * @param samples Pointer to a 1-d array containing audio samples. + * Each sample should be normalized to the range [-1, 1]. + * @param n Number of samples. + * + * @return Return true if speech is detected. Return false otherwise. + */ + bool IsSpeech(const float *samples, int32_t n) override; + + // For silero vad V4, it is WindowShift(). + // For silero vad V5, it is WindowShift()+64 for 16kHz and + // WindowShift()+32 for 8kHz + int32_t WindowSize() const override; + + // 512 + int32_t WindowShift() const override; + + int32_t MinSilenceDurationSamples() const override; + int32_t MinSpeechDurationSamples() const override; + + void SetMinSilenceDuration(float s) override; + void SetThreshold(float threshold) override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice-test.cc new file mode 100644 index 00000000..c503c2ce --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice-test.cc @@ -0,0 +1,52 @@ +// sherpa-mnn/csrc/slice-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/slice.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TEST(Slice, Slice3D) { + MNNAllocator* allocator; + std::array shape{5, 5, 4}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); + + auto v1 = Slice(allocator, &v, 2, 4, 0, 2); + auto v2 = Slice(allocator, &v, 1, 3, 1, 3); + + Print3D(&v); + Print3D(&v1); + Print3D(&v2); + + // TODO(fangjun): Check that the results are correct +} + +TEST(Slice, Slice2D) { + MNNAllocator* allocator; + std::array shape{5, 8}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + std::iota(p, p + shape[0] * shape[1], 0); + + auto v1 = Slice(allocator, &v, 1, 3); + auto v2 = Slice(allocator, &v, 0, 2); + + Print2D(&v); + Print2D(&v1); + Print2D(&v2); + + // TODO(fangjun): Check that the results are correct +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice.cc new file mode 100644 index 00000000..b5ba757b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice.cc @@ -0,0 +1,77 @@ +// sherpa-mnn/csrc/slice.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/slice.h" + +#include +#include +#include + +namespace sherpa_mnn { + +template +MNN::Express::VARP Slice(MNNAllocator *allocator, MNN::Express::VARP v, + int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, + int32_t dim1_end) { + std::vector shape = v->getInfo()->dim; + 
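+  // v must be a 3-D tensor; the asserts below check that the requested
+  // ranges lie within its first two dimensions. The result is a deep copy:
+  // for every selected index along dim0, rows [dim1_start, dim1_end) are
+  // copied into a new contiguous tensor.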
assert(shape.size() == 3); + + assert(0 <= dim0_start); + assert(dim0_start < dim0_end); + assert(dim0_end <= shape[0]); + + assert(0 <= dim1_start); + assert(dim1_start < dim1_end); + assert(dim1_end <= shape[1]); + + std::array ans_shape{dim0_end - dim0_start, dim1_end - dim1_start, + shape[2]}; + + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + T *dst = ans->writeMap(); + for (int32_t i = dim0_start; i != dim0_end; ++i) { + const T *src = v->readMap() + i * shape[1] * shape[2]; + const T *start = src + dim1_start * shape[2]; + const T *end = src + dim1_end * shape[2]; + + std::copy(start, end, dst); + dst += ans_shape[1] * ans_shape[2]; + } + + return ans; +} + +template +MNN::Express::VARP Slice(MNNAllocator *allocator, MNN::Express::VARP v, + int32_t dim0_start, int32_t dim0_end) { + std::vector shape = v->getInfo()->dim; + assert(shape.size() == 2); + + assert(0 <= dim0_start); + assert(dim0_start < dim0_end); + assert(dim0_end <= shape[0]); + + const T *src = v->readMap(); + + std::array ans_shape{dim0_end - dim0_start, shape[1]}; + + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + const T *start = v->readMap() + dim0_start * shape[1]; + const T *end = v->readMap() + dim0_end * shape[1]; + T *dst = ans->writeMap(); + std::copy(start, end, dst); + + return ans; +} + +template MNN::Express::VARP Slice(MNNAllocator *allocator, MNN::Express::VARP v, + int32_t dim0_start, int32_t dim0_end, + int32_t dim1_start, int32_t dim1_end); + +template MNN::Express::VARP Slice(MNNAllocator *allocator, MNN::Express::VARP v, + int32_t dim0_start, int32_t dim0_end); + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice.h new file mode 100644 index 00000000..f6dfca2a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/slice.h @@ -0,0 +1,48 @@ +// sherpa-mnn/csrc/slice.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SLICE_H_ +#define SHERPA_ONNX_CSRC_SLICE_H_ + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +/** Get a deep copy by slicing a 3-D tensor v. + * + * It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :] + * + * @param allocator + * @param v A 3-D tensor. Its data type is T. + * @param dim0_start Start index of the first dimension.. + * @param dim0_end End index of the first dimension.. + * @param dim1_start Start index of the second dimension. + * @param dim1_end End index of the second dimension. + * + * @return Return a 3-D tensor of shape + * (dim0_end-dim0_start, dim1_end-dim1_start, v.shape[2]) + */ +template +MNN::Express::VARP Slice(MNNAllocator *allocator, MNN::Express::VARP v, + int32_t dim0_start, int32_t dim0_end, int32_t dim1_start, + int32_t dim1_end); + +/** Get a deep copy by slicing a 2-D tensor v. + * + * It returns v[dim0_start:dim0_end, :] + * + * @param allocator + * @param v A 2-D tensor. Its data type is T. + * @param dim0_start Start index of the first dimension.. + * @param dim0_end End index of the first dimension.. 
+ * + * @return Return a 2-D tensor of shape + * (dim0_end-dim0_start, v.shape[1]) + */ +template +MNN::Express::VARP Slice(MNNAllocator *allocator, MNN::Express::VARP v, + int32_t dim0_start, int32_t dim0_end); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SLICE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-general-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-general-impl.h new file mode 100644 index 00000000..80fc4224 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-general-impl.h @@ -0,0 +1,117 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-general-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ +#include +#include +#include +#include + +#include "Eigen/Dense" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-impl.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-model.h" + +namespace sherpa_mnn { + +class SpeakerEmbeddingExtractorGeneralImpl + : public SpeakerEmbeddingExtractorImpl { + public: + explicit SpeakerEmbeddingExtractorGeneralImpl( + const SpeakerEmbeddingExtractorConfig &config) + : model_(config) {} + + template + SpeakerEmbeddingExtractorGeneralImpl( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) + : model_(mgr, config) {} + + int32_t Dim() const override { return model_.GetMetaData().output_dim; } + + std::unique_ptr CreateStream() const override { + FeatureExtractorConfig feat_config; + const auto &meta_data = model_.GetMetaData(); + feat_config.sampling_rate = meta_data.sample_rate; + feat_config.normalize_samples = meta_data.normalize_samples; + + return std::make_unique(feat_config); + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() < s->NumFramesReady(); + } + + std::vector Compute(OnlineStream *s) const override { + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); + if (num_frames <= 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Please make sure IsReady(s) returns true. num_frames: %{public}d", + num_frames); +#else + SHERPA_ONNX_LOGE( + "Please make sure IsReady(s) returns true. 
num_frames: %d", + num_frames); +#endif + return {}; + } + + std::vector features = + s->GetFrames(s->GetNumProcessedFrames(), num_frames); + + s->GetNumProcessedFrames() += num_frames; + + int32_t feat_dim = features.size() / num_frames; + + const auto &meta_data = model_.GetMetaData(); + if (!meta_data.feature_normalize_type.empty()) { + if (meta_data.feature_normalize_type == "global-mean") { + SubtractGlobalMean(features.data(), num_frames, feat_dim); + } else { +#if __OHOS__ + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %{public}s", + meta_data.feature_normalize_type.c_str()); +#else + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s", + meta_data.feature_normalize_type.c_str()); +#endif + exit(-1); + } + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{1, num_frames, feat_dim}; + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, features.data(), features.size(), + x_shape.data(), x_shape.size()); + MNN::Express::VARP embedding = model_.Compute(std::move(x)); + std::vector embedding_shape = + embedding->getInfo()->dim; + + std::vector ans(embedding_shape[1]); + std::copy(embedding->readMap(), + embedding->readMap() + ans.size(), ans.begin()); + + return ans; + } + + private: + void SubtractGlobalMean(float *p, int32_t num_frames, + int32_t feat_dim) const { + auto m = Eigen::Map< + Eigen::Matrix>( + p, num_frames, feat_dim); + + m = m.rowwise() - m.colwise().mean(); + } + + private: + SpeakerEmbeddingExtractorModel model_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-impl.cc new file mode 100644 index 00000000..4961fef1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-impl.cc @@ -0,0 +1,154 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/speaker-embedding-extractor-impl.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-general-impl.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-nemo-impl.h" + +namespace sherpa_mnn { + +namespace { + +enum class ModelType : std::uint8_t { + kWeSpeaker, + k3dSpeaker, + kNeMo, + kUnknown, +}; + +} // namespace + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + MNNEnv env; + std::shared_ptr sess_opts; + + + + auto sess = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts)); + + MNNMeta meta_data = sess->getInfo()->metaData; + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; + auto model_type = + LookupCustomModelMetaData(meta_data, "framework", allocator); + if (model_type.empty()) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "Please make sure you have added metadata to the model.\n\n" + "For instance, you can use\n" + 
"https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/wespeaker/" + "add_meta_data.py" + "to add metadata to models from WeSpeaker\n"); + return ModelType::kUnknown; + } + + if (model_type == "wespeaker") { + return ModelType::kWeSpeaker; + } else if (model_type == "3d-speaker") { + return ModelType::k3dSpeaker; + } else if (model_type == "nemo") { + return ModelType::kNeMo; + } else { +#if __OHOS__ + SHERPA_ONNX_LOGE("Unsupported model_type: %{public}s", model_type.c_str()); +#else + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); +#endif + return ModelType::kUnknown; + } +} + +std::unique_ptr +SpeakerEmbeddingExtractorImpl::Create( + const SpeakerEmbeddingExtractorConfig &config) { + ModelType model_type = ModelType::kUnknown; + + { + auto buffer = ReadFile(config.model); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kWeSpeaker: + // fall through + case ModelType::k3dSpeaker: + return std::make_unique(config); + case ModelType::kNeMo: + return std::make_unique(config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!"); + return nullptr; + } + + // unreachable code + return nullptr; +} + +template +std::unique_ptr +SpeakerEmbeddingExtractorImpl::Create( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) { + ModelType model_type = ModelType::kUnknown; + + { + auto buffer = ReadFile(mgr, config.model); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kWeSpeaker: + // fall through + case ModelType::k3dSpeaker: + return std::make_unique(mgr, + config); + case ModelType::kNeMo: + return std::make_unique(mgr, config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE( + "Unknown model type in for speaker embedding extractor!"); + return nullptr; + } + + // unreachable code + return nullptr; +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr +SpeakerEmbeddingExtractorImpl::Create( + AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr +SpeakerEmbeddingExtractorImpl::Create( + NativeResourceManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-impl.h new file mode 100644 index 00000000..5af020ce --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-impl.h @@ -0,0 +1,38 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +namespace sherpa_mnn { + +class SpeakerEmbeddingExtractorImpl { + public: + virtual ~SpeakerEmbeddingExtractorImpl() = default; + + static std::unique_ptr Create( + const SpeakerEmbeddingExtractorConfig &config); + + template + static std::unique_ptr Create( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config); + + virtual int32_t Dim() const = 0; + + virtual std::unique_ptr CreateStream() const = 0; + + virtual bool IsReady(OnlineStream *s) const = 0; + + virtual std::vector Compute(OnlineStream *s) const = 0; +}; + +} // namespace sherpa_mnn + +#endif // 
SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model-meta-data.h new file mode 100644 index 00000000..05680df3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model-meta-data.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-model-meta-data.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_mnn { + +struct SpeakerEmbeddingExtractorModelMetaData { + int32_t output_dim = 0; + int32_t sample_rate = 0; + + // for wespeaker models, it is 0; + // for 3d-speaker models, it is 1 + int32_t normalize_samples = 1; + + // Chinese, English, etc. + std::string language; + + // for 3d-speaker, it is global-mean + std::string feature_normalize_type; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model.cc new file mode 100644 index 00000000..0f8d43fe --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model.cc @@ -0,0 +1,156 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/speaker-embedding-extractor-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-model-meta-data.h" + +namespace sherpa_mnn { + +class SpeakerEmbeddingExtractorModel::Impl { + public: + explicit Impl(const SpeakerEmbeddingExtractorConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.model); + Init(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.model); + Init(buf.data(), buf.size()); + } + } + + MNN::Express::VARP Compute(MNN::Express::VARP x) const { + std::vector inputs = {std::move(x)}; + + auto outputs = + sess_->onForward(inputs); + return std::move(outputs[0]); + } + + const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const { + return meta_data_; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + 
SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_samples, + "normalize_samples"); + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT( + meta_data_.feature_normalize_type, "feature_normalize_type", ""); + + std::string framework; + SHERPA_ONNX_READ_META_DATA_STR(framework, "framework"); + if (framework != "wespeaker" && framework != "3d-speaker") { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Expect a wespeaker or a 3d-speaker model, given: %{public}s", + framework.c_str()); +#else + SHERPA_ONNX_LOGE("Expect a wespeaker or a 3d-speaker model, given: %s", + framework.c_str()); +#endif + exit(-1); + } + } + + private: + SpeakerEmbeddingExtractorConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + SpeakerEmbeddingExtractorModelMetaData meta_data_; +}; + +SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel( + const SpeakerEmbeddingExtractorConfig &config) + : impl_(std::make_unique(config)) {} + +template +SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +SpeakerEmbeddingExtractorModel::~SpeakerEmbeddingExtractorModel() = default; + +const SpeakerEmbeddingExtractorModelMetaData & +SpeakerEmbeddingExtractorModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +MNN::Express::VARP SpeakerEmbeddingExtractorModel::Compute(MNN::Express::VARP x) const { + return impl_->Compute(std::move(x)); +} + +#if __ANDROID_API__ >= 9 +template SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel( + AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +#if __OHOS__ +template SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel( + NativeResourceManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model.h new file mode 100644 index 00000000..034adb4e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-model.h @@ -0,0 +1,41 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/speaker-embedding-extractor-model-meta-data.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +namespace sherpa_mnn { + +class SpeakerEmbeddingExtractorModel { + public: + explicit SpeakerEmbeddingExtractorModel( + const SpeakerEmbeddingExtractorConfig &config); + + template + SpeakerEmbeddingExtractorModel(Manager *mgr, + const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractorModel(); + + const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const; + + /** + * @param x A float32 tensor of shape (N, T, C) + * @return A float32 tensor of shape 
(N, C) + */ + MNN::Express::VARP Compute(MNN::Express::VARP x) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-impl.h new file mode 100644 index 00000000..002e3341 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-impl.h @@ -0,0 +1,145 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-nemo-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ +#include +#include +#include +#include + +#include "Eigen/Dense" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-impl.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl { + public: + explicit SpeakerEmbeddingExtractorNeMoImpl( + const SpeakerEmbeddingExtractorConfig &config) + : model_(config) {} + + template + SpeakerEmbeddingExtractorNeMoImpl( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) + : model_(mgr, config) {} + + int32_t Dim() const override { return model_.GetMetaData().output_dim; } + + std::unique_ptr CreateStream() const override { + FeatureExtractorConfig feat_config; + const auto &meta_data = model_.GetMetaData(); + feat_config.sampling_rate = meta_data.sample_rate; + feat_config.feature_dim = meta_data.feat_dim; + feat_config.normalize_samples = true; + feat_config.snip_edges = true; + feat_config.frame_shift_ms = meta_data.window_stride_ms; + feat_config.frame_length_ms = meta_data.window_size_ms; + feat_config.low_freq = 0; + feat_config.is_librosa = true; + feat_config.remove_dc_offset = false; + feat_config.window_type = meta_data.window_type; + + return std::make_unique(feat_config); + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() < s->NumFramesReady(); + } + + std::vector Compute(OnlineStream *s) const override { + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); + if (num_frames <= 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Please make sure IsReady(s) returns true. num_frames: %{public}d", + num_frames); +#else + SHERPA_ONNX_LOGE( + "Please make sure IsReady(s) returns true. 
num_frames: %d", + num_frames); +#endif + return {}; + } + + std::vector features = + s->GetFrames(s->GetNumProcessedFrames(), num_frames); + + s->GetNumProcessedFrames() += num_frames; + + int32_t feat_dim = features.size() / num_frames; + + const auto &meta_data = model_.GetMetaData(); + if (!meta_data.feature_normalize_type.empty()) { + if (meta_data.feature_normalize_type == "per_feature") { + NormalizePerFeature(features.data(), num_frames, feat_dim); + } else { +#if __OHOS__ + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %{public}s", + meta_data.feature_normalize_type.c_str()); +#else + + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s", + meta_data.feature_normalize_type.c_str()); +#endif + exit(-1); + } + } + + if (num_frames % 16 != 0) { + int32_t pad = 16 - num_frames % 16; + features.resize((num_frames + pad) * feat_dim); + } + + auto memory_info = + (MNNAllocator*)(nullptr); + + std::array x_shape{1, num_frames, feat_dim}; + MNN::Express::VARP x = + MNNUtilsCreateTensor(memory_info, features.data(), features.size(), + x_shape.data(), x_shape.size()); + + x = Transpose12(model_.Allocator(), x); + + int x_lens = num_frames; + std::array x_lens_shape{1}; + MNN::Express::VARP x_lens_tensor = MNNUtilsCreateTensor( + memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size()); + + MNN::Express::VARP embedding = + model_.Compute(std::move(x), std::move(x_lens_tensor)); + std::vector embedding_shape = + embedding->getInfo()->dim; + + std::vector ans(embedding_shape[1]); + std::copy(embedding->readMap(), + embedding->readMap() + ans.size(), ans.begin()); + + return ans; + } + + private: + void NormalizePerFeature(float *p, int32_t num_frames, + int32_t feat_dim) const { + auto m = Eigen::Map< + Eigen::Matrix>( + p, num_frames, feat_dim); + + auto EX = m.colwise().mean(); + auto EX2 = m.array().pow(2).colwise().sum() / num_frames; + auto variance = EX2 - EX.array().pow(2); + auto stddev = variance.array().sqrt(); + + m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5); + } + + private: + SpeakerEmbeddingExtractorNeMoModel model_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model-meta-data.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model-meta-data.h new file mode 100644 index 00000000..0ae8caf3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model-meta-data.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model-meta-data.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_mnn { + +struct SpeakerEmbeddingExtractorNeMoModelMetaData { + int32_t output_dim = 0; + int32_t feat_dim = 80; + int32_t sample_rate = 0; + int32_t window_size_ms = 25; + int32_t window_stride_ms = 25; + + // Chinese, English, etc. 
+ std::string language; + + // for 3d-speaker, it is global-mean + std::string feature_normalize_type; + std::string window_type = "hann"; +}; + +} // namespace sherpa_mnn +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.cc new file mode 100644 index 00000000..38e3bbd6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.cc @@ -0,0 +1,169 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/session.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model-meta-data.h" + +namespace sherpa_mnn { + +class SpeakerEmbeddingExtractorNeMoModel::Impl { + public: + explicit Impl(const SpeakerEmbeddingExtractorConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.model); + Init(buf.data(), buf.size()); + } + } + + template + Impl(Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) + : config_(config), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(mgr, config.model); + Init(buf.data(), buf.size()); + } + } + + MNN::Express::VARP Compute(MNN::Express::VARP x, MNN::Express::VARP x_lens) const { + std::vector inputs = {std::move(x), std::move(x_lens)}; + + // output_names_ptr_[0] is logits + // output_names_ptr_[1] is embeddings + // so we use output_names_ptr_.data() + 1 here to extract only the + // embeddings + auto outputs = sess_->onForward(inputs); + return std::move(outputs[0]); + } + + MNNAllocator *Allocator() { return allocator_; } + + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const { + return meta_data_; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts_.pManager, &sess_opts_.pConfig)); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + MNNMeta meta_data = sess_->getInfo()->metaData; + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + MNNAllocator* allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.feat_dim, "feat_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_size_ms, "window_size_ms"); + SHERPA_ONNX_READ_META_DATA(meta_data_.window_stride_ms, "window_stride_ms"); + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT( + meta_data_.feature_normalize_type, 
"feature_normalize_type", ""); + + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.window_type, + "window_type", "povey"); + + std::string framework; + SHERPA_ONNX_READ_META_DATA_STR(framework, "framework"); + if (framework != "nemo") { +#if __OHOS__ + SHERPA_ONNX_LOGE("Expect a NeMo model, given: %{public}s", + framework.c_str()); +#else + SHERPA_ONNX_LOGE("Expect a NeMo model, given: %s", framework.c_str()); +#endif + exit(-1); + } + } + + private: + SpeakerEmbeddingExtractorConfig config_; + MNNEnv env_; + MNNConfig sess_opts_; + MNNAllocator* allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + SpeakerEmbeddingExtractorNeMoModelMetaData meta_data_; +}; + +SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel( + const SpeakerEmbeddingExtractorConfig &config) + : impl_(std::make_unique(config)) {} + +template +SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() = + default; + +const SpeakerEmbeddingExtractorNeMoModelMetaData & +SpeakerEmbeddingExtractorNeMoModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +MNN::Express::VARP SpeakerEmbeddingExtractorNeMoModel::Compute( + MNN::Express::VARP x, MNN::Express::VARP x_lens) const { + return impl_->Compute(std::move(x), std::move(x_lens)); +} + +MNNAllocator *SpeakerEmbeddingExtractorNeMoModel::Allocator() const { + return impl_->Allocator(); +} + +#if __ANDROID_API__ >= 9 +template SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel( + AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +#if __OHOS__ +template SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel( + NativeResourceManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.h new file mode 100644 index 00000000..9df34697 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.h @@ -0,0 +1,44 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT +#include "sherpa-mnn/csrc/speaker-embedding-extractor-nemo-model-meta-data.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +namespace sherpa_mnn { + +class SpeakerEmbeddingExtractorNeMoModel { + public: + explicit SpeakerEmbeddingExtractorNeMoModel( + const SpeakerEmbeddingExtractorConfig &config); + + template + SpeakerEmbeddingExtractorNeMoModel( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractorNeMoModel(); + + const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const; + + /** + * @param x A float32 tensor of shape (N, C, T) + * @param x_len A int64 tensor of shape (N,) + * @return A float32 tensor of shape (N, C) + */ + MNN::Express::VARP Compute(MNN::Express::VARP x, MNN::Express::VARP x_len) const; + + MNNAllocator *Allocator() const; + + private: + 
class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor.cc new file mode 100644 index 00000000..820eefc9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor.cc @@ -0,0 +1,98 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/speaker-embedding-extractor-impl.h" + +namespace sherpa_mnn { + +void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { + po->Register("model", &model, "Path to the speaker embedding model."); + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool SpeakerEmbeddingExtractorConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string SpeakerEmbeddingExtractorConfig::ToString() const { + std::ostringstream os; + + os << "SpeakerEmbeddingExtractorConfig("; + os << "model=\"" << model << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor( + const SpeakerEmbeddingExtractorConfig &config) + : impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {} + +template +SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor( + Manager *mgr, const SpeakerEmbeddingExtractorConfig &config) + : impl_(SpeakerEmbeddingExtractorImpl::Create(mgr, config)) {} + +SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default; + +int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); } + +std::unique_ptr SpeakerEmbeddingExtractor::CreateStream() const { + return impl_->CreateStream(); +} + +bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const { + return impl_->IsReady(s); +} + +std::vector SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const { + return impl_->Compute(s); +} + +#if __ANDROID_API__ >= 9 +template SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor( + AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +#if __OHOS__ +template SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor( + NativeResourceManager *mgr, const SpeakerEmbeddingExtractorConfig &config); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor.h new file mode 100644 index 00000000..60aba324 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-extractor.h @@ -0,0 +1,71 @@ +// sherpa-mnn/csrc/speaker-embedding-extractor.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct SpeakerEmbeddingExtractorConfig { + std::string model; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + SpeakerEmbeddingExtractorConfig() = default; + SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads, + bool debug, const std::string &provider) + : model(model), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +class SpeakerEmbeddingExtractorImpl; + +class SpeakerEmbeddingExtractor { + public: + explicit SpeakerEmbeddingExtractor( + const SpeakerEmbeddingExtractorConfig &config); + + template + SpeakerEmbeddingExtractor(Manager *mgr, + const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractor(); + + // Return the dimension of the embedding + int32_t Dim() const; + + // Create a stream to accept audio samples and compute features + std::unique_ptr CreateStream() const; + + // Return true if there are feature frames in OnlineStream that + // can be used to compute embeddings. + bool IsReady(OnlineStream *s) const; + + // Compute the speaker embedding from the available unprocessed features + // of the given stream + // + // You have to ensure IsReady(s) returns true before you call this method. 
+ std::vector Compute(OnlineStream *s) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager-test.cc new file mode 100644 index 00000000..0dbe218a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager-test.cc @@ -0,0 +1,147 @@ +// sherpa-mnn/csrc/speaker-embedding-manager-test.cc +// +// Copyright (c) 2024 Jingzhao Ou (jingzhao.ou@gmail.com) + +#include "sherpa-mnn/csrc/speaker-embedding-manager.h" + +#include "gtest/gtest.h" + +namespace sherpa_mnn { + +TEST(SpeakerEmbeddingManager, AddAndRemove) { + int32_t dim = 2; + SpeakerEmbeddingManager manager(dim); + std::vector v = {0.1, 0.1}; + bool status = manager.Add("first", v.data()); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 1); + + // duplicate + status = manager.Add("first", v.data()); + ASSERT_FALSE(status); + ASSERT_EQ(manager.NumSpeakers(), 1); + + // non-duplicate + v = {0.1, 0.9}; + status = manager.Add("second", v.data()); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 2); + + // do not exist + status = manager.Remove("third"); + ASSERT_FALSE(status); + + status = manager.Remove("first"); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 1); + + v = {0.1, 0.1}; + status = manager.Add("first", v.data()); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 2); + + status = manager.Remove("first"); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 1); + + status = manager.Remove("second"); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 0); +} + +TEST(SpeakerEmbeddingManager, Search) { + int32_t dim = 2; + SpeakerEmbeddingManager manager(dim); + std::vector v1 = {0.1, 0.1}; + std::vector v2 = {0.1, 0.9}; + std::vector v3 = {0.9, 0.1}; + bool status = manager.Add("first", v1.data()); + ASSERT_TRUE(status); + + status = manager.Add("second", v2.data()); + ASSERT_TRUE(status); + + status = manager.Add("third", v3.data()); + ASSERT_TRUE(status); + + ASSERT_EQ(manager.NumSpeakers(), 3); + + std::vector v = {15, 16}; + float threshold = 0.9; + + std::string name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, "first"); + + v = {2, 17}; + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, "second"); + + v = {17, 2}; + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, "third"); + + threshold = 0.9; + v = {15, 16}; + status = manager.Remove("first"); + ASSERT_TRUE(status); + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, ""); + + v = {17, 2}; + status = manager.Remove("third"); + ASSERT_TRUE(status); + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, ""); + + v = {2, 17}; + status = manager.Remove("second"); + ASSERT_TRUE(status); + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, ""); + + ASSERT_EQ(manager.NumSpeakers(), 0); +} + +TEST(SpeakerEmbeddingManager, Verify) { + int32_t dim = 2; + SpeakerEmbeddingManager manager(dim); + std::vector v1 = {0.1, 0.1}; + std::vector v2 = {0.1, 0.9}; + std::vector v3 = {0.9, 0.1}; + bool status = manager.Add("first", v1.data()); + ASSERT_TRUE(status); + + status = manager.Add("second", v2.data()); + ASSERT_TRUE(status); + + status = manager.Add("third", v3.data()); + ASSERT_TRUE(status); + + std::vector v = {15, 16}; + float threshold = 0.9; + + status = manager.Verify("first", v.data(), threshold); 
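+  // The manager L2-normalizes embeddings before comparing, and {15, 16}
+  // points in nearly the same direction as the {0.1, 0.1} registered for
+  // "first" (cosine similarity ~0.999 >= 0.9), so verification should succeed.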
+ ASSERT_TRUE(status); + + v = {2, 17}; + status = manager.Verify("first", v.data(), threshold); + ASSERT_FALSE(status); + + status = manager.Verify("second", v.data(), threshold); + ASSERT_TRUE(status); + + v = {17, 2}; + status = manager.Verify("first", v.data(), threshold); + ASSERT_FALSE(status); + + status = manager.Verify("second", v.data(), threshold); + ASSERT_FALSE(status); + + status = manager.Verify("third", v.data(), threshold); + ASSERT_TRUE(status); + + status = manager.Verify("fourth", v.data(), threshold); + ASSERT_FALSE(status); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager.cc new file mode 100644 index 00000000..b6277e4a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager.cc @@ -0,0 +1,286 @@ +// sherpa-mnn/csrc/speaker-embedding-manager.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/speaker-embedding-manager.h" + +#include +#include +#include + +#include "Eigen/Dense" +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { + +using FloatMatrix = + Eigen::Matrix; + +class SpeakerEmbeddingManager::Impl { + public: + explicit Impl(int32_t dim) : dim_(dim) {} + + bool Add(const std::string &name, const float *p) { + if (name2row_.count(name)) { + // a speaker with the same name already exists + return false; + } + + embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_); + + std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0)); + + embedding_matrix_.bottomRows(1).normalize(); // inplace + + name2row_[name] = embedding_matrix_.rows() - 1; + row2name_[embedding_matrix_.rows() - 1] = name; + + return true; + } + + bool Add(const std::string &name, + const std::vector> &embedding_list) { + if (name2row_.count(name)) { + // a speaker with the same name already exists + return false; + } + + if (embedding_list.empty()) { + SHERPA_ONNX_LOGE("Empty list of embeddings"); + return false; + } + + for (const auto &x : embedding_list) { + if (static_cast(x.size()) != dim_) { + SHERPA_ONNX_LOGE("Given dim: %d, expected dim: %d", + static_cast(x.size()), dim_); + return false; + } + } + + // compute the average + Eigen::RowVectorXf v = Eigen::Map( + const_cast(embedding_list[0].data()), dim_); + int32_t i = -1; + for (const auto &x : embedding_list) { + ++i; + if (i == 0) { + continue; + } + v += Eigen::Map(const_cast(x.data()), dim_); + } + + // no need to compute the mean since we are going to normalize it anyway + // v /= embedding_list.size(); + + v.normalize(); + + embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_); + embedding_matrix_.bottomRows(1) = v; + + name2row_[name] = embedding_matrix_.rows() - 1; + row2name_[embedding_matrix_.rows() - 1] = name; + + return true; + } + + bool Remove(const std::string &name) { + if (!name2row_.count(name)) { + return false; + } + + int32_t row_idx = name2row_.at(name); + + int32_t num_rows = embedding_matrix_.rows(); + + if (row_idx < num_rows - 1) { + embedding_matrix_.block(row_idx, 0, num_rows - 1 - row_idx, dim_) = + embedding_matrix_.bottomRows(num_rows - 1 - row_idx); + } + + embedding_matrix_.conservativeResize(num_rows - 1, dim_); + for (auto &p : name2row_) { + if (p.second > row_idx) { + p.second -= 1; + row2name_[p.second] = p.first; + } + } + + name2row_.erase(name); + row2name_.erase(num_rows - 1); + + return true; + } + + std::string Search(const float *p, float 
threshold) { + if (embedding_matrix_.rows() == 0) { + return {}; + } + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + Eigen::VectorXf scores = embedding_matrix_ * v; + + Eigen::VectorXf::Index max_index = 0; + float max_score = scores.maxCoeff(&max_index); + if (max_score < threshold) { + return {}; + } + + return row2name_.at(max_index); + } + + std::vector GetBestMatches(const float *p, float threshold, + int32_t n) { + std::vector matches; + + if (embedding_matrix_.rows() == 0) { + return matches; + } + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + Eigen::VectorXf scores = embedding_matrix_ * v; + + std::vector> score_indices; + for (int i = 0; i < scores.size(); ++i) { + if (scores[i] >= threshold) { + score_indices.emplace_back(scores[i], i); + } + } + + std::sort(score_indices.rbegin(), score_indices.rend(), + [](const auto &a, const auto &b) { return a.first < b.first; }); + + matches.reserve(score_indices.size()); + for (int i = 0; i < std::min(n, static_cast(score_indices.size())); + ++i) { + const auto &pair = score_indices[i]; + matches.push_back({row2name_.at(pair.second), pair.first}); + } + + return matches; + } + + bool Verify(const std::string &name, const float *p, float threshold) { + if (!name2row_.count(name)) { + return false; + } + + int32_t row_idx = name2row_.at(name); + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + float score = embedding_matrix_.row(row_idx) * v; + + if (score < threshold) { + return false; + } + + return true; + } + + float Score(const std::string &name, const float *p) { + if (!name2row_.count(name)) { + // Setting a default value if the name is not found + return -2.0; + } + + int32_t row_idx = name2row_.at(name); + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + float score = embedding_matrix_.row(row_idx) * v; + + return score; + } + + bool Contains(const std::string &name) const { + return name2row_.count(name) > 0; + } + + int32_t NumSpeakers() const { return embedding_matrix_.rows(); } + + int32_t Dim() const { return dim_; } + + std::vector GetAllSpeakers() const { + std::vector all_speakers; + all_speakers.reserve(name2row_.size()); + for (const auto &p : name2row_) { + all_speakers.push_back(p.first); + } + + std::sort(all_speakers.begin(), all_speakers.end()); + return all_speakers; + } + + private: + int32_t dim_; + FloatMatrix embedding_matrix_; + std::unordered_map name2row_; + std::unordered_map row2name_; +}; + +SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim) + : impl_(std::make_unique(dim)) {} + +SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default; + +bool SpeakerEmbeddingManager::Add(const std::string &name, + const float *p) const { + return impl_->Add(name, p); +} + +bool SpeakerEmbeddingManager::Add( + const std::string &name, + const std::vector> &embedding_list) const { + return impl_->Add(name, embedding_list); +} + +bool SpeakerEmbeddingManager::Remove(const std::string &name) const { + return impl_->Remove(name); +} + +std::string SpeakerEmbeddingManager::Search(const float *p, + float threshold) const { + return impl_->Search(p, threshold); +} + +std::vector SpeakerEmbeddingManager::GetBestMatches( + const float *p, float threshold, int32_t n) const { + return impl_->GetBestMatches(p, threshold, n); +} + +bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, + float threshold) const { + return impl_->Verify(name, p, threshold); +} + +float 
SpeakerEmbeddingManager::Score(const std::string &name, + const float *p) const { + return impl_->Score(name, p); +} + +int32_t SpeakerEmbeddingManager::NumSpeakers() const { + return impl_->NumSpeakers(); +} + +int32_t SpeakerEmbeddingManager::Dim() const { return impl_->Dim(); } + +bool SpeakerEmbeddingManager::Contains(const std::string &name) const { + return impl_->Contains(name); +} + +std::vector SpeakerEmbeddingManager::GetAllSpeakers() const { + return impl_->GetAllSpeakers(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager.h new file mode 100644 index 00000000..92728807 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/speaker-embedding-manager.h @@ -0,0 +1,120 @@ +// sherpa-mnn/csrc/speaker-embedding-manager.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ + +#include +#include +#include + +struct SpeakerMatch { + const std::string name; + float score; +}; + +namespace sherpa_mnn { + +class SpeakerEmbeddingManager { + public: + // @param dim Embedding dimension. + explicit SpeakerEmbeddingManager(int32_t dim); + ~SpeakerEmbeddingManager(); + + /* Add the embedding and name of a speaker to the manager. + * + * @param name Name of the speaker + * @param p Pointer to the embedding. Its length is `dim`. + * @return Return true if added successfully. Return false if it failed. + * At present, the only reason for a failure is that there is already + * a speaker with the same `name`. + */ + bool Add(const std::string &name, const float *p) const; + + /** Add a list of embeddings of a speaker. + * + * @param name Name of the speaker + * @param embedding_list A list of embeddings. Each entry should be of size + * `dim`. The average of the list is the final + * embedding. + * @return Return true if added successfully. Return false if it failed. + * At present, the only reason for a failure is that there is already + * a speaker with the same `name`. + */ + bool Add(const std::string &name, + const std::vector> &embedding_list) const; + + /* Remove a speaker by its name. + * + * @param name Name of the speaker to remove. + * @return Return true if it is removed successfully. Return false + * if there is no such a speaker. + */ + bool Remove(const std::string &name) const; + + /** It is for speaker identification. + * + * It computes the cosine similarity between and given embedding and all + * other embeddings and find the embedding that has the largest score + * and the score is above or equal to threshold. Return the speaker + * name for the embedding if found; otherwise, it returns an empty string. + * + * @param p The input embedding. + * @param threshold A value between 0 and 1. + * @param If found, return the name of the speaker. Otherwise, return an + * empty string. + */ + std::string Search(const float *p, float threshold) const; + + /** + * It is for speaker identification. + * + * It computes the cosine similarity between a given embedding and all + * other embeddings and finds the embeddings that have the largest scores + * and the scores are above or equal to the threshold. Returns a vector of + * SpeakerMatch structures containing the speaker names and scores for the + * embeddings if found; otherwise, returns an empty vector. + * + * @param p A pointer to the input embedding. 
+ * @param threshold A value between 0 and 1. + * @param n The number of top matches to return. + * @return A vector of SpeakerMatch structures. If matches are found, the + * vector contains the names and scores of the speakers. Otherwise, + * it returns an empty vector. + */ + std::vector GetBestMatches(const float *p, float threshold, + int32_t n) const; + + /* Check whether the input embedding matches the embedding of the input + * speaker. + * + * It is for speaker verification. + * + * @param name The target speaker name. + * @param p The input embedding to check. + * @param threshold A value between 0 and 1. + * @return Return true if it matches. Otherwise, it returns false. + */ + bool Verify(const std::string &name, const float *p, float threshold) const; + + float Score(const std::string &name, const float *p) const; + + // Return true if the given speaker already exists; return false otherwise. + bool Contains(const std::string &name) const; + + int32_t NumSpeakers() const; + + int32_t Dim() const; + + // Return a list of speaker names + std::vector GetAllSpeakers() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-impl.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-impl.cc new file mode 100644 index 00000000..07c5b159 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-impl.cc @@ -0,0 +1,123 @@ +// sherpa-mnn/csrc/spoken-language-identification-impl.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/spoken-language-identification-impl.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/spoken-language-identification-whisper-impl.h" + +namespace sherpa_mnn { + +namespace { + +enum class ModelType : std::uint8_t { + kWhisper, + kUnknown, +}; + +} + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + MNNEnv env; + std::shared_ptr sess_opts; + + auto sess = std::unique_ptr(MNN::Express::Module::load({}, {}, (const uint8_t*)model_data, model_data_length, + sess_opts)); + + MNNMeta meta_data = sess->getInfo()->metaData; + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + MNNAllocator* allocator; + auto model_type = + LookupCustomModelMetaData(meta_data, "model_type", allocator); + if (model_type.empty()) { + SHERPA_ONNX_LOGE( + "No model_type in the metadata!\n" + "Please make sure you have added metadata to the model.\n\n" + "For instance, you can use\n" + "https://github.com/k2-fsa/sherpa-mnn/blob/master/scripts/whisper/" + "export-onnx.py " + "to add metadata to models from whisper\n"); + return ModelType::kUnknown; + } + + if (model_type.find("whisper") == 0) { + return ModelType::kWhisper; + } else { + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.c_str()); + return ModelType::kUnknown; + } +} + +std::unique_ptr +SpokenLanguageIdentificationImpl::Create( + const SpokenLanguageIdentificationConfig &config) { + ModelType model_type = ModelType::kUnknown; + { + if (config.whisper.encoder.empty()) { + SHERPA_ONNX_LOGE("Only whisper models are supported at 
present"); + exit(-1); + } + auto buffer = ReadFile(config.whisper.encoder); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kWhisper: + return std::make_unique(config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE( + "Unknown model type for spoken language identification!"); + return nullptr; + } + + // unreachable code + return nullptr; +} + +#if __ANDROID_API__ >= 9 +std::unique_ptr +SpokenLanguageIdentificationImpl::Create( + AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config) { + ModelType model_type = ModelType::kUnknown; + { + if (config.whisper.encoder.empty()) { + SHERPA_ONNX_LOGE("Only whisper models are supported at present"); + exit(-1); + } + auto buffer = ReadFile(mgr, config.whisper.encoder); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kWhisper: + return std::make_unique(mgr, + config); + case ModelType::kUnknown: + SHERPA_ONNX_LOGE( + "Unknown model type for spoken language identification!"); + return nullptr; + } + + // unreachable code + return nullptr; +} +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-impl.h new file mode 100644 index 00000000..cf6b1619 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-impl.h @@ -0,0 +1,38 @@ +// sherpa-mnn/csrc/spoken-language-identification-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ + +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/spoken-language-identification.h" + +namespace sherpa_mnn { + +class SpokenLanguageIdentificationImpl { + public: + virtual ~SpokenLanguageIdentificationImpl() = default; + + static std::unique_ptr Create( + const SpokenLanguageIdentificationConfig &config); + +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config); +#endif + + virtual std::unique_ptr CreateStream() const = 0; + + virtual std::string Compute(OfflineStream *s) const = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-whisper-impl.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-whisper-impl.h new file mode 100644 index 00000000..7458cdd8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification-whisper-impl.h @@ -0,0 +1,123 @@ +// sherpa-mnn/csrc/spoken-language-identification-whisper-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ + +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/offline-whisper-model.h" +#include "sherpa-mnn/csrc/spoken-language-identification-impl.h" +#include "sherpa-mnn/csrc/transpose.h" + +namespace sherpa_mnn { + +class 
SpokenLanguageIdentificationWhisperImpl + : public SpokenLanguageIdentificationImpl { + public: + explicit SpokenLanguageIdentificationWhisperImpl( + const SpokenLanguageIdentificationConfig &config) + : config_(config), model_(std::make_unique(config)) { + Check(); + } + +#if __ANDROID_API__ >= 9 + SpokenLanguageIdentificationWhisperImpl( + AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config) + : config_(config), + model_(std::make_unique(mgr, config)) { + Check(); + } +#endif + + std::unique_ptr CreateStream() const override { + return std::make_unique(WhisperTag{}); + } + + std::string Compute(OfflineStream *s) const override { + int32_t max_num_frames = 3000; + auto memory_info = + (MNNAllocator*)(nullptr); + + int32_t feat_dim = s->FeatureDim(); + std::vector f = s->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + // we use 50 here so that there will be some zero tail paddings + if (num_frames >= max_num_frames - 50) { + SHERPA_ONNX_LOGE( + "Only waves less than 30 seconds are supported. We process only the " + "first 30 seconds and discard the remaining data"); + num_frames = max_num_frames - 50; + } + + model_->NormalizeFeatures(f.data(), num_frames, feat_dim); + + // note that 1000 is an experience-value. + // You can replace 1000 by other values, say, 100. + // + // Since we have removed the 30 seconds constraint, we need + // tail_padding_frames so that whisper is able to detect the eot token. + int32_t tail_padding_frames = 1000; + + if (config_.whisper.tail_paddings > 0) { + tail_padding_frames = config_.whisper.tail_paddings; + } + + int32_t actual_frames = + std::min(num_frames + tail_padding_frames, max_num_frames); + + std::array shape{1, actual_frames, feat_dim}; + + MNN::Express::VARP mel = MNNUtilsCreateTensor( + model_->Allocator(), shape.data(), shape.size()); + + float *p_mel = mel->writeMap(); + std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel); + + std::fill_n(p_mel + num_frames * feat_dim, + (actual_frames - num_frames) * feat_dim, 0); + + mel = Transpose12(model_->Allocator(), mel); + + auto cross_kv = model_->ForwardEncoder(std::move(mel)); + int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second); + const auto &id2lang = model_->GetID2Lang(); + if (id2lang.count(lang_id)) { + return id2lang.at(lang_id); + } else { + SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.", + lang_id); + return ""; + } + } + + private: + void Check() const { + if (!model_->IsMultiLingual()) { + SHERPA_ONNX_LOGE( + "Only whisper multilingual models can be used for spoken language " + "identification. 
Given: %s,%s", + config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str()); + exit(-1); + } + } + + private: + SpokenLanguageIdentificationConfig config_; + std::unique_ptr model_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification.cc new file mode 100644 index 00000000..ba9f72ab --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification.cc @@ -0,0 +1,130 @@ +// sherpa-mnn/csrc/spoken-language-identification.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/spoken-language-identification.h" + +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/spoken-language-identification-impl.h" + +namespace sherpa_mnn { + +void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) { + po->Register( + "whisper-encoder", &encoder, + "Path to then encoder of a whisper multilingual model. Support only " + "tiny, base, small, medium, large."); + + po->Register( + "whisper-decoder", &decoder, + "Path to the decoder of a whisper multilingual model. Support only " + "tiny, base, small, medium, large."); + + po->Register( + "whisper-tail-paddings", &tail_paddings, + "Suggested value: 300 for multilingual models. " + "Since we have removed the 30-second constraint, we need to add some " + "tail padding frames " + "so that whisper can detect the eot token. Leave it to -1 to use 1000"); +} + +bool SpokenLanguageIdentificationWhisperConfig::Validate() const { + if (encoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-encoder"); + return false; + } + + if (!FileExists(encoder)) { + SHERPA_ONNX_LOGE("whisper encoder file '%s' does not exist", + encoder.c_str()); + return false; + } + + if (decoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --whisper-decoder"); + return false; + } + + if (!FileExists(decoder)) { + SHERPA_ONNX_LOGE("whisper decoder file '%s' does not exist", + decoder.c_str()); + return false; + } + + return true; +} + +std::string SpokenLanguageIdentificationWhisperConfig::ToString() const { + std::ostringstream os; + + os << "SpokenLanguageIdentificationWhisperConfig("; + os << "encoder=\"" << encoder << "\", "; + os << "decoder=\"" << decoder << "\", "; + os << "tail_paddings=" << tail_paddings << ")"; + + return os.str(); +} + +void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) { + whisper.Register(po); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool SpokenLanguageIdentificationConfig::Validate() const { + if (!whisper.Validate()) { + return false; + } + + return true; +} + +std::string SpokenLanguageIdentificationConfig::ToString() const { + std::ostringstream os; + + os << "SpokenLanguageIdentificationConfig("; + os << "whisper=" << whisper.ToString() << ", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? 
"True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +SpokenLanguageIdentification::SpokenLanguageIdentification( + const SpokenLanguageIdentificationConfig &config) + : impl_(SpokenLanguageIdentificationImpl::Create(config)) {} + +#if __ANDROID_API__ >= 9 +SpokenLanguageIdentification::SpokenLanguageIdentification( + AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config) + : impl_(SpokenLanguageIdentificationImpl::Create(mgr, config)) {} +#endif + +SpokenLanguageIdentification::~SpokenLanguageIdentification() = default; + +std::unique_ptr SpokenLanguageIdentification::CreateStream() + const { + return impl_->CreateStream(); +} + +std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const { + return impl_->Compute(s); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification.h new file mode 100644 index 00000000..ff3d9378 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/spoken-language-identification.h @@ -0,0 +1,99 @@ +// sherpa-mnn/csrc/spoken-language-identification.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ +#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ + +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-mnn/csrc/offline-stream.h" +#include "sherpa-mnn/csrc/parse-options.h" + +namespace sherpa_mnn { + +struct SpokenLanguageIdentificationWhisperConfig { + // Requires a multi-lingual whisper model. + // That is, it supports only tiny, base, small, medium, large. + // Note: It does NOT support tiny.en, base.en, small.en, medium.en + std::string encoder; + std::string decoder; + + // Number of tail padding frames. + // + // Since we remove the 30-second constraint, we need to add some paddings + // at the end. 
+ // + // Recommended values: + // - 50 for English models + // - 300 for multilingual models + int32_t tail_paddings = -1; + + SpokenLanguageIdentificationWhisperConfig() = default; + + SpokenLanguageIdentificationWhisperConfig(const std::string &encoder, + const std::string &decoder, + int32_t tail_paddings) + : encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +struct SpokenLanguageIdentificationConfig { + SpokenLanguageIdentificationWhisperConfig whisper; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + SpokenLanguageIdentificationConfig() = default; + + SpokenLanguageIdentificationConfig( + const SpokenLanguageIdentificationWhisperConfig &whisper, + int32_t num_threads, bool debug, const std::string &provider) + : whisper(whisper), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +class SpokenLanguageIdentificationImpl; + +class SpokenLanguageIdentification { + public: + explicit SpokenLanguageIdentification( + const SpokenLanguageIdentificationConfig &config); + +#if __ANDROID_API__ >= 9 + SpokenLanguageIdentification( + AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config); +#endif + + ~SpokenLanguageIdentification(); + + // Create a stream to accept audio samples and compute features + std::unique_ptr CreateStream() const; + + // Return a string containing the language, e.g., en, zh, de, + // etc. + // Note: en is for English, zh is for Chinese, de is for German, etc. + std::string Compute(OfflineStream *s) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack-test.cc new file mode 100644 index 00000000..118231ac --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack-test.cc @@ -0,0 +1,254 @@ +// sherpa-mnn/csrc/stack-test.cc +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#include "sherpa-mnn/csrc/stack.h" + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TEST(Stack, Test1DTensors) { + MNNAllocator* allocator; + + std::array a_shape{3}; + std::array b_shape{3}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Stack(allocator, {&a, &b}, 0); + + Print1D(&a); + Print1D(&b); + Print2D(&ans); + + const float *pans = ans->readMap(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0]]); + } +} + +TEST(Stack, Test2DTensorsDim0) { + MNNAllocator* allocator; + + std::array a_shape{2, 3}; + std::array b_shape{2, 3}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + 
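+  // Fill a with 0..5 and b with 10..15; stacking two (2, 3) tensors along
+  // dim 0 yields a (2, 2, 3) tensor holding all of a followed by all of b.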
float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Stack(allocator, {&a, &b}, 0); + + Print2D(&a); + Print2D(&b); + Print3D(&ans); + + const float *pans = ans->readMap(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]); + } +} + +TEST(Stack, Test2DTensorsDim1) { + MNNAllocator* allocator; + + std::array a_shape{4, 3}; + std::array b_shape{4, 3}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Stack(allocator, {&a, &b}, 1); + + Print2D(&a); + Print2D(&b); + Print3D(&ans); + + const float *pans = ans->readMap(); + + for (int32_t r = 0; r != static_cast(a_shape[0]); ++r) { + for (int32_t i = 0; i != static_cast(a_shape[1]); + ++i, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t i = 0; i != static_cast(b_shape[1]); + ++i, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } +} + +TEST(Stack, Test3DTensorsDim0) { + MNNAllocator* allocator; + + std::array a_shape{2, 3, 2}; + std::array b_shape{2, 3, 2}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Stack(allocator, {&a, &b}, 0); + + const float *pans = ans->readMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]); + } + + Print3D(&a); + Print3D(&b); + Print4D(&ans); +} + +TEST(Stack, Test3DTensorsDim1) { + MNNAllocator* allocator; + + std::array a_shape{2, 2, 3}; + std::array b_shape{2, 2, 3}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Stack(allocator, {&a, &b}, 1); + + const float *pans = ans->readMap(); + + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[1] * a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[1] * 
b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print3D(&a); + Print3D(&b); + Print4D(&ans); +} + +TEST(Stack, Test3DTensorsDim2) { + MNNAllocator* allocator; + + std::array a_shape{2, 3, 4}; + std::array b_shape{2, 3, 4}; + + MNN::Express::VARP a = MNNUtilsCreateTensor(allocator, a_shape.data(), + a_shape.size()); + + MNN::Express::VARP b = MNNUtilsCreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a->writeMap(); + float *pb = b->writeMap(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + MNN::Express::VARP ans = Stack(allocator, {&a, &b}, 2); + + const float *pans = ans->readMap(); + + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print3D(&a); + Print3D(&b); + Print4D(&ans); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack.cc new file mode 100644 index 00000000..83d87e39 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack.cc @@ -0,0 +1,94 @@ +// sherpa-mnn/csrc/stack.cc +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#include "sherpa-mnn/csrc/stack.h" + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +static bool Compare(const std::vector &a, + const std::vector &b) { + if (a.size() != b.size()) return false; + + for (int32_t i = 0; i != static_cast(a.size()); ++i) { + if (a[i] != b[i]) return false; + } + + return true; +} + +static void PrintShape(const std::vector &a) { + for (auto i : a) { + fprintf(stderr, "%d ", static_cast(i)); + } + fprintf(stderr, "\n"); +} + +template +MNN::Express::VARP Stack(MNNAllocator *allocator, + const std::vector &values, int32_t dim) { + std::vector v0_shape = + values[0]->getInfo()->dim; + + for (int32_t i = 1; i != static_cast(values.size()); ++i) { + auto s = values[i]->getInfo()->dim; + bool ret = Compare(v0_shape, s); + if (!ret) { + fprintf(stderr, "Incorrect shape in Stack !\n"); + + fprintf(stderr, "Shape for tensor 0: "); + PrintShape(v0_shape); + + fprintf(stderr, "Shape for tensor %d: ", i); + PrintShape(s); + + exit(-1); + } + } + + std::vector ans_shape; + ans_shape.reserve(v0_shape.size() + 1); + ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim); + ans_shape.push_back(values.size()); + ans_shape.insert(ans_shape.end(), v0_shape.data() + dim, + v0_shape.data() + v0_shape.size()); + + auto leading_size = static_cast(std::accumulate( + v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies())); + + auto trailing_size = static_cast(std::accumulate( + v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies())); + + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + T *dst = ans->writeMap(); + + for (int32_t i = 0; i != leading_size; ++i) { + for (auto value : values) { + const T *src = value->readMap(); + src += i * trailing_size; + + std::copy(src, src + trailing_size, dst); + dst += trailing_size; + } + } + + return ans; +} + +template MNN::Express::VARP Stack(MNNAllocator *allocator, + const std::vector &values, + 
int32_t dim); + +template MNN::Express::VARP Stack( + MNNAllocator *allocator, const std::vector &values, + int32_t dim); + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack.h new file mode 100644 index 00000000..34189813 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/stack.h @@ -0,0 +1,29 @@ +// sherpa-mnn/csrc/stack.h +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#ifndef SHERPA_ONNX_CSRC_STACK_H_ +#define SHERPA_ONNX_CSRC_STACK_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +/** Stack a list of tensors along the given dim. + * + * @param allocator Allocator to allocate space for the returned tensor + * @param values Pointer to a list of tensors. The shape of the tensor must + * be the same except on the dim to be stacked. + * @param dim The dim along which to concatenate the input tensors + * + * @return Return the stacked tensor + */ +template +MNN::Express::VARP Stack(MNNAllocator *allocator, + const std::vector &values, int32_t dim); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_STACK_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/symbol-table.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/symbol-table.cc new file mode 100644 index 00000000..d47d025a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/symbol-table.cc @@ -0,0 +1,272 @@ +// sherpa-mnn/csrc/symbol-table.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/symbol-table.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/base64-decode.h" +#include "sherpa-mnn/csrc/bbpe.h" +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/lexicon.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +namespace { +// copied from +// https://stackoverflow.com/questions/216823/how-to-trim-a-stdstring +const char *ws = " \t\n\r\f\v"; + +// trim from end of string (right) +inline void TrimRight(std::string *s, const char *t = ws) { + s->erase(s->find_last_not_of(t) + 1); +} + +// trim from beginning of string (left) +inline void TrimLeft(std::string *s, const char *t = ws) { + s->erase(0, s->find_first_not_of(t)); +} + +// trim from both ends of string (right then left) +inline void Trim(std::string *s, const char *t = ws) { + TrimRight(s, t); + TrimLeft(s, t); +} + +bool IsByteBPE(const char *s, int32_t n) { + const uint8_t *p = reinterpret_cast(s); + if (n >= 3 && p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + return IsByteBPE(s + 3, n - 3); + } + + for (int32_t i = 0; i != n; ++i) { + if (p[i] > 0xc6) { + return false; + } + } + + return true; +} + +bool IsByteBPE(const std::unordered_map &sym2id) { + uint8_t max_v = 0; + for (const auto &p : sym2id) { + const auto &s = p.first; + if (!IsByteBPE(s.c_str(), s.size())) { + return false; + } + + uint8_t m = 0; + if (s.size() >= 3) { + const uint8_t *p = reinterpret_cast(s.c_str()); + + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + if (s.size() > 3) { + m = *std::max_element( + reinterpret_cast(s.data()) + 3, + reinterpret_cast(s.data()) + s.size()); + } else { + m = 0; + } + } else { + m = *std::max_element( + reinterpret_cast(s.data()), + reinterpret_cast(s.data()) + s.size()); + } + } else 
{ + m = *std::max_element( + reinterpret_cast(s.data()), + reinterpret_cast(s.data()) + s.size()); + } + + max_v = (m > max_v) ? m : max_v; + } + + return static_cast(max_v) == 0xc6; +} + +} // namespace + +std::unordered_map ReadTokens( + std::istream &is, + std::unordered_map *id2token /*= nullptr*/) { + std::unordered_map token2id; + + std::string line; + + std::string sym; + int32_t id = -1; + while (std::getline(is, line)) { + Trim(&line); + std::istringstream iss(line); + iss >> sym; + if (iss.eof()) { + id = atoi(sym.c_str()); + sym = " "; + } else { + iss >> id; + } + + // eat the trailing \r\n on windows + iss >> std::ws; + if (!iss.eof()) { + SHERPA_ONNX_LOGE("Error: %s", line.c_str()); + exit(-1); + } + +#if 0 + if (token2id.count(sym)) { + SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", + sym.c_str(), line.c_str(), token2id.at(sym)); + exit(-1); + } +#endif + if (id2token) { + id2token->insert({id, sym}); + } + + token2id.insert({std::move(sym), id}); + } + + return token2id; +} + +SymbolTable::SymbolTable(const std::string &filename, bool is_file) { + if (is_file) { + std::ifstream is(filename); + Init(is); + } else { + std::istringstream iss(filename); + Init(iss); + } +} + +template +SymbolTable::SymbolTable(Manager *mgr, const std::string &filename) { + auto buf = ReadFile(mgr, filename); + + std::istrstream is(buf.data(), buf.size()); + Init(is); +} + +void SymbolTable::Init(std::istream &is) { + sym2id_ = ReadTokens(is, &id2sym_); + is_bbpe_ = IsByteBPE(sym2id_); +} + +std::string SymbolTable::ToString() const { + std::ostringstream os; + char sep = ' '; + for (const auto &p : sym2id_) { + os << p.first << sep << p.second << "\n"; + } + return os.str(); +} + +const std::string SymbolTable::operator[](int32_t id) const { + std::string sym = id2sym_.at(id); + if (sym.size() >= 3 && !is_bbpe_) { + // For BPE-based models, we replace ▁ with a space + // Unicode 9601, hex 0x2581, utf8 0xe29681 + const uint8_t *p = reinterpret_cast(sym.c_str()); + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + sym = sym.replace(0, 3, " "); + } + } + + // for BPE with byte_fallback + // id 0 is blank, id 1 is sos/eos, id 2 is unk + // + // Note: For moonshine models, 0 is , 1, is , 2 is + if (id >= 3 && id <= 258 && sym.size() == 6 && sym[0] == '<' && + sym[1] == '0' && sym[2] == 'x' && sym[5] == '>') { + std::ostringstream os; + os << std::hex << std::uppercase << (id - 3); + + if (std::string(sym.data() + 3, sym.data() + 5) == os.str()) { + uint8_t i = id - 3; + sym = std::string(&i, &i + 1); + } + } + return sym; +} + +int32_t SymbolTable::operator[](const std::string &sym) const { + return sym2id_.at(sym); +} + +bool SymbolTable::Contains(int32_t id) const { return id2sym_.count(id) != 0; } + +bool SymbolTable::Contains(const std::string &sym) const { + return sym2id_.count(sym) != 0; +} + +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) { + return os << symbol_table.ToString(); +} + +void SymbolTable::ApplyBase64Decode() { + sym2id_.clear(); + for (auto &p : id2sym_) { + p.second = Base64Decode(p.second); + sym2id_[p.second] = p.first; + } +} + +std::string SymbolTable::DecodeByteBpe(const std::string &text) const { + if (!is_bbpe_) { + return text; + } + auto v = SplitUtf8(text); + + const auto &bbpe_table = GetByteBpeTable(); + std::string ans; + for (const auto &s : v) { + if (s == "▁") { + if (!ans.empty() && ans.back() != ' ' && std::isprint(ans.back())) { + ans.push_back(' '); + } + } else if (bbpe_table.count(s)) { + 
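+      // Byte-level BPE: each remaining symbol represents exactly one raw
+      // byte, so it is mapped back through the table; the appended bytes
+      // then recombine into the original UTF-8 characters (descriptive
+      // note, e.g. three consecutive byte symbols may form one CJK char).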
ans.push_back(bbpe_table.at(s)); + } else if (std::isprint(s[0])) { + ans.append(s); + } else { + // Should not happen + SHERPA_ONNX_LOGE("Skip OOV: %s from %s", s.c_str(), text.c_str()); + } + } + + // TODO(fangjun): Filter invalid utf-8 sequences + return ans; +} + +#if __ANDROID_API__ >= 9 +template SymbolTable::SymbolTable(AAssetManager *mgr, + const std::string &filename); +#endif + +#if __OHOS__ +template SymbolTable::SymbolTable(NativeResourceManager *mgr, + const std::string &filename); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/symbol-table.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/symbol-table.h new file mode 100644 index 00000000..77f46f0a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/symbol-table.h @@ -0,0 +1,76 @@ +// sherpa-mnn/csrc/symbol-table.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ +#define SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ + +#include +#include +#include +#include + +namespace sherpa_mnn { + +// The same token can be mapped to different integer IDs, so +// we need an id2token argument here. +std::unordered_map ReadTokens( + std::istream &is, + std::unordered_map *id2token = nullptr); + +std::vector ConvertTokensToIds( + const std::unordered_map &token2id, + const std::vector &tokens); + +/// It manages mapping between symbols and integer IDs. +class SymbolTable { + public: + SymbolTable() = default; + /// Construct a symbol table from a file or from a buffered string. + /// Each line in the file contains two fields: + /// + /// sym ID + /// + /// Fields are separated by space(s). + explicit SymbolTable(const std::string &filename, bool is_file = true); + + template + SymbolTable(Manager *mgr, const std::string &filename); + + /// Return a string representation of this symbol table + std::string ToString() const; + + /// Return the symbol corresponding to the given ID. + const std::string operator[](int32_t id) const; + /// Return the ID corresponding to the given symbol. + int32_t operator[](const std::string &sym) const; + + /// Return true if there is a symbol with the given ID. + bool Contains(int32_t id) const; + + /// Return true if there is a given symbol in the symbol table. 
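+  ///
+  /// A minimal usage sketch (illustrative only; the token file path is
+  /// hypothetical):
+  ///
+  ///   SymbolTable sym_table("./tokens.txt");
+  ///   if (sym_table.Contains("hello")) {
+  ///     int32_t id = sym_table["hello"];   // symbol -> ID
+  ///     std::string s = sym_table[id];     // ID -> symbol
+  ///   }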
+ bool Contains(const std::string &sym) const; + + // for tokens.txt from Whisper + void ApplyBase64Decode(); + + int32_t NumSymbols() const { return id2sym_.size(); } + + std::string DecodeByteBpe(const std::string &text) const; + + bool IsByteBpe() const { return is_bbpe_; } + + private: + void Init(std::istream &is); + + private: + std::unordered_map sym2id_; + std::unordered_map id2sym_; + bool is_bbpe_ = false; +}; + +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_SYMBOL_TABLE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/tee-stream.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/tee-stream.h new file mode 100644 index 00000000..ffdc5b0a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/tee-stream.h @@ -0,0 +1,61 @@ +// Code in this file is copied and modified from +// https://wordaligned.org/articles/cpp-streambufs + +#ifndef SHERPA_ONNX_CSRC_TEE_STREAM_H_ +#define SHERPA_ONNX_CSRC_TEE_STREAM_H_ +#include +#include +#include + +namespace sherpa_mnn { + +template > +class basic_teebuf : public std::basic_streambuf { + public: + using int_type = typename traits::int_type; + + basic_teebuf(std::basic_streambuf *sb1, + std::basic_streambuf *sb2) + : sb1(sb1), sb2(sb2) {} + + private: + int sync() override { + int const r1 = sb1->pubsync(); + int const r2 = sb2->pubsync(); + return r1 == 0 && r2 == 0 ? 0 : -1; + } + + int_type overflow(int_type c) override { + int_type const eof = traits::eof(); + + if (traits::eq_int_type(c, eof)) { + return traits::not_eof(c); + } else { + char_type const ch = traits::to_char_type(c); + int_type const r1 = sb1->sputc(ch); + int_type const r2 = sb2->sputc(ch); + + return traits::eq_int_type(r1, eof) || traits::eq_int_type(r2, eof) ? eof + : c; + } + } + + private: + std::basic_streambuf *sb1; + std::basic_streambuf *sb2; +}; + +using teebuf = basic_teebuf; + +class TeeStream : public std::ostream { + public: + TeeStream(std::ostream &o1, std::ostream &o2) + : std::ostream(&tbuf), tbuf(o1.rdbuf(), o2.rdbuf()) {} + + private: + teebuf tbuf; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_TEE_STREAM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils-test.cc new file mode 100644 index 00000000..72aa508a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils-test.cc @@ -0,0 +1,131 @@ +// sherpa-mnn/csrc/text-utils-test.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/text-utils.h" + +#include "gtest/gtest.h" + +namespace sherpa_mnn { + +TEST(ToLowerCase, WideString) { + std::string text = + "Hallo! 
Übeltäter übergibt Ärzten öfters äußerst ätzende Öle 3€"; + auto t = ToLowerCase(text); + std::cout << text << "\n"; + std::cout << t << "\n"; +} + +TEST(RemoveInvalidUtf8Sequences, Case1) { + std::vector v = { + 0xe4, 0xbb, 0x8a, // 今 + 0xe5, 0xa4, 0xa9, // 天 + 'i', 's', ' ', 'M', 'o', 'd', 'a', 'y', ',', // is Monday, + ' ', 'w', 'i', 'e', ' ', 'h', 'e', 'i', 0xc3, // wie heißen Size + 0x9f, 'e', 'n', ' ', 'S', 'i', 'e', 0xf0, 0x9d, 0x84, 0x81}; + + std::vector v0 = v; + v0[1] = 0xc0; // make the first 3 bytes an invalid utf8 character + std::string s0{v0.begin(), v0.end()}; + EXPECT_EQ(s0.size(), v0.size()); + + auto s = RemoveInvalidUtf8Sequences(s0); // should remove 今 + + v0 = v; + // v0[23] == 0xc3 + // v0[24] == 0x9f + + v0[23] = 0xc1; + + s0 = {v0.begin(), v0.end()}; + s = RemoveInvalidUtf8Sequences(s0); // should remove ß + + EXPECT_EQ(s.size() + 2, v.size()); + + v0 = v; + // v0[31] = 0xf0; + // v0[32] = 0x9d; + // v0[33] = 0x84; + // v0[34] = 0x81; + v0[31] = 0xf5; + + s0 = {v0.begin(), v0.end()}; + s = RemoveInvalidUtf8Sequences(s0); + + EXPECT_EQ(s.size() + 4, v.size()); +} + + +// Tests for sanitizeUtf8 +TEST(RemoveInvalidUtf8Sequences, ValidUtf8StringPassesUnchanged) { + std::string input = "Valid UTF-8 🌍"; + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input); +} + +TEST(RemoveInvalidUtf8Sequences, SingleInvalidByteReplaced) { + std::string input = "Invalid \xFF UTF-8"; + std::string expected = "Invalid UTF-8"; + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, TruncatedUtf8SequenceReplaced) { + std::string input = "Broken \xE2\x82"; // Incomplete UTF-8 sequence + std::string expected = "Broken "; + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, MultipleInvalidBytes) { + std::string input = "Test \xC0\xC0\xF8\xA0"; // Multiple invalid sequences + std::string expected = "Test "; + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, BreakingCase_SpaceFollowedByInvalidByte) { + std::string input = "\x20\xC4"; // Space followed by an invalid byte + std::string expected = " "; // 0xC4 removed + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, ValidUtf8WithEdgeCaseCharacters) { + std::string input = "Edge 🏆💯"; + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), input); +} + +TEST(RemoveInvalidUtf8Sequences, MixedValidAndInvalidBytes) { + std::string input = "Mix \xE2\x82\xAC \xF0\x9F\x98\x81 \xFF"; + std::string expected = "Mix € 😁 "; // Invalid bytes removed + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte) { + std::string input = "\x20\xC4"; // Space (0x20) followed by invalid (0xC4) + std::string expected = " "; // Space remains, 0xC4 is removed + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, RemoveTruncatedC4) { + std::string input = "Hello \xc4 world"; // Invalid `0xC4` + std::string expected = "Hello world"; // `0xC4` should be removed + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, SpaceFollowedByInvalidByte_Breaking) { + std::string input = "\x20\xc4"; // Space followed by invalid `0xc4` + std::string expected = " "; // `0xc4` should be removed, space remains + EXPECT_EQ(RemoveInvalidUtf8Sequences(input), expected); +} + +TEST(RemoveInvalidUtf8Sequences, DebugSpaceFollowedByInvalidByte) { + std::string input = "\x20\xc4"; // 
Space followed by invalid `0xc4` + std::string output = RemoveInvalidUtf8Sequences(input); + + std::cout << "Processed string: "; + for (unsigned char c : output) { + printf("\\x%02x ", c); + } + std::cout << std::endl; + + EXPECT_EQ(output, " "); // Expect `0xc4` to be removed, leaving only space +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils.cc new file mode 100644 index 00000000..41206e37 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils.cc @@ -0,0 +1,710 @@ +// sherpa-mnn/csrc/text-utils.cc +// +// Copyright 2009-2011 Saarland University; Microsoft Corporation +// Copyright 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/text-utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) +#include +#endif + +#include "sherpa-mnn/csrc/macros.h" + +// This file is copied/modified from +// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.cc + +namespace sherpa_mnn { + +// copied from kaldi/src/util/text-util.cc +template +class NumberIstream { + public: + explicit NumberIstream(std::istream &i) : in_(i) {} + + NumberIstream &operator>>(T &x) { + if (!in_.good()) return *this; + in_ >> x; + if (!in_.fail() && RemainderIsOnlySpaces()) return *this; + return ParseOnFail(&x); + } + + private: + std::istream &in_; + + bool RemainderIsOnlySpaces() { + if (in_.tellg() != std::istream::pos_type(-1)) { + std::string rem; + in_ >> rem; + + if (rem.find_first_not_of(' ') != std::string::npos) { + // there is not only spaces + return false; + } + } + + in_.clear(); + return true; + } + + NumberIstream &ParseOnFail(T *x) { + std::string str; + in_.clear(); + in_.seekg(0); + // If the stream is broken even before trying + // to read from it or if there are many tokens, + // it's pointless to try. + if (!(in_ >> str) || !RemainderIsOnlySpaces()) { + in_.setstate(std::ios_base::failbit); + return *this; + } + + std::unordered_map inf_nan_map; + // we'll keep just uppercase values. + inf_nan_map["INF"] = std::numeric_limits::infinity(); + inf_nan_map["+INF"] = std::numeric_limits::infinity(); + inf_nan_map["-INF"] = -std::numeric_limits::infinity(); + inf_nan_map["INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["+INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["-INFINITY"] = -std::numeric_limits::infinity(); + inf_nan_map["NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["+NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-NAN"] = -std::numeric_limits::quiet_NaN(); + // MSVC + inf_nan_map["1.#INF"] = std::numeric_limits::infinity(); + inf_nan_map["-1.#INF"] = -std::numeric_limits::infinity(); + inf_nan_map["1.#QNAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-1.#QNAN"] = -std::numeric_limits::quiet_NaN(); + + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + + if (inf_nan_map.find(str) != inf_nan_map.end()) { + *x = inf_nan_map[str]; + } else { + in_.setstate(std::ios_base::failbit); + } + + return *this; + } +}; + +/// ConvertStringToReal converts a string into either float or double +/// and returns false if there was any kind of problem (i.e. the string +/// was not a floating point number or contained extra non-whitespace junk). +/// Be careful- this function will successfully read inf's or nan's. 
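+///
+/// A minimal usage sketch (illustrative only):
+///
+///   float f = 0;
+///   bool ok = ConvertStringToReal("3.5", &f);   // ok == true,  f == 3.5f
+///   ok = ConvertStringToReal("-inf", &f);       // ok == true,  f == -infinity
+///   ok = ConvertStringToReal("3.5abc", &f);     // ok == false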
+template +bool ConvertStringToReal(const std::string &str, T *out) { + std::istringstream iss(str); + + NumberIstream i(iss); + + i >> *out; + + if (iss.fail()) { + // Number conversion failed. + return false; + } + + return true; +} + +template bool ConvertStringToReal(const std::string &str, float *out); + +template bool ConvertStringToReal(const std::string &str, double *out); + +void SplitStringToVector(const std::string &full, const char *delim, + bool omit_empty_strings, + std::vector *out) { + size_t start = 0, found = 0, end = full.size(); + out->clear(); + while (found != std::string::npos) { + found = full.find_first_of(delim, start); + // start != end condition is for when the delimiter is at the end + if (!omit_empty_strings || (found != start && start != end)) + out->push_back(full.substr(start, found - start)); + start = found + 1; + } +} + +template +bool SplitStringToFloats(const std::string &full, const char *delim, + bool omit_empty_strings, // typically false + std::vector *out) { + assert(out != nullptr); + if (*(full.c_str()) == '\0') { + out->clear(); + return true; + } + std::vector split; + SplitStringToVector(full, delim, omit_empty_strings, &split); + out->resize(split.size()); + for (size_t i = 0; i < split.size(); ++i) { + // assume atof never fails + F f = 0; + if (!ConvertStringToReal(split[i], &f)) return false; + (*out)[i] = f; + } + return true; +} + +// Instantiate the template above for float and double. +template bool SplitStringToFloats(const std::string &full, const char *delim, + bool omit_empty_strings, + std::vector *out); +template bool SplitStringToFloats(const std::string &full, const char *delim, + bool omit_empty_strings, + std::vector *out); + +static bool IsPunct(char c) { return c != '\'' && std::ispunct(c); } +static bool IsGermanUmlaut(const std::string &word) { + // ä 0xC3 0xA4 + // ö 0xC3 0xB6 + // ü 0xC3 0xBC + // Ä 0xC3 0x84 + // Ö 0xC3 0x96 + // Ü 0xC3 0x9C + // ß 0xC3 0x9F + + if (word.size() != 2 || static_cast(word[0]) != 0xc3) { + return false; + } + + auto c = static_cast(word[1]); + if (c == 0xa4 || c == 0xb6 || c == 0xbc || c == 0x84 || c == 0x96 || + c == 0x9c || c == 0x9f) { + return true; + } + + return false; +} + +// see https://www.tandem.net/blog/spanish-accents +// https://www.compart.com/en/unicode/U+00DC +static bool IsSpanishDiacritic(const std::string &word) { + // á 0xC3 0xA1 + // é 0xC3 0xA9 + // í 0xC3 0xAD + // ó 0xC3 0xB3 + // ú 0xC3 0xBA + // ü 0xC3 0xBC + // ñ 0xC3 0xB1 + // + // uppercase + // + // Á 0xC3 0x81 + // É 0xC3 0x89 + // Í 0xC3 0x8D + // Ó 0xC3 0x93 + // Ú 0xC3 0x9A + // Ü 0xC3 0x9C + // Ñ 0xC3 0x91 + + if (word.size() != 2 || static_cast(word[0]) != 0xc3) { + return false; + } + + auto c = static_cast(word[1]); + if (c == 0xa1 || c == 0xa9 || c == 0xad || c == 0xb3 || c == 0xba || + c == 0xbc || c == 0xb1 || c == 0x81 || c == 0x89 || c == 0x8d || + c == 0x93 || c == 0x9a || c == 0x9c || c == 0x91) { + return true; + } + + return false; +} + +// see https://www.busuu.com/en/french/accent-marks +static bool IsFrenchDiacritic(const std::string &word) { + // acute accent + // é 0xC3 0xA9 + // + // grave accent + // à 0xC3 0xA0 + // è 0xC3 0xA8 + // ù 0xC3 0xB9 + // + // cedilla + // ç 0xC3 0xA7 + // + // circumflex + // â 0xC3 0xA2 + // ê 0xC3 0xAA + // î 0xC3 0xAE + // ô 0xC3 0xB4 + // û 0xC3 0xBB + // + // trema + // ë 0xC3 0xAB + // ï 0xC3 0xAF + // ü 0xC3 0xBC + // + // É 0xC3 0x89 + // + // À 0xC3 0x80 + // È 0xC3 0x88 + // Ù 0xC3 0x99 + // Ç 0xC3 0x87 + // Â 0xC3 0x82 + // Ê 0xC3 0x8A + // Î 
0xC3 0x8E + // Ô 0xC3 0x94 + // Û 0xC3 0x9B + // Ë 0xC3 0x8B + // Ï 0xC3 0x8F + // Ü 0xC3 0x9C + + if (word.size() != 2 || static_cast(word[0]) != 0xc3) { + return false; + } + + auto c = static_cast(word[1]); + if (c == 0xa9 || c == 0xa0 || c == 0xa8 || c == 0xb9 || c == 0xa7 || + c == 0xa2 || c == 0xaa || c == 0xae || c == 0xb4 || c == 0xbb || + c == 0xab || c == 0xaf || c == 0xbc || c == 0x89 || c == 0x80 || + c == 0x88 || c == 0x99 || c == 0x87 || c == 0x82 || c == 0x8a || + c == 0x8e || c == 0x94 || c == 0x9b || c == 0x8b || c == 0x8f || + c == 0x9c) { + return true; + } + return false; +} + +static bool IsSpecial(const std::string &w) { + bool ans = IsGermanUmlaut(w) || IsSpanishDiacritic(w) || IsFrenchDiacritic(w); + + // for french d’impossible + // ’ 0xE2 0x80 0x99 + bool ans2 = false; + if (w.size() == 3) { + auto c0 = static_cast(w[0]); + auto c1 = static_cast(w[1]); + auto c2 = static_cast(w[2]); + if (c0 == 0xe2 && c1 == 0x80 && c2 == 0x99) { + ans2 = true; + } + } + + return ans || ans2; +} + +static std::vector MergeCharactersIntoWords( + const std::vector &words) { + std::vector ans; + + int32_t n = static_cast(words.size()); + int32_t i = 0; + int32_t prev = -1; + + while (i < n) { + const auto &w = words[i]; + if (w.size() >= 3 || (w.size() == 2 && !IsSpecial(w)) || + (w.size() == 1 && (IsPunct(w[0]) || std::isspace(w[0])))) { + if (prev != -1) { + std::string t; + for (; prev < i; ++prev) { + t.append(words[prev]); + } + prev = -1; + ans.push_back(std::move(t)); + } + + if (!std::isspace(w[0])) { + ans.push_back(w); + } + ++i; + continue; + } + + // e.g., öffnen + if (w.size() == 1 || (w.size() == 2 && IsSpecial(w))) { + if (prev == -1) { + prev = i; + } + ++i; + continue; + } + + SHERPA_ONNX_LOGE("Ignore %s", w.c_str()); + ++i; + } + + if (prev != -1) { + std::string t; + for (; prev < i; ++prev) { + t.append(words[prev]); + } + ans.push_back(std::move(t)); + } + + return ans; +} + +std::vector SplitUtf8(const std::string &text) { + const uint8_t *begin = reinterpret_cast(text.c_str()); + const uint8_t *end = begin + text.size(); + + // Note that English words are split into single characters. 
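+  // For instance (an illustrative sketch, not output copied from the tests),
+  // "你好 hi" is first split into {"你", "好", " ", "h", "i"}.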
+ // We need to invoke MergeCharactersIntoWords() to merge them + std::vector ans; + + auto start = begin; + while (start < end) { + uint8_t c = *start; + uint8_t i = 0x80; + int32_t num_bytes = 0; + + // see + // https://en.wikipedia.org/wiki/UTF-8 + for (; c & i; i >>= 1) { + ++num_bytes; + } + + if (num_bytes == 0) { + // this is an ascii + ans.emplace_back(reinterpret_cast(start), 1); + ++start; + } else if (2 <= num_bytes && num_bytes <= 4) { + ans.emplace_back(reinterpret_cast(start), num_bytes); + start += num_bytes; + } else { + SHERPA_ONNX_LOGE("Invalid byte at position: %d", + static_cast(start - begin)); + // skip this byte + ++start; + } + } + + return MergeCharactersIntoWords(ans); +} + +std::string ToLowerCase(const std::string &s) { + return ToString(ToLowerCase(ToWideString(s))); +} + +void ToLowerCase(std::string *in_out) { + std::transform(in_out->begin(), in_out->end(), in_out->begin(), + [](unsigned char c) { return std::tolower(c); }); +} + +std::wstring ToLowerCase(const std::wstring &s) { + std::wstring ans(s.size(), 0); + std::transform(s.begin(), s.end(), ans.begin(), [](wchar_t c) -> wchar_t { + switch (c) { + // French + case L'À': + return L'à'; + case L'Â': + return L'â'; + case L'Æ': + return L'æ'; + case L'Ç': + return L'ç'; + case L'È': + return L'è'; + case L'É': + return L'é'; + case L'Ë': + return L'ë'; + case L'Î': + return L'î'; + case L'Ï': + return L'ï'; + case L'Ô': + return L'ô'; + case L'Ù': + return L'ù'; + case L'Û': + return L'û'; + case L'Ü': + return L'ü'; + + // others + case L'Á': + return L'á'; + case L'Í': + return L'í'; + case L'Ó': + return L'ó'; + case L'Ú': + return L'ú'; + case L'Ñ': + return L'ñ'; + case L'Ì': + return L'ì'; + case L'Ò': + return L'ò'; + case L'Ä': + return L'ä'; + case L'Ö': + return L'ö'; + // TODO(fangjun): Add more + + default: + return std::towlower(c); + } + }); + return ans; +} + +static inline bool InRange(uint8_t x, uint8_t low, uint8_t high) { + return low <= x && x <= high; +} + +/* +Please see +https://stackoverflow.com/questions/6555015/check-for-invalid-utf8 + + +Table 3-7. 
Well-Formed UTF-8 Byte Sequences + +Code Points First Byte Second Byte Third Byte Fourth Byte +U+0000..U+007F 00..7F +U+0080..U+07FF C2..DF 80..BF +U+0800..U+0FFF E0 A0..BF 80..BF +U+1000..U+CFFF E1..EC 80..BF 80..BF +U+D000..U+D7FF ED 80..9F 80..BF +U+E000..U+FFFF EE..EF 80..BF 80..BF +U+10000..U+3FFFF F0 90..BF 80..BF 80..BF +U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF +U+100000..U+10FFFF F4 80..8F 80..BF 80..BF + */ +std::string RemoveInvalidUtf8Sequences(const std::string &text, + bool show_debug_msg /*= false*/) { + int32_t n = static_cast(text.size()); + + std::string ans; + ans.reserve(n); + + int32_t i = 0; + const uint8_t *p = reinterpret_cast(text.data()); + while (i < n) { + if (p[i] <= 0x7f) { + ans.append(text, i, 1); + i += 1; + continue; + } + + if (InRange(p[i], 0xc2, 0xdf) && i + 1 < n && + InRange(p[i + 1], 0x80, 0xbf)) { + ans.append(text, i, 2); + i += 2; + continue; + } + + if (p[i] == 0xe0 && i + 2 < n && InRange(p[i + 1], 0xa0, 0xbf) && + InRange(p[i + 2], 0x80, 0xbf)) { + ans.append(text, i, 3); + i += 3; + continue; + } + + if (InRange(p[i], 0xe1, 0xec) && i + 2 < n && + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf)) { + ans.append(text, i, 3); + i += 3; + continue; + } + + if (p[i] == 0xed && i + 2 < n && InRange(p[i + 1], 0x80, 0x9f) && + InRange(p[i + 2], 0x80, 0xbf)) { + ans.append(text, i, 3); + i += 3; + continue; + } + + if (InRange(p[i], 0xee, 0xef) && i + 2 < n && + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf)) { + ans.append(text, i, 3); + i += 3; + continue; + } + + if (p[i] == 0xf0 && i + 3 < n && InRange(p[i + 1], 0x90, 0xbf) && + InRange(p[i + 2], 0x80, 0xbf) && InRange(p[i + 3], 0x80, 0xbf)) { + ans.append(text, i, 4); + i += 4; + continue; + } + + if (InRange(p[i], 0xf1, 0xf3) && i + 3 < n && + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf) && + InRange(p[i + 3], 0x80, 0xbf)) { + ans.append(text, i, 4); + i += 4; + continue; + } + + if (p[i] == 0xf4 && i + 3 < n && InRange(p[i + 1], 0x80, 0x8f) && + InRange(p[i + 2], 0x80, 0xbf) && InRange(p[i + 3], 0x80, 0xbf)) { + ans.append(text, i, 4); + i += 4; + continue; + } + + if (show_debug_msg) { + SHERPA_ONNX_LOGE("Ignore invalid utf8 sequence at pos: %d, value: %02x", + i, p[i]); + } + + i += 1; + } + + return ans; +} + +bool IsUtf8(const std::string &text) { + int32_t n = static_cast(text.size()); + int32_t i = 0; + const uint8_t *p = reinterpret_cast(text.data()); + while (i < n) { + if (p[i] <= 0x7f) { + i += 1; + continue; + } + + if (InRange(p[i], 0xc2, 0xdf) && i + 1 < n && + InRange(p[i + 1], 0x80, 0xbf)) { + i += 2; + continue; + } + + if (p[i] == 0xe0 && i + 2 < n && InRange(p[i + 1], 0xa0, 0xbf) && + InRange(p[i + 2], 0x80, 0xbf)) { + i += 3; + continue; + } + + if (InRange(p[i], 0xe1, 0xec) && i + 2 < n && + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf)) { + i += 3; + continue; + } + + if (p[i] == 0xed && i + 2 < n && InRange(p[i + 1], 0x80, 0x9f) && + InRange(p[i + 2], 0x80, 0xbf)) { + i += 3; + continue; + } + + if (InRange(p[i], 0xee, 0xef) && i + 2 < n && + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf)) { + i += 3; + continue; + } + + if (p[i] == 0xf0 && i + 3 < n && InRange(p[i + 1], 0x90, 0xbf) && + InRange(p[i + 2], 0x80, 0xbf) && InRange(p[i + 3], 0x80, 0xbf)) { + i += 4; + continue; + } + + if (InRange(p[i], 0xf1, 0xf3) && i + 3 < n && + InRange(p[i + 1], 0x80, 0xbf) && InRange(p[i + 2], 0x80, 0xbf) && + InRange(p[i + 3], 0x80, 0xbf)) { + i += 4; + continue; + } + + if (p[i] == 0xf4 && i + 3 < n && 
InRange(p[i + 1], 0x80, 0x8f) && + InRange(p[i + 2], 0x80, 0xbf) && InRange(p[i + 3], 0x80, 0xbf)) { + i += 4; + continue; + } + + return false; + } + + return true; +} + +bool IsGB2312(const std::string &text) { + int32_t n = static_cast(text.size()); + int32_t i = 0; + const uint8_t *p = reinterpret_cast(text.data()); + while (i < n) { + if (p[i] <= 0x7f) { + i += 1; + continue; + } + + if (InRange(p[i], 0xa1, 0xf7) && i + 1 < n && + InRange(p[i + 1], 0xa1, 0xfe)) { + i += 2; + continue; + } + + return false; + } + + return true; +} + +#if defined(_WIN32) +std::string Gb2312ToUtf8(const std::string &text) { + // https://learn.microsoft.com/en-us/windows/win32/api/stringapiset/nf-stringapiset-multibytetowidechar + // 936 is from + // https://learn.microsoft.com/en-us/windows/win32/intl/code-page-identifiers + // GB2312 -> 936 + int32_t num_wchars = + MultiByteToWideChar(936, 0, text.c_str(), text.size(), nullptr, 0); + SHERPA_ONNX_LOGE("num of wchars: %d", num_wchars); + if (num_wchars == 0) { + return {}; + } + + std::wstring wstr; + wstr.resize(num_wchars); + MultiByteToWideChar(936, 0, text.c_str(), text.size(), wstr.data(), + num_wchars); + // https://learn.microsoft.com/en-us/windows/win32/api/stringapiset/nf-stringapiset-widechartomultibyte + int32_t num_chars = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, + 0, nullptr, nullptr); + if (num_chars == 0) { + return {}; + } + + std::string ans(num_chars, 0); + WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, ans.data(), num_chars, + nullptr, nullptr); + + return ans; +} +#endif + +std::wstring ToWideString(const std::string &s) { + // see + // https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t + std::wstring_convert> converter; + return converter.from_bytes(s); +} + +std::string ToString(const std::wstring &s) { + // see + // https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t + std::wstring_convert> converter; + return converter.to_bytes(s); +} + +bool EndsWith(const std::string &haystack, const std::string &needle) { + if (needle.size() > haystack.size()) { + return false; + } + + return std::equal(needle.rbegin(), needle.rend(), haystack.rbegin()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils.h new file mode 100644 index 00000000..b254e131 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text-utils.h @@ -0,0 +1,152 @@ +// sherpa-mnn/csrc/text-utils.h +// +// Copyright 2009-2011 Saarland University; Microsoft Corporation +// Copyright 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_TEXT_UTILS_H_ +#define SHERPA_ONNX_CSRC_TEXT_UTILS_H_ +#include +#include + +#include +#include +#include +#include + +#ifdef _MSC_VER +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) \ + _strtoi64(cur_cstr, end_cstr, 10); +#else +#define SHERPA_ONNX_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); +#endif + +// This file is copied/modified from +// https://github.com/kaldi-asr/kaldi/blob/master/src/util/text-utils.h + +namespace sherpa_mnn { + +/// Converts a string into an integer via strtoll and returns false if there was +/// any kind of problem (i.e. the string was not an integer or contained extra +/// non-whitespace junk, or the integer was too large to fit into the type it is +/// being converted into). Only sets *out if everything was OK and it returns +/// true. 
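+///
+/// A small usage sketch (illustrative only):
+///
+///   int32_t n = 0;
+///   bool ok = ConvertStringToInteger("123", &n);   // ok == true, n == 123
+///   ok = ConvertStringToInteger("123abc", &n);     // ok == false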
+template +bool ConvertStringToInteger(const std::string &str, Int *out) { + // copied from kaldi/src/util/text-util.h + static_assert(std::is_integral::value, ""); + const char *this_str = str.c_str(); + char *end = nullptr; + errno = 0; + int i = SHERPA_ONNX_STRTOLL(this_str, &end); + if (end != this_str) { + while (isspace(*end)) ++end; + } + if (end == this_str || *end != '\0' || errno != 0) return false; + Int iInt = static_cast(i); + if (static_cast(iInt) != i || + (i < 0 && !std::numeric_limits::is_signed)) { + return false; + } + *out = iInt; + return true; +} + +/// Split a string using any of the single character delimiters. +/// If omit_empty_strings == true, the output will contain any +/// nonempty strings after splitting on any of the +/// characters in the delimiter. If omit_empty_strings == false, +/// the output will contain n+1 strings if there are n characters +/// in the set "delim" within the input string. In this case +/// the empty string is split to a single empty string. +void SplitStringToVector(const std::string &full, const char *delim, + bool omit_empty_strings, + std::vector *out); + +/** + \brief Split a string (e.g. 1:2:3) into a vector of integers. + + \param [in] delim String containing a list of characters, any of which + is allowed as a delimiter. + \param [in] omit_empty_strings If true, empty strings between delimiters are + allowed and will not produce an output integer; if false, + instances of characters in 'delim' that are consecutive or + at the start or end of the string would be an error. + You'll normally want this to be true if 'delim' consists + of spaces, and false otherwise. + \param [out] out The output list of integers. +*/ +template +bool SplitStringToIntegers(const std::string &full, const char *delim, + bool omit_empty_strings, // typically false [but + // should probably be true + // if "delim" is spaces]. + std::vector *out) { + static_assert(std::is_integral::value, ""); + if (*(full.c_str()) == '\0') { + out->clear(); + return true; + } + std::vector split; + SplitStringToVector(full, delim, omit_empty_strings, &split); + out->resize(split.size()); + for (size_t i = 0; i < split.size(); i++) { + const char *this_str = split[i].c_str(); + char *end = NULL; + int j = 0; + j = SHERPA_ONNX_STRTOLL(this_str, &end); + if (end == this_str || *end != '\0') { + out->clear(); + return false; + } else { + I jI = static_cast(j); + if (static_cast(jI) != j) { + // output type cannot fit this integer. + out->clear(); + return false; + } + (*out)[i] = jI; + } + } + return true; +} + +// This is defined for F = float and double. +template +bool SplitStringToFloats(const std::string &full, const char *delim, + bool omit_empty_strings, // typically false + std::vector *out); + +// This is defined for F = float and double. +template +bool ConvertStringToReal(const std::string &str, T *out); + +std::vector SplitUtf8(const std::string &text); + +std::string ToLowerCase(const std::string &s); +void ToLowerCase(std::string *in_out); + +std::wstring ToLowerCase(const std::wstring &s); + +std::string RemoveInvalidUtf8Sequences(const std::string &text, + bool show_debug_msg = false); + +// Return true if text contains valid utf8 sequence. 
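+// (e.g. plain ASCII and well-formed multi-byte sequences; a string holding a
+// stray 0xFF byte is rejected and can first be cleaned up with
+// RemoveInvalidUtf8Sequences() declared above. Illustrative note.)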
+// Return false otherwise +bool IsUtf8(const std::string &text); + +// Return true if text contains valid gb2312 encoded sequence +// Return false otherwise +bool IsGB2312(const std::string &text); + +#if defined(_WIN32) +std::string Gb2312ToUtf8(const std::string &text); +#endif + +std::wstring ToWideString(const std::string &s); + +std::string ToString(const std::wstring &s); + +bool EndsWith(const std::string &haystack, const std::string &needle); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text2token-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text2token-test.cc new file mode 100644 index 00000000..63f6d53a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/text2token-test.cc @@ -0,0 +1,170 @@ +// sherpa-mnn/csrc/text2token-test.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include +#include +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" + +namespace sherpa_mnn { + +// Please refer to +// https://github.com/pkufool/sherpa-test-data +// to download test data for testing +static const char dir[] = "/tmp/sherpa-test-data"; + +TEST(TEXT2TOKEN, TEST_cjkchar) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_cn.txt"; + + std::string tokens = oss.str(); + + if (!std::ifstream(tokens).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_cjkchar()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + + std::string text = + "世界人民大团结\n中国 V S 美国\n\n"; // Test blank lines also + + std::istringstream iss(text); + + std::vector> ids; + std::vector scores; + + auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores); + + std::vector> expected_ids( + {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); + EXPECT_EQ(ids, expected_ids); + + EXPECT_EQ(scores.size(), 0); +} + +TEST(TEXT2TOKEN, TEST_bpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_en.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bpe_en.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bpe()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "HELLO WORLD\nI LOVE YOU :2.0"; + + std::istringstream iss(text); + + std::vector> ids; + std::vector scores; + + auto r = + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); + + std::vector> expected_ids( + {{22, 58, 24, 425}, {19, 370, 47}}); + EXPECT_EQ(ids, expected_ids); + + std::vector expected_scores({0, 2.0}); + EXPECT_EQ(scores, expected_scores); +} + +TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_mix.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bpe_mix.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_cjkchar_bpe()." 
+ "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "世界人民 GOES TOGETHER :1.5\n中国 GOES WITH 美国 :0.5"; + + std::istringstream iss(text); + + std::vector> ids; + std::vector scores; + + auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), + &ids, &scores); + + std::vector> expected_ids( + {{1368, 1392, 557, 680, 275, 178, 475}, + {685, 736, 275, 178, 179, 921, 736}}); + EXPECT_EQ(ids, expected_ids); + + std::vector expected_scores({1.5, 0.5}); + EXPECT_EQ(scores, expected_scores); +} + +TEST(TEXT2TOKEN, TEST_bbpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_bbpe.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bbpe.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bbpe()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "频繁 :1.0\n李鞑靼"; + + std::istringstream iss(text); + + std::vector> ids; + std::vector scores; + + auto r = + EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); + + std::vector> expected_ids( + {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); + EXPECT_EQ(ids, expected_ids); + + std::vector expected_scores({1.0, 0}); + EXPECT_EQ(scores, expected_scores); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transducer-keyword-decoder.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transducer-keyword-decoder.cc new file mode 100644 index 00000000..3693e42b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transducer-keyword-decoder.cc @@ -0,0 +1,185 @@ +// sherpa-mnn/csrc/transducer-keywords-decoder.cc +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/transducer-keyword-decoder.h" + +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/log.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + TransducerKeywordResult r; + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + + Hypotheses blank_hyp({{blanks, 0}}); + r.hyps = std::move(blank_hyp); + return r; +} + +void TransducerKeywordDecoder::Decode( + MNN::Express::VARP encoder_out, OnlineStream **ss, + std::vector *result) { + std::vector encoder_out_shape = + encoder_out->getInfo()->dim; + + if (encoder_out_shape[0] != result->size()) { + SHERPA_ONNX_LOGE( + "Size mismatch! 
encoder_out.size(0) %d, result.size(0): %d\n", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); + exit(-1); + } + + int32_t batch_size = static_cast(encoder_out_shape[0]); + + int32_t num_frames = static_cast(encoder_out_shape[1]); + int32_t vocab_size = model_->VocabSize(); + int32_t context_size = model_->ContextSize(); + std::vector blanks(context_size, -1); + blanks.back() = 0; // blank_id is hardcoded to 0 + + std::vector cur; + for (auto &r : *result) { + cur.push_back(std::move(r.hyps)); + } + std::vector prev; + + for (int32_t t = 0; t != num_frames; ++t) { + // Due to merging paths with identical token sequences, + // not all utterances have "num_active_paths" paths. + auto hyps_row_splits = GetHypsRowSplits(cur); + int32_t num_hyps = + hyps_row_splits.back(); // total num hyps for all utterance + prev.clear(); + for (auto &hyps : cur) { + for (auto &h : hyps) { + prev.push_back(std::move(h.second)); + } + } + cur.clear(); + cur.reserve(batch_size); + + MNN::Express::VARP decoder_input = model_->BuildDecoderInput(prev); + MNN::Express::VARP decoder_out = model_->RunDecoder(std::move(decoder_input)); + + MNN::Express::VARP cur_encoder_out = + GetEncoderOutFrame(model_->Allocator(), encoder_out, t); + cur_encoder_out = + Repeat(model_->Allocator(), cur_encoder_out, hyps_row_splits); + MNN::Express::VARP logit = + model_->RunJoiner(std::move(cur_encoder_out), View(decoder_out)); + + float *p_logit = logit->writeMap(); + LogSoftmax(p_logit, vocab_size, num_hyps); + + // The acoustic logprobs for current frame + std::vector logprobs(vocab_size * num_hyps); + std::memcpy(logprobs.data(), p_logit, + sizeof(float) * vocab_size * num_hyps); + + // now p_logit contains log_softmax output, we rename it to p_logprob + // to match what it actually contains + float *p_logprob = p_logit; + + // add log_prob of each hypothesis to p_logprob before taking top_k + for (int32_t i = 0; i != num_hyps; ++i) { + float log_prob = prev[i].log_prob; + for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) { + *p_logprob += log_prob; + } + } + p_logprob = p_logit; // we changed p_logprob in the above for loop + + for (int32_t b = 0; b != batch_size; ++b) { + int32_t frame_offset = (*result)[b].frame_offset; + int32_t start = hyps_row_splits[b]; + int32_t end = hyps_row_splits[b + 1]; + auto topk = + TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_); + + Hypotheses hyps; + for (auto k : topk) { + int32_t hyp_index = k / vocab_size + start; + int32_t new_token = k % vocab_size; + + Hypothesis new_hyp = prev[hyp_index]; + float context_score = 0; + auto context_state = new_hyp.context_state; + + // blank is hardcoded to 0 + // also, it treats unk as blank + if (new_token != 0 && new_token != unk_id_) { + new_hyp.ys.push_back(new_token); + new_hyp.timestamps.push_back(t + frame_offset); + new_hyp.ys_probs.push_back( + exp(logprobs[hyp_index * vocab_size + new_token])); + + new_hyp.num_trailing_blanks = 0; + auto context_res = ss[b]->GetContextGraph()->ForwardOneStep( + context_state, new_token); + context_score = std::get<0>(context_res); + new_hyp.context_state = std::get<1>(context_res); + // Start matching from the start state, forget the decoder history. 
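+          // A token of -1 marks the root (start) state of the keyword
+          // context graph, i.e. no partial keyword is being matched any
+          // more, so the hypothesis is reset to the all-blank history
+          // below (descriptive note).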
+ if (new_hyp.context_state->token == -1) { + new_hyp.ys = blanks; + new_hyp.timestamps.clear(); + new_hyp.ys_probs.clear(); + } + } else { + ++new_hyp.num_trailing_blanks; + } + new_hyp.log_prob = p_logprob[k] + context_score; + hyps.Add(std::move(new_hyp)); + } // for (auto k : topk) + + auto best_hyp = hyps.GetMostProbable(false); + + auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state); + bool matched = std::get<0>(status); + const ContextState *matched_state = std::get<1>(status); + + if (matched) { + float ys_prob = 0.0; + for (int32_t i = 0; i < matched_state->level; ++i) { + ys_prob += best_hyp.ys_probs[i]; + } + ys_prob /= matched_state->level; + if (best_hyp.num_trailing_blanks > num_trailing_blanks_ && + ys_prob >= matched_state->ac_threshold) { + auto &r = (*result)[b]; + r.tokens = {best_hyp.ys.end() - matched_state->level, + best_hyp.ys.end()}; + r.timestamps = {best_hyp.timestamps.end() - matched_state->level, + best_hyp.timestamps.end()}; + r.keyword = matched_state->phrase; + + hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}}); + } + } + cur.push_back(std::move(hyps)); + p_logprob += (end - start) * vocab_size; + } // for (int32_t b = 0; b != batch_size; ++b) + } + + for (int32_t b = 0; b != batch_size; ++b) { + auto &hyps = cur[b]; + auto best_hyp = hyps.GetMostProbable(false); + auto &r = (*result)[b]; + r.hyps = std::move(hyps); + r.num_trailing_blanks = best_hyp.num_trailing_blanks; + r.frame_offset += num_frames; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transducer-keyword-decoder.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transducer-keyword-decoder.h new file mode 100644 index 00000000..714e1770 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transducer-keyword-decoder.h @@ -0,0 +1,62 @@ +// sherpa-mnn/csrc/transducer-keywords-decoder.h +// +// Copyright (c) 2023-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ +#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ + +#include +#include +#include + +#include "sherpa-mnn/csrc/online-stream.h" +#include "sherpa-mnn/csrc/online-transducer-model.h" + +namespace sherpa_mnn { + +struct TransducerKeywordResult { + /// Number of frames after subsampling we have decoded so far + int32_t frame_offset = 0; + + /// The decoded token IDs for keywords + std::vector tokens; + + /// The triggered keyword + std::string keyword; + + /// number of trailing blank frames decoded so far + int32_t num_trailing_blanks = 0; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. 
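+  /// The indices are counted in (subsampled) encoder output frames and
+  /// already include frame_offset, i.e. they are positions within the
+  /// whole stream rather than within the current chunk.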
+ std::vector timestamps; + + // used only in modified beam_search + Hypotheses hyps; +}; + +class TransducerKeywordDecoder { + public: + TransducerKeywordDecoder(OnlineTransducerModel *model, + int32_t max_active_paths, + int32_t num_trailing_blanks, int32_t unk_id) + : model_(model), + max_active_paths_(max_active_paths), + num_trailing_blanks_(num_trailing_blanks), + unk_id_(unk_id) {} + + TransducerKeywordResult GetEmptyResult() const; + + void Decode(MNN::Express::VARP encoder_out, OnlineStream **ss, + std::vector *result); + + private: + OnlineTransducerModel *model_; // Not owned + + int32_t max_active_paths_; + int32_t num_trailing_blanks_; + int32_t unk_id_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose-test.cc new file mode 100644 index 00000000..f2a1bbb3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose-test.cc @@ -0,0 +1,62 @@ +// sherpa-mnn/csrc/transpose-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/transpose.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TEST(Tranpose, Tranpose01) { + MNNAllocator* allocator; + std::array shape{3, 2, 5}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); + + auto ans = Transpose01(allocator, &v); + auto v2 = Transpose01(allocator, &ans); + + Print3D(&v); + Print3D(&ans); + Print3D(&v2); + + const float *q = v2->readMap(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + EXPECT_EQ(p[i], q[i]); + } +} + +TEST(Tranpose, Tranpose12) { + MNNAllocator* allocator; + std::array shape{3, 2, 5}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + std::iota(p, p + shape[0] * shape[1] * shape[2], 0); + + auto ans = Transpose12(allocator, &v); + auto v2 = Transpose12(allocator, &ans); + + Print3D(&v); + Print3D(&ans); + Print3D(&v2); + + const float *q = v2->readMap(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + EXPECT_EQ(p[i], q[i]); + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose.cc new file mode 100644 index 00000000..f7acee49 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose.cc @@ -0,0 +1,65 @@ +// sherpa-mnn/csrc/transpose.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/transpose.h" + +#include +#include +#include + +namespace sherpa_mnn { + +template +MNN::Express::VARP Transpose01(MNNAllocator *allocator, MNN::Express::VARP v) { + std::vector shape = v->getInfo()->dim; + assert(shape.size() == 3); + + std::array ans_shape{shape[1], shape[0], shape[2]}; + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + + T *dst = ans->writeMap(); + auto plane_offset = shape[1] * shape[2]; + + for (int i = 0; i != ans_shape[0]; ++i) { + const T *src = v->readMap() + i * shape[2]; + for (int k = 0; k != ans_shape[1]; ++k) { + std::copy(src, src + shape[2], dst); + src += plane_offset; + dst += shape[2]; + } + } + + return ans; +} + +template +MNN::Express::VARP Transpose12(MNNAllocator *allocator, 
MNN::Express::VARP v) { + std::vector shape = v->getInfo()->dim; + assert(shape.size() == 3); + + std::array ans_shape{shape[0], shape[2], shape[1]}; + MNN::Express::VARP ans = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + T *dst = ans->writeMap(); + auto row_stride = shape[2]; + for (int b = 0; b != ans_shape[0]; ++b) { + const T *src = v->readMap() + b * shape[1] * shape[2]; + for (int i = 0; i != ans_shape[1]; ++i) { + for (int k = 0; k != ans_shape[2]; ++k, ++dst) { + *dst = (src + k * row_stride)[i]; + } + } + } + + return ans; +} + +template MNN::Express::VARP Transpose01(MNNAllocator *allocator, + MNN::Express::VARP v); + +template MNN::Express::VARP Transpose12(MNNAllocator *allocator, + MNN::Express::VARP v); + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose.h new file mode 100644 index 00000000..3c4b3491 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/transpose.h @@ -0,0 +1,32 @@ +// sherpa-mnn/csrc/transpose.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_TRANSPOSE_H_ +#define SHERPA_ONNX_CSRC_TRANSPOSE_H_ + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { +/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C). + * + * @param allocator + * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. + * + * @return Return a 3-D tensor of shape (T, B, C). Its datatype is type. + */ +template +MNN::Express::VARP Transpose01(MNNAllocator *allocator, MNN::Express::VARP v); + +/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T). + * + * @param allocator + * @param v A 3-D tensor of shape (B, T, C). Its dataype is type. + * + * @return Return a 3-D tensor of shape (B, C, T). Its datatype is type. 
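+ *
+ * A shape-only sketch (illustrative): for an input of shape (2, 3, 4),
+ * the result has shape (2, 4, 3) and out[b][c][t] == in[b][t][c].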
+ */ +template +MNN::Express::VARP Transpose12(MNNAllocator *allocator, MNN::Express::VARP v); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind-test.cc new file mode 100644 index 00000000..f31c9db1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind-test.cc @@ -0,0 +1,223 @@ +// sherpa-mnn/csrc/unbind-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/unbind.h" + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/cat.h" +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +TEST(Ubind, Test1DTensors) { + MNNAllocator* allocator; + std::array shape{3}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 0); + EXPECT_EQ(ans.size(), shape[0]); + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + EXPECT_EQ(ans[i]->readMap()[0], p[i]); + } + Print1D(&v); + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + Print1D(&ans[i]); + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + MNN::Express::VARP v2 = Cat(allocator, vec, 0); + const float *p2 = v2->readMap(); + for (int32_t i = 0; i != shape[0]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test2DTensorsDim0) { + MNNAllocator* allocator; + std::array shape{3, 2}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1]); ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 0); + + Print2D(&v); + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + Print2D(&ans[i]); + } + + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + const float *pans = ans[i]->readMap(); + for (int32_t k = 0; k != static_cast(shape[1]); ++k, ++p) { + EXPECT_EQ(*p, pans[k]); + } + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + MNN::Express::VARP v2 = Cat(allocator, vec, 0); + Print2D(&v2); + + p = v->writeMap(); + const float *p2 = v2->readMap(); + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test2DTensorsDim1) { + MNNAllocator* allocator; + std::array shape{3, 2}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1]); ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 1); + + Print2D(&v); + for (int32_t i = 0; i != static_cast(shape[1]); ++i) { + Print2D(&ans[i]); + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + MNN::Express::VARP v2 = Cat(allocator, vec, 1); + Print2D(&v2); + + p = v->writeMap(); + const float *p2 = v2->readMap(); + for (int32_t i = 0; i != shape[0] * shape[1]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test3DTensorsDim0) { + MNNAllocator* allocator; + std::array shape{3, 2, 5}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + 
p[i] = i; + } + auto ans = Unbind(allocator, &v, 0); + + Print3D(&v); + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + Print3D(&ans[i]); + } + + for (int32_t i = 0; i != static_cast(shape[0]); ++i) { + const float *pans = ans[i]->readMap(); + for (int32_t k = 0; k != static_cast(shape[1] * shape[2]); + ++k, ++p) { + EXPECT_EQ(*p, pans[k]); + } + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + MNN::Express::VARP v2 = Cat(allocator, vec, 0); + Print3D(&v2); + + p = v->writeMap(); + const float *p2 = v2->readMap(); + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test3DTensorsDim1) { + MNNAllocator* allocator; + std::array shape{3, 2, 5}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 1); + + Print3D(&v); + for (int32_t i = 0; i != static_cast(shape[1]); ++i) { + Print3D(&ans[i]); + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + MNN::Express::VARP v2 = Cat(allocator, vec, 1); + Print3D(&v2); + + p = v->writeMap(); + const float *p2 = v2->readMap(); + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +TEST(Ubind, Test3DTensorsDim2) { + MNNAllocator* allocator; + std::array shape{3, 2, 5}; + MNN::Express::VARP v = + MNNUtilsCreateTensor(allocator, shape.data(), shape.size()); + float *p = v->writeMap(); + + for (int32_t i = 0; i != static_cast(shape[0] * shape[1] * shape[2]); + ++i) { + p[i] = i; + } + auto ans = Unbind(allocator, &v, 2); + + Print3D(&v); + for (int32_t i = 0; i != static_cast(shape[2]); ++i) { + Print3D(&ans[i]); + } + + // For Cat + std::vector vec(ans.size()); + for (int32_t i = 0; i != static_cast(vec.size()); ++i) { + vec[i] = &ans[i]; + } + MNN::Express::VARP v2 = Cat(allocator, vec, 2); + Print3D(&v2); + + p = v->writeMap(); + const float *p2 = v2->readMap(); + for (int32_t i = 0; i != shape[0] * shape[1] * shape[2]; ++i) { + EXPECT_EQ(p[i], p2[i]); + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind.cc new file mode 100644 index 00000000..5505099c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind.cc @@ -0,0 +1,70 @@ +// sherpa-mnn/csrc/unbind.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/unbind.h" + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/onnx-utils.h" + +namespace sherpa_mnn { + +template +std::vector Unbind(MNNAllocator *allocator, MNN::Express::VARP value, + int32_t dim) { + std::vector shape = value->getInfo()->dim; + assert(dim >= 0); + assert(dim < static_cast(shape.size())); + int32_t n = static_cast(shape[dim]); + if (n == 1) { + std::vector ans; + ans.push_back(Clone(allocator, value)); + return ans; + } + + std::vector ans_shape = shape; + ans_shape[dim] = 1; // // Unlike torch, we keep the dim to 1 + + // allocator tensors + std::vector ans; + ans.reserve(n); + for (int32_t i = 0; i != n; ++i) { + MNN::Express::VARP t = MNNUtilsCreateTensor(allocator, ans_shape.data(), + ans_shape.size()); + ans.push_back(std::move(t)); + } + + auto leading_size = 
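+  // Descriptive note: for a tensor of shape (d_0, ..., d_{dim-1}, n,
+  // d_{dim+1}, ...), leading_size is d_0 * ... * d_{dim-1} (the number of
+  // outer slices) and trailing_size is the product of the dims after `dim`
+  // (contiguous elements copied at a time). Each of the n outputs then
+  // receives leading_size chunks of trailing_size contiguous elements.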
static_cast(std::accumulate( + shape.begin(), shape.begin() + dim, 1, std::multiplies())); + + auto trailing_size = static_cast(std::accumulate( + shape.begin() + dim + 1, shape.end(), 1, std::multiplies())); + + const T *src = value->readMap(); + + for (int32_t i = 0; i != leading_size; ++i) { + for (int32_t k = 0; k != n; ++k) { + T *dst = ans[k]->writeMap() + i * trailing_size; + std::copy(src, src + trailing_size, dst); + src += trailing_size; + } + } + + return ans; +} + +template std::vector Unbind(MNNAllocator *allocator, + MNN::Express::VARP value, + int32_t dim); + +template std::vector Unbind(MNNAllocator *allocator, + MNN::Express::VARP value, + int32_t dim); + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind.h new file mode 100644 index 00000000..0641111f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/unbind.h @@ -0,0 +1,28 @@ +// sherpa-mnn/csrc/unbind.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_UNBIND_H_ +#define SHERPA_ONNX_CSRC_UNBIND_H_ + +#include + +#include "MNNUtils.hpp" // NOLINT + +namespace sherpa_mnn { + +/** It is similar to torch.unbind() but we keep the unbind dim to 1 in + * the output + * + * @param allocator Allocator to allocate space for the returned tensor + * @param value The tensor to unbind + * @param dim The dim along which to unbind the tensor + * + * @return Return a list of tensors + */ +template +std::vector Unbind(MNNAllocator *allocator, MNN::Express::VARP value, + int32_t dim); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_UNBIND_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utfcpp-test.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utfcpp-test.cc new file mode 100644 index 00000000..04273dc5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utfcpp-test.cc @@ -0,0 +1,21 @@ +// sherpa-mnn/csrc/utfcpp-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include +#include + +#include "gtest/gtest.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +TEST(UTF8, Case1) { + std::string hello = "你好, 早上好!世界. hello!。Hallo! 
how are you?"; + std::vector ss = SplitUtf8(hello); + for (const auto &s : ss) { + std::cout << s << "\n"; + } +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utils.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utils.cc new file mode 100644 index 00000000..87892b02 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utils.cc @@ -0,0 +1,204 @@ +// sherpa-mnn/csrc/utils.cc +// +// Copyright 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/utils.h" + +#include +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/log.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/text-utils.h" + +namespace sherpa_mnn { + +static bool EncodeBase(const std::vector &lines, + const SymbolTable &symbol_table, + std::vector> *ids, + std::vector *phrases, + std::vector *scores, + std::vector *thresholds) { + ids->clear(); + + std::vector tmp_ids; + std::vector tmp_scores; + std::vector tmp_thresholds; + std::vector tmp_phrases; + + std::string word; + bool has_scores = false; + bool has_thresholds = false; + bool has_phrases = false; + bool has_oov = false; + + for (const auto &line : lines) { + float score = 0; + float threshold = 0; + std::string phrase = ""; + + std::istringstream iss(line); + while (iss >> word) { + if (symbol_table.Contains(word)) { + int32_t id = symbol_table[word]; + tmp_ids.push_back(id); + } else { + switch (word[0]) { + case ':': // boosting score for current keyword + score = std::stof(word.substr(1)); + has_scores = true; + break; + case '#': // triggering threshold (probability) for current keyword + threshold = std::stof(word.substr(1)); + has_thresholds = true; + break; + case '@': // the original keyword string + phrase = word.substr(1); + has_phrases = true; + break; + default: + SHERPA_ONNX_LOGE( + "Cannot find ID for token %s at line: %s. 
(Hint: Check the " + "tokens.txt see if %s in it)", + word.c_str(), line.c_str(), word.c_str()); + has_oov = true; + break; + } + } + } + ids->push_back(std::move(tmp_ids)); + tmp_ids = {}; + tmp_scores.push_back(score); + tmp_phrases.push_back(phrase); + tmp_thresholds.push_back(threshold); + } + if (scores != nullptr) { + if (has_scores) { + scores->swap(tmp_scores); + } else { + scores->clear(); + } + } + if (phrases != nullptr) { + if (has_phrases) { + *phrases = std::move(tmp_phrases); + } else { + phrases->clear(); + } + } + if (thresholds != nullptr) { + if (has_thresholds) { + thresholds->swap(tmp_thresholds); + } else { + thresholds->clear(); + } + } + return !has_oov; +} + +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder, + std::vector> *hotwords, + std::vector *boost_scores) { + std::vector lines; + std::string line; + std::string word; + + while (std::getline(is, line)) { + std::string score; + std::string phrase; + + std::ostringstream oss; + std::istringstream iss(line); + while (iss >> word) { + switch (word[0]) { + case ':': // boosting score for current keyword + score = word; + break; + default: + if (!score.empty()) { + SHERPA_ONNX_LOGE( + "Boosting score should be put after the words/phrase, given " + "%s.", + line.c_str()); + return false; + } + oss << " " << word; + break; + } + } + phrase = oss.str(); + if (phrase.empty()) { + continue; + } else { + phrase = phrase.substr(1); + } + std::istringstream piss(phrase); + oss.clear(); + oss.str(""); + while (piss >> word) { + if (modeling_unit == "cjkchar") { + for (const auto &w : SplitUtf8(word)) { + oss << " " << w; + } + } else if (modeling_unit == "bpe") { + std::vector bpes; + bpe_encoder->Encode(word, &bpes); + for (const auto &bpe : bpes) { + oss << " " << bpe; + } + } else { + if (modeling_unit != "cjkchar+bpe") { + SHERPA_ONNX_LOGE( + "modeling_unit should be one of bpe, cjkchar or cjkchar+bpe, " + "given " + "%s", + modeling_unit.c_str()); + exit(-1); + } + for (const auto &w : SplitUtf8(word)) { + if (isalpha(w[0])) { + std::vector bpes; + bpe_encoder->Encode(w, &bpes); + for (const auto &bpe : bpes) { + oss << " " << bpe; + } + } else { + oss << " " << w; + } + } + } + } + std::string encoded_phrase = oss.str().substr(1); + oss.clear(); + oss.str(""); + oss << encoded_phrase; + if (!score.empty()) { + oss << " " << score; + } + lines.push_back(oss.str()); + } + return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores, + nullptr); +} + +bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, + std::vector> *keywords_id, + std::vector *keywords, + std::vector *boost_scores, + std::vector *threshold) { + std::vector lines; + std::string line; + while (std::getline(is, line)) { + lines.push_back(line); + } + return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores, + threshold); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utils.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utils.h new file mode 100644 index 00000000..85929ace --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/utils.h @@ -0,0 +1,62 @@ +// sherpa-mnn/csrc/utils.h +// +// Copyright 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_UTILS_H_ +#define SHERPA_ONNX_CSRC_UTILS_H_ + +#include +#include + +#include "sherpa-mnn/csrc/symbol-table.h" +#include "ssentencepiece/csrc/ssentencepiece.h" + +namespace sherpa_mnn { + +/* Encode the hotwords 
in an input stream to be tokens ids. + * + * @param is The input stream, it contains several lines, one hotword for each + * line. For each hotword, the tokens (cjkchar or bpe) are separated + * by spaces. + * @param symbol_table The tokens table mapping symbols to ids. All the symbols + * in the stream should be in the symbol_table, if not this + * function returns fasle. + * + * @@param hotwords The encoded ids to be written to. + * + * @return If all the symbols from ``is`` are in the symbol_table, returns true + * otherwise returns false. + */ +bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder_, + std::vector> *hotwords_id, + std::vector *boost_scores); + +/* Encode the keywords in an input stream to be tokens ids. + * + * @param is The input stream, it contains several lines, one hotword for each + * line. For each hotword, the tokens (cjkchar or bpe) are separated + * by spaces, it might contain boosting score (starting with :), + * triggering threshold (starting with #) and keyword string (starting + * with @) too. + * @param symbol_table The tokens table mapping symbols to ids. All the symbols + * in the stream should be in the symbol_table, if not this + * function returns fasle. + * + * @param keywords_id The encoded ids to be written to. + * @param keywords The original keyword string to be written to. + * @param boost_scores The boosting score for each keyword to be written to. + * @param threshold The triggering threshold for each keyword to be written to. + * + * @return If all the symbols from ``is`` are in the symbol_table, returns true + * otherwise returns false. + */ +bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, + std::vector> *keywords_id, + std::vector *keywords, + std::vector *boost_scores, + std::vector *threshold); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_UTILS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model-config.cc new file mode 100644 index 00000000..5d378c4d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model-config.cc @@ -0,0 +1,44 @@ +// sherpa-mnn/csrc/vad-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/vad-model-config.h" + +#include +#include + +namespace sherpa_mnn { + +void VadModelConfig::Register(ParseOptions *po) { + silero_vad.Register(po); + + po->Register("vad-sample-rate", &sample_rate, + "Sample rate expected by the VAD model"); + + po->Register("vad-num-threads", &num_threads, + "Number of threads to run the VAD model"); + + po->Register("vad-provider", &provider, + "Specify a provider to run the VAD model. Supported values: " + "cpu, cuda, coreml"); + + po->Register("vad-debug", &debug, + "true to display debug information when loading vad models"); +} + +bool VadModelConfig::Validate() const { return silero_vad.Validate(); } + +std::string VadModelConfig::ToString() const { + std::ostringstream os; + + os << "VadModelConfig("; + os << "silero_vad=" << silero_vad.ToString() << ", "; + os << "sample_rate=" << sample_rate << ", "; + os << "num_threads=" << num_threads << ", "; + os << "provider=\"" << provider << "\", "; + os << "debug=" << (debug ? 
"True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model-config.h new file mode 100644 index 00000000..a40b4037 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model-config.h @@ -0,0 +1,42 @@ +// sherpa-mnn/csrc/vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ + +#include + +#include "sherpa-mnn/csrc/parse-options.h" +#include "sherpa-mnn/csrc/silero-vad-model-config.h" + +namespace sherpa_mnn { + +struct VadModelConfig { + SileroVadModelConfig silero_vad; + + int32_t sample_rate = 16000; + int32_t num_threads = 1; + std::string provider = "cpu"; + + // true to show debug information when loading models + bool debug = false; + + VadModelConfig() = default; + + VadModelConfig(const SileroVadModelConfig &silero_vad, int32_t sample_rate, + int32_t num_threads, const std::string &provider, bool debug) + : silero_vad(silero_vad), + sample_rate(sample_rate), + num_threads(num_threads), + provider(provider), + debug(debug) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model.cc new file mode 100644 index 00000000..ded9f117 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model.cc @@ -0,0 +1,41 @@ +// sherpa-mnn/csrc/vad-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/vad-model.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/silero-vad-model.h" + +namespace sherpa_mnn { + +std::unique_ptr VadModel::Create(const VadModelConfig &config) { + // TODO(fangjun): Support other VAD models. + return std::make_unique(config); +} + +template +std::unique_ptr VadModel::Create(Manager *mgr, + const VadModelConfig &config) { + // TODO(fangjun): Support other VAD models. + return std::make_unique(mgr, config); +} + +#if __ANDROID_API__ >= 9 +template std::unique_ptr VadModel::Create( + AAssetManager *mgr, const VadModelConfig &config); +#endif + +#if __OHOS__ +template std::unique_ptr VadModel::Create( + NativeResourceManager *mgr, const VadModelConfig &config); +#endif +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model.h new file mode 100644 index 00000000..05861956 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/vad-model.h @@ -0,0 +1,47 @@ +// sherpa-mnn/csrc/vad-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_VAD_MODEL_H_ +#define SHERPA_ONNX_CSRC_VAD_MODEL_H_ + +#include + +#include "sherpa-mnn/csrc/vad-model-config.h" + +namespace sherpa_mnn { + +class VadModel { + public: + virtual ~VadModel() = default; + + static std::unique_ptr Create(const VadModelConfig &config); + + template + static std::unique_ptr Create(Manager *mgr, + const VadModelConfig &config); + + // reset the internal model states + virtual void Reset() = 0; + + /** + * @param samples Pointer to a 1-d array containing audio samples. 
+ * Each sample should be normalized to the range [-1, 1]. + * @param n Number of samples. Should be equal to WindowSize() + * + * @return Return true if speech is detected. Return false otherwise. + */ + virtual bool IsSpeech(const float *samples, int32_t n) = 0; + + virtual int32_t WindowSize() const = 0; + + virtual int32_t WindowShift() const = 0; + + virtual int32_t MinSilenceDurationSamples() const = 0; + virtual int32_t MinSpeechDurationSamples() const = 0; + virtual void SetMinSilenceDuration(float s) = 0; + virtual void SetThreshold(float threshold) = 0; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_VAD_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/voice-activity-detector.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/voice-activity-detector.cc new file mode 100644 index 00000000..71f53856 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/voice-activity-detector.cc @@ -0,0 +1,234 @@ +// sherpa-mnn/csrc/voice-activity-detector.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/voice-activity-detector.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-mnn/csrc/circular-buffer.h" +#include "sherpa-mnn/csrc/vad-model.h" + +namespace sherpa_mnn { + +class VoiceActivityDetector::Impl { + public: + explicit Impl(const VadModelConfig &config, float buffer_size_in_seconds = 60) + : model_(VadModel::Create(config)), + config_(config), + buffer_(buffer_size_in_seconds * config.sample_rate) { + Init(); + } + + template + Impl(Manager *mgr, const VadModelConfig &config, + float buffer_size_in_seconds = 60) + : model_(VadModel::Create(mgr, config)), + config_(config), + buffer_(buffer_size_in_seconds * config.sample_rate) { + Init(); + } + + void AcceptWaveform(const float *samples, int32_t n) { + if (buffer_.Size() > max_utterance_length_) { + model_->SetMinSilenceDuration(new_min_silence_duration_s_); + model_->SetThreshold(new_threshold_); + } else { + model_->SetMinSilenceDuration(config_.silero_vad.min_silence_duration); + model_->SetThreshold(config_.silero_vad.threshold); + } + + int32_t window_size = model_->WindowSize(); + int32_t window_shift = model_->WindowShift(); + + // note n is usually window_size and there is no need to use + // an extra buffer here + last_.insert(last_.end(), samples, samples + n); + + if (last_.size() < window_size) { + return; + } + + // Note: For v4, window_shift == window_size + int32_t k = + (static_cast(last_.size()) - window_size) / window_shift + 1; + const float *p = last_.data(); + bool is_speech = false; + + for (int32_t i = 0; i < k; ++i, p += window_shift) { + buffer_.Push(p, window_shift); + // NOTE(fangjun): Please don't use a very large n. 
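+      // Each pass classifies one window of window_size samples and then
+      // advances by window_shift; for silero-vad v4 the two are equal (see
+      // the note above), so consecutive windows do not overlap.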
+ bool this_window_is_speech = model_->IsSpeech(p, window_size); + is_speech = is_speech || this_window_is_speech; + } + + last_ = std::vector( + p, static_cast(last_.data()) + last_.size()); + + if (is_speech) { + if (start_ == -1) { + // beginning of speech + start_ = std::max(buffer_.Tail() - 2 * model_->WindowSize() - + model_->MinSpeechDurationSamples(), + buffer_.Head()); + } + } else { + // non-speech + if (start_ != -1 && buffer_.Size()) { + // end of speech, save the speech segment + int32_t end = buffer_.Tail() - model_->MinSilenceDurationSamples(); + + std::vector s = buffer_.Get(start_, end - start_); + SpeechSegment segment; + + segment.start = start_; + segment.samples = std::move(s); + + segments_.push(std::move(segment)); + + buffer_.Pop(end - buffer_.Head()); + } + + if (start_ == -1) { + int32_t end = buffer_.Tail() - 2 * model_->WindowSize() - + model_->MinSpeechDurationSamples(); + int32_t n = std::max(0, end - buffer_.Head()); + if (n > 0) { + buffer_.Pop(n); + } + } + + start_ = -1; + } + } + + bool Empty() const { return segments_.empty(); } + + void Pop() { segments_.pop(); } + + void Clear() { std::queue().swap(segments_); } + + const SpeechSegment &Front() const { return segments_.front(); } + + void Reset() { + std::queue().swap(segments_); + + model_->Reset(); + buffer_.Reset(); + + start_ = -1; + } + + void Flush() { + if (start_ == -1 || buffer_.Size() == 0) { + return; + } + + int32_t end = buffer_.Tail(); + if (end <= start_) { + return; + } + + std::vector s = buffer_.Get(start_, end - start_); + + SpeechSegment segment; + + segment.start = start_; + segment.samples = std::move(s); + + segments_.push(std::move(segment)); + + buffer_.Pop(end - buffer_.Head()); + start_ = -1; + } + + bool IsSpeechDetected() const { return start_ != -1; } + + const VadModelConfig &GetConfig() const { return config_; } + + private: + void Init() { + // TODO(fangjun): Currently, we support only one vad model. + // If a new vad model is added, we need to change the place + // where max_speech_duration is placed. 
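+    // For example, with the default sample_rate of 16000 and, say,
+    // max_speech_duration = 20 seconds (an illustrative value), this works
+    // out to 320000 samples; once the buffer grows past that,
+    // AcceptWaveform() switches to the stricter new_min_silence_duration_s_
+    // and new_threshold_ settings to close the current segment sooner.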
+ max_utterance_length_ = + config_.sample_rate * config_.silero_vad.max_speech_duration; + } + + private: + std::queue segments_; + + std::unique_ptr model_; + VadModelConfig config_; + CircularBuffer buffer_; + std::vector last_; + + int max_utterance_length_ = -1; // in samples + float new_min_silence_duration_s_ = 0.1; + float new_threshold_ = 0.90; + + int32_t start_ = -1; +}; + +VoiceActivityDetector::VoiceActivityDetector( + const VadModelConfig &config, float buffer_size_in_seconds /*= 60*/) + : impl_(std::make_unique(config, buffer_size_in_seconds)) {} + +template +VoiceActivityDetector::VoiceActivityDetector( + Manager *mgr, const VadModelConfig &config, + float buffer_size_in_seconds /*= 60*/) + : impl_(std::make_unique(mgr, config, buffer_size_in_seconds)) {} + +VoiceActivityDetector::~VoiceActivityDetector() = default; + +void VoiceActivityDetector::AcceptWaveform(const float *samples, int32_t n) { + impl_->AcceptWaveform(samples, n); +} + +bool VoiceActivityDetector::Empty() const { return impl_->Empty(); } + +void VoiceActivityDetector::Pop() { impl_->Pop(); } + +void VoiceActivityDetector::Clear() { impl_->Clear(); } + +const SpeechSegment &VoiceActivityDetector::Front() const { + return impl_->Front(); +} + +void VoiceActivityDetector::Reset() const { impl_->Reset(); } + +void VoiceActivityDetector::Flush() const { impl_->Flush(); } + +bool VoiceActivityDetector::IsSpeechDetected() const { + return impl_->IsSpeechDetected(); +} + +const VadModelConfig &VoiceActivityDetector::GetConfig() const { + return impl_->GetConfig(); +} + +#if __ANDROID_API__ >= 9 +template VoiceActivityDetector::VoiceActivityDetector( + AAssetManager *mgr, const VadModelConfig &config, + float buffer_size_in_seconds = 60); +#endif + +#if __OHOS__ +template VoiceActivityDetector::VoiceActivityDetector( + NativeResourceManager *mgr, const VadModelConfig &config, + float buffer_size_in_seconds = 60); +#endif + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/voice-activity-detector.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/voice-activity-detector.h new file mode 100644 index 00000000..189d559e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/voice-activity-detector.h @@ -0,0 +1,53 @@ +// sherpa-mnn/csrc/voice-activity-detector.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_VOICE_ACTIVITY_DETECTOR_H_ +#define SHERPA_ONNX_CSRC_VOICE_ACTIVITY_DETECTOR_H_ + +#include +#include + +#include "sherpa-mnn/csrc/vad-model-config.h" + +namespace sherpa_mnn { + +struct SpeechSegment { + int32_t start; // in samples + std::vector samples; +}; + +class VoiceActivityDetector { + public: + explicit VoiceActivityDetector(const VadModelConfig &config, + float buffer_size_in_seconds = 60); + + template + VoiceActivityDetector(Manager *mgr, const VadModelConfig &config, + float buffer_size_in_seconds = 60); + + ~VoiceActivityDetector(); + + void AcceptWaveform(const float *samples, int32_t n); + bool Empty() const; + void Pop(); + void Clear(); + const SpeechSegment &Front() const; + + bool IsSpeechDetected() const; + + void Reset() const; + + // At the end of the utterance, you can invoke this method so that + // the last speech segment can be detected. 
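+  //
+  // A minimal usage sketch (using only the methods declared in this header):
+  //
+  //   sherpa_mnn::VoiceActivityDetector vad(config);
+  //   while (/* more audio */) {
+  //     vad.AcceptWaveform(samples, n);
+  //     while (!vad.Empty()) { /* consume vad.Front().samples */ vad.Pop(); }
+  //   }
+  //   vad.Flush();  // emit the trailing segment, if any
+  //   while (!vad.Empty()) { /* consume vad.Front().samples */ vad.Pop(); }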
+ void Flush() const; + + const VadModelConfig &GetConfig() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_VOICE_ACTIVITY_DETECTOR_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-reader.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-reader.cc new file mode 100644 index 00000000..a8dd250d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-reader.cc @@ -0,0 +1,327 @@ +// sherpa-mnn/csrc/wave-reader.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/wave-reader.h" + +#include +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { +namespace { +// see http://soundfile.sapp.org/doc/WaveFormat/ +// +// Note: We assume little endian here +// TODO(fangjun): Support big endian +struct WaveHeader { + // See + // https://en.wikipedia.org/wiki/WAV#Metadata + // and + // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf + void SeekToDataChunk(std::istream &is) { + // a t a d + while (is && subchunk2_id != 0x61746164) { + // const char *p = reinterpret_cast(&subchunk2_id); + // printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0], + // p[1], p[2], p[3], subchunk2_size); + is.seekg(subchunk2_size, std::istream::cur); + is.read(reinterpret_cast(&subchunk2_id), sizeof(int32_t)); + is.read(reinterpret_cast(&subchunk2_size), sizeof(int32_t)); + } + } + + int32_t chunk_id; + int32_t chunk_size; + int32_t format; + int32_t subchunk1_id; + int32_t subchunk1_size; + int16_t audio_format; + int16_t num_channels; + int32_t sample_rate; + int32_t byte_rate; + int16_t block_align; + int16_t bits_per_sample; + int32_t subchunk2_id; // a tag of this chunk + int32_t subchunk2_size; // size of subchunk2 +}; +static_assert(sizeof(WaveHeader) == 44); + +/* +sox int16-1-channel-zh.wav -b 8 int8-1-channel-zh.wav + +sox int16-1-channel-zh.wav -c 2 int16-2-channel-zh.wav + +we use audacity to generate int32-1-channel-zh.wav and float32-1-channel-zh.wav +because sox uses WAVE_FORMAT_EXTENSIBLE, which is not easy to support +in sherpa-mnn. + */ + +// Read a wave file of mono-channel. +// Return its samples normalized to the range [-1, 1). +std::vector ReadWaveImpl(std::istream &is, int32_t *sampling_rate, + bool *is_ok) { + WaveHeader header{}; + is.read(reinterpret_cast(&header.chunk_id), sizeof(header.chunk_id)); + + // F F I R + if (header.chunk_id != 0x46464952) { + SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n", + header.chunk_id); + *is_ok = false; + return {}; + } + + is.read(reinterpret_cast(&header.chunk_size), + sizeof(header.chunk_size)); + + is.read(reinterpret_cast(&header.format), sizeof(header.format)); + + // E V A W + if (header.format != 0x45564157) { + SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", header.format); + *is_ok = false; + return {}; + } + + is.read(reinterpret_cast(&header.subchunk1_id), + sizeof(header.subchunk1_id)); + + is.read(reinterpret_cast(&header.subchunk1_size), + sizeof(header.subchunk1_size)); + + if (header.subchunk1_id == 0x4b4e554a) { + // skip junk padding + is.seekg(header.subchunk1_size, std::istream::cur); + + is.read(reinterpret_cast(&header.subchunk1_id), + sizeof(header.subchunk1_id)); + + is.read(reinterpret_cast(&header.subchunk1_size), + sizeof(header.subchunk1_size)); + } + + if (header.subchunk1_id != 0x20746d66) { + SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. 
Given: 0x%08x\n", + header.subchunk1_id); + *is_ok = false; + return {}; + } + + // NAudio uses 18 + // See https://github.com/naudio/NAudio/issues/1132 + if (header.subchunk1_size != 16 && + header.subchunk1_size != 18) { // 16 for PCM + SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n", + header.subchunk1_size); + *is_ok = false; + return {}; + } + + is.read(reinterpret_cast(&header.audio_format), + sizeof(header.audio_format)); + + if (header.audio_format != 1 && header.audio_format != 3) { + // 1 for integer PCM + // 3 for floating point PCM + // see https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html + // and https://github.com/microsoft/DirectXTK/wiki/Wave-Formats + SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", + header.audio_format); + + if (header.audio_format == static_cast(0xfffe)) { + SHERPA_ONNX_LOGE("We don't support WAVE_FORMAT_EXTENSIBLE files."); + } + + *is_ok = false; + return {}; + } + + is.read(reinterpret_cast(&header.num_channels), + sizeof(header.num_channels)); + + if (header.num_channels != 1) { // we support only single channel for now + SHERPA_ONNX_LOGE( + "Warning: %d channels are found. We only use the first channel.\n", + header.num_channels); + } + + is.read(reinterpret_cast(&header.sample_rate), + sizeof(header.sample_rate)); + + is.read(reinterpret_cast(&header.byte_rate), + sizeof(header.byte_rate)); + + is.read(reinterpret_cast(&header.block_align), + sizeof(header.block_align)); + + is.read(reinterpret_cast(&header.bits_per_sample), + sizeof(header.bits_per_sample)); + + if (header.byte_rate != + (header.sample_rate * header.num_channels * header.bits_per_sample / 8)) { + SHERPA_ONNX_LOGE("Incorrect byte rate: %d. Expected: %d", header.byte_rate, + (header.sample_rate * header.num_channels * + header.bits_per_sample / 8)); + *is_ok = false; + return {}; + } + + if (header.block_align != + (header.num_channels * header.bits_per_sample / 8)) { + SHERPA_ONNX_LOGE("Incorrect block align: %d. Expected: %d\n", + header.block_align, + (header.num_channels * header.bits_per_sample / 8)); + *is_ok = false; + return {}; + } + + if (header.bits_per_sample != 8 && header.bits_per_sample != 16 && + header.bits_per_sample != 32) { + SHERPA_ONNX_LOGE("Expected bits_per_sample 8, 16 or 32. Given: %d\n", + header.bits_per_sample); + *is_ok = false; + return {}; + } + + if (header.subchunk1_size == 18) { + // this is for NAudio. It puts extra bytes after bits_per_sample + // See + // https://github.com/naudio/NAudio/blob/master/NAudio.Core/Wave/WaveFormats/WaveFormat.cs#L223 + + int16_t extra_size = -1; + is.read(reinterpret_cast(&extra_size), sizeof(int16_t)); + if (extra_size != 0) { + SHERPA_ONNX_LOGE( + "Extra size should be 0 for wave from NAudio. Current extra size " + "%d\n", + extra_size); + *is_ok = false; + return {}; + } + } + + is.read(reinterpret_cast(&header.subchunk2_id), + sizeof(header.subchunk2_id)); + + is.read(reinterpret_cast(&header.subchunk2_size), + sizeof(header.subchunk2_size)); + + header.SeekToDataChunk(is); + if (!is) { + *is_ok = false; + return {}; + } + + *sampling_rate = header.sample_rate; + + std::vector ans; + + if (header.bits_per_sample == 16 && header.audio_format == 1) { + // header.subchunk2_size contains the number of bytes in the data. 
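+    // (For example, one second of 16 kHz mono int16 audio gives
+    // subchunk2_size = 32000 bytes, i.e. 16000 samples.)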
+ // As we assume each sample contains two bytes, so it is divided by 2 here + std::vector samples(header.subchunk2_size / 2); + + is.read(reinterpret_cast(samples.data()), header.subchunk2_size); + if (!is) { + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size); + *is_ok = false; + return {}; + } + + ans.resize(samples.size() / header.num_channels); + + // samples are interleaved + for (int32_t i = 0; i != static_cast(ans.size()); ++i) { + ans[i] = samples[i * header.num_channels] / 32768.; + } + } else if (header.bits_per_sample == 8 && header.audio_format == 1) { + // number of samples == number of bytes for 8-bit encoded samples + // + // For 8-bit encoded samples, they are unsigned! + std::vector samples(header.subchunk2_size); + + is.read(reinterpret_cast(samples.data()), header.subchunk2_size); + if (!is) { + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size); + *is_ok = false; + return {}; + } + + ans.resize(samples.size() / header.num_channels); + for (int32_t i = 0; i != static_cast(ans.size()); ++i) { + // Note(fangjun): We want to normalize each sample into the range [-1, 1] + // Since each original sample is in the range [0, 256], dividing + // them by 128 converts them to the range [0, 2]; + // so after subtracting 1, we get the range [-1, 1] + // + ans[i] = samples[i * header.num_channels] / 128. - 1; + } + } else if (header.bits_per_sample == 32 && header.audio_format == 1) { + // 32 here is for int32 + // + // header.subchunk2_size contains the number of bytes in the data. + // As we assume each sample contains 4 bytes, so it is divided by 4 here + std::vector samples(header.subchunk2_size / 4); + + is.read(reinterpret_cast(samples.data()), header.subchunk2_size); + if (!is) { + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size); + *is_ok = false; + return {}; + } + + ans.resize(samples.size() / header.num_channels); + for (int32_t i = 0; i != static_cast(ans.size()); ++i) { + ans[i] = static_cast(samples[i * header.num_channels]) / (1 << 31); + } + } else if (header.bits_per_sample == 32 && header.audio_format == 3) { + // 32 here is for float32 + // + // header.subchunk2_size contains the number of bytes in the data. + // As we assume each sample contains 4 bytes, so it is divided by 4 here + std::vector samples(header.subchunk2_size / 4); + + is.read(reinterpret_cast(samples.data()), header.subchunk2_size); + if (!is) { + SHERPA_ONNX_LOGE("Failed to read %d bytes", header.subchunk2_size); + *is_ok = false; + return {}; + } + + ans.resize(samples.size() / header.num_channels); + for (int32_t i = 0; i != static_cast(ans.size()); ++i) { + ans[i] = samples[i * header.num_channels]; + } + } else { + SHERPA_ONNX_LOGE( + "Unsupported %d bits per sample and audio format: %d. 
Supported values " + "are: 8, 16, 32.", + header.bits_per_sample, header.audio_format); + *is_ok = false; + return {}; + } + + *is_ok = true; + return ans; +} + +} // namespace + +std::vector ReadWave(const std::string &filename, int32_t *sampling_rate, + bool *is_ok) { + std::ifstream is(filename, std::ifstream::binary); + return ReadWave(is, sampling_rate, is_ok); +} + +std::vector ReadWave(std::istream &is, int32_t *sampling_rate, + bool *is_ok) { + auto samples = ReadWaveImpl(is, sampling_rate, is_ok); + return samples; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-reader.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-reader.h new file mode 100644 index 00000000..2305a659 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-reader.h @@ -0,0 +1,31 @@ +// sherpa-mnn/csrc/wave-reader.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_WAVE_READER_H_ +#define SHERPA_ONNX_CSRC_WAVE_READER_H_ + +#include +#include +#include + +namespace sherpa_mnn { + +/** Read a wave file with expected sample rate. + + @param filename Path to a wave file. It MUST be single channel, 16-bit + PCM encoded. + @param sampling_rate On return, it contains the sampling rate of the file. + @param is_ok On return it is true if the reading succeeded; false otherwise. + + @return Return wave samples normalized to the range [-1, 1). + */ +std::vector ReadWave(const std::string &filename, int32_t *sampling_rate, + bool *is_ok); + +std::vector ReadWave(std::istream &is, int32_t *sampling_rate, + bool *is_ok); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_WAVE_READER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-writer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-writer.cc new file mode 100644 index 00000000..adcbd639 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-writer.cc @@ -0,0 +1,92 @@ +// sherpa-mnn/csrc/wave-writer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/wave-writer.h" + +#include +#include +#include +#include + +#include "sherpa-mnn/csrc/macros.h" + +namespace sherpa_mnn { +namespace { + +// see http://soundfile.sapp.org/doc/WaveFormat/ +// +// Note: We assume little endian here +// TODO(fangjun): Support big endian +struct WaveHeader { + int32_t chunk_id; + int32_t chunk_size; + int32_t format; + int32_t subchunk1_id; + int32_t subchunk1_size; + int16_t audio_format; + int16_t num_channels; + int32_t sample_rate; + int32_t byte_rate; + int16_t block_align; + int16_t bits_per_sample; + int32_t subchunk2_id; // a tag of this chunk + int32_t subchunk2_size; // size of subchunk2 +}; + +} // namespace + +int WaveFileSize(int32_t n_samples) { + return sizeof(WaveHeader) + n_samples * sizeof(int16_t); +} + +void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, + int32_t n) { + WaveHeader header{}; + header.chunk_id = 0x46464952; // FFIR + header.format = 0x45564157; // EVAW + header.subchunk1_id = 0x20746d66; // "fmt " + header.subchunk1_size = 16; // 16 for PCM + header.audio_format = 1; // PCM =1 + + int32_t num_channels = 1; + int32_t bits_per_sample = 16; // int16_t + header.num_channels = num_channels; + header.sample_rate = sampling_rate; + header.byte_rate = sampling_rate * num_channels * bits_per_sample / 8; + header.block_align = num_channels * bits_per_sample / 8; + header.bits_per_sample = bits_per_sample; + header.subchunk2_id = 0x61746164; // atad + header.subchunk2_size = n * num_channels * 
bits_per_sample / 8; + + header.chunk_size = 36 + header.subchunk2_size; + + std::vector samples_int16(n); + for (int32_t i = 0; i != n; ++i) { + samples_int16[i] = samples[i] * 32676; + } + + memcpy(buffer, &header, sizeof(WaveHeader)); + memcpy(buffer + sizeof(WaveHeader), samples_int16.data(), + n * sizeof(int16_t)); +} + +bool WriteWave(const std::string &filename, int32_t sampling_rate, + const float *samples, int32_t n) { + std::string buffer; + buffer.resize(WaveFileSize(n)); + WriteWave(buffer.data(), sampling_rate, samples, n); + std::ofstream os(filename, std::ios::binary); + if (!os) { + SHERPA_ONNX_LOGE("Failed to create %s", filename.c_str()); + return false; + } + os << buffer; + if (!os) { + SHERPA_ONNX_LOGE("Write %s failed", filename.c_str()); + return false; + } + return true; +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-writer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-writer.h new file mode 100644 index 00000000..b2a5c5fb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/csrc/wave-writer.h @@ -0,0 +1,32 @@ +// sherpa-mnn/csrc/wave-writer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_WAVE_WRITER_H_ +#define SHERPA_ONNX_CSRC_WAVE_WRITER_H_ + +#include +#include + +namespace sherpa_mnn { + +// Write a single channel wave file. +// Note that the input samples are in the range [-1, 1]. It will be multiplied +// by 32767 and saved in int16_t format in the wave file. +// +// @param filename Path to save the samples. +// @param sampling_rate Sample rate of the samples. +// @param samples Pointer to the samples +// @param n Number of samples +// @return Return true if the write succeeds; return false otherwise. +bool WriteWave(const std::string &filename, int32_t sampling_rate, + const float *samples, int32_t n); + +void WriteWave(char *buffer, int32_t sampling_rate, const float *samples, + int32_t n); + +int WaveFileSize(int32_t n_samples); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_CSRC_WAVE_WRITER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/CMakeLists.txt b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/CMakeLists.txt new file mode 100644 index 00000000..06c074ad --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/CMakeLists.txt @@ -0,0 +1,49 @@ +include_directories(${CMAKE_SOURCE_DIR}) + +if(NOT DEFINED ANDROID_ABI) + if(NOT DEFINED ENV{JAVA_HOME}) + message(FATAL_ERROR "Please set the environment variable JAVA_HOME") + endif() + include_directories($ENV{JAVA_HOME}/include) + include_directories($ENV{JAVA_HOME}/include/linux) + include_directories($ENV{JAVA_HOME}/include/darwin) + include_directories($ENV{JAVA_HOME}/include/win32) +endif() + +set(sources + audio-tagging.cc + jni.cc + keyword-spotter.cc + offline-punctuation.cc + offline-recognizer.cc + offline-stream.cc + online-punctuation.cc + online-recognizer.cc + online-stream.cc + speaker-embedding-extractor.cc + speaker-embedding-manager.cc + spoken-language-identification.cc + voice-activity-detector.cc + wave-reader.cc + wave-writer.cc +) + +if(SHERPA_MNN_ENABLE_TTS) + list(APPEND sources + offline-tts.cc + ) +endif() + +if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + list(APPEND sources + offline-speaker-diarization.cc + ) +endif() + +add_library(sherpa-mnn-jni SHARED ${sources}) + +target_compile_definitions(sherpa-mnn-jni PRIVATE SHERPA_MNN_BUILD_SHARED_LIBS=1) +target_compile_definitions(sherpa-mnn-jni PRIVATE SHERPA_MNN_BUILD_MAIN_LIB=1) + +target_link_libraries(sherpa-mnn-jni 
sherpa-mnn-core) +install(TARGETS sherpa-mnn-jni DESTINATION lib) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/audio-tagging.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/audio-tagging.cc new file mode 100644 index 00000000..b5aa9201 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/audio-tagging.cc @@ -0,0 +1,156 @@ +// sherpa-mnn/jni/audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/audio-tagging.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) { + AudioTaggingConfig ans; + + jclass cls = env->GetObjectClass(config); + + jfieldID fid = env->GetFieldID( + cls, "model", "Lcom/k2fsa/sherpa/onnx/AudioTaggingModelConfig;"); + jobject model = env->GetObjectField(config, fid); + jclass model_cls = env->GetObjectClass(model); + + fid = env->GetFieldID( + model_cls, "zipformer", + "Lcom/k2fsa/sherpa/onnx/OfflineZipformerAudioTaggingModelConfig;"); + jobject zipformer = env->GetObjectField(model, fid); + jclass zipformer_cls = env->GetObjectClass(zipformer); + + fid = env->GetFieldID(zipformer_cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(zipformer, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.model.zipformer.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_cls, "ced", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.ced = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_cls, "numThreads", "I"); + ans.model.num_threads = env->GetIntField(model, fid); + + fid = env->GetFieldID(model_cls, "debug", "Z"); + ans.model.debug = env->GetBooleanField(model, fid); + + fid = env->GetFieldID(model_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "labels", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.labels = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "topK", "I"); + ans.top_k = env->GetIntField(config, fid); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_AudioTagging_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + + auto config = sherpa_mnn::GetAudioTaggingConfig(env, _config); + SHERPA_ONNX_LOGE("audio tagging newFromAsset config:\n%s", + config.ToString().c_str()); + + auto tagger = new sherpa_mnn::AudioTagging( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)tagger; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_AudioTagging_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_mnn::GetAudioTaggingConfig(env, _config); + SHERPA_ONNX_LOGE("audio tagging newFromFile config:\n%s", + config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto tagger = new sherpa_mnn::AudioTagging(config); + + return 
(jlong)tagger; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_AudioTagging_delete( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_AudioTagging_createStream( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto tagger = reinterpret_cast(ptr); + std::unique_ptr s = tagger->CreateStream(); + + // The user is responsible to free the returned pointer. + // + // See Java_com_k2fsa_sherpa_mnn_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_mnn::OfflineStream *p = s.release(); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_mnn_AudioTagging_compute( + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr, jint top_k) { + auto tagger = reinterpret_cast(ptr); + auto stream = reinterpret_cast(streamPtr); + std::vector events = tagger->Compute(stream, top_k); + + // TODO(fangjun): Return an array of AudioEvent directly + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + events.size(), env->FindClass("java/lang/Object"), nullptr); + + int32_t i = 0; + for (const auto &e : events) { + jobjectArray a = (jobjectArray)env->NewObjectArray( + 3, env->FindClass("java/lang/Object"), nullptr); + + // 0 name + // 1 index + // 2 prob + jstring js = env->NewStringUTF(e.name.c_str()); + env->SetObjectArrayElement(a, 0, js); + env->SetObjectArrayElement(a, 1, NewInteger(env, e.index)); + env->SetObjectArrayElement(a, 2, NewFloat(env, e.prob)); + + env->SetObjectArrayElement(obj_arr, i, a); + i += 1; + } + + return obj_arr; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/common.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/common.h new file mode 100644 index 00000000..34c343d8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/common.h @@ -0,0 +1,105 @@ +// sherpa-mnn/jni/common.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_JNI_COMMON_H_ +#define SHERPA_ONNX_JNI_COMMON_H_ + +#include + +#if __ANDROID_API__ >= 9 +#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if defined(_WIN32) +#if defined(SHERPA_ONNX_BUILD_SHARED_LIBS) +#define SHERPA_ONNX_EXPORT __declspec(dllexport) +#define SHERPA_ONNX_IMPORT __declspec(dllimport) +#else +#define SHERPA_ONNX_EXPORT +#define SHERPA_ONNX_IMPORT +#endif +#else // WIN32 +#define SHERPA_ONNX_EXPORT __attribute__((visibility("default"))) + +#define SHERPA_ONNX_IMPORT SHERPA_ONNX_EXPORT +#endif // WIN32 + +#if defined(SHERPA_ONNX_BUILD_MAIN_LIB) +#define SHERPA_ONNX_API SHERPA_ONNX_EXPORT +#else +#define SHERPA_ONNX_API SHERPA_ONNX_IMPORT +#endif + +// If you use ndk, you can find "jni.h" inside +// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include +#include "jni.h" // NOLINT + +#define SHERPA_ONNX_EXTERN_C extern "C" SHERPA_ONNX_API + +// defined in jni.cc +jobject NewInteger(JNIEnv *env, int32_t value); +jobject NewFloat(JNIEnv *env, float value); + +// Template function for non-void return types +template +ReturnType SafeJNI(JNIEnv *env, const char *functionName, Func func, + ReturnType defaultValue) { + try { + return func(); + } catch (const std::exception &e) { + jclass exClass = env->FindClass("java/lang/RuntimeException"); + if (exClass != nullptr) { + std::string errorMessage = std::string(functionName) + ": " + e.what(); + env->ThrowNew(exClass, errorMessage.c_str()); + } + } catch (...) 
{ + jclass exClass = env->FindClass("java/lang/RuntimeException"); + if (exClass != nullptr) { + std::string errorMessage = std::string(functionName) + + ": Native exception: caught unknown exception"; + env->ThrowNew(exClass, errorMessage.c_str()); + } + } + return defaultValue; +} + +// Specialization for void return type +template +void SafeJNI(JNIEnv *env, const char *functionName, Func func) { + try { + func(); + } catch (const std::exception &e) { + jclass exClass = env->FindClass("java/lang/RuntimeException"); + if (exClass != nullptr) { + std::string errorMessage = std::string(functionName) + ": " + e.what(); + env->ThrowNew(exClass, errorMessage.c_str()); + } + } catch (...) { + jclass exClass = env->FindClass("java/lang/RuntimeException"); + if (exClass != nullptr) { + std::string errorMessage = std::string(functionName) + + ": Native exception: caught unknown exception"; + env->ThrowNew(exClass, errorMessage.c_str()); + } + } +} + +// Helper function to validate JNI pointers +inline bool ValidatePointer(JNIEnv *env, jlong ptr, + const char *functionName, const char *message) { + if (ptr == 0) { + jclass exClass = env->FindClass("java/lang/NullPointerException"); + if (exClass != nullptr) { + std::string errorMessage = std::string(functionName) + ": " + message; + env->ThrowNew(exClass, errorMessage.c_str()); + } + return false; + } + return true; +} + +#endif // SHERPA_ONNX_JNI_COMMON_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/jni.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/jni.cc new file mode 100644 index 00000000..8244742b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/jni.cc @@ -0,0 +1,65 @@ +// sherpa-mnn/jni/jni.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation +// 2022 Pingfeng Luo +// 2023 Zhaoming + +#include + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/csrc/onnx-utils.h" +#include "sherpa-mnn/csrc/wave-writer.h" +#include "sherpa-mnn/jni/common.h" + +// see +// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables +jobject NewInteger(JNIEnv *env, int32_t value) { + jclass cls = env->FindClass("java/lang/Integer"); + jmethodID constructor = env->GetMethodID(cls, "", "(I)V"); + return env->NewObject(cls, constructor, value); +} + +jobject NewFloat(JNIEnv *env, float value) { + jclass cls = env->FindClass("java/lang/Float"); + jmethodID constructor = env->GetMethodID(cls, "", "(F)V"); + return env->NewObject(cls, constructor, value); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL Java_com_k2fsa_sherpa_mnn_GeneratedAudio_saveImpl( + JNIEnv *env, jobject /*obj*/, jstring filename, jfloatArray samples, + jint sample_rate) { + const char *p_filename = env->GetStringUTFChars(filename, nullptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + + bool ok = sherpa_mnn::WriteWave(p_filename, sample_rate, p, n); + + env->ReleaseStringUTFChars(filename, p_filename); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + return ok; +} + +#if 0 +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_decodeStreams(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlongArray ss_ptr, + jint stream_size) { + sherpa_mnn::OnlineRecognizer *model = + reinterpret_cast(ptr); + jlong *p = env->GetLongArrayElements(ss_ptr, nullptr); + jsize n = env->GetArrayLength(ss_ptr); + std::vector p_ss(n); + for (int32_t i = 0; i != n; ++i) { + p_ss[i] = reinterpret_cast(p[i]); + } + + model->DecodeStreams(p_ss.data(), n); + 
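+  // JNI_ABORT releases the native array copy without writing anything back
+  // to the Java array, which is fine here because the elements are only read.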
env->ReleaseLongArrayElements(ss_ptr, p, JNI_ABORT); +} +#endif diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/keyword-spotter.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/keyword-spotter.cc new file mode 100644 index 00000000..b12077e5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/keyword-spotter.cc @@ -0,0 +1,244 @@ +// sherpa-mnn/jni/keyword-spotter.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/keyword-spotter.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { + KeywordSpotterConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html + + //---------- decoding ---------- + fid = env->GetFieldID(cls, "maxActivePaths", "I"); + ans.max_active_paths = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.keywords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "keywordsScore", "F"); + ans.keywords_score = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "keywordsThreshold", "F"); + ans.keywords_threshold = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "numTrailingBlanks", "I"); + ans.num_trailing_blanks = env->GetIntField(config, fid); + + //---------- feat config ---------- + fid = env->GetFieldID(cls, "featConfig", + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); + jobject feat_config = env->GetObjectField(config, fid); + jclass feat_config_cls = env->GetObjectClass(feat_config); + + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); + + fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); + + //---------- model config ---------- + fid = env->GetFieldID(cls, "modelConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + // transducer + fid = env->GetFieldID(model_config_cls, "transducer", + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); + jobject transducer_config = env->GetObjectField(model_config, fid); + jclass transducer_config_cls = env->GetObjectClass(transducer_config); + + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.joiner = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "tokens", 
"Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model_config.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model_config.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.provider_config.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.model_type = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_KeywordSpotter_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetKwsConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto kws = new sherpa_mnn::KeywordSpotter( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)kws; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_KeywordSpotter_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_mnn::GetKwsConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto kws = new sherpa_mnn::KeywordSpotter(config); + + return (jlong)kws; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_KeywordSpotter_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_KeywordSpotter_decode( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto kws = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + kws->DecodeStream(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_KeywordSpotter_reset( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto kws = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + kws->Reset(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_KeywordSpotter_createStream( + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) { + auto kws = reinterpret_cast(ptr); + + const char *p = env->GetStringUTFChars(keywords, nullptr); + std::unique_ptr stream; + + if (strlen(p) == 0) { + stream = kws->CreateStream(); + } else { + stream = kws->CreateStream(p); + } + + env->ReleaseStringUTFChars(keywords, p); + + // The user is responsible to free the returned pointer. 
+ // + // See Java_com_k2fsa_sherpa_mnn_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_mnn::OnlineStream *ans = stream.release(); + return (jlong)ans; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_mnn_KeywordSpotter_isReady( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto kws = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + return kws->IsReady(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_KeywordSpotter_getResult(JNIEnv *env, + jobject /*obj*/, jlong ptr, + jlong stream_ptr) { + auto kws = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + sherpa_mnn::KeywordResult result = kws->GetResult(stream); + + // [0]: keyword, jstring + // [1]: tokens, array of jstring + // [2]: timestamps, array of float + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 3, env->FindClass("java/lang/Object"), nullptr); + + jstring keyword = env->NewStringUTF(result.keyword.c_str()); + env->SetObjectArrayElement(obj_arr, 0, keyword); + + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray( + result.tokens.size(), env->FindClass("java/lang/String"), nullptr); + + int32_t i = 0; + for (const auto &t : result.tokens) { + jstring jtext = env->NewStringUTF(t.c_str()); + env->SetObjectArrayElement(tokens_arr, i, jtext); + i += 1; + } + + env->SetObjectArrayElement(obj_arr, 1, tokens_arr); + + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size()); + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(), + result.timestamps.data()); + + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr); + + return obj_arr; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-punctuation.cc new file mode 100644 index 00000000..48d48b58 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-punctuation.cc @@ -0,0 +1,110 @@ +// sherpa-mnn/jni/offline-punctuation.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-punctuation.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static OfflinePunctuationConfig GetOfflinePunctuationConfig(JNIEnv *env, + jobject config) { + OfflinePunctuationConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + fid = env->GetFieldID( + cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflinePunctuationModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + fid = + env->GetFieldID(model_config_cls, "ctTransformer", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(model_config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.model.ct_transformer = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.provider = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL 
+Java_com_k2fsa_sherpa_mnn_OfflinePunctuation_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetOfflinePunctuationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto model = new sherpa_mnn::OfflinePunctuation( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OfflinePunctuation_newFromFile(JNIEnv *env, + jobject /*obj*/, + jobject _config) { + auto config = sherpa_mnn::GetOfflinePunctuationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto model = new sherpa_mnn::OfflinePunctuation(config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OfflinePunctuation_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jstring JNICALL +Java_com_k2fsa_sherpa_mnn_OfflinePunctuation_addPunctuation(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jstring text) { + auto punct = reinterpret_cast(ptr); + + const char *ptext = env->GetStringUTFChars(text, nullptr); + + std::string result = punct->AddPunctuation(ptext); + + env->ReleaseStringUTFChars(text, ptext); + + return env->NewStringUTF(result.c_str()); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-recognizer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-recognizer.cc new file mode 100644 index 00000000..d6d1cbac --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-recognizer.cc @@ -0,0 +1,420 @@ +// sherpa-mnn/jni/offline-recognizer.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-recognizer.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { + OfflineRecognizerConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + //---------- decoding ---------- + fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.decoding_method = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "maxActivePaths", "I"); + ans.max_active_paths = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.hotwords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "hotwordsScore", "F"); + ans.hotwords_score = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fsts = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fars = p; + env->ReleaseStringUTFChars(s, p); + + fid = 
env->GetFieldID(cls, "blankPenalty", "F"); + ans.blank_penalty = env->GetFloatField(config, fid); + + //---------- feat config ---------- + fid = env->GetFieldID(cls, "featConfig", + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); + jobject feat_config = env->GetObjectField(config, fid); + jclass feat_config_cls = env->GetObjectClass(feat_config); + + fid = env->GetFieldID(feat_config_cls, "sampleRate", "I"); + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); + + fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); + + //---------- model config ---------- + fid = env->GetFieldID(cls, "modelConfig", + "Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model_config.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model_config.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.model_type = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.modeling_unit = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + + // transducer + fid = env->GetFieldID(model_config_cls, "transducer", + "Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;"); + jobject transducer_config = env->GetObjectField(model_config, fid); + jclass transducer_config_cls = env->GetObjectClass(transducer_config); + + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.encoder_filename = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.decoder_filename = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.joiner_filename = p; + env->ReleaseStringUTFChars(s, p); + + // paraformer + fid = env->GetFieldID(model_config_cls, "paraformer", + 
"Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;"); + jobject paraformer_config = env->GetObjectField(model_config, fid); + jclass paraformer_config_cls = env->GetObjectClass(paraformer_config); + + fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;"); + + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.model = p; + env->ReleaseStringUTFChars(s, p); + + // whisper + fid = env->GetFieldID(model_config_cls, "whisper", + "Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;"); + jobject whisper_config = env->GetObjectField(model_config, fid); + jclass whisper_config_cls = env->GetObjectClass(whisper_config); + + fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.language = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.whisper.task = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I"); + ans.model_config.whisper.tail_paddings = + env->GetIntField(whisper_config, fid); + + // FireRedAsr + fid = env->GetFieldID(model_config_cls, "fireRedAsr", + "Lcom/k2fsa/sherpa/onnx/OfflineFireRedAsrModelConfig;"); + jobject fire_red_asr_config = env->GetObjectField(model_config, fid); + jclass fire_red_asr_config_cls = env->GetObjectClass(fire_red_asr_config); + + fid = + env->GetFieldID(fire_red_asr_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(fire_red_asr_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.fire_red_asr.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = + env->GetFieldID(fire_red_asr_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(fire_red_asr_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.fire_red_asr.decoder = p; + env->ReleaseStringUTFChars(s, p); + + // moonshine + fid = env->GetFieldID(model_config_cls, "moonshine", + "Lcom/k2fsa/sherpa/onnx/OfflineMoonshineModelConfig;"); + jobject moonshine_config = env->GetObjectField(model_config, fid); + jclass moonshine_config_cls = env->GetObjectClass(moonshine_config); + + fid = env->GetFieldID(moonshine_config_cls, "preprocessor", + "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(moonshine_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.moonshine.preprocessor = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(moonshine_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(moonshine_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.moonshine.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = 
env->GetFieldID(moonshine_config_cls, "uncachedDecoder", + "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(moonshine_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.moonshine.uncached_decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(moonshine_config_cls, "cachedDecoder", + "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(moonshine_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.moonshine.cached_decoder = p; + env->ReleaseStringUTFChars(s, p); + + // sense voice + fid = env->GetFieldID(model_config_cls, "senseVoice", + "Lcom/k2fsa/sherpa/onnx/OfflineSenseVoiceModelConfig;"); + jobject sense_voice_config = env->GetObjectField(model_config, fid); + jclass sense_voice_config_cls = env->GetObjectClass(sense_voice_config); + + fid = env->GetFieldID(sense_voice_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(sense_voice_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.sense_voice.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = + env->GetFieldID(sense_voice_config_cls, "language", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(sense_voice_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.sense_voice.language = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(sense_voice_config_cls, "useInverseTextNormalization", + "Z"); + ans.model_config.sense_voice.use_itn = + env->GetBooleanField(sense_voice_config, fid); + + // nemo + fid = env->GetFieldID( + model_config_cls, "nemo", + "Lcom/k2fsa/sherpa/onnx/OfflineNemoEncDecCtcModelConfig;"); + jobject nemo_config = env->GetObjectField(model_config, fid); + jclass nemo_config_cls = env->GetObjectClass(nemo_config); + + fid = env->GetFieldID(nemo_config_cls, "model", "Ljava/lang/String;"); + + s = (jstring)env->GetObjectField(nemo_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.nemo_ctc.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "teleSpeech", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.telespeech_ctc = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineRecognizer_newFromAsset(JNIEnv *env, + jobject /*obj*/, + jobject asset_manager, + jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetOfflineConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto model = new sherpa_mnn::OfflineRecognizer( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineRecognizer_newFromFile(JNIEnv *env, + jobject /*obj*/, + jobject _config) { + auto config = sherpa_mnn::GetOfflineConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto model = new sherpa_mnn::OfflineRecognizer(config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OfflineRecognizer_setConfig( + JNIEnv *env, 
jobject /*obj*/, jlong ptr, jobject _config) { + auto config = sherpa_mnn::GetOfflineConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto recognizer = reinterpret_cast(ptr); + recognizer->SetConfig(config); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OfflineRecognizer_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineRecognizer_createStream(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + auto recognizer = reinterpret_cast(ptr); + std::unique_ptr s = recognizer->CreateStream(); + + // The user is responsible to free the returned pointer. + // + // See Java_com_k2fsa_sherpa_mnn_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_mnn::OfflineStream *p = s.release(); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OfflineRecognizer_decode( + JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr) { + SafeJNI(env, "OfflineRecognizer_decode", [&] { + if (!ValidatePointer(env, ptr, "OfflineRecognizer_decode", + "OfflineRecognizer pointer is null.") || + !ValidatePointer(env, streamPtr, "OfflineRecognizer_decode", + "OfflineStream pointer is null.")) { + return; + } + + auto recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(streamPtr); + recognizer->DecodeStream(stream); + }); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineRecognizer_getResult(JNIEnv *env, + jobject /*obj*/, + jlong streamPtr) { + auto stream = reinterpret_cast(streamPtr); + sherpa_mnn::OfflineRecognitionResult result = stream->GetResult(); + + // [0]: text, jstring + // [1]: tokens, array of jstring + // [2]: timestamps, array of float + // [3]: lang, jstring + // [4]: emotion, jstring + // [5]: event, jstring + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 6, env->FindClass("java/lang/Object"), nullptr); + + jstring text = env->NewStringUTF(result.text.c_str()); + env->SetObjectArrayElement(obj_arr, 0, text); + + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray( + result.tokens.size(), env->FindClass("java/lang/String"), nullptr); + + int32_t i = 0; + for (const auto &t : result.tokens) { + jstring jtext = env->NewStringUTF(t.c_str()); + env->SetObjectArrayElement(tokens_arr, i, jtext); + i += 1; + } + + env->SetObjectArrayElement(obj_arr, 1, tokens_arr); + + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size()); + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(), + result.timestamps.data()); + + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr); + + // [3]: lang, jstring + // [4]: emotion, jstring + // [5]: event, jstring + env->SetObjectArrayElement(obj_arr, 3, + env->NewStringUTF(result.lang.c_str())); + env->SetObjectArrayElement(obj_arr, 4, + env->NewStringUTF(result.emotion.c_str())); + env->SetObjectArrayElement(obj_arr, 5, + env->NewStringUTF(result.event.c_str())); + + return obj_arr; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-speaker-diarization.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-speaker-diarization.cc new file mode 100644 index 00000000..1f7b1e75 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-speaker-diarization.cc @@ -0,0 +1,237 @@ +// sherpa-mnn/jni/offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include 
"sherpa-mnn/csrc/offline-speaker-diarization.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static OfflineSpeakerDiarizationConfig GetOfflineSpeakerDiarizationConfig( + JNIEnv *env, jobject config) { + OfflineSpeakerDiarizationConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + //---------- segmentation ---------- + fid = env->GetFieldID( + cls, "segmentation", + "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationModelConfig;"); + jobject segmentation_config = env->GetObjectField(config, fid); + jclass segmentation_config_cls = env->GetObjectClass(segmentation_config); + + fid = env->GetFieldID( + segmentation_config_cls, "pyannote", + "Lcom/k2fsa/sherpa/onnx/OfflineSpeakerSegmentationPyannoteModelConfig;"); + jobject pyannote_config = env->GetObjectField(segmentation_config, fid); + jclass pyannote_config_cls = env->GetObjectClass(pyannote_config); + + fid = env->GetFieldID(pyannote_config_cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(pyannote_config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.segmentation.pyannote.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(segmentation_config_cls, "numThreads", "I"); + ans.segmentation.num_threads = env->GetIntField(segmentation_config, fid); + + fid = env->GetFieldID(segmentation_config_cls, "debug", "Z"); + ans.segmentation.debug = env->GetBooleanField(segmentation_config, fid); + + fid = env->GetFieldID(segmentation_config_cls, "provider", + "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(segmentation_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.segmentation.provider = p; + env->ReleaseStringUTFChars(s, p); + + //---------- embedding ---------- + fid = env->GetFieldID( + cls, "embedding", + "Lcom/k2fsa/sherpa/onnx/SpeakerEmbeddingExtractorConfig;"); + jobject embedding_config = env->GetObjectField(config, fid); + jclass embedding_config_cls = env->GetObjectClass(embedding_config); + + fid = env->GetFieldID(embedding_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(embedding_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.embedding.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(embedding_config_cls, "numThreads", "I"); + ans.embedding.num_threads = env->GetIntField(embedding_config, fid); + + fid = env->GetFieldID(embedding_config_cls, "debug", "Z"); + ans.embedding.debug = env->GetBooleanField(embedding_config, fid); + + fid = env->GetFieldID(embedding_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(embedding_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.embedding.provider = p; + env->ReleaseStringUTFChars(s, p); + + //---------- clustering ---------- + fid = env->GetFieldID(cls, "clustering", + "Lcom/k2fsa/sherpa/onnx/FastClusteringConfig;"); + jobject clustering_config = env->GetObjectField(config, fid); + jclass clustering_config_cls = env->GetObjectClass(clustering_config); + + fid = env->GetFieldID(clustering_config_cls, "numClusters", "I"); + ans.clustering.num_clusters = env->GetIntField(clustering_config, fid); + + fid = env->GetFieldID(clustering_config_cls, "threshold", "F"); + ans.clustering.threshold = env->GetFloatField(clustering_config, fid); + + // its own fields + fid = env->GetFieldID(cls, "minDurationOn", "F"); + ans.min_duration_on = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "minDurationOff", 
"F"); + ans.min_duration_off = env->GetFloatField(config, fid); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineSpeakerDiarization_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + + auto config = sherpa_mnn::GetOfflineSpeakerDiarizationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto sd = new sherpa_mnn::OfflineSpeakerDiarization( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)sd; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineSpeakerDiarization_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_mnn::GetOfflineSpeakerDiarizationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto sd = new sherpa_mnn::OfflineSpeakerDiarization(config); + + return (jlong)sd; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineSpeakerDiarization_setConfig( + JNIEnv *env, jobject /*obj*/, jlong ptr, jobject _config) { + auto config = sherpa_mnn::GetOfflineSpeakerDiarizationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto sd = reinterpret_cast(ptr); + sd->SetConfig(config); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineSpeakerDiarization_delete(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + delete reinterpret_cast(ptr); +} + +static jobjectArray ProcessImpl( + JNIEnv *env, + const std::vector + &segments) { + jclass cls = + env->FindClass("com/k2fsa/sherpa/onnx/OfflineSpeakerDiarizationSegment"); + + jobjectArray obj_arr = + (jobjectArray)env->NewObjectArray(segments.size(), cls, nullptr); + + jmethodID constructor = env->GetMethodID(cls, "", "(FFI)V"); + + for (int32_t i = 0; i != segments.size(); ++i) { + const auto &s = segments[i]; + jobject segment = + env->NewObject(cls, constructor, s.Start(), s.End(), s.Speaker()); + env->SetObjectArrayElement(obj_arr, i, segment); + } + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineSpeakerDiarization_process( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) { + auto sd = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + auto segments = sd->Process(p, n).SortByStartTime(); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + return ProcessImpl(env, segments); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineSpeakerDiarization_processWithCallback( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jobject callback, jlong arg) { + std::function callback_wrapper = + [env, callback](int32_t num_processed_chunks, int32_t num_total_chunks, + void *data) -> int { + jclass cls = env->GetObjectClass(callback); + + jmethodID mid = env->GetMethodID(cls, "invoke", "(IIJ)Ljava/lang/Integer;"); + if (mid == nullptr) { + SHERPA_ONNX_LOGE("Failed to get the callback. 
Ignore it."); + return 0; + } + + jobject ret = env->CallObjectMethod(callback, mid, num_processed_chunks, + num_total_chunks, (jlong)data); + jclass jklass = env->GetObjectClass(ret); + jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I"); + return env->CallIntMethod(ret, int_value_mid); + }; + + auto sd = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + auto segments = + sd->Process(p, n, callback_wrapper, reinterpret_cast(arg)) + .SortByStartTime(); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + + return ProcessImpl(env, segments); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL +Java_com_k2fsa_sherpa_mnn_OfflineSpeakerDiarization_getSampleRate( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + return reinterpret_cast(ptr) + ->SampleRate(); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-stream.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-stream.cc new file mode 100644 index 00000000..d157cd05 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-stream.cc @@ -0,0 +1,25 @@ +// sherpa-mnn/jni/offline-stream.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-stream.h" + +#include "sherpa-mnn/jni/common.h" + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OfflineStream_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OfflineStream_acceptWaveform( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jint sample_rate) { + auto stream = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + stream->AcceptWaveform(sample_rate, p, n); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-tts.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-tts.cc new file mode 100644 index 00000000..611a0006 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/offline-tts.cc @@ -0,0 +1,342 @@ +// sherpa-mnn/jni/offline-tts.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tts.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { + OfflineTtsConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + fid = env->GetFieldID(cls, "model", + "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;"); + jobject model = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model); + + // vits + fid = env->GetFieldID(model_config_cls, "vits", + "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;"); + jobject vits = env->GetObjectField(model, fid); + jclass vits_cls = env->GetObjectClass(vits); + + fid = env->GetFieldID(vits_cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(vits, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "lexicon", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.lexicon = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "tokens", "Ljava/lang/String;"); + s = 
(jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.data_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "dictDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(vits, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.vits.dict_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(vits_cls, "noiseScale", "F"); + ans.model.vits.noise_scale = env->GetFloatField(vits, fid); + + fid = env->GetFieldID(vits_cls, "noiseScaleW", "F"); + ans.model.vits.noise_scale_w = env->GetFloatField(vits, fid); + + fid = env->GetFieldID(vits_cls, "lengthScale", "F"); + ans.model.vits.length_scale = env->GetFloatField(vits, fid); + + // matcha + fid = env->GetFieldID(model_config_cls, "matcha", + "Lcom/k2fsa/sherpa/onnx/OfflineTtsMatchaModelConfig;"); + jobject matcha = env->GetObjectField(model, fid); + jclass matcha_cls = env->GetObjectClass(matcha); + + fid = env->GetFieldID(matcha_cls, "acousticModel", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.acoustic_model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "vocoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.vocoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "lexicon", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.lexicon = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "dataDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.data_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "dictDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(matcha, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.matcha.dict_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(matcha_cls, "noiseScale", "F"); + ans.model.matcha.noise_scale = env->GetFloatField(matcha, fid); + + fid = env->GetFieldID(matcha_cls, "lengthScale", "F"); + ans.model.matcha.length_scale = env->GetFloatField(matcha, fid); + + // kokoro + fid = env->GetFieldID(model_config_cls, "kokoro", + "Lcom/k2fsa/sherpa/onnx/OfflineTtsKokoroModelConfig;"); + jobject kokoro = env->GetObjectField(model, fid); + jclass kokoro_cls = env->GetObjectClass(kokoro); + + fid = env->GetFieldID(kokoro_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(kokoro, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.kokoro.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(kokoro_cls, "voices", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(kokoro, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.kokoro.voices = p; + env->ReleaseStringUTFChars(s, p); + + fid = 
env->GetFieldID(kokoro_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(kokoro, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.kokoro.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(kokoro_cls, "lexicon", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(kokoro, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.kokoro.lexicon = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(kokoro_cls, "dataDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(kokoro, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.kokoro.data_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(kokoro_cls, "dictDir", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(kokoro, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.kokoro.dict_dir = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(kokoro_cls, "lengthScale", "F"); + ans.model.kokoro.length_scale = env->GetFloatField(kokoro, fid); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model.num_threads = env->GetIntField(model, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model.debug = env->GetBooleanField(model, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.provider = p; + env->ReleaseStringUTFChars(s, p); + + // for ruleFsts + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fsts = p; + env->ReleaseStringUTFChars(s, p); + + // for ruleFars + fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fars = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "maxNumSentences", "I"); + ans.max_num_sentences = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "silenceScale", "F"); + ans.silence_scale = env->GetFloatField(config, fid); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_OfflineTts_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetOfflineTtsConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto tts = new sherpa_mnn::OfflineTts( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)tts; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_OfflineTts_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + return SafeJNI( + env, "OfflineTts_newFromFile", + [&]() -> jlong { + auto config = sherpa_mnn::GetOfflineTtsConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + } + + auto tts = new sherpa_mnn::OfflineTts(config); + return reinterpret_cast(tts); + }, + 0L); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OfflineTts_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + 
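// --- Illustrative aside (not part of the original patch) ---
// Each Get*Config() helper in these JNI files repeats the same four calls for
// every String field: look up the field ID, read the jstring, copy the UTF-8
// bytes into the config, and release them. The sketch below captures that
// pattern for reference only; the helper name GetStringField is hypothetical
// and does not exist in this tree.
//
// static std::string GetStringField(JNIEnv *env, jobject obj, jclass cls,
//                                   const char *name) {
//   jfieldID fid = env->GetFieldID(cls, name, "Ljava/lang/String;");
//   jstring s = (jstring)env->GetObjectField(obj, fid);
//   const char *p = env->GetStringUTFChars(s, nullptr);
//   std::string ans = p ? p : "";   // guard against a failed UTF-8 copy
//   env->ReleaseStringUTFChars(s, p);
//   return ans;
// }
//
// Usage, e.g. inside GetOfflineTtsConfig():
//   ans.model.kokoro.model = GetStringField(env, kokoro, kokoro_cls, "model");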
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_mnn_OfflineTts_getSampleRate(
+    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
+  return reinterpret_cast<sherpa_mnn::OfflineTts *>(ptr)->SampleRate();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_mnn_OfflineTts_getNumSpeakers(
+    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
+  return reinterpret_cast<sherpa_mnn::OfflineTts *>(ptr)->NumSpeakers();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL
+Java_com_k2fsa_sherpa_mnn_OfflineTts_generateImpl(JNIEnv *env, jobject /*obj*/,
+                                                  jlong ptr, jstring text,
+                                                  jint sid, jfloat speed) {
+  const char *p_text = env->GetStringUTFChars(text, nullptr);
+
+  auto audio = reinterpret_cast<sherpa_mnn::OfflineTts *>(ptr)->Generate(
+      p_text, sid, speed);
+
+  jfloatArray samples_arr = env->NewFloatArray(audio.samples.size());
+  env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(),
+                           audio.samples.data());
+
+  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
+      2, env->FindClass("java/lang/Object"), nullptr);
+
+  env->SetObjectArrayElement(obj_arr, 0, samples_arr);
+  env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate));
+
+  env->ReleaseStringUTFChars(text, p_text);
+
+  return obj_arr;
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT jobjectArray JNICALL
+Java_com_k2fsa_sherpa_mnn_OfflineTts_generateWithCallbackImpl(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring text, jint sid,
+    jfloat speed, jobject callback) {
+  const char *p_text = env->GetStringUTFChars(text, nullptr);
+
+  std::function<int32_t(const float *, int32_t, float)> callback_wrapper =
+      [env, callback](const float *samples, int32_t n,
+                      float /*progress*/) -> int {
+    jclass cls = env->GetObjectClass(callback);
+
+#if 0
+    // this block is for debugging only
+    // see also
+    // https://jnjosh.com/posts/kotlinfromcpp/
+    jmethodID classMethodId =
+        env->GetMethodID(cls, "getClass", "()Ljava/lang/Class;");
+    jobject klassObj = env->CallObjectMethod(callback, classMethodId);
+    auto klassObject = env->GetObjectClass(klassObj);
+    auto nameMethodId =
+        env->GetMethodID(klassObject, "getName", "()Ljava/lang/String;");
+    jstring classString =
+        (jstring)env->CallObjectMethod(klassObj, nameMethodId);
+    auto className = env->GetStringUTFChars(classString, NULL);
+    SHERPA_ONNX_LOGE("name is: %s", className);
+    env->ReleaseStringUTFChars(classString, className);
+#endif
+
+    jmethodID mid = env->GetMethodID(cls, "invoke", "([F)Ljava/lang/Integer;");
+    if (mid == nullptr) {
+      SHERPA_ONNX_LOGE("Failed to get the callback.
Ignore it."); + return 1; + } + + jfloatArray samples_arr = env->NewFloatArray(n); + env->SetFloatArrayRegion(samples_arr, 0, n, samples); + + jobject should_continue = env->CallObjectMethod(callback, mid, samples_arr); + jclass jklass = env->GetObjectClass(should_continue); + jmethodID int_value_mid = env->GetMethodID(jklass, "intValue", "()I"); + return env->CallIntMethod(should_continue, int_value_mid); + }; + + auto tts = reinterpret_cast(ptr); + auto audio = tts->Generate(p_text, sid, speed, callback_wrapper); + + jfloatArray samples_arr = env->NewFloatArray(audio.samples.size()); + env->SetFloatArrayRegion(samples_arr, 0, audio.samples.size(), + audio.samples.data()); + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, samples_arr); + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, audio.sample_rate)); + + env->ReleaseStringUTFChars(text, p_text); + + return obj_arr; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-punctuation.cc new file mode 100644 index 00000000..36f59f74 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-punctuation.cc @@ -0,0 +1,117 @@ +// sherpa-mnn/jni/online-punctuation.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-punctuation.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static OnlinePunctuationConfig GetOnlinePunctuationConfig(JNIEnv *env, + jobject config) { + OnlinePunctuationConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + fid = env->GetFieldID(cls, "model", + "Lcom/k2fsa/sherpa/onnx/OnlinePunctuationModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + fid = env->GetFieldID(model_config_cls, "cnnBilstm", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(model_config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.model.cnn_bilstm = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model.provider = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OnlinePunctuation_newFromAsset(JNIEnv *env, + jobject /*obj*/, + jobject asset_manager, + jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetOnlinePunctuationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto model = new sherpa_mnn::OnlinePunctuation( +#if __ANDROID_API__ >= 
9 + mgr, +#endif + config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OnlinePunctuation_newFromFile(JNIEnv *env, + jobject /*obj*/, + jobject _config) { + auto config = sherpa_mnn::GetOnlinePunctuationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto model = new sherpa_mnn::OnlinePunctuation(config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OnlinePunctuation_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jstring JNICALL +Java_com_k2fsa_sherpa_mnn_OnlinePunctuation_addPunctuation(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jstring text) { + auto punct = reinterpret_cast(ptr); + + const char *ptext = env->GetStringUTFChars(text, nullptr); + + std::string result = punct->AddPunctuationWithCase(ptext); + + env->ReleaseStringUTFChars(text, ptext); + + return env->NewStringUTF(result.c_str()); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-recognizer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-recognizer.cc new file mode 100644 index 00000000..13cc767e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-recognizer.cc @@ -0,0 +1,408 @@ +// sherpa-mnn/jni/online-recognizer.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-recognizer.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { + OnlineRecognizerConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html + // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html + + //---------- decoding ---------- + fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.decoding_method = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "maxActivePaths", "I"); + ans.max_active_paths = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.hotwords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "hotwordsScore", "F"); + ans.hotwords_score = env->GetFloatField(config, fid); + + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fsts = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.rule_fars = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "blankPenalty", "F"); + ans.blank_penalty = env->GetFloatField(config, fid); + + //---------- feat config ---------- + fid = env->GetFieldID(cls, "featConfig", + "Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); + jobject feat_config = env->GetObjectField(config, fid); + jclass feat_config_cls = env->GetObjectClass(feat_config); + + fid = 
env->GetFieldID(feat_config_cls, "sampleRate", "I"); + ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid); + + fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); + ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); + + //---------- enable endpoint ---------- + fid = env->GetFieldID(cls, "enableEndpoint", "Z"); + ans.enable_endpoint = env->GetBooleanField(config, fid); + + //---------- endpoint_config ---------- + + fid = env->GetFieldID(cls, "endpointConfig", + "Lcom/k2fsa/sherpa/onnx/EndpointConfig;"); + jobject endpoint_config = env->GetObjectField(config, fid); + jclass endpoint_config_cls = env->GetObjectClass(endpoint_config); + + fid = env->GetFieldID(endpoint_config_cls, "rule1", + "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); + jobject rule1 = env->GetObjectField(endpoint_config, fid); + jclass rule_class = env->GetObjectClass(rule1); + + fid = env->GetFieldID(endpoint_config_cls, "rule2", + "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); + jobject rule2 = env->GetObjectField(endpoint_config, fid); + + fid = env->GetFieldID(endpoint_config_cls, "rule3", + "Lcom/k2fsa/sherpa/onnx/EndpointRule;"); + jobject rule3 = env->GetObjectField(endpoint_config, fid); + + fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z"); + ans.endpoint_config.rule1.must_contain_nonsilence = + env->GetBooleanField(rule1, fid); + ans.endpoint_config.rule2.must_contain_nonsilence = + env->GetBooleanField(rule2, fid); + ans.endpoint_config.rule3.must_contain_nonsilence = + env->GetBooleanField(rule3, fid); + + fid = env->GetFieldID(rule_class, "minTrailingSilence", "F"); + ans.endpoint_config.rule1.min_trailing_silence = + env->GetFloatField(rule1, fid); + ans.endpoint_config.rule2.min_trailing_silence = + env->GetFloatField(rule2, fid); + ans.endpoint_config.rule3.min_trailing_silence = + env->GetFloatField(rule3, fid); + + fid = env->GetFieldID(rule_class, "minUtteranceLength", "F"); + ans.endpoint_config.rule1.min_utterance_length = + env->GetFloatField(rule1, fid); + ans.endpoint_config.rule2.min_utterance_length = + env->GetFloatField(rule2, fid); + ans.endpoint_config.rule3.min_utterance_length = + env->GetFloatField(rule3, fid); + + //---------- model config ---------- + fid = env->GetFieldID(cls, "modelConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); + + // transducer + fid = env->GetFieldID(model_config_cls, "transducer", + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); + jobject transducer_config = env->GetObjectField(model_config, fid); + jclass transducer_config_cls = env->GetObjectClass(transducer_config); + + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(transducer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.transducer.joiner = p; + env->ReleaseStringUTFChars(s, p); + + // paraformer + fid = 
env->GetFieldID(model_config_cls, "paraformer", + "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"); + jobject paraformer_config = env->GetObjectField(model_config, fid); + jclass paraformer_config_cls = env->GetObjectClass(paraformer_config); + + fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.decoder = p; + env->ReleaseStringUTFChars(s, p); + + // streaming zipformer2 CTC + fid = + env->GetFieldID(model_config_cls, "zipformer2Ctc", + "Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;"); + jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid); + jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config); + + fid = + env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.zipformer2_ctc.model = p; + env->ReleaseStringUTFChars(s, p); + + // streaming NeMo CTC + fid = env->GetFieldID(model_config_cls, "neMoCtc", + "Lcom/k2fsa/sherpa/onnx/OnlineNeMoCtcModelConfig;"); + jobject nemo_ctc_config = env->GetObjectField(model_config, fid); + jclass nemo_ctc_config_cls = env->GetObjectClass(nemo_ctc_config); + + fid = env->GetFieldID(nemo_ctc_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(nemo_ctc_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.nemo_ctc.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.tokens = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "numThreads", "I"); + ans.model_config.num_threads = env->GetIntField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "debug", "Z"); + ans.model_config.debug = env->GetBooleanField(model_config, fid); + + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.provider_config.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.model_type = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.modeling_unit = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + + //---------- rnn lm model config ---------- + fid = env->GetFieldID(cls, "lmConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); + jobject 
lm_model_config = env->GetObjectField(config, fid); + jclass lm_model_config_cls = env->GetObjectClass(lm_model_config); + + fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(lm_model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.lm_config.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(lm_model_config_cls, "scale", "F"); + ans.lm_config.scale = env->GetFloatField(lm_model_config, fid); + + fid = env->GetFieldID(cls, "ctcFstDecoderConfig", + "Lcom/k2fsa/sherpa/onnx/OnlineCtcFstDecoderConfig;"); + + jobject fst_decoder_config = env->GetObjectField(config, fid); + jclass fst_decoder_config_cls = env->GetObjectClass(fst_decoder_config); + + fid = env->GetFieldID(fst_decoder_config_cls, "graph", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(fst_decoder_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.ctc_fst_decoder_config.graph = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(fst_decoder_config_cls, "maxActive", "I"); + ans.ctc_fst_decoder_config.max_active = + env->GetIntField(fst_decoder_config, fid); + + return ans; +} +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_newFromAsset(JNIEnv *env, + jobject /*obj*/, + jobject asset_manager, + jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto recognizer = new sherpa_mnn::OnlineRecognizer( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)recognizer; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_mnn::GetConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto recognizer = new sherpa_mnn::OnlineRecognizer(config); + + return (jlong)recognizer; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_reset( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + recognizer->Reset(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_isReady( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + return recognizer->IsReady(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_isEndpoint( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + return recognizer->IsEndpoint(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_decode( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr, jlong stream_ptr) { + auto 
recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + recognizer->DecodeStream(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_createStream(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jstring hotwords) { + auto recognizer = reinterpret_cast(ptr); + + const char *p = env->GetStringUTFChars(hotwords, nullptr); + std::unique_ptr stream; + + if (strlen(p) == 0) { + stream = recognizer->CreateStream(); + } else { + stream = recognizer->CreateStream(p); + } + + env->ReleaseStringUTFChars(hotwords, p); + + // The user is responsible to free the returned pointer. + // + // See Java_com_k2fsa_sherpa_mnn_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_mnn::OnlineStream *ans = stream.release(); + return (jlong)ans; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_OnlineRecognizer_getResult(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto recognizer = reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + sherpa_mnn::OnlineRecognizerResult result = recognizer->GetResult(stream); + + // [0]: text, jstring + // [1]: tokens, array of jstring + // [2]: timestamps, array of float + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 3, env->FindClass("java/lang/Object"), nullptr); + + jstring text = env->NewStringUTF(result.text.c_str()); + env->SetObjectArrayElement(obj_arr, 0, text); + + jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray( + result.tokens.size(), env->FindClass("java/lang/String"), nullptr); + + int32_t i = 0; + for (const auto &t : result.tokens) { + jstring jtext = env->NewStringUTF(t.c_str()); + env->SetObjectArrayElement(tokens_arr, i, jtext); + i += 1; + } + + env->SetObjectArrayElement(obj_arr, 1, tokens_arr); + + jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size()); + env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(), + result.timestamps.data()); + + env->SetObjectArrayElement(obj_arr, 2, timestamps_arr); + + return obj_arr; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-stream.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-stream.cc new file mode 100644 index 00000000..881a3b30 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/online-stream.cc @@ -0,0 +1,32 @@ +// sherpa-mnn/jni/online-stream.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-stream.h" + +#include "sherpa-mnn/jni/common.h" + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OnlineStream_delete( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OnlineStream_acceptWaveform( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jint sample_rate) { + auto stream = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + stream->AcceptWaveform(sample_rate, p, n); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_OnlineStream_inputFinished( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + auto stream = reinterpret_cast(ptr); + stream->InputFinished(); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/speaker-embedding-extractor.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/speaker-embedding-extractor.cc new 
file mode 100644 index 00000000..220c4f13 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/speaker-embedding-extractor.cc @@ -0,0 +1,138 @@ +// sherpa-mnn/jni/speaker-embedding-extractor.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig( + JNIEnv *env, jobject config) { + SpeakerEmbeddingExtractorConfig ans; + + jclass cls = env->GetObjectClass(config); + + jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + + ans.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "numThreads", "I"); + ans.num_threads = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "debug", "Z"); + ans.debug = env->GetBooleanField(config, fid); + + fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.provider = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingExtractor_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetSpeakerEmbeddingExtractorConfig(env, _config); + SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str()); + + auto extractor = new sherpa_mnn::SpeakerEmbeddingExtractor( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)extractor; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingExtractor_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_mnn::GetSpeakerEmbeddingExtractorConfig(env, _config); + SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + } + + auto extractor = new sherpa_mnn::SpeakerEmbeddingExtractor(config); + + return (jlong)extractor; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingExtractor_delete(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingExtractor_createStream( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + std::unique_ptr s = + reinterpret_cast(ptr) + ->CreateStream(); + + // The user is responsible to free the returned pointer. 
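+  //
+  // For example (an illustrative sketch only; the class names follow the
+  // Kotlin wrappers added elsewhere in this change), the JVM side is expected
+  // to wrap the returned handle and free it again via the matching native
+  // delete():
+  //
+  //   val stream = OnlineStream(createStream(ptr))  // takes ownership of the jlong
+  //   ...                                           // acceptWaveform(), inputFinished(), ...
+  //   stream.release()                              // ends up in OnlineStream_delete()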
+ // + // See Java_com_k2fsa_sherpa_mnn_OnlineStream_delete() from + // ./online-stream.cc + sherpa_mnn::OnlineStream *p = s.release(); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingExtractor_isReady(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto extractor = + reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + return extractor->IsReady(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jfloatArray JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingExtractor_compute(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto extractor = + reinterpret_cast(ptr); + auto stream = reinterpret_cast(stream_ptr); + + std::vector embedding = extractor->Compute(stream); + jfloatArray embedding_arr = env->NewFloatArray(embedding.size()); + env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(), + embedding.data()); + return embedding_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingExtractor_dim( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + auto extractor = + reinterpret_cast(ptr); + return extractor->Dim(); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/speaker-embedding-manager.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/speaker-embedding-manager.cc new file mode 100644 index 00000000..5c94e215 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/speaker-embedding-manager.cc @@ -0,0 +1,207 @@ +// sherpa-mnn/jni/speaker-embedding-manager.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/speaker-embedding-manager.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_create(JNIEnv *env, + jobject /*obj*/, + jint dim) { + auto p = new sherpa_mnn::SpeakerEmbeddingManager(dim); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_delete(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + auto manager = reinterpret_cast(ptr); + delete manager; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_add(JNIEnv *env, + jobject /*obj*/, + jlong ptr, jstring name, + jfloatArray embedding) { + auto manager = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(embedding, nullptr); + jsize n = env->GetArrayLength(embedding); + + if (n != manager->Dim()) { + SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(), + static_cast(n)); + exit(-1); + } + + const char *p_name = env->GetStringUTFChars(name, nullptr); + + jboolean ok = manager->Add(p_name, p); + env->ReleaseStringUTFChars(name, p_name); + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT); + + return ok; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_addList( + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name, + jobjectArray embedding_arr) { + auto manager = reinterpret_cast(ptr); + + int num_embeddings = env->GetArrayLength(embedding_arr); + if (num_embeddings == 0) { + return false; + } + + std::vector> embedding_list; + embedding_list.reserve(num_embeddings); + for (int32_t i = 0; i != num_embeddings; ++i) { + jfloatArray embedding = + (jfloatArray)env->GetObjectArrayElement(embedding_arr, i); + + jfloat *p = env->GetFloatArrayElements(embedding, nullptr); + jsize n = 
env->GetArrayLength(embedding); + + if (n != manager->Dim()) { + SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(), + static_cast(n)); + exit(-1); + } + + embedding_list.push_back({p, p + n}); + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT); + } + + const char *p_name = env->GetStringUTFChars(name, nullptr); + + jboolean ok = manager->Add(p_name, embedding_list); + + env->ReleaseStringUTFChars(name, p_name); + + return ok; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_remove(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jstring name) { + auto manager = reinterpret_cast(ptr); + + const char *p_name = env->GetStringUTFChars(name, nullptr); + + jboolean ok = manager->Remove(p_name); + + env->ReleaseStringUTFChars(name, p_name); + + return ok; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jstring JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_search(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jfloatArray embedding, + jfloat threshold) { + auto manager = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(embedding, nullptr); + jsize n = env->GetArrayLength(embedding); + + if (n != manager->Dim()) { + SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(), + static_cast(n)); + exit(-1); + } + + std::string name = manager->Search(p, threshold); + + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT); + + return env->NewStringUTF(name.c_str()); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_verify( + JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name, + jfloatArray embedding, jfloat threshold) { + auto manager = reinterpret_cast(ptr); + + jfloat *p = env->GetFloatArrayElements(embedding, nullptr); + jsize n = env->GetArrayLength(embedding); + + if (n != manager->Dim()) { + SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(), + static_cast(n)); + exit(-1); + } + + const char *p_name = env->GetStringUTFChars(name, nullptr); + + jboolean ok = manager->Verify(p_name, p, threshold); + + env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT); + + env->ReleaseStringUTFChars(name, p_name); + + return ok; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_contains(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jstring name) { + auto manager = reinterpret_cast(ptr); + + const char *p_name = env->GetStringUTFChars(name, nullptr); + + jboolean ok = manager->Contains(p_name); + + env->ReleaseStringUTFChars(name, p_name); + + return ok; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jint JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_numSpeakers(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + auto manager = reinterpret_cast(ptr); + return manager->NumSpeakers(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_SpeakerEmbeddingManager_allSpeakerNames( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto manager = reinterpret_cast(ptr); + std::vector all_speakers = manager->GetAllSpeakers(); + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + all_speakers.size(), env->FindClass("java/lang/String"), nullptr); + + int32_t i = 0; + for (auto &s : all_speakers) { + jstring js = env->NewStringUTF(s.c_str()); + env->SetObjectArrayElement(obj_arr, i, js); + + ++i; + } + + return obj_arr; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/spoken-language-identification.cc 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/spoken-language-identification.cc new file mode 100644 index 00000000..74cd2998 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/spoken-language-identification.cc @@ -0,0 +1,139 @@ +// sherpa-mnn/jni/spoken-language-identification.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/spoken-language-identification.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig( + JNIEnv *env, jobject config) { + SpokenLanguageIdentificationConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid = env->GetFieldID( + cls, "whisper", + "Lcom/k2fsa/sherpa/onnx/SpokenLanguageIdentificationWhisperConfig;"); + + jobject whisper = env->GetObjectField(config, fid); + jclass whisper_cls = env->GetObjectClass(whisper); + + fid = env->GetFieldID(whisper_cls, "encoder", "Ljava/lang/String;"); + + jstring s = (jstring)env->GetObjectField(whisper, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + ans.whisper.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_cls, "decoder", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(whisper, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.whisper.decoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(whisper_cls, "tailPaddings", "I"); + ans.whisper.tail_paddings = env->GetIntField(whisper, fid); + + fid = env->GetFieldID(cls, "numThreads", "I"); + ans.num_threads = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "debug", "Z"); + ans.debug = env->GetBooleanField(config, fid); + + fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.provider = p; + env->ReleaseStringUTFChars(s, p); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_SpokenLanguageIdentification_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + + auto config = + sherpa_mnn::GetSpokenLanguageIdentificationConfig(env, _config); + SHERPA_ONNX_LOGE("spoken language identification newFromAsset config:\n%s", + config.ToString().c_str()); + + auto slid = new sherpa_mnn::SpokenLanguageIdentification( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + SHERPA_ONNX_LOGE("slid %p", slid); + + return (jlong)slid; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_mnn_SpokenLanguageIdentification_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = + sherpa_mnn::GetSpokenLanguageIdentificationConfig(env, _config); + SHERPA_ONNX_LOGE("SpokenLanguageIdentification newFromFile config:\n%s", + config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto tagger = new sherpa_mnn::SpokenLanguageIdentification(config); + + return (jlong)tagger; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_mnn_SpokenLanguageIdentification_delete(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL 
+Java_com_k2fsa_sherpa_mnn_SpokenLanguageIdentification_createStream( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + auto slid = + reinterpret_cast(ptr); + std::unique_ptr s = slid->CreateStream(); + + // The user is responsible to free the returned pointer. + // + // See Java_com_k2fsa_sherpa_mnn_OfflineStream_delete() from + // ./offline-stream.cc + sherpa_mnn::OfflineStream *p = s.release(); + return (jlong)p; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jstring JNICALL +Java_com_k2fsa_sherpa_mnn_SpokenLanguageIdentification_compute(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong s_ptr) { + sherpa_mnn::SpokenLanguageIdentification *slid = + reinterpret_cast(ptr); + sherpa_mnn::OfflineStream *s = + reinterpret_cast(s_ptr); + std::string lang = slid->Compute(s); + return env->NewStringUTF(lang.c_str()); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/voice-activity-detector.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/voice-activity-detector.cc new file mode 100644 index 00000000..6eb58bcb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/voice-activity-detector.cc @@ -0,0 +1,201 @@ +// sherpa-mnn/csrc/voice-activity-detector.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/voice-activity-detector.h" + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +namespace sherpa_mnn { + +static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) { + VadModelConfig ans; + + jclass cls = env->GetObjectClass(config); + jfieldID fid; + + // silero_vad + fid = env->GetFieldID(cls, "sileroVadModelConfig", + "Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;"); + jobject silero_vad_config = env->GetObjectField(config, fid); + jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config); + + fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;"); + auto s = (jstring)env->GetObjectField(silero_vad_config, fid); + auto p = env->GetStringUTFChars(s, nullptr); + ans.silero_vad.model = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F"); + ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid); + + fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", "F"); + ans.silero_vad.min_silence_duration = + env->GetFloatField(silero_vad_config, fid); + + fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F"); + ans.silero_vad.min_speech_duration = + env->GetFloatField(silero_vad_config, fid); + + fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I"); + ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid); + + fid = env->GetFieldID(silero_vad_config_cls, "maxSpeechDuration", "F"); + ans.silero_vad.max_speech_duration = + env->GetFloatField(silero_vad_config, fid); + + fid = env->GetFieldID(cls, "sampleRate", "I"); + ans.sample_rate = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "numThreads", "I"); + ans.num_threads = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "debug", "Z"); + ans.debug = env->GetBooleanField(config, fid); + + return ans; +} + +} // namespace sherpa_mnn + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_Vad_newFromAsset( + JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +#if 
__ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + auto config = sherpa_mnn::GetVadModelConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto model = new sherpa_mnn::VoiceActivityDetector( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_mnn_Vad_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_mnn::GetVadModelConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + return 0; + } + + auto model = new sherpa_mnn::VoiceActivityDetector(config); + + return (jlong)model; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_Vad_delete(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + delete reinterpret_cast(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_Vad_acceptWaveform( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) { + SafeJNI(env, "Vad_acceptWaveform", [&] { + if (!ValidatePointer(env, ptr, "Vad_acceptWaveform", + "VoiceActivityDetector pointer is null.")) { + return; + } + + auto model = reinterpret_cast(ptr); + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + + model->AcceptWaveform(p, n); + + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + }); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_mnn_Vad_empty(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + auto model = reinterpret_cast(ptr); + return model->Empty(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_Vad_pop(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + auto model = reinterpret_cast(ptr); + model->Pop(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_Vad_clear(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + auto model = reinterpret_cast(ptr); + model->Clear(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) { + const auto &front = + reinterpret_cast(ptr)->Front(); + + jfloatArray samples_arr = env->NewFloatArray(front.samples.size()); + env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(), + front.samples.data()); + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start)); + env->SetObjectArrayElement(obj_arr, 1, samples_arr); + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_mnn_Vad_isSpeechDetected( + JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) { + auto model = reinterpret_cast(ptr); + return model->IsSpeechDetected(); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_Vad_reset( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + SafeJNI(env, "Vad_reset", [&] { + if (!ValidatePointer(env, ptr, "Vad_reset", + "VoiceActivityDetector pointer is null.")) { + return; + } + + auto model = reinterpret_cast(ptr); + model->Reset(); + }); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_mnn_Vad_flush(JNIEnv * /*env*/, + jobject /*obj*/, + jlong ptr) { + auto model = 
reinterpret_cast(ptr); + model->Flush(); +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/wave-reader.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/wave-reader.cc new file mode 100644 index 00000000..f8e8a133 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/wave-reader.cc @@ -0,0 +1,84 @@ +// sherpa-mnn/jni/wave-reader.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/wave-reader.h" + +#include + +#include "sherpa-mnn/csrc/file-utils.h" +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/jni/common.h" + +static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is, + const char *p_filename) { + bool is_ok = false; + int32_t sampling_rate = -1; + std::vector samples = + sherpa_mnn::ReadWave(is, &sampling_rate, &is_ok); + + if (!is_ok) { + SHERPA_ONNX_LOGE("Failed to read '%s'", p_filename); + jclass exception_class = env->FindClass("java/lang/Exception"); + env->ThrowNew(exception_class, "Failed to read wave file."); + return nullptr; + } + + jfloatArray samples_arr = env->NewFloatArray(samples.size()); + env->SetFloatArrayRegion(samples_arr, 0, samples.size(), samples.data()); + + jobjectArray obj_arr = (jobjectArray)env->NewObjectArray( + 2, env->FindClass("java/lang/Object"), nullptr); + + env->SetObjectArrayElement(obj_arr, 0, samples_arr); + env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate)); + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_WaveReader_00024Companion_readWaveFromFile( + JNIEnv *env, jclass /*cls*/, jstring filename) { + const char *p_filename = env->GetStringUTFChars(filename, nullptr); + std::ifstream is(p_filename, std::ios::binary); + + auto obj_arr = ReadWaveImpl(env, is, p_filename); + + env->ReleaseStringUTFChars(filename, p_filename); + + return obj_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_WaveReader_readWaveFromFile(JNIEnv *env, + jclass /*obj*/, + jstring filename) { + return Java_com_k2fsa_sherpa_mnn_WaveReader_00024Companion_readWaveFromFile( + env, nullptr, filename); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jobjectArray JNICALL +Java_com_k2fsa_sherpa_mnn_WaveReader_00024Companion_readWaveFromAsset( + JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) { + const char *p_filename = env->GetStringUTFChars(filename, nullptr); +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + exit(-1); + } + std::vector buffer = sherpa_mnn::ReadFile(mgr, p_filename); + + std::istrstream is(buffer.data(), buffer.size()); +#else + std::ifstream is(p_filename, std::ios::binary); +#endif + + auto obj_arr = ReadWaveImpl(env, is, p_filename); + + env->ReleaseStringUTFChars(filename, p_filename); + + return obj_arr; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/wave-writer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/wave-writer.cc new file mode 100644 index 00000000..8d2acfc8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/jni/wave-writer.cc @@ -0,0 +1,23 @@ +// sherpa-mnn/jni/wave-writer.cc +// +// Copyright (c) 2024 Xiaomi Corporation +#include "sherpa-mnn/csrc/wave-writer.h" + +#include "sherpa-mnn/jni/common.h" + +SHERPA_ONNX_EXTERN_C +JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_mnn_WaveWriter_writeWaveToFile( + JNIEnv *env, jclass /*obj*/, jstring filename, jfloatArray samples, + jint sample_rate) { + jfloat *p = 
env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + + const char *p_filename = env->GetStringUTFChars(filename, nullptr); + + bool ok = sherpa_mnn::WriteWave(p_filename, sample_rate, p, n); + + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); + env->ReleaseStringUTFChars(filename, p_filename); + + return ok; +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/AudioTagging.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/AudioTagging.kt new file mode 100644 index 00000000..5f420486 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/AudioTagging.kt @@ -0,0 +1,186 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class OfflineZipformerAudioTaggingModelConfig( + var model: String = "", +) + +data class AudioTaggingModelConfig( + var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(), + var ced: String = "", + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +data class AudioTaggingConfig( + var model: AudioTaggingModelConfig = AudioTaggingModelConfig(), + var labels: String = "", + var topK: Int = 5, +) + +data class AudioEvent( + val name: String, + val index: Int, + val prob: Float, +) + +class AudioTagging( + assetManager: AssetManager? = null, + config: AudioTaggingConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OfflineStream { + val p = createStream(ptr) + return OfflineStream(p) + } + + @Suppress("UNCHECKED_CAST") + fun compute(stream: OfflineStream, topK: Int = -1): ArrayList { + val events: Array = compute(ptr, stream.ptr, topK) + val ans = ArrayList() + + for (e in events) { + val p: Array = e as Array + ans.add( + AudioEvent( + name = p[0] as String, + index = p[1] as Int, + prob = p[2] as Float, + ) + ) + } + + return ans + } + + private external fun newFromAsset( + assetManager: AssetManager, + config: AudioTaggingConfig, + ): Long + + private external fun newFromFile( + config: AudioTaggingConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +// please refer to +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models +// to download more models +// +// See also +// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/ +fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? 
{ + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15" + return AudioTaggingConfig( + model = AudioTaggingModelConfig( + zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"), + numThreads = numThreads, + debug = true, + ), + labels = "$modelDir/class_labels_indices.csv", + topK = 3, + ) + } + + 1 -> { + val modelDir = "sherpa-onnx-zipformer-audio-tagging-2024-04-09" + return AudioTaggingConfig( + model = AudioTaggingModelConfig( + zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"), + numThreads = numThreads, + debug = true, + ), + labels = "$modelDir/class_labels_indices.csv", + topK = 3, + ) + } + + 2 -> { + val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19" + return AudioTaggingConfig( + model = AudioTaggingModelConfig( + ced = "$modelDir/model.int8.onnx", + numThreads = numThreads, + debug = true, + ), + labels = "$modelDir/class_labels_indices.csv", + topK = 3, + ) + } + + 3 -> { + val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19" + return AudioTaggingConfig( + model = AudioTaggingModelConfig( + ced = "$modelDir/model.int8.onnx", + numThreads = numThreads, + debug = true, + ), + labels = "$modelDir/class_labels_indices.csv", + topK = 3, + ) + } + + 4 -> { + val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19" + return AudioTaggingConfig( + model = AudioTaggingModelConfig( + ced = "$modelDir/model.int8.onnx", + numThreads = numThreads, + debug = true, + ), + labels = "$modelDir/class_labels_indices.csv", + topK = 3, + ) + } + + 5 -> { + val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19" + return AudioTaggingConfig( + model = AudioTaggingModelConfig( + ced = "$modelDir/model.int8.onnx", + numThreads = numThreads, + debug = true, + ), + labels = "$modelDir/class_labels_indices.csv", + topK = 3, + ) + } + } + + return null +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/FeatureConfig.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/FeatureConfig.kt new file mode 100644 index 00000000..76d928ed --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/FeatureConfig.kt @@ -0,0 +1,10 @@ +package com.k2fsa.sherpa.mnn + +data class FeatureConfig( + var sampleRate: Int = 16000, + var featureDim: Int = 80, +) + +fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig { + return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim) +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/KeywordSpotter.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/KeywordSpotter.kt new file mode 100644 index 00000000..824b201a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/KeywordSpotter.kt @@ -0,0 +1,156 @@ +// Copyright (c) 2024 Xiaomi Corporation +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class KeywordSpotterConfig( + var featConfig: FeatureConfig = FeatureConfig(), + var modelConfig: OnlineModelConfig = OnlineModelConfig(), + var maxActivePaths: Int = 4, + var keywordsFile: String = "keywords.txt", + var keywordsScore: Float = 1.5f, + var keywordsThreshold: Float = 0.25f, + var numTrailingBlanks: Int = 2, +) + +data class KeywordSpotterResult( + val keyword: String, + val tokens: Array, + val timestamps: FloatArray, + // TODO(fangjun): Add more fields +) + +class KeywordSpotter( + assetManager: AssetManager? 
= null, + val config: KeywordSpotterConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(keywords: String = ""): OnlineStream { + val p = createStream(ptr, keywords) + return OnlineStream(p) + } + + fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) + fun reset(stream: OnlineStream) = reset(ptr, stream.ptr) + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) + fun getResult(stream: OnlineStream): KeywordSpotterResult { + val objArray = getResult(ptr, stream.ptr) + + val keyword = objArray[0] as String + val tokens = objArray[1] as Array + val timestamps = objArray[2] as FloatArray + + return KeywordSpotterResult(keyword = keyword, tokens = tokens, timestamps = timestamps) + } + + private external fun delete(ptr: Long) + + private external fun newFromAsset( + assetManager: AssetManager, + config: KeywordSpotterConfig, + ): Long + + private external fun newFromFile( + config: KeywordSpotterConfig, + ): Long + + private external fun createStream(ptr: Long, keywords: String): Long + private external fun isReady(ptr: Long, streamPtr: Long): Boolean + private external fun decode(ptr: Long, streamPtr: Long) + private external fun reset(ptr: Long, streamPtr: Long) + private external fun getResult(ptr: Long, streamPtr: Long): Array + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own. (It should be straightforward to add a new model +by following the code) + +@param type +0 - sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 (Chinese) + https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary + +1 - sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 (English) + https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary + + */ +fun getKwsModelConfig(type: Int): OnlineModelConfig? { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx", + decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx", + joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer2", + ) + } + + 1 -> { + val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx", + decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx", + joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer2", + ) + } + + } + return null +} + +/* + * Get the default keywords for each model. + * Caution: The types and modelDir should be the same as those in getModelConfig + * function above. 
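+ *
+ * For example (an illustrative sketch; `assetManager` and the chosen `type`
+ * come from the calling code, not from this file), the two helpers are
+ * normally combined when building a KeywordSpotter so that the model and the
+ * keywords file stay in sync:
+ *
+ *   val type = 0  // sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
+ *   val config = KeywordSpotterConfig(
+ *       featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
+ *       modelConfig = getKwsModelConfig(type)!!,
+ *       keywordsFile = getKeywordsFile(type),
+ *   )
+ *   val kws = KeywordSpotter(assetManager, config)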
+ */ +fun getKeywordsFile(type: Int): String { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01" + return "$modelDir/keywords.txt" + } + + 1 -> { + val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01" + return "$modelDir/keywords.txt" + } + + } + return "" +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflinePunctuation.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflinePunctuation.kt new file mode 100644 index 00000000..b2e5cd86 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflinePunctuation.kt @@ -0,0 +1,60 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class OfflinePunctuationModelConfig( + var ctTransformer: String = "", + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + + +data class OfflinePunctuationConfig( + var model: OfflinePunctuationModelConfig, +) + +class OfflinePunctuation( + assetManager: AssetManager? = null, + config: OfflinePunctuationConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun addPunctuation(text: String) = addPunctuation(ptr, text) + + private external fun delete(ptr: Long) + + private external fun addPunctuation(ptr: Long, text: String): String + + private external fun newFromAsset( + assetManager: AssetManager, + config: OfflinePunctuationConfig, + ): Long + + private external fun newFromFile( + config: OfflinePunctuationConfig, + ): Long + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineRecognizer.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineRecognizer.kt new file mode 100644 index 00000000..f07b3fab --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineRecognizer.kt @@ -0,0 +1,486 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class OfflineRecognizerResult( + val text: String, + val tokens: Array, + val timestamps: FloatArray, + val lang: String, + val emotion: String, + val event: String, +) + +data class OfflineTransducerModelConfig( + var encoder: String = "", + var decoder: String = "", + var joiner: String = "", +) + +data class OfflineParaformerModelConfig( + var model: String = "", +) + +data class OfflineNemoEncDecCtcModelConfig( + var model: String = "", +) + +data class OfflineWhisperModelConfig( + var encoder: String = "", + var decoder: String = "", + var language: String = "en", // Used with multilingual model + var task: String = "transcribe", // transcribe or translate + var tailPaddings: Int = 1000, // Padding added at the end of the samples +) + +data class OfflineFireRedAsrModelConfig( + var encoder: String = "", + var decoder: String = "", +) + +data class OfflineMoonshineModelConfig( + var preprocessor: String = "", + var encoder: String = "", + var uncachedDecoder: String = "", + var cachedDecoder: String = "", +) + +data class OfflineSenseVoiceModelConfig( + var model: String = "", + var language: String = "", + var useInverseTextNormalization: Boolean = true, +) + +data class OfflineModelConfig( + var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(), + var paraformer: OfflineParaformerModelConfig = 
OfflineParaformerModelConfig(), + var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(), + var fireRedAsr: OfflineFireRedAsrModelConfig = OfflineFireRedAsrModelConfig(), + var moonshine: OfflineMoonshineModelConfig = OfflineMoonshineModelConfig(), + var nemo: OfflineNemoEncDecCtcModelConfig = OfflineNemoEncDecCtcModelConfig(), + var senseVoice: OfflineSenseVoiceModelConfig = OfflineSenseVoiceModelConfig(), + var teleSpeech: String = "", + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", + var modelType: String = "", + var tokens: String = "", + var modelingUnit: String = "", + var bpeVocab: String = "", +) + +data class OfflineRecognizerConfig( + var featConfig: FeatureConfig = FeatureConfig(), + var modelConfig: OfflineModelConfig = OfflineModelConfig(), + // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it + var decodingMethod: String = "greedy_search", + var maxActivePaths: Int = 4, + var hotwordsFile: String = "", + var hotwordsScore: Float = 1.5f, + var ruleFsts: String = "", + var ruleFars: String = "", + var blankPenalty: Float = 0.0f, +) + +class OfflineRecognizer( + assetManager: AssetManager? = null, + config: OfflineRecognizerConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OfflineStream { + val p = createStream(ptr) + return OfflineStream(p) + } + + fun getResult(stream: OfflineStream): OfflineRecognizerResult { + val objArray = getResult(stream.ptr) + + val text = objArray[0] as String + val tokens = objArray[1] as Array + val timestamps = objArray[2] as FloatArray + val lang = objArray[3] as String + val emotion = objArray[4] as String + val event = objArray[5] as String + return OfflineRecognizerResult( + text = text, + tokens = tokens, + timestamps = timestamps, + lang = lang, + emotion = emotion, + event = event + ) + } + + fun decode(stream: OfflineStream) = decode(ptr, stream.ptr) + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun newFromAsset( + assetManager: AssetManager, + config: OfflineRecognizerConfig, + ): Long + + private external fun newFromFile( + config: OfflineRecognizerConfig, + ): Long + + private external fun decode(ptr: Long, streamPtr: Long) + + private external fun getResult(streamPtr: Long): Array + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own. 
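+
+For example (an illustrative sketch; `assetManager` and the chosen `type` come
+from the caller, not from this file), the returned config is normally passed to
+OfflineRecognizer together with a feature config:
+
+    val config = OfflineRecognizerConfig(
+        featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
+        modelConfig = getOfflineModelConfig(type = 0)!!,
+    )
+    val recognizer = OfflineRecognizer(assetManager, config)
+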
(It should be straightforward to add a new model +by following the code) + +@param type + +0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-09-14 (Chinese) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-09-14-chinese + int8 + +1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english + encoder int8, decoder/joiner float32 + +2 - sherpa-onnx-whisper-tiny.en + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en + encoder int8, decoder int8 + +3 - sherpa-onnx-whisper-base.en + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en + encoder int8, decoder int8 + +4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese + encoder/joiner int8, decoder fp32 + + */ +fun getOfflineModelConfig(type: Int): OfflineModelConfig? { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-paraformer-zh-2023-09-14" + return OfflineModelConfig( + paraformer = OfflineParaformerModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "paraformer", + ) + } + + 1 -> { + val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx", + decoder = "$modelDir/decoder-epoch-30-avg-4.onnx", + joiner = "$modelDir/joiner-epoch-30-avg-4.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 2 -> { + val modelDir = "sherpa-onnx-whisper-tiny.en" + return OfflineModelConfig( + whisper = OfflineWhisperModelConfig( + encoder = "$modelDir/tiny.en-encoder.int8.onnx", + decoder = "$modelDir/tiny.en-decoder.int8.onnx", + ), + tokens = "$modelDir/tiny.en-tokens.txt", + modelType = "whisper", + ) + } + + 3 -> { + val modelDir = "sherpa-onnx-whisper-base.en" + return OfflineModelConfig( + whisper = OfflineWhisperModelConfig( + encoder = "$modelDir/base.en-encoder.int8.onnx", + decoder = "$modelDir/base.en-decoder.int8.onnx", + ), + tokens = "$modelDir/base.en-tokens.txt", + modelType = "whisper", + ) + } + + + 4 -> { + val modelDir = "icefall-asr-zipformer-wenetspeech-20230615" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx", + decoder = "$modelDir/decoder-epoch-12-avg-4.onnx", + joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 5 -> { + val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-20-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 6 -> { + val modelDir = "sherpa-onnx-nemo-ctc-en-citrinet-512" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", 
+ ) + } + + 7 -> { + val modelDir = "sherpa-onnx-nemo-fast-conformer-ctc-be-de-en-es-fr-hr-it-pl-ru-uk-20k" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 8 -> { + val modelDir = "sherpa-onnx-nemo-fast-conformer-ctc-en-24500" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 9 -> { + val modelDir = "sherpa-onnx-nemo-fast-conformer-ctc-en-de-es-fr-14288" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 10 -> { + val modelDir = "sherpa-onnx-nemo-fast-conformer-ctc-es-1424" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 11 -> { + val modelDir = "sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04" + return OfflineModelConfig( + teleSpeech = "$modelDir/model.int8.onnx", + tokens = "$modelDir/tokens.txt", + modelType = "telespeech_ctc", + ) + } + + 12 -> { + val modelDir = "sherpa-onnx-zipformer-thai-2024-06-20" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-12-avg-5.int8.onnx", + decoder = "$modelDir/decoder-epoch-12-avg-5.onnx", + joiner = "$modelDir/joiner-epoch-12-avg-5.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 13 -> { + val modelDir = "sherpa-onnx-zipformer-korean-2024-06-24" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 14 -> { + val modelDir = "sherpa-onnx-paraformer-zh-small-2024-03-09" + return OfflineModelConfig( + paraformer = OfflineParaformerModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "paraformer", + ) + } + + 15 -> { + val modelDir = "sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17" + return OfflineModelConfig( + senseVoice = OfflineSenseVoiceModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 16 -> { + val modelDir = "sherpa-onnx-zipformer-ja-reazonspeech-2024-08-01" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 17 -> { + val modelDir = "sherpa-onnx-zipformer-ru-2024-09-18" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder.int8.onnx", + decoder = "$modelDir/decoder.onnx", + joiner = "$modelDir/joiner.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 18 -> { + val modelDir = "sherpa-onnx-small-zipformer-ru-2024-09-18" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder.int8.onnx", + decoder = "$modelDir/decoder.onnx", + joiner = "$modelDir/joiner.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 19 -> { + 
val modelDir = "sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24" + return OfflineModelConfig( + nemo = OfflineNemoEncDecCtcModelConfig( + model = "$modelDir/model.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 20 -> { + val modelDir = "sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder.int8.onnx", + decoder = "$modelDir/decoder.onnx", + joiner = "$modelDir/joiner.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "nemo_transducer", + ) + } + + 21 -> { + val modelDir = "sherpa-onnx-moonshine-tiny-en-int8" + return OfflineModelConfig( + moonshine = OfflineMoonshineModelConfig( + preprocessor = "$modelDir/preprocess.onnx", + encoder = "$modelDir/encode.int8.onnx", + uncachedDecoder = "$modelDir/uncached_decode.int8.onnx", + cachedDecoder = "$modelDir/cached_decode.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 22 -> { + val modelDir = "sherpa-onnx-moonshine-base-en-int8" + return OfflineModelConfig( + moonshine = OfflineMoonshineModelConfig( + preprocessor = "$modelDir/preprocess.onnx", + encoder = "$modelDir/encode.int8.onnx", + uncachedDecoder = "$modelDir/uncached_decode.int8.onnx", + cachedDecoder = "$modelDir/cached_decode.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 23 -> { + val modelDir = "sherpa-onnx-zipformer-zh-en-2023-11-22" + return OfflineModelConfig( + transducer = OfflineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-34-avg-19.int8.onnx", + decoder = "$modelDir/decoder-epoch-34-avg-19.onnx", + joiner = "$modelDir/joiner-epoch-34-avg-19.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "transducer", + ) + } + + 24 -> { + val modelDir = "sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16" + return OfflineModelConfig( + fireRedAsr = OfflineFireRedAsrModelConfig( + encoder = "$modelDir/encoder.int8.onnx", + decoder = "$modelDir/decoder.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + } + return null +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineSpeakerDiarization.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineSpeakerDiarization.kt new file mode 100644 index 00000000..fcad3b5f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineSpeakerDiarization.kt @@ -0,0 +1,104 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class OfflineSpeakerSegmentationPyannoteModelConfig( + var model: String = "", +) + +data class OfflineSpeakerSegmentationModelConfig( + var pyannote: OfflineSpeakerSegmentationPyannoteModelConfig = OfflineSpeakerSegmentationPyannoteModelConfig(), + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +data class FastClusteringConfig( + var numClusters: Int = -1, + var threshold: Float = 0.5f, +) + +data class OfflineSpeakerDiarizationConfig( + var segmentation: OfflineSpeakerSegmentationModelConfig = OfflineSpeakerSegmentationModelConfig(), + var embedding: SpeakerEmbeddingExtractorConfig = SpeakerEmbeddingExtractorConfig(), + var clustering: FastClusteringConfig = FastClusteringConfig(), + var minDurationOn: Float = 0.2f, + var minDurationOff: Float = 0.5f, +) + +data class OfflineSpeakerDiarizationSegment( + val start: Float, // in seconds + val end: Float, // in seconds + val speaker: Int, // ID of the speaker; count from 0 +) + +class OfflineSpeakerDiarization( + assetManager: AssetManager? 
= null, + val config: OfflineSpeakerDiarizationConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + // Only config.clustering is used. All other fields in config + // are ignored + fun setConfig(config: OfflineSpeakerDiarizationConfig) = setConfig(ptr, config) + + fun sampleRate() = getSampleRate(ptr) + + fun process(samples: FloatArray) = process(ptr, samples) + + fun processWithCallback( + samples: FloatArray, + callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int, + arg: Long = 0, + ) = processWithCallback(ptr, samples, callback, arg) + + private external fun delete(ptr: Long) + + private external fun newFromAsset( + assetManager: AssetManager, + config: OfflineSpeakerDiarizationConfig, + ): Long + + private external fun newFromFile( + config: OfflineSpeakerDiarizationConfig, + ): Long + + private external fun setConfig(ptr: Long, config: OfflineSpeakerDiarizationConfig) + + private external fun getSampleRate(ptr: Long): Int + + private external fun process( + ptr: Long, + samples: FloatArray + ): Array + + private external fun processWithCallback( + ptr: Long, + samples: FloatArray, + callback: (numProcessedChunks: Int, numTotalChunks: Int, arg: Long) -> Int, + arg: Long, + ): Array + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineStream.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineStream.kt new file mode 100644 index 00000000..3e1a14b1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OfflineStream.kt @@ -0,0 +1,32 @@ +package com.k2fsa.sherpa.mnn + +class OfflineStream(var ptr: Long) { + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = + acceptWaveform(ptr, samples, sampleRate) + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun use(block: (OfflineStream) -> Unit) { + try { + block(this) + } finally { + release() + } + } + + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) + private external fun delete(ptr: Long) + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlinePunctuation.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlinePunctuation.kt new file mode 100644 index 00000000..f7e233ab --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlinePunctuation.kt @@ -0,0 +1,61 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class OnlinePunctuationModelConfig( + var cnnBilstm: String = "", + var bpeVocab: String = "", + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + + +data class OnlinePunctuationConfig( + var model: OnlinePunctuationModelConfig, +) + +class OnlinePunctuation( + assetManager: AssetManager? 
= null, + config: OnlinePunctuationConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun addPunctuation(text: String) = addPunctuation(ptr, text) + + private external fun delete(ptr: Long) + + private external fun addPunctuation(ptr: Long, text: String): String + + private external fun newFromAsset( + assetManager: AssetManager, + config: OnlinePunctuationConfig, + ): Long + + private external fun newFromFile( + config: OnlinePunctuationConfig, + ): Long + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlineRecognizer.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlineRecognizer.kt new file mode 100644 index 00000000..9b8f82e6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlineRecognizer.kt @@ -0,0 +1,414 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class EndpointRule( + var mustContainNonSilence: Boolean, + var minTrailingSilence: Float, + var minUtteranceLength: Float, +) + +data class EndpointConfig( + var rule1: EndpointRule = EndpointRule(false, 2.4f, 0.0f), + var rule2: EndpointRule = EndpointRule(true, 1.4f, 0.0f), + var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f) +) + +data class OnlineTransducerModelConfig( + var encoder: String = "", + var decoder: String = "", + var joiner: String = "", +) + +data class OnlineParaformerModelConfig( + var encoder: String = "", + var decoder: String = "", +) + +data class OnlineZipformer2CtcModelConfig( + var model: String = "", +) + +data class OnlineNeMoCtcModelConfig( + var model: String = "", +) + +data class OnlineModelConfig( + var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), + var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(), + var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(), + var neMoCtc: OnlineNeMoCtcModelConfig = OnlineNeMoCtcModelConfig(), + var tokens: String = "", + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", + var modelType: String = "", + var modelingUnit: String = "", + var bpeVocab: String = "", +) + +data class OnlineLMConfig( + var model: String = "", + var scale: Float = 0.5f, +) + +data class OnlineCtcFstDecoderConfig( + var graph: String = "", + var maxActive: Int = 3000, +) + + +data class OnlineRecognizerConfig( + var featConfig: FeatureConfig = FeatureConfig(), + var modelConfig: OnlineModelConfig = OnlineModelConfig(), + var lmConfig: OnlineLMConfig = OnlineLMConfig(), + var ctcFstDecoderConfig: OnlineCtcFstDecoderConfig = OnlineCtcFstDecoderConfig(), + var endpointConfig: EndpointConfig = EndpointConfig(), + var enableEndpoint: Boolean = true, + var decodingMethod: String = "greedy_search", + var maxActivePaths: Int = 4, + var hotwordsFile: String = "", + var hotwordsScore: Float = 1.5f, + var ruleFsts: String = "", + var ruleFars: String = "", + var blankPenalty: Float = 0.0f, +) + +data class OnlineRecognizerResult( + val text: String, + val tokens: Array, + val timestamps: FloatArray, + // TODO(fangjun): Add more fields +) + +class OnlineRecognizer( + assetManager: AssetManager? 
= null, + val config: OnlineRecognizerConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(hotwords: String = ""): OnlineStream { + val p = createStream(ptr, hotwords) + return OnlineStream(p) + } + + fun reset(stream: OnlineStream) = reset(ptr, stream.ptr) + fun decode(stream: OnlineStream) = decode(ptr, stream.ptr) + fun isEndpoint(stream: OnlineStream) = isEndpoint(ptr, stream.ptr) + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) + fun getResult(stream: OnlineStream): OnlineRecognizerResult { + val objArray = getResult(ptr, stream.ptr) + + val text = objArray[0] as String + val tokens = objArray[1] as Array + val timestamps = objArray[2] as FloatArray + + return OnlineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps) + } + + private external fun delete(ptr: Long) + + private external fun newFromAsset( + assetManager: AssetManager, + config: OnlineRecognizerConfig, + ): Long + + private external fun newFromFile( + config: OnlineRecognizerConfig, + ): Long + + private external fun createStream(ptr: Long, hotwords: String): Long + private external fun reset(ptr: Long, streamPtr: Long) + private external fun decode(ptr: Long, streamPtr: Long) + private external fun isEndpoint(ptr: Long, streamPtr: Long): Boolean + private external fun isReady(ptr: Long, streamPtr: Long): Boolean + private external fun getResult(ptr: Long, streamPtr: Long): Array + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own. 
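For example, assuming the files for type 0 below have been unpacked into the
app's assets directory and that `samples` holds 16 kHz mono audio as a
FloatArray, a minimal decode loop could look like this (a sketch only; adjust
the model type and sample rate to the model you ship):

    val recognizer = OnlineRecognizer(
        assetManager = assetManager,
        config = OnlineRecognizerConfig(modelConfig = getModelConfig(type = 0)!!),
    )
    val stream = recognizer.createStream()
    stream.acceptWaveform(samples, sampleRate = 16000)
    while (recognizer.isReady(stream)) {
        recognizer.decode(stream)
    }
    println(recognizer.getResult(stream).text)
    stream.release()
    recognizer.release()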
(It should be straightforward to add a new model +by following the code) + +@param type +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + +1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese) + + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese + +2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english + +3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 + https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 + 3 - int8 encoder + 4 - float32 encoder + +5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + +6 - sherpa-onnx-streaming-zipformer-en-2023-06-26 + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26 + +7 - shaojieli/sherpa-onnx-streaming-zipformer-fr-2023-04-14 (French) + https://huggingface.co/shaojieli/sherpa-onnx-streaming-zipformer-fr-2023-04-14 + +8 - csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 + encoder int8, decoder/joiner float32 + + */ +fun getModelConfig(type: Int): OnlineModelConfig? { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 1 -> { + val modelDir = "sherpa-onnx-lstm-zh-2023-02-20" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-11-avg-1.onnx", + decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "lstm", + ) + } + + 2 -> { + val modelDir = "sherpa-onnx-lstm-en-2023-02-17" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "lstm", + ) + } + + 3 -> { + val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx", + decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", + joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", + ), + tokens = "$modelDir/data/lang_char/tokens.txt", + modelType = "zipformer2", + ) + } + + 4 -> { + val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx", + decoder = 
"$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", + joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", + ), + tokens = "$modelDir/data/lang_char/tokens.txt", + modelType = "zipformer2", + ) + } + + 5 -> { + val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en" + return OnlineModelConfig( + paraformer = OnlineParaformerModelConfig( + encoder = "$modelDir/encoder.int8.onnx", + decoder = "$modelDir/decoder.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "paraformer", + ) + } + + 6 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-en-2023-06-26" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1-chunk-16-left-128.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1-chunk-16-left-128.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer2", + ) + } + + 7 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-fr-2023-04-14" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-29-avg-9-with-averaged-model.int8.onnx", + decoder = "$modelDir/decoder-epoch-29-avg-9-with-averaged-model.onnx", + joiner = "$modelDir/joiner-epoch-29-avg-9-with-averaged-model.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 8 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 9 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 10 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + + 11 -> { + val modelDir = "sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-80ms" + return OnlineModelConfig( + neMoCtc = OnlineNeMoCtcModelConfig( + model = "$modelDir/model.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 12 -> { + val modelDir = "sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-480ms" + return OnlineModelConfig( + neMoCtc = OnlineNeMoCtcModelConfig( + model = "$modelDir/model.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 13 -> { + val modelDir = "sherpa-onnx-nemo-streaming-fast-conformer-ctc-en-1040ms" + return OnlineModelConfig( + neMoCtc = OnlineNeMoCtcModelConfig( + model = "$modelDir/model.onnx", + ), + tokens = "$modelDir/tokens.txt", + ) + } + + 14 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-korean-2024-06-16" + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx", + decoder = 
"$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "zipformer", + ) + } + } + return null +} + +/* +Please see +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html +for a list of pre-trained models. + +We only add a few here. Please change the following code +to add your own LM model. (It should be straightforward to train a new NN LM model +by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py) + +@param type +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) + https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + */ +fun getOnlineLMConfig(type: Int): OnlineLMConfig { + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" + return OnlineLMConfig( + model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx", + scale = 0.5f, + ) + } + } + return OnlineLMConfig() +} + +fun getEndpointConfig(): EndpointConfig { + return EndpointConfig( + rule1 = EndpointRule(false, 2.4f, 0.0f), + rule2 = EndpointRule(true, 1.4f, 0.0f), + rule3 = EndpointRule(false, 0.0f, 20.0f) + ) +} + diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlineStream.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlineStream.kt new file mode 100644 index 00000000..677d9bf5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/OnlineStream.kt @@ -0,0 +1,36 @@ +package com.k2fsa.sherpa.mnn + +class OnlineStream(var ptr: Long = 0) { + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = + acceptWaveform(ptr, samples, sampleRate) + + fun inputFinished() = inputFinished(ptr) + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun use(block: (OnlineStream) -> Unit) { + try { + block(this) + } finally { + release() + } + } + + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) + private external fun inputFinished(ptr: Long) + private external fun delete(ptr: Long) + + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Speaker.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Speaker.kt new file mode 100644 index 00000000..a92654ef --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Speaker.kt @@ -0,0 +1,157 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager +import android.util.Log + +class SpeakerEmbeddingExtractor( + assetManager: AssetManager? 
= null, + config: SpeakerEmbeddingExtractorConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OnlineStream { + val p = createStream(ptr) + return OnlineStream(p) + } + + fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr) + fun compute(stream: OnlineStream) = compute(ptr, stream.ptr) + fun dim() = dim(ptr) + + private external fun newFromAsset( + assetManager: AssetManager, + config: SpeakerEmbeddingExtractorConfig, + ): Long + + private external fun newFromFile( + config: SpeakerEmbeddingExtractorConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun isReady(ptr: Long, streamPtr: Long): Boolean + + private external fun compute(ptr: Long, streamPtr: Long): FloatArray + + private external fun dim(ptr: Long): Int + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +class SpeakerEmbeddingManager(val dim: Int) { + private var ptr: Long + + init { + ptr = create(dim) + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + fun add(name: String, embedding: FloatArray) = add(ptr, name, embedding) + fun add(name: String, embedding: Array) = addList(ptr, name, embedding) + fun remove(name: String) = remove(ptr, name) + fun search(embedding: FloatArray, threshold: Float) = search(ptr, embedding, threshold) + fun verify(name: String, embedding: FloatArray, threshold: Float) = + verify(ptr, name, embedding, threshold) + + fun contains(name: String) = contains(ptr, name) + fun numSpeakers() = numSpeakers(ptr) + + fun allSpeakerNames() = allSpeakerNames(ptr) + + private external fun create(dim: Int): Long + private external fun delete(ptr: Long): Unit + private external fun add(ptr: Long, name: String, embedding: FloatArray): Boolean + private external fun addList(ptr: Long, name: String, embedding: Array): Boolean + private external fun remove(ptr: Long, name: String): Boolean + private external fun search(ptr: Long, embedding: FloatArray, threshold: Float): String + private external fun verify( + ptr: Long, + name: String, + embedding: FloatArray, + threshold: Float + ): Boolean + + private external fun contains(ptr: Long, name: String): Boolean + private external fun numSpeakers(ptr: Long): Int + + private external fun allSpeakerNames(ptr: Long): Array + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +// Please download the model file from +// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +// and put it inside the assets directory. +// +// Please don't put it in a subdirectory of assets +private val modelName = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + +object SpeakerRecognition { + var _extractor: SpeakerEmbeddingExtractor? = null + var _manager: SpeakerEmbeddingManager? = null + + val extractor: SpeakerEmbeddingExtractor + get() { + return _extractor!! + } + + val manager: SpeakerEmbeddingManager + get() { + return _manager!! + } + + fun initExtractor(assetManager: AssetManager? 
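// A minimal end-to-end sketch for the classes above (assumptions: `assetManager`
// comes from the host Activity, `samples` holds 16 kHz mono audio as a
// FloatArray, and the model file named above is present in assets; the 0.6f
// threshold is only an example value):
//
//     SpeakerRecognition.initExtractor(assetManager)
//     val stream = SpeakerRecognition.extractor.createStream()
//     stream.acceptWaveform(samples, sampleRate = 16000)
//     stream.inputFinished()
//     val embedding = SpeakerRecognition.extractor.compute(stream)
//     stream.release()
//     SpeakerRecognition.manager.add("alice", embedding)
//     val matched = SpeakerRecognition.manager.search(embedding, threshold = 0.6f)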
= null) { + synchronized(this) { + if (_extractor != null) { + return + } + Log.i("sherpa-onnx", "Initializing speaker embedding extractor") + + _extractor = SpeakerEmbeddingExtractor( + assetManager = assetManager, + config = SpeakerEmbeddingExtractorConfig( + model = modelName, + numThreads = 2, + debug = false, + provider = "cpu", + ) + ) + + _manager = SpeakerEmbeddingManager(dim = _extractor!!.dim()) + } + } +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/SpeakerEmbeddingExtractorConfig.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/SpeakerEmbeddingExtractorConfig.kt new file mode 100644 index 00000000..11279b01 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/SpeakerEmbeddingExtractorConfig.kt @@ -0,0 +1,8 @@ +package com.k2fsa.sherpa.mnn + +data class SpeakerEmbeddingExtractorConfig( + val model: String = "", + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/SpokenLanguageIdentification.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/SpokenLanguageIdentification.kt new file mode 100644 index 00000000..6fb1ec56 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/SpokenLanguageIdentification.kt @@ -0,0 +1,103 @@ +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class SpokenLanguageIdentificationWhisperConfig( + var encoder: String = "", + var decoder: String = "", + var tailPaddings: Int = -1, +) + +data class SpokenLanguageIdentificationConfig( + var whisper: SpokenLanguageIdentificationWhisperConfig = SpokenLanguageIdentificationWhisperConfig(), + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +class SpokenLanguageIdentification( + assetManager: AssetManager? = null, + config: SpokenLanguageIdentificationConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun createStream(): OfflineStream { + val p = createStream(ptr) + return OfflineStream(p) + } + + fun compute(stream: OfflineStream) = compute(ptr, stream.ptr) + + private external fun newFromAsset( + assetManager: AssetManager, + config: SpokenLanguageIdentificationConfig, + ): Long + + private external fun newFromFile( + config: SpokenLanguageIdentificationConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun compute(ptr: Long, streamPtr: Long): String + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +// please refer to +// https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#whisper +// to download more models +fun getSpokenLanguageIdentificationConfig( + type: Int, + numThreads: Int = 1 +): SpokenLanguageIdentificationConfig? 
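// A minimal usage sketch for this helper (assumptions: the whisper-tiny files
// listed under type 0 are in assets and `samples` holds 16 kHz mono audio as a
// FloatArray):
//
//     val slid = SpokenLanguageIdentification(
//         assetManager = assetManager,
//         config = getSpokenLanguageIdentificationConfig(type = 0)!!,
//     )
//     val stream = slid.createStream()
//     stream.acceptWaveform(samples, sampleRate = 16000)
//     val lang = slid.compute(stream)  // e.g. "en" or "zh"
//     stream.release()
//     slid.release()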
{ + when (type) { + 0 -> { + val modelDir = "sherpa-onnx-whisper-tiny" + return SpokenLanguageIdentificationConfig( + whisper = SpokenLanguageIdentificationWhisperConfig( + encoder = "$modelDir/tiny-encoder.int8.onnx", + decoder = "$modelDir/tiny-decoder.int8.onnx", + ), + numThreads = numThreads, + debug = true, + ) + } + + 1 -> { + val modelDir = "sherpa-onnx-whisper-base" + return SpokenLanguageIdentificationConfig( + whisper = SpokenLanguageIdentificationWhisperConfig( + encoder = "$modelDir/tiny-encoder.int8.onnx", + decoder = "$modelDir/tiny-decoder.int8.onnx", + ), + numThreads = 1, + debug = true, + ) + } + } + return null +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Tts.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Tts.kt new file mode 100644 index 00000000..9c2ee1a4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Tts.kt @@ -0,0 +1,283 @@ +// Copyright (c) 2023 Xiaomi Corporation +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class OfflineTtsVitsModelConfig( + var model: String = "", + var lexicon: String = "", + var tokens: String = "", + var dataDir: String = "", + var dictDir: String = "", + var noiseScale: Float = 0.667f, + var noiseScaleW: Float = 0.8f, + var lengthScale: Float = 1.0f, +) + +data class OfflineTtsMatchaModelConfig( + var acousticModel: String = "", + var vocoder: String = "", + var lexicon: String = "", + var tokens: String = "", + var dataDir: String = "", + var dictDir: String = "", + var noiseScale: Float = 1.0f, + var lengthScale: Float = 1.0f, +) + +data class OfflineTtsKokoroModelConfig( + var model: String = "", + var voices: String = "", + var tokens: String = "", + var dataDir: String = "", + var lexicon: String = "", + var dictDir: String = "", + var lengthScale: Float = 1.0f, +) + +data class OfflineTtsModelConfig( + var vits: OfflineTtsVitsModelConfig = OfflineTtsVitsModelConfig(), + var matcha: OfflineTtsMatchaModelConfig = OfflineTtsMatchaModelConfig(), + var kokoro: OfflineTtsKokoroModelConfig = OfflineTtsKokoroModelConfig(), + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +data class OfflineTtsConfig( + var model: OfflineTtsModelConfig = OfflineTtsModelConfig(), + var ruleFsts: String = "", + var ruleFars: String = "", + var maxNumSentences: Int = 1, + var silenceScale: Float = 0.2f, +) + +class GeneratedAudio( + val samples: FloatArray, + val sampleRate: Int, +) { + fun save(filename: String) = + saveImpl(filename = filename, samples = samples, sampleRate = sampleRate) + + private external fun saveImpl( + filename: String, + samples: FloatArray, + sampleRate: Int + ): Boolean +} + +class OfflineTts( + assetManager: AssetManager? 
= null, + var config: OfflineTtsConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + + fun sampleRate() = getSampleRate(ptr) + + fun numSpeakers() = getNumSpeakers(ptr) + + fun generate( + text: String, + sid: Int = 0, + speed: Float = 1.0f + ): GeneratedAudio { + val objArray = generateImpl(ptr, text = text, sid = sid, speed = speed) + return GeneratedAudio( + samples = objArray[0] as FloatArray, + sampleRate = objArray[1] as Int + ) + } + + fun generateWithCallback( + text: String, + sid: Int = 0, + speed: Float = 1.0f, + callback: (samples: FloatArray) -> Int + ): GeneratedAudio { + val objArray = generateWithCallbackImpl( + ptr, + text = text, + sid = sid, + speed = speed, + callback = callback + ) + return GeneratedAudio( + samples = objArray[0] as FloatArray, + sampleRate = objArray[1] as Int + ) + } + + fun allocate(assetManager: AssetManager? = null) { + if (ptr == 0L) { + ptr = if (assetManager != null) { + newFromAsset(assetManager, config) + } else { + newFromFile(config) + } + } + } + + fun free() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + private external fun newFromAsset( + assetManager: AssetManager, + config: OfflineTtsConfig, + ): Long + + private external fun newFromFile( + config: OfflineTtsConfig, + ): Long + + private external fun delete(ptr: Long) + private external fun getSampleRate(ptr: Long): Int + private external fun getNumSpeakers(ptr: Long): Int + + // The returned array has two entries: + // - the first entry is an 1-D float array containing audio samples. + // Each sample is normalized to the range [-1, 1] + // - the second entry is the sample rate + private external fun generateImpl( + ptr: Long, + text: String, + sid: Int = 0, + speed: Float = 1.0f + ): Array + + private external fun generateWithCallbackImpl( + ptr: Long, + text: String, + sid: Int = 0, + speed: Float = 1.0f, + callback: (samples: FloatArray) -> Int + ): Array + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +// please refer to +// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html +// to download models +fun getOfflineTtsConfig( + modelDir: String, + modelName: String, // for VITS + acousticModelName: String, // for Matcha + vocoder: String, // for Matcha + voices: String, // for Kokoro + lexicon: String, + dataDir: String, + dictDir: String, + ruleFsts: String, + ruleFars: String, + numThreads: Int? 
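// A minimal usage sketch for this helper (the directory and file names below
// are assumptions; substitute the Kokoro/VITS/Matcha model you actually ship
// in assets):
//
//     val config = getOfflineTtsConfig(
//         modelDir = "kokoro-en-v0_19",
//         modelName = "model.onnx",
//         acousticModelName = "",
//         vocoder = "",
//         voices = "voices.bin",
//         lexicon = "",
//         dataDir = "kokoro-en-v0_19/espeak-ng-data",
//         dictDir = "",
//         ruleFsts = "",
//         ruleFars = "",
//     )
//     val tts = OfflineTts(assetManager = assetManager, config = config)
//     val audio = tts.generate(text = "Hello from sherpa", sid = 0, speed = 1.0f)
//     audio.save(filename = "/some/writable/path/out.wav")
//     tts.release()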
= null +): OfflineTtsConfig { + // For Matcha TTS, please set + // acousticModelName, vocoder + + // For Kokoro TTS, please set + // modelName, voices + + // For VITS, please set + // modelName + + val numberOfThreads = if (numThreads != null) { + numThreads + } else if (voices.isNotEmpty()) { + // for Kokoro TTS models, we use more threads + 4 + } else { + 2 + } + + if (modelName.isEmpty() && acousticModelName.isEmpty()) { + throw IllegalArgumentException("Please specify a TTS model") + } + + if (modelName.isNotEmpty() && acousticModelName.isNotEmpty()) { + throw IllegalArgumentException("Please specify either a VITS or a Matcha model, but not both") + } + + if (acousticModelName.isNotEmpty() && vocoder.isEmpty()) { + throw IllegalArgumentException("Please provide vocoder for Matcha TTS") + } + + val vits = if (modelName.isNotEmpty() && voices.isEmpty()) { + OfflineTtsVitsModelConfig( + model = "$modelDir/$modelName", + lexicon = "$modelDir/$lexicon", + tokens = "$modelDir/tokens.txt", + dataDir = dataDir, + dictDir = dictDir, + ) + } else { + OfflineTtsVitsModelConfig() + } + + val matcha = if (acousticModelName.isNotEmpty()) { + OfflineTtsMatchaModelConfig( + acousticModel = "$modelDir/$acousticModelName", + vocoder = vocoder, + lexicon = "$modelDir/$lexicon", + tokens = "$modelDir/tokens.txt", + dictDir = dictDir, + dataDir = dataDir, + ) + } else { + OfflineTtsMatchaModelConfig() + } + + val kokoro = if (voices.isNotEmpty()) { + OfflineTtsKokoroModelConfig( + model = "$modelDir/$modelName", + voices = "$modelDir/$voices", + tokens = "$modelDir/tokens.txt", + dataDir = dataDir, + lexicon = when { + lexicon == "" -> lexicon + "," in lexicon -> lexicon + else -> "$modelDir/$lexicon" + }, + dictDir = dictDir, + ) + } else { + OfflineTtsKokoroModelConfig() + } + + return OfflineTtsConfig( + model = OfflineTtsModelConfig( + vits = vits, + matcha = matcha, + kokoro = kokoro, + numThreads = numberOfThreads, + debug = true, + provider = "cpu", + ), + ruleFsts = ruleFsts, + ruleFars = ruleFars, + ) +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Vad.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Vad.kt new file mode 100644 index 00000000..6355c2c8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/Vad.kt @@ -0,0 +1,116 @@ +// Copyright (c) 2023 Xiaomi Corporation +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class SileroVadModelConfig( + var model: String = "", + var threshold: Float = 0.5F, + var minSilenceDuration: Float = 0.25F, + var minSpeechDuration: Float = 0.25F, + var windowSize: Int = 512, + var maxSpeechDuration: Float = 5.0F, +) + +data class VadModelConfig( + var sileroVadModelConfig: SileroVadModelConfig = SileroVadModelConfig(), + var sampleRate: Int = 16000, + var numThreads: Int = 1, + var provider: String = "cpu", + var debug: Boolean = false, +) + +class SpeechSegment(val start: Int, val samples: FloatArray) + +class Vad( + assetManager: AssetManager? 
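// A minimal usage sketch for this class (assumptions: silero_vad.onnx is in
// assets and `samples` holds 16 kHz mono audio; windows of `windowSize`
// samples, 512 by default, are fed one at a time):
//
//     val vad = Vad(assetManager = assetManager, config = getVadModelConfig(type = 0)!!)
//     var offset = 0
//     while (offset + 512 <= samples.size) {
//         vad.acceptWaveform(samples.copyOfRange(offset, offset + 512))
//         offset += 512
//     }
//     vad.flush()
//     while (!vad.empty()) {
//         val segment = vad.front()  // segment.start is a sample index into the input
//         vad.pop()
//     }
//     vad.release()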
= null, + var config: VadModelConfig, +) { + private var ptr: Long + + init { + if (assetManager != null) { + ptr = newFromAsset(assetManager, config) + } else { + ptr = newFromFile(config) + } + } + + protected fun finalize() { + if (ptr != 0L) { + delete(ptr) + ptr = 0 + } + } + + fun release() = finalize() + + fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples) + + fun empty(): Boolean = empty(ptr) + fun pop() = pop(ptr) + + fun front(): SpeechSegment { + val segment = front(ptr) + return SpeechSegment(segment[0] as Int, segment[1] as FloatArray) + } + + fun clear() = clear(ptr) + + fun isSpeechDetected(): Boolean = isSpeechDetected(ptr) + + fun reset() = reset(ptr) + + fun flush() = flush(ptr) + + private external fun delete(ptr: Long) + + private external fun newFromAsset( + assetManager: AssetManager, + config: VadModelConfig, + ): Long + + private external fun newFromFile( + config: VadModelConfig, + ): Long + + private external fun acceptWaveform(ptr: Long, samples: FloatArray) + private external fun empty(ptr: Long): Boolean + private external fun pop(ptr: Long) + private external fun clear(ptr: Long) + private external fun front(ptr: Long): Array + private external fun isSpeechDetected(ptr: Long): Boolean + private external fun reset(ptr: Long) + private external fun flush(ptr: Long) + + companion object { + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} + +// Please visit +// https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx +// to download silero_vad.onnx +// and put it inside the assets/ +// directory +fun getVadModelConfig(type: Int): VadModelConfig? { + when (type) { + 0 -> { + return VadModelConfig( + sileroVadModelConfig = SileroVadModelConfig( + model = "silero_vad.onnx", + threshold = 0.5F, + minSilenceDuration = 0.25F, + minSpeechDuration = 0.25F, + windowSize = 512, + ), + sampleRate = 16000, + numThreads = 1, + provider = "cpu", + ) + } + } + return null +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/WaveReader.kt b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/WaveReader.kt new file mode 100644 index 00000000..94305542 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/kotlin-api/WaveReader.kt @@ -0,0 +1,70 @@ +// Copyright (c) 2023 Xiaomi Corporation +package com.k2fsa.sherpa.mnn + +import android.content.res.AssetManager + +data class WaveData( + val samples: FloatArray, + val sampleRate: Int, +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as WaveData + + if (!samples.contentEquals(other.samples)) return false + if (sampleRate != other.sampleRate) return false + + return true + } + + override fun hashCode(): Int { + var result = samples.contentHashCode() + result = 31 * result + sampleRate + return result + } +} + +class WaveReader { + companion object { + + fun readWave( + assetManager: AssetManager, + filename: String, + ): WaveData { + return readWaveFromAsset(assetManager, filename).let { + WaveData(it[0] as FloatArray, it[1] as Int) + } + } + + fun readWave( + filename: String, + ): WaveData { + return readWaveFromFile(filename).let { + WaveData(it[0] as FloatArray, it[1] as Int) + } + } + + // Read a mono wave file asset + // The returned array has two entries: + // - the first entry contains an 1-D float array + // - the second entry is the sample rate + external fun readWaveFromAsset( + assetManager: AssetManager, + filename: String, + ): Array + + // Read a mono wave 
file from disk + // The returned array has two entries: + // - the first entry contains an 1-D float array + // - the second entry is the sample rate + external fun readWaveFromFile( + filename: String, + ): Array + + init { + System.loadLibrary("sherpa-mnn-jni") + } + } +} diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/CMakeLists.txt b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/CMakeLists.txt new file mode 100644 index 00000000..ae2bb779 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(csrc) + +if(SHERPA_MNN_ENABLE_TESTS) + add_subdirectory(tests) +endif() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/CMakeLists.txt b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/CMakeLists.txt new file mode 100644 index 00000000..501c8145 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/CMakeLists.txt @@ -0,0 +1,109 @@ +include_directories(${CMAKE_SOURCE_DIR}) + +set(srcs + audio-tagging.cc + circular-buffer.cc + cuda-config.cc + display.cc + endpoint.cc + features.cc + keyword-spotter.cc + offline-ctc-fst-decoder-config.cc + offline-fire-red-asr-model-config.cc + offline-lm-config.cc + offline-model-config.cc + offline-moonshine-model-config.cc + offline-nemo-enc-dec-ctc-model-config.cc + offline-paraformer-model-config.cc + offline-punctuation.cc + offline-recognizer.cc + offline-sense-voice-model-config.cc + offline-speech-denoiser-gtcrn-model-config.cc + offline-speech-denoiser-model-config.cc + offline-speech-denoiser.cc + offline-stream.cc + offline-tdnn-model-config.cc + offline-transducer-model-config.cc + offline-wenet-ctc-model-config.cc + offline-whisper-model-config.cc + offline-zipformer-ctc-model-config.cc + online-ctc-fst-decoder-config.cc + online-lm-config.cc + online-model-config.cc + online-nemo-ctc-model-config.cc + online-paraformer-model-config.cc + online-punctuation.cc + online-recognizer.cc + online-stream.cc + online-transducer-model-config.cc + online-wenet-ctc-model-config.cc + online-zipformer2-ctc-model-config.cc + provider-config.cc + sherpa-mnn.cc + silero-vad-model-config.cc + speaker-embedding-extractor.cc + speaker-embedding-manager.cc + spoken-language-identification.cc + tensorrt-config.cc + vad-model-config.cc + vad-model.cc + voice-activity-detector.cc + wave-writer.cc +) +if(SHERPA_MNN_HAS_ALSA) + list(APPEND srcs ${CMAKE_SOURCE_DIR}/sherpa-mnn/csrc/alsa.cc alsa.cc) +else() + list(APPEND srcs faked-alsa.cc) +endif() + +if(SHERPA_MNN_ENABLE_TTS) + list(APPEND srcs + offline-tts-kokoro-model-config.cc + offline-tts-matcha-model-config.cc + offline-tts-model-config.cc + offline-tts-vits-model-config.cc + offline-tts.cc + ) +endif() + +if(SHERPA_MNN_ENABLE_SPEAKER_DIARIZATION) + list(APPEND srcs + fast-clustering.cc + offline-speaker-diarization-result.cc + offline-speaker-diarization.cc + ) +endif() + +pybind11_add_module(_sherpa_mnn ${srcs}) + +if(APPLE) + execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE PYTHON_SITE_PACKAGE_DIR + ) + message(STATUS "PYTHON_SITE_PACKAGE_DIR: ${PYTHON_SITE_PACKAGE_DIR}") + if(PYTHON_SITE_PACKAGE_DIR STREQUAL "") + message(WARNING "PYTHON_SITE_PACKAGE_DIR is empty!") + else() + target_link_libraries(_sherpa_mnn PRIVATE "-Wl,-rpath,${PYTHON_SITE_PACKAGE_DIR}") + endif() +endif() + +if(NOT WIN32) + target_link_libraries(_sherpa_mnn PRIVATE 
"-Wl,-rpath,${SHERPA_MNN_RPATH_ORIGIN}/sherpa_mnn/lib") +endif() + +target_link_libraries(_sherpa_mnn PRIVATE sherpa-mnn-core) + +if(SHERPA_MNN_HAS_ALSA) + if(DEFINED ENV{SHERPA_MNN_ALSA_LIB_DIR}) + target_link_libraries(_sherpa_mnn PRIVATE -L$ENV{SHERPA_MNN_ALSA_LIB_DIR} -lasound) + else() + target_link_libraries(_sherpa_mnn PRIVATE asound) + endif() +endif() + +install(TARGETS _sherpa_mnn + DESTINATION ../ +) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/alsa.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/alsa.cc new file mode 100644 index 00000000..68679881 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/alsa.cc @@ -0,0 +1,30 @@ +// sherpa-mnn/python/csrc/alsa.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/alsa.h" + +#include + +#include "sherpa-mnn/csrc/alsa.h" + +namespace sherpa_mnn { + +void PybindAlsa(py::module *m) { + using PyClass = Alsa; + py::class_(*m, "Alsa") + .def(py::init(), py::arg("device_name"), + py::call_guard()) + .def( + "read", + [](PyClass &self, int32_t num_samples) -> std::vector { + return self.Read(num_samples); + }, + py::arg("num_samples"), py::call_guard()) + .def_property_readonly("expected_sample_rate", + &PyClass::GetExpectedSampleRate) + .def_property_readonly("actual_sample_rate", + &PyClass::GetActualSampleRate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/alsa.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/alsa.h new file mode 100644 index 00000000..57b97f49 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/alsa.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/alsa.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ALSA_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ALSA_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindAlsa(py::module *m); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_PYTHON_CSRC_ALSA_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/audio-tagging.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/audio-tagging.cc new file mode 100644 index 00000000..7519e55a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/audio-tagging.cc @@ -0,0 +1,88 @@ +// sherpa-mnn/python/csrc/audio-tagging.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/audio-tagging.h" + +#include + +#include "sherpa-mnn/csrc/audio-tagging.h" + +namespace sherpa_mnn { + +static void PybindOfflineZipformerAudioTaggingModelConfig(py::module *m) { + using PyClass = OfflineZipformerAudioTaggingModelConfig; + py::class_(*m, "OfflineZipformerAudioTaggingModelConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindAudioTaggingModelConfig(py::module *m) { + PybindOfflineZipformerAudioTaggingModelConfig(m); + + using PyClass = AudioTaggingModelConfig; + + py::class_(*m, "AudioTaggingModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("zipformer") = OfflineZipformerAudioTaggingModelConfig{}, + py::arg("ced") = "", py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("zipformer", &PyClass::zipformer) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", 
&PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindAudioTaggingConfig(py::module *m) { + PybindAudioTaggingModelConfig(m); + + using PyClass = AudioTaggingConfig; + + py::class_(*m, "AudioTaggingConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("labels"), py::arg("top_k") = 5) + .def_readwrite("model", &PyClass::model) + .def_readwrite("labels", &PyClass::labels) + .def_readwrite("top_k", &PyClass::top_k) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindAudioEvent(py::module *m) { + using PyClass = AudioEvent; + + py::class_(*m, "AudioEvent") + .def_property_readonly( + "name", [](const PyClass &self) -> std::string { return self.name; }) + .def_property_readonly( + "index", [](const PyClass &self) -> int32_t { return self.index; }) + .def_property_readonly( + "prob", [](const PyClass &self) -> float { return self.prob; }) + .def("__str__", &PyClass::ToString); +} + +void PybindAudioTagging(py::module *m) { + PybindAudioTaggingConfig(m); + PybindAudioEvent(m); + + using PyClass = AudioTagging; + + py::class_(*m, "AudioTagging") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def("create_stream", &PyClass::CreateStream, + py::call_guard()) + .def("compute", &PyClass::Compute, py::arg("s"), py::arg("top_k") = -1, + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/audio-tagging.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/audio-tagging.h new file mode 100644 index 00000000..9ffc6ca3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/audio-tagging.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/audio-tagging.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_ +#define SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindAudioTagging(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_AUDIO_TAGGING_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/circular-buffer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/circular-buffer.cc new file mode 100644 index 00000000..f09e69f2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/circular-buffer.cc @@ -0,0 +1,33 @@ +// sherpa-mnn/python/csrc/circular-buffer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/circular-buffer.h" + +#include + +#include "sherpa-mnn/csrc/circular-buffer.h" + +namespace sherpa_mnn { + +void PybindCircularBuffer(py::module *m) { + using PyClass = CircularBuffer; + py::class_(*m, "CircularBuffer") + .def(py::init(), py::arg("capacity")) + .def( + "push", + [](PyClass &self, const std::vector &samples) { + self.Push(samples.data(), samples.size()); + }, + py::arg("samples"), py::call_guard()) + .def("get", &PyClass::Get, py::arg("start_index"), py::arg("n"), + py::call_guard()) + .def("pop", &PyClass::Pop, py::arg("n"), + py::call_guard()) + .def("reset", &PyClass::Reset, py::call_guard()) + .def_property_readonly("size", &PyClass::Size) + .def_property_readonly("head", &PyClass::Head) + .def_property_readonly("tail", &PyClass::Tail); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/circular-buffer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/circular-buffer.h new file mode 100644 index 00000000..2ad5a970 --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/circular-buffer.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/circular-buffer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_CIRCULAR_BUFFER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_CIRCULAR_BUFFER_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindCircularBuffer(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_CIRCULAR_BUFFER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/cuda-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/cuda-config.cc new file mode 100644 index 00000000..a1574abd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/cuda-config.cc @@ -0,0 +1,24 @@ +// sherpa-mnn/python/csrc/cuda-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#include "sherpa-mnn/python/csrc/cuda-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/provider-config.h" + +namespace sherpa_mnn { + +void PybindCudaConfig(py::module *m) { + using PyClass = CudaConfig; + py::class_(*m, "CudaConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("cudnn_conv_algo_search") = 1) + .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/cuda-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/cuda-config.h new file mode 100644 index 00000000..da9be6ab --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/cuda-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/cuda-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#ifndef SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindCudaConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/display.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/display.cc new file mode 100644 index 00000000..e638f1b1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/display.cc @@ -0,0 +1,18 @@ +// sherpa-mnn/python/csrc/display.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/display.h" + +#include "sherpa-mnn/csrc/display.h" + +namespace sherpa_mnn { + +void PybindDisplay(py::module *m) { + using PyClass = Display; + py::class_(*m, "Display") + .def(py::init(), py::arg("max_word_per_line") = 60) + .def("print", &PyClass::Print, py::arg("idx"), py::arg("s")); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/display.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/display.h new file mode 100644 index 00000000..30836e19 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/display.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/display.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_ +#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindDisplay(py::module *m); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/endpoint.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/endpoint.cc new file mode 100644 index 00000000..5505f2b2 
--- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/endpoint.cc @@ -0,0 +1,100 @@ +// sherpa-mnn/csrc/endpoint.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/endpoint.h" + +#include +#include + +#include "sherpa-mnn/csrc/endpoint.h" + +namespace sherpa_mnn { + +static constexpr const char *kEndpointRuleInitDoc = R"doc( +Constructor for EndpointRule. + +Args: + must_contain_nonsilence: + If True, for this endpointing rule to apply there must be nonsilence in the + best-path traceback. For decoding, a non-blank token is considered as + non-silence. + min_trailing_silence: + This endpointing rule requires duration of trailing silence (in seconds) + to be ``>=`` this value. + min_utterance_length: + This endpointing rule requires utterance-length (in seconds) to + be ``>=`` this value. +)doc"; + +static constexpr const char *kEndpointConfigInitDoc = R"doc( +If any rule in EndpointConfig is activated, it is said that an endpointing +is detected. + +Args: + rule1: + By default, it times out after 2.4 seconds of silence, even if + we decoded nothing. + rule2: + By default, it times out after 1.2 seconds of silence after decoding + something. + rule3: + By default, it times out after the utterance is 20 seconds long, regardless of + anything else. +)doc"; + +static void PybindEndpointRule(py::module *m) { + using PyClass = EndpointRule; + py::class_(*m, "EndpointRule") + .def(py::init(), py::arg("must_contain_nonsilence"), + py::arg("min_trailing_silence"), py::arg("min_utterance_length"), + kEndpointRuleInitDoc) + .def("__str__", &PyClass::ToString) + .def_readwrite("must_contain_nonsilence", + &PyClass::must_contain_nonsilence) + .def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence) + .def_readwrite("min_utterance_length", &PyClass::min_utterance_length); +} + +static void PybindEndpointConfig(py::module *m) { + using PyClass = EndpointConfig; + py::class_(*m, "EndpointConfig") + .def( + py::init( + [](float rule1_min_trailing_silence, + float rule2_min_trailing_silence, + float rule3_min_utterance_length) -> std::unique_ptr { + EndpointRule rule1(false, rule1_min_trailing_silence, 0); + EndpointRule rule2(true, rule2_min_trailing_silence, 0); + EndpointRule rule3(false, 0, rule3_min_utterance_length); + + return std::make_unique(rule1, rule2, rule3); + }), + py::arg("rule1_min_trailing_silence"), + py::arg("rule2_min_trailing_silence"), + py::arg("rule3_min_utterance_length")) + .def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2, + const EndpointRule &rule3) -> std::unique_ptr { + auto ans = std::make_unique(); + ans->rule1 = rule1; + ans->rule2 = rule2; + ans->rule3 = rule3; + return ans; + }), + py::arg("rule1") = EndpointRule(false, 2.4, 0), + py::arg("rule2") = EndpointRule(true, 1.2, 0), + py::arg("rule3") = EndpointRule(false, 0, 20), + kEndpointConfigInitDoc) + .def("__str__", + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def_readwrite("rule1", &PyClass::rule1) + .def_readwrite("rule2", &PyClass::rule2) + .def_readwrite("rule3", &PyClass::rule3); +} + +void PybindEndpoint(py::module *m) { + PybindEndpointRule(m); + PybindEndpointConfig(m); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/endpoint.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/endpoint.h new file mode 100644 index 00000000..302d32ec --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/endpoint.h @@ -0,0 +1,16 @@ +// 
sherpa-mnn/csrc/endpoint.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindEndpoint(py::module *m); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_PYTHON_CSRC_ENDPOINT_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/faked-alsa.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/faked-alsa.cc new file mode 100644 index 00000000..7c7c0b44 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/faked-alsa.cc @@ -0,0 +1,45 @@ +// sherpa-mnn/python/csrc/faked-alsa.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/macros.h" +#include "sherpa-mnn/python/csrc/alsa.h" + +namespace sherpa_mnn { + +class FakedAlsa { + public: + explicit FakedAlsa(const char *) { + SHERPA_ONNX_LOGE("This function is for Linux only."); +#if (SHERPA_ONNX_ENABLE_ALSA == 0) && (defined(__unix__) || defined(__unix)) + SHERPA_ONNX_LOGE(R"doc( +sherpa-mnn is compiled without alsa support. To enable that, please run + (1) sudo apt-get install alsa-utils libasound2-dev + (2) rebuild sherpa-mnn +)doc"); +#endif + exit(-1); + } + + std::vector Read(int32_t) const { return {}; } + int32_t GetExpectedSampleRate() const { return -1; } + int32_t GetActualSampleRate() const { return -1; } +}; + +void PybindAlsa(py::module *m) { + using PyClass = FakedAlsa; + py::class_(*m, "Alsa") + .def(py::init(), py::arg("device_name")) + .def( + "read", + [](PyClass &self, int32_t num_samples) -> std::vector { + return self.Read(num_samples); + }, + py::arg("num_samples"), py::call_guard()) + .def_property_readonly("expected_sample_rate", + &PyClass::GetExpectedSampleRate) + .def_property_readonly("actual_sample_rate", + &PyClass::GetActualSampleRate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/fast-clustering.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/fast-clustering.cc new file mode 100644 index 00000000..7b30c4e6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/fast-clustering.cc @@ -0,0 +1,52 @@ +// sherpa-mnn/python/csrc/fast-clustering.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/fast-clustering.h" + +#include +#include + +#include "sherpa-mnn/csrc/fast-clustering.h" + +namespace sherpa_mnn { + +static void PybindFastClusteringConfig(py::module *m) { + using PyClass = FastClusteringConfig; + py::class_(*m, "FastClusteringConfig") + .def(py::init(), py::arg("num_clusters") = -1, + py::arg("threshold") = 0.5) + .def_readwrite("num_clusters", &PyClass::num_clusters) + .def_readwrite("threshold", &PyClass::threshold) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +void PybindFastClustering(py::module *m) { + PybindFastClusteringConfig(m); + + using PyClass = FastClustering; + py::class_(*m, "FastClustering") + .def(py::init(), py::arg("config")) + .def( + "__call__", + [](const PyClass &self, + py::array_t features) -> std::vector { + int num_dim = features.ndim(); + if (num_dim != 2) { + std::ostringstream os; + os << "Expect an array of 2 dimensions. 
Given dim: " << num_dim + << "\n"; + throw py::value_error(os.str()); + } + + int32_t num_rows = features.shape(0); + int32_t num_cols = features.shape(1); + float *p = features.mutable_data(); + py::gil_scoped_release release; + return self.Cluster(p, num_rows, num_cols); + }, + py::arg("features")); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/fast-clustering.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/fast-clustering.h new file mode 100644 index 00000000..e9db2a4c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/fast-clustering.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/fast-clustering.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_FAST_CLUSTERING_H_ +#define SHERPA_ONNX_PYTHON_CSRC_FAST_CLUSTERING_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindFastClustering(py::module *m); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_PYTHON_CSRC_FAST_CLUSTERING_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/features.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/features.cc new file mode 100644 index 00000000..4fc2a1b4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/features.cc @@ -0,0 +1,34 @@ +// sherpa-mnn/python/csrc/features.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/features.h" + +#include "sherpa-mnn/csrc/features.h" + +namespace sherpa_mnn { + +static void PybindFeatureExtractorConfig(py::module *m) { + using PyClass = FeatureExtractorConfig; + py::class_(*m, "FeatureExtractorConfig") + .def(py::init(), + py::arg("sampling_rate") = 16000, + py::arg("feature_dim") = 80, + py::arg("low_freq") = 20.0f, + py::arg("high_freq") = -400.0f, + py::arg("dither") = 0.0f, + py::arg("normalize_samples") = true, + py::arg("snip_edges") = false) + .def_readwrite("sampling_rate", &PyClass::sampling_rate) + .def_readwrite("feature_dim", &PyClass::feature_dim) + .def_readwrite("low_freq", &PyClass::low_freq) + .def_readwrite("high_freq", &PyClass::high_freq) + .def_readwrite("dither", &PyClass::dither) + .def_readwrite("normalize_samples", &PyClass::normalize_samples) + .def_readwrite("snip_edges", &PyClass::snip_edges) + .def("__str__", &PyClass::ToString); +} + +void PybindFeatures(py::module *m) { PybindFeatureExtractorConfig(m); } + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/features.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/features.h new file mode 100644 index 00000000..04fd965d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/features.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/features.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_ +#define SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindFeatures(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_FEATURES_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/keyword-spotter.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/keyword-spotter.cc new file mode 100644 index 00000000..2c643672 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/keyword-spotter.cc @@ -0,0 +1,83 @@ +// sherpa-mnn/python/csrc/keyword-spotter.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/keyword-spotter.h" + +#include +#include + 
+#include "sherpa-mnn/csrc/keyword-spotter.h" + +namespace sherpa_mnn { + +static void PybindKeywordResult(py::module *m) { + using PyClass = KeywordResult; + py::class_(*m, "KeywordResult") + .def_property_readonly( + "keyword", + [](PyClass &self) -> py::str { + return py::str(PyUnicode_DecodeUTF8(self.keyword.c_str(), + self.keyword.size(), "ignore")); + }) + .def_property_readonly( + "tokens", + [](PyClass &self) -> std::vector { return self.tokens; }) + .def_property_readonly( + "timestamps", + [](PyClass &self) -> std::vector { return self.timestamps; }); +} + +static void PybindKeywordSpotterConfig(py::module *m) { + using PyClass = KeywordSpotterConfig; + py::class_(*m, "KeywordSpotterConfig") + .def(py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("max_active_paths") = 4, py::arg("num_trailing_blanks") = 1, + py::arg("keywords_score") = 1.0, + py::arg("keywords_threshold") = 0.25, py::arg("keywords_file") = "") + .def_readwrite("feat_config", &PyClass::feat_config) + .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("max_active_paths", &PyClass::max_active_paths) + .def_readwrite("num_trailing_blanks", &PyClass::num_trailing_blanks) + .def_readwrite("keywords_score", &PyClass::keywords_score) + .def_readwrite("keywords_threshold", &PyClass::keywords_threshold) + .def_readwrite("keywords_file", &PyClass::keywords_file) + .def("__str__", &PyClass::ToString); +} + +void PybindKeywordSpotter(py::module *m) { + PybindKeywordResult(m); + PybindKeywordSpotterConfig(m); + + using PyClass = KeywordSpotter; + py::class_(*m, "KeywordSpotter") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def( + "create_stream", + [](const PyClass &self) { return self.CreateStream(); }, + py::call_guard()) + .def( + "create_stream", + [](PyClass &self, const std::string &keywords) { + return self.CreateStream(keywords); + }, + py::arg("keywords"), py::call_guard()) + .def("is_ready", &PyClass::IsReady, + py::call_guard()) + .def("reset", &PyClass::Reset, py::call_guard()) + .def("decode_stream", &PyClass::DecodeStream, + py::call_guard()) + .def( + "decode_streams", + [](PyClass &self, std::vector ss) { + self.DecodeStreams(ss.data(), ss.size()); + }, + py::call_guard()) + .def("get_result", &PyClass::GetResult, + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/keyword-spotter.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/keyword-spotter.h new file mode 100644 index 00000000..80a0dae4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/keyword-spotter.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/keyword-spotter.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindKeywordSpotter(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_KEYWORD_SPOTTER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.cc new file mode 100644 index 00000000..f3fcc415 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.cc @@ -0,0 +1,23 @@ +// sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.h" 
+ +#include + +#include "sherpa-mnn/csrc/offline-ctc-fst-decoder-config.h" + +namespace sherpa_mnn { + +void PybindOfflineCtcFstDecoderConfig(py::module *m) { + using PyClass = OfflineCtcFstDecoderConfig; + py::class_(*m, "OfflineCtcFstDecoderConfig") + .def(py::init(), py::arg("graph") = "", + py::arg("max_active") = 3000) + .def_readwrite("graph", &PyClass::graph) + .def_readwrite("max_active", &PyClass::max_active) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.h new file mode 100644 index 00000000..2b1fb2f9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineCtcFstDecoderConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.cc new file mode 100644 index 00000000..4608b61b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.cc @@ -0,0 +1,24 @@ +// sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-fire-red-asr-model-config.h" + +#include +#include + +#include "sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineFireRedAsrModelConfig(py::module *m) { + using PyClass = OfflineFireRedAsrModelConfig; + py::class_(*m, "OfflineFireRedAsrModelConfig") + .def(py::init(), + py::arg("encoder"), py::arg("decoder")) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.h new file mode 100644 index 00000000..8a7a8b1e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineFireRedAsrModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_FIRE_RED_ASR_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-lm-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-lm-config.cc new file mode 100644 index 00000000..7d9c8688 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-lm-config.cc @@ -0,0 +1,26 @@ +// sherpa-mnn/python/csrc/offline-lm-config.cc +// +// Copyright (c) 2023 Xiaomi 
Corporation + +#include "sherpa-mnn/python/csrc/offline-lm-config.h" + +#include + +#include "sherpa-mnn//csrc/offline-lm-config.h" + +namespace sherpa_mnn { + +void PybindOfflineLMConfig(py::module *m) { + using PyClass = OfflineLMConfig; + py::class_(*m, "OfflineLMConfig") + .def(py::init(), + py::arg("model"), py::arg("scale") = 0.5f, + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu") + .def_readwrite("model", &PyClass::model) + .def_readwrite("scale", &PyClass::scale) + .def_readwrite("lm_provider", &PyClass::lm_provider) + .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-lm-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-lm-config.h new file mode 100644 index 00000000..ad38082d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-lm-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineLMConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-model-config.cc new file mode 100644 index 00000000..fa4a3461 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-model-config.cc @@ -0,0 +1,87 @@ +// sherpa-mnn/python/csrc/offline-model-config.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-mnn/python/csrc/offline-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/offline-model-config.h" +#include "sherpa-mnn/python/csrc/offline-fire-red-asr-model-config.h" +#include "sherpa-mnn/python/csrc/offline-moonshine-model-config.h" +#include "sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" +#include "sherpa-mnn/python/csrc/offline-paraformer-model-config.h" +#include "sherpa-mnn/python/csrc/offline-sense-voice-model-config.h" +#include "sherpa-mnn/python/csrc/offline-tdnn-model-config.h" +#include "sherpa-mnn/python/csrc/offline-transducer-model-config.h" +#include "sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.h" +#include "sherpa-mnn/python/csrc/offline-whisper-model-config.h" +#include "sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineModelConfig(py::module *m) { + PybindOfflineTransducerModelConfig(m); + PybindOfflineParaformerModelConfig(m); + PybindOfflineNemoEncDecCtcModelConfig(m); + PybindOfflineWhisperModelConfig(m); + PybindOfflineFireRedAsrModelConfig(m); + PybindOfflineTdnnModelConfig(m); + PybindOfflineZipformerCtcModelConfig(m); + PybindOfflineWenetCtcModelConfig(m); + PybindOfflineSenseVoiceModelConfig(m); + PybindOfflineMoonshineModelConfig(m); + + using PyClass = OfflineModelConfig; + py::class_(*m, "OfflineModelConfig") + .def(py::init(), + py::arg("transducer") = OfflineTransducerModelConfig(), + py::arg("paraformer") = OfflineParaformerModelConfig(), + py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), + py::arg("whisper") = OfflineWhisperModelConfig(), + py::arg("fire_red_asr") = OfflineFireRedAsrModelConfig(), + py::arg("tdnn") = OfflineTdnnModelConfig(), + py::arg("zipformer_ctc") = 
OfflineZipformerCtcModelConfig(), + py::arg("wenet_ctc") = OfflineWenetCtcModelConfig(), + py::arg("sense_voice") = OfflineSenseVoiceModelConfig(), + py::arg("moonshine") = OfflineMoonshineModelConfig(), + py::arg("telespeech_ctc") = "", py::arg("tokens"), + py::arg("num_threads"), py::arg("debug") = false, + py::arg("provider") = "cpu", py::arg("model_type") = "", + py::arg("modeling_unit") = "cjkchar", py::arg("bpe_vocab") = "") + .def_readwrite("transducer", &PyClass::transducer) + .def_readwrite("paraformer", &PyClass::paraformer) + .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) + .def_readwrite("whisper", &PyClass::whisper) + .def_readwrite("fire_red_asr", &PyClass::fire_red_asr) + .def_readwrite("tdnn", &PyClass::tdnn) + .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) + .def_readwrite("sense_voice", &PyClass::sense_voice) + .def_readwrite("moonshine", &PyClass::moonshine) + .def_readwrite("telespeech_ctc", &PyClass::telespeech_ctc) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def_readwrite("model_type", &PyClass::model_type) + .def_readwrite("modeling_unit", &PyClass::modeling_unit) + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-model-config.h new file mode 100644 index 00000000..b0d7f92a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-model-config.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-moonshine-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-moonshine-model-config.cc new file mode 100644 index 00000000..d2371623 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-moonshine-model-config.cc @@ -0,0 +1,28 @@ +// sherpa-mnn/python/csrc/offline-moonshine-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-moonshine-model-config.h" + +#include +#include + +#include "sherpa-mnn/python/csrc/offline-moonshine-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineMoonshineModelConfig(py::module *m) { + using PyClass = OfflineMoonshineModelConfig; + py::class_(*m, "OfflineMoonshineModelConfig") + .def(py::init(), + py::arg("preprocessor"), py::arg("encoder"), + py::arg("uncached_decoder"), py::arg("cached_decoder")) + .def_readwrite("preprocessor", &PyClass::preprocessor) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("uncached_decoder", &PyClass::uncached_decoder) + .def_readwrite("cached_decoder", &PyClass::cached_decoder) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-moonshine-model-config.h 
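`OfflineModelConfig` aggregates the per-architecture configs registered above and is built from Python with keyword arguments. A hedged sketch follows (the `sherpa_mnn` import name and all file paths are assumptions), using the SenseVoice sub-config bound later in this patch:

```python
import sherpa_mnn

# Illustrative only: model/token paths are placeholders.
model_config = sherpa_mnn.OfflineModelConfig(
    sense_voice=sherpa_mnn.OfflineSenseVoiceModelConfig(
        model="sense-voice.onnx",
        language="auto",          # assumed value following upstream sherpa-onnx usage
        use_itn=True,
    ),
    tokens="tokens.txt",          # required: no default in the binding
    num_threads=2,                # required: no default in the binding
    debug=False,
    provider="cpu",
)
model_config.validate()           # assumed to report an invalid/incomplete config
```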
b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-moonshine-model-config.h new file mode 100644 index 00000000..6f2823b2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-moonshine-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-moonshine-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineMoonshineModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_MOONSHINE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc new file mode 100644 index 00000000..1883a8d9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/offline-nemo-enc-dec-ctc-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m) { + using PyClass = OfflineNemoEncDecCtcModelConfig; + py::class_(*m, "OfflineNemoEncDecCtcModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.h new file mode 100644 index 00000000..912f9b1b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-nemo-enc-dec-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineNemoEncDecCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-paraformer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-paraformer-model-config.cc new file mode 100644 index 00000000..5b54038a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-paraformer-model-config.cc @@ -0,0 +1,23 @@ +// sherpa-mnn/python/csrc/offline-paraformer-model-config.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-mnn/python/csrc/offline-paraformer-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/offline-paraformer-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineParaformerModelConfig(py::module *m) { + using PyClass = OfflineParaformerModelConfig; + py::class_(*m, "OfflineParaformerModelConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git 
a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-paraformer-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-paraformer-model-config.h new file mode 100644 index 00000000..1093190e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-paraformer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-paraformer-model-config.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineParaformerModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-punctuation.cc new file mode 100644 index 00000000..47c9cb77 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-punctuation.cc @@ -0,0 +1,51 @@ +// sherpa-mnn/python/csrc/offline-punctuation.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-punctuation.h" + +#include + +#include "sherpa-mnn/csrc/offline-punctuation.h" + +namespace sherpa_mnn { + +static void PybindOfflinePunctuationModelConfig(py::module *m) { + using PyClass = OfflinePunctuationModelConfig; + py::class_(*m, "OfflinePunctuationModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("ct_transformer"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("ct_transformer", &PyClass::ct_transformer) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindOfflinePunctuationConfig(py::module *m) { + PybindOfflinePunctuationModelConfig(m); + using PyClass = OfflinePunctuationConfig; + + py::class_(*m, "OfflinePunctuationConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindOfflinePunctuation(py::module *m) { + PybindOfflinePunctuationConfig(m); + using PyClass = OfflinePunctuation; + + py::class_(*m, "OfflinePunctuation") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def("add_punctuation", &PyClass::AddPunctuation, py::arg("text"), + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-punctuation.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-punctuation.h new file mode 100644 index 00000000..fc266122 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-punctuation.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-punctuation.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflinePunctuation(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-recognizer.cc 
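The `OfflinePunctuation` bindings above map directly onto a small Python API: build the model config, wrap it in `OfflinePunctuationConfig`, then call `add_punctuation`. A minimal sketch, assuming the module imports as `sherpa_mnn` and using a placeholder model path:

```python
import sherpa_mnn

config = sherpa_mnn.OfflinePunctuationConfig(
    model=sherpa_mnn.OfflinePunctuationModelConfig(
        ct_transformer="ct-transformer.onnx",  # placeholder path
        num_threads=1,
        debug=False,
        provider="cpu",
    )
)
config.validate()

punct = sherpa_mnn.OfflinePunctuation(config)
print(punct.add_punctuation("how are you today i am fine"))
```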
b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-recognizer.cc new file mode 100644 index 00000000..7e36e65a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-recognizer.cc @@ -0,0 +1,69 @@ +// sherpa-mnn/python/csrc/offline-recognizer.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-mnn/python/csrc/offline-recognizer.h" + +#include +#include + +#include "sherpa-mnn/csrc/offline-recognizer.h" + +namespace sherpa_mnn { + +static void PybindOfflineRecognizerConfig(py::module *m) { + using PyClass = OfflineRecognizerConfig; + py::class_(*m, "OfflineRecognizerConfig") + .def(py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("lm_config") = OfflineLMConfig(), + py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), + py::arg("decoding_method") = "greedy_search", + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0, + py::arg("rule_fsts") = "", py::arg("rule_fars") = "") + .def_readwrite("feat_config", &PyClass::feat_config) + .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("lm_config", &PyClass::lm_config) + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config) + .def_readwrite("decoding_method", &PyClass::decoding_method) + .def_readwrite("max_active_paths", &PyClass::max_active_paths) + .def_readwrite("hotwords_file", &PyClass::hotwords_file) + .def_readwrite("hotwords_score", &PyClass::hotwords_score) + .def_readwrite("blank_penalty", &PyClass::blank_penalty) + .def_readwrite("rule_fsts", &PyClass::rule_fsts) + .def_readwrite("rule_fars", &PyClass::rule_fars) + .def("__str__", &PyClass::ToString); +} + +void PybindOfflineRecognizer(py::module *m) { + PybindOfflineRecognizerConfig(m); + + using PyClass = OfflineRecognizer; + py::class_(*m, "OfflineRecognizer") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def( + "create_stream", + [](const PyClass &self) { return self.CreateStream(); }, + py::call_guard()) + .def( + "create_stream", + [](PyClass &self, const std::string &hotwords) { + return self.CreateStream(hotwords); + }, + py::arg("hotwords"), py::call_guard()) + .def("decode_stream", &PyClass::DecodeStream, + py::call_guard()) + .def( + "decode_streams", + [](const PyClass &self, std::vector ss) { + self.DecodeStreams(ss.data(), ss.size()); + }, + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-recognizer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-recognizer.h new file mode 100644 index 00000000..52702208 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-recognizer.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-recognizer.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineRecognizer(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_RECOGNIZER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-sense-voice-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-sense-voice-model-config.cc new file mode 100644 index 00000000..2da9f39b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-sense-voice-model-config.cc @@ -0,0 +1,26 @@ +// 
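`OfflineRecognizer` above exposes stream creation (optionally with hotwords) plus single- and batch-stream decoding. The sketch below shows only the decode flow; constructing `OfflineRecognizerConfig` needs a `feat_config` whose type is bound elsewhere in this patch, so the recognizer is assumed to already exist. The stream methods come from the `OfflineStream` bindings later in this patch.

```python
# Hypothetical decode flow; `recognizer` is an existing sherpa_mnn.OfflineRecognizer.
samples = [0.0] * 16000                          # placeholder audio, float in [-1, 1]

stream = recognizer.create_stream()              # or recognizer.create_stream(hotwords)
stream.accept_waveform(16000, samples)
recognizer.decode_stream(stream)
print(stream.result.text)

# Batch decoding: decode_streams takes a list of streams.
streams = [recognizer.create_stream() for _ in range(4)]
# ... call accept_waveform on each stream ...
recognizer.decode_streams(streams)
```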
sherpa-mnn/python/csrc/offline-sense-voice-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-sense-voice-model-config.h" + +#include +#include + +#include "sherpa-mnn/python/csrc/offline-sense-voice-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineSenseVoiceModelConfig(py::module *m) { + using PyClass = OfflineSenseVoiceModelConfig; + py::class_(*m, "OfflineSenseVoiceModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("language"), py::arg("use_itn")) + .def_readwrite("model", &PyClass::model) + .def_readwrite("language", &PyClass::language) + .def_readwrite("use_itn", &PyClass::use_itn) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-sense-voice-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-sense-voice-model-config.h new file mode 100644 index 00000000..82a00433 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-sense-voice-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-sense-voice-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineSenseVoiceModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SENSE_VOICE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization-result.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization-result.cc new file mode 100644 index 00000000..4e3f23fa --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization-result.cc @@ -0,0 +1,32 @@ +// sherpa-mnn/python/csrc/offline-speaker-diarization-result.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-speaker-diarization-result.h" + +#include "sherpa-mnn/csrc/offline-speaker-diarization-result.h" + +namespace sherpa_mnn { + +static void PybindOfflineSpeakerDiarizationSegment(py::module *m) { + using PyClass = OfflineSpeakerDiarizationSegment; + py::class_(*m, "OfflineSpeakerDiarizationSegment") + .def_property_readonly("start", &PyClass::Start) + .def_property_readonly("end", &PyClass::End) + .def_property_readonly("duration", &PyClass::Duration) + .def_property_readonly("speaker", &PyClass::Speaker) + .def_property("text", &PyClass::Text, &PyClass::SetText) + .def("__str__", &PyClass::ToString); +} + +void PybindOfflineSpeakerDiarizationResult(py::module *m) { + PybindOfflineSpeakerDiarizationSegment(m); + using PyClass = OfflineSpeakerDiarizationResult; + py::class_(*m, "OfflineSpeakerDiarizationResult") + .def_property_readonly("num_speakers", &PyClass::NumSpeakers) + .def_property_readonly("num_segments", &PyClass::NumSegments) + .def("sort_by_start_time", &PyClass::SortByStartTime) + .def("sort_by_speaker", &PyClass::SortBySpeaker); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization-result.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization-result.h new file mode 100644 index 00000000..05c5c251 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization-result.h @@ -0,0 +1,16 @@ +// 
sherpa-mnn/python/csrc/offline-speaker-diarization-result.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeakerDiarizationResult(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization.cc new file mode 100644 index 00000000..2bc9efd8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization.cc @@ -0,0 +1,93 @@ +// sherpa-mnn/python/csrc/offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-speaker-diarization.h" + +#include +#include + +#include "sherpa-mnn/csrc/offline-speaker-diarization.h" +#include "sherpa-mnn/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa-mnn/csrc/offline-speaker-segmentation-pyannote-model-config.h" + +namespace sherpa_mnn { + +static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) { + using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig; + py::class_(*m, "OfflineSpeakerSegmentationPyannoteModelConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) { + PybindOfflineSpeakerSegmentationPyannoteModelConfig(m); + + using PyClass = OfflineSpeakerSegmentationModelConfig; + py::class_(*m, "OfflineSpeakerSegmentationModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("pyannote"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("pyannote", &PyClass::pyannote) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +static void PybindOfflineSpeakerDiarizationConfig(py::module *m) { + PybindOfflineSpeakerSegmentationModelConfig(m); + + using PyClass = OfflineSpeakerDiarizationConfig; + py::class_(*m, "OfflineSpeakerDiarizationConfig") + .def(py::init(), + py::arg("segmentation"), py::arg("embedding"), py::arg("clustering"), + py::arg("min_duration_on") = 0.3, py::arg("min_duration_off") = 0.5) + .def_readwrite("segmentation", &PyClass::segmentation) + .def_readwrite("embedding", &PyClass::embedding) + .def_readwrite("clustering", &PyClass::clustering) + .def_readwrite("min_duration_on", &PyClass::min_duration_on) + .def_readwrite("min_duration_off", &PyClass::min_duration_off) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +void PybindOfflineSpeakerDiarization(py::module *m) { + PybindOfflineSpeakerDiarizationConfig(m); + + using PyClass = OfflineSpeakerDiarization; + py::class_(*m, "OfflineSpeakerDiarization") + .def(py::init(), + py::arg("config")) + .def_property_readonly("sample_rate", &PyClass::SampleRate) + .def("set_config", &PyClass::SetConfig, py::arg("config")) + .def( + "process", + [](const PyClass &self, const std::vector samples, + std::function callback) { + if (!callback) { + 
return self.Process(samples.data(), samples.size()); + } + + std::function callback_wrapper = + [callback](int32_t processed_chunks, int32_t num_chunks, + void *) -> int32_t { + callback(processed_chunks, num_chunks); + return 0; + }; + + return self.Process(samples.data(), samples.size(), + callback_wrapper); + }, + py::arg("samples"), py::arg("callback") = py::none()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization.h new file mode 100644 index 00000000..cc7ada54 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speaker-diarization.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-speaker-diarization.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeakerDiarization(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.cc new file mode 100644 index 00000000..d4c5c719 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.h" + +#include + +#include "sherpa-mnn/csrc/offline-speech-denoiser-gtcrn-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeechDenoiserGtcrnModelConfig(py::module *m) { + using PyClass = OfflineSpeechDenoiserGtcrnModelConfig; + py::class_(*m, "OfflineSpeechDenoiserGtcrnModelConfig") + .def(py::init(), py::arg("model") = "") + .def_readwrite("model", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.h new file mode 100644 index 00000000..6aad6f43 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeechDenoiserGtcrnModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_GTCRN_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.cc new file mode 100644 index 00000000..f016bd16 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.cc @@ -0,0 +1,33 @@ +// 
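The `process` binding above accepts an optional progress callback taking `(processed_chunks, num_chunks)`; the C++ wrapper ignores the callback's return value and always reports 0 back to the native side. A hedged sketch: the diarization object is assumed to exist already (its embedding/clustering config types are bound outside this excerpt), and `sort_by_start_time` is assumed to return the segment list.

```python
# Hypothetical usage of the OfflineSpeakerDiarization bindings above.
samples = [0.0] * 16000                               # placeholder audio

def on_progress(processed_chunks, num_chunks):
    # Return value is ignored by the wrapper above.
    print(f"diarization progress: {processed_chunks}/{num_chunks}")

result = sd.process(samples, callback=on_progress)    # sd: existing instance
print(result.num_speakers, result.num_segments)
for seg in result.sort_by_start_time():               # assumed to return segments
    print(f"{seg.start:.2f} -- {seg.end:.2f} speaker_{seg.speaker} {seg.text}")
```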
sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.h" + +#include + +#include "sherpa-mnn/csrc/offline-speech-denoiser-model-config.h" +#include "sherpa-mnn/python/csrc/offline-speech-denoiser-gtcrn-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeechDenoiserModelConfig(py::module *m) { + PybindOfflineSpeechDenoiserGtcrnModelConfig(m); + + using PyClass = OfflineSpeechDenoiserModelConfig; + py::class_(*m, "OfflineSpeechDenoiserModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("gtcrn") = OfflineSpeechDenoiserGtcrnModelConfig{}, + py::arg("num_threads") = 1, py::arg("debug") = false, + py::arg("provider") = "cpu") + .def_readwrite("gtcrn", &PyClass::gtcrn) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.h new file mode 100644 index 00000000..d20f8cfb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeechDenoiserModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser.cc new file mode 100644 index 00000000..f0027cad --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser.cc @@ -0,0 +1,61 @@ +// sherpa-mnn/python/csrc/offline-speech-denoiser.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-speech-denoiser.h" + +#include + +#include "sherpa-mnn/csrc/offline-speech-denoiser.h" +#include "sherpa-mnn/python/csrc/offline-speech-denoiser-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeechDenoiserConfig(py::module *m) { + PybindOfflineSpeechDenoiserModelConfig(m); + + using PyClass = OfflineSpeechDenoiserConfig; + + py::class_(*m, "OfflineSpeechDenoiserConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model") = OfflineSpeechDenoiserModelConfig{}) + .def_readwrite("model", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindDenoisedAudio(py::module *m) { + using PyClass = DenoisedAudio; + py::class_(*m, "DenoisedAudio") + .def_property_readonly( + "sample_rate", [](const PyClass &self) { return self.sample_rate; }) + .def_property_readonly("samples", + [](const PyClass &self) { return self.samples; }); +} + +void PybindOfflineSpeechDenoiser(py::module *m) { + PybindOfflineSpeechDenoiserConfig(m); + PybindDenoisedAudio(m); + using PyClass = OfflineSpeechDenoiser; + py::class_(*m, "OfflineSpeechDenoiser") + .def(py::init(), py::arg("config"), + 
py::call_guard()) + .def( + "__call__", + [](const PyClass &self, const std::vector &samples, + int32_t sample_rate) { + return self.Run(samples.data(), samples.size(), sample_rate); + }, + py::call_guard()) + .def( + "run", + [](const PyClass &self, const std::vector &samples, + int32_t sample_rate) { + return self.Run(samples.data(), samples.size(), sample_rate); + }, + py::call_guard()) + .def_property_readonly("sample_rate", &PyClass::GetSampleRate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser.h new file mode 100644 index 00000000..1978d412 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-speech-denoiser.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-speech-denoiser.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineSpeechDenoiser(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEECH_DENOISER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-stream.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-stream.cc new file mode 100644 index 00000000..1d6929bf --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-stream.cc @@ -0,0 +1,65 @@ +// sherpa-mnn/python/csrc/offline-stream.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-mnn/python/csrc/offline-stream.h" + +#include + +#include "sherpa-mnn/csrc/offline-stream.h" + +namespace sherpa_mnn { + +constexpr const char *kAcceptWaveformUsage = R"( +Process audio samples. + +Args: + sample_rate: + Sample rate of the input samples. If it is different from the one + expected by the model, we will do resampling inside. + waveform: + A 1-D float32 tensor containing audio samples. It must be normalized + to the range [-1, 1]. 
+)"; + +static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT + using PyClass = OfflineRecognitionResult; + py::class_(*m, "OfflineRecognitionResult") + .def("__str__", &PyClass::AsJsonString) + .def_property_readonly( + "text", + [](const PyClass &self) -> py::str { + return py::str(PyUnicode_DecodeUTF8(self.text.c_str(), + self.text.size(), "ignore")); + }) + .def_property_readonly("lang", + [](const PyClass &self) { return self.lang; }) + .def_property_readonly("emotion", + [](const PyClass &self) { return self.emotion; }) + .def_property_readonly("event", + [](const PyClass &self) { return self.event; }) + .def_property_readonly("tokens", + [](const PyClass &self) { return self.tokens; }) + .def_property_readonly("words", + [](const PyClass &self) { return self.words; }) + .def_property_readonly( + "timestamps", [](const PyClass &self) { return self.timestamps; }); +} + +void PybindOfflineStream(py::module *m) { + PybindOfflineRecognitionResult(m); + + using PyClass = OfflineStream; + py::class_(*m, "OfflineStream") + .def( + "accept_waveform", + [](PyClass &self, float sample_rate, + const std::vector &waveform) { + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); + }, + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage, + py::call_guard()) + .def_property_readonly("result", &PyClass::GetResult); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-stream.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-stream.h new file mode 100644 index 00000000..80c3db95 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-stream.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-stream.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineStream(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_STREAM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tdnn-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tdnn-model-config.cc new file mode 100644 index 00000000..ce3a6d21 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tdnn-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-mnn/python/csrc/offline-tdnn-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-tdnn-model-config.h" + +#include +#include + +#include "sherpa-mnn/python/csrc/offline-tdnn-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineTdnnModelConfig(py::module *m) { + using PyClass = OfflineTdnnModelConfig; + py::class_(*m, "OfflineTdnnModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tdnn-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tdnn-model-config.h new file mode 100644 index 00000000..9e3642af --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tdnn-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-tdnn-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ + +#include 
"sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineTdnnModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TDNN_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-transducer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-transducer-model-config.cc new file mode 100644 index 00000000..190aaa3c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-transducer-model-config.cc @@ -0,0 +1,27 @@ +// sherpa-mnn/python/csrc/offline-transducer-model-config.cc +// +// Copyright (c) 2023 by manyeyes + +#include "sherpa-mnn/python/csrc/offline-transducer-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/offline-transducer-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineTransducerModelConfig(py::module *m) { + using PyClass = OfflineTransducerModelConfig; + py::class_(*m, "OfflineTransducerModelConfig") + .def(py::init(), + py::arg("encoder_filename"), py::arg("decoder_filename"), + py::arg("joiner_filename")) + .def_readwrite("encoder_filename", &PyClass::encoder_filename) + .def_readwrite("decoder_filename", &PyClass::decoder_filename) + .def_readwrite("joiner_filename", &PyClass::joiner_filename) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-transducer-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-transducer-model-config.h new file mode 100644 index 00000000..6a31d330 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-transducer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-transducer-model-config.h +// +// Copyright (c) 2023 by manyeyes + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineTransducerModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.cc new file mode 100644 index 00000000..b75f7d01 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.cc @@ -0,0 +1,35 @@ +// sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.h" + +#include + +#include "sherpa-mnn/csrc/offline-tts-kokoro-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsKokoroModelConfig(py::module *m) { + using PyClass = OfflineTtsKokoroModelConfig; + + py::class_(*m, "OfflineTtsKokoroModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("voices"), py::arg("tokens"), + py::arg("lexicon") = "", py::arg("data_dir"), + py::arg("dict_dir") = "", py::arg("length_scale") = 1.0) + .def_readwrite("model", &PyClass::model) + .def_readwrite("voices", &PyClass::voices) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("lexicon", &PyClass::lexicon) + .def_readwrite("data_dir", &PyClass::data_dir) + .def_readwrite("dict_dir", &PyClass::dict_dir) + .def_readwrite("length_scale", &PyClass::length_scale) + .def("__str__", &PyClass::ToString) + .def("validate", 
&PyClass::Validate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.h new file mode 100644 index 00000000..4db20b1d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_KOKORO_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_KOKORO_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsKokoroModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_KOKORO_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-matcha-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-matcha-model-config.cc new file mode 100644 index 00000000..5be79fc4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-matcha-model-config.cc @@ -0,0 +1,37 @@ +// sherpa-mnn/python/csrc/offline-tts-matcha-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-tts-matcha-model-config.h" + +#include + +#include "sherpa-mnn/csrc/offline-tts-matcha-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsMatchaModelConfig(py::module *m) { + using PyClass = OfflineTtsMatchaModelConfig; + + py::class_(*m, "OfflineTtsMatchaModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("acoustic_model"), py::arg("vocoder"), py::arg("lexicon"), + py::arg("tokens"), py::arg("data_dir") = "", + py::arg("dict_dir") = "", py::arg("noise_scale") = 1.0, + py::arg("length_scale") = 1.0) + .def_readwrite("acoustic_model", &PyClass::acoustic_model) + .def_readwrite("vocoder", &PyClass::vocoder) + .def_readwrite("lexicon", &PyClass::lexicon) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("data_dir", &PyClass::data_dir) + .def_readwrite("dict_dir", &PyClass::dict_dir) + .def_readwrite("noise_scale", &PyClass::noise_scale) + .def_readwrite("length_scale", &PyClass::length_scale) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-matcha-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-matcha-model-config.h new file mode 100644 index 00000000..870536a0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-matcha-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-tts-matcha-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsMatchaModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-model-config.cc new file mode 100644 index 00000000..606839a1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-model-config.cc @@ -0,0 +1,43 @@ 
+// sherpa-mnn/python/csrc/offline-tts-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-tts-model-config.h" + +#include + +#include "sherpa-mnn/csrc/offline-tts-model-config.h" +#include "sherpa-mnn/python/csrc/offline-tts-kokoro-model-config.h" +#include "sherpa-mnn/python/csrc/offline-tts-matcha-model-config.h" +#include "sherpa-mnn/python/csrc/offline-tts-vits-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsModelConfig(py::module *m) { + PybindOfflineTtsVitsModelConfig(m); + PybindOfflineTtsMatchaModelConfig(m); + PybindOfflineTtsKokoroModelConfig(m); + + using PyClass = OfflineTtsModelConfig; + + py::class_(*m, "OfflineTtsModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("vits") = OfflineTtsVitsModelConfig{}, + py::arg("matcha") = OfflineTtsMatchaModelConfig{}, + py::arg("kokoro") = OfflineTtsKokoroModelConfig{}, + py::arg("num_threads") = 1, py::arg("debug") = false, + py::arg("provider") = "cpu") + .def_readwrite("vits", &PyClass::vits) + .def_readwrite("matcha", &PyClass::matcha) + .def_readwrite("kokoro", &PyClass::kokoro) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-model-config.h new file mode 100644 index 00000000..b3298f71 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-tts-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-vits-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-vits-model-config.cc new file mode 100644 index 00000000..45ca95bd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-vits-model-config.cc @@ -0,0 +1,37 @@ +// sherpa-mnn/python/csrc/offline-tts-vits-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-tts-vits-model-config.h" + +#include + +#include "sherpa-mnn/csrc/offline-tts-vits-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsVitsModelConfig(py::module *m) { + using PyClass = OfflineTtsVitsModelConfig; + + py::class_(*m, "OfflineTtsVitsModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("lexicon"), py::arg("tokens"), + py::arg("data_dir") = "", py::arg("dict_dir") = "", + py::arg("noise_scale") = 0.667, py::arg("noise_scale_w") = 0.8, + py::arg("length_scale") = 1.0) + .def_readwrite("model", &PyClass::model) + .def_readwrite("lexicon", &PyClass::lexicon) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("data_dir", &PyClass::data_dir) + .def_readwrite("dict_dir", &PyClass::dict_dir) + .def_readwrite("noise_scale", &PyClass::noise_scale) + .def_readwrite("noise_scale_w", &PyClass::noise_scale_w) + .def_readwrite("length_scale", &PyClass::length_scale) + .def("__str__", 
&PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-vits-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-vits-model-config.h new file mode 100644 index 00000000..df33954d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts-vits-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-tts-vits-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineTtsVitsModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_VITS_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts.cc new file mode 100644 index 00000000..55ff1192 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts.cc @@ -0,0 +1,90 @@ +// sherpa-mnn/python/csrc/offline-tts.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-mnn/python/csrc/offline-tts.h" + +#include +#include + +#include "sherpa-mnn/csrc/offline-tts.h" +#include "sherpa-mnn/python/csrc/offline-tts-model-config.h" + +namespace sherpa_mnn { + +static void PybindGeneratedAudio(py::module *m) { + using PyClass = GeneratedAudio; + py::class_(*m, "GeneratedAudio") + .def(py::init<>()) + .def_readwrite("samples", &PyClass::samples) + .def_readwrite("sample_rate", &PyClass::sample_rate) + .def("__str__", [](PyClass &self) { + std::ostringstream os; + os << "GeneratedAudio(sample_rate=" << self.sample_rate << ", "; + os << "num_samples=" << self.samples.size() << ")"; + return os.str(); + }); +} + +static void PybindOfflineTtsConfig(py::module *m) { + PybindOfflineTtsModelConfig(m); + + using PyClass = OfflineTtsConfig; + py::class_(*m, "OfflineTtsConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("rule_fsts") = "", + py::arg("rule_fars") = "", py::arg("max_num_sentences") = 2, + py::arg("silence_scale") = 0.2) + .def_readwrite("model", &PyClass::model) + .def_readwrite("rule_fsts", &PyClass::rule_fsts) + .def_readwrite("rule_fars", &PyClass::rule_fars) + .def_readwrite("max_num_sentences", &PyClass::max_num_sentences) + .def_readwrite("silence_scale", &PyClass::silence_scale) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindOfflineTts(py::module *m) { + PybindOfflineTtsConfig(m); + PybindGeneratedAudio(m); + + using PyClass = OfflineTts; + py::class_(*m, "OfflineTts") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def_property_readonly("sample_rate", &PyClass::SampleRate) + .def_property_readonly("num_speakers", &PyClass::NumSpeakers) + .def( + "generate", + [](const PyClass &self, const std::string &text, int64_t sid, + float speed, + std::function, float)> callback) + -> GeneratedAudio { + if (!callback) { + return self.Generate(text, sid, speed); + } + + std::function + callback_wrapper = [callback](const float *samples, int32_t n, + float progress) { + // CAUTION(fangjun): we have to copy samples since it is + // freed once the call back returns. 
+ + pybind11::gil_scoped_acquire acquire; + + pybind11::array_t array(n); + py::buffer_info buf = array.request(); + auto p = static_cast(buf.ptr); + std::copy(samples, samples + n, p); + return callback(array, progress); + }; + + return self.Generate(text, sid, speed, callback_wrapper); + }, + py::arg("text"), py::arg("sid") = 0, py::arg("speed") = 1.0, + py::arg("callback") = py::none(), + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts.h new file mode 100644 index 00000000..7e17016a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-tts.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-tts.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineTts(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.cc new file mode 100644 index 00000000..e9253757 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-mnn/python/csrc/offline-wenet-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-wenet-ctc-model-config.h" + +#include +#include + +#include "sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineWenetCtcModelConfig(py::module *m) { + using PyClass = OfflineWenetCtcModelConfig; + py::class_(*m, "OfflineWenetCtcModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.h new file mode 100644 index 00000000..7e1331ea --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-wenet-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-wenet-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineWenetCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WENET_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-whisper-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-whisper-model-config.cc new file mode 100644 index 00000000..9dea3689 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-whisper-model-config.cc @@ -0,0 +1,29 @@ +// sherpa-mnn/python/csrc/offline-whisper-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/offline-whisper-model-config.h" + +#include +#include + +#include "sherpa-mnn/python/csrc/offline-whisper-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineWhisperModelConfig(py::module *m) { + using PyClass = OfflineWhisperModelConfig; + 
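The `OfflineTts` bindings above copy each generated chunk into a NumPy array (note the GIL acquisition and `std::copy` in the callback wrapper) before invoking the Python callback. A hedged end-to-end sketch using the VITS config bound earlier in this patch; the module name, file paths, and the callback's return-value convention (non-zero keeps generation going, mirroring upstream sherpa-onnx) are assumptions:

```python
import sherpa_mnn

tts_config = sherpa_mnn.OfflineTtsConfig(
    model=sherpa_mnn.OfflineTtsModelConfig(
        vits=sherpa_mnn.OfflineTtsVitsModelConfig(
            model="vits.onnx", lexicon="lexicon.txt", tokens="tokens.txt"),
        num_threads=2,
        provider="cpu",
    ),
    max_num_sentences=2,
    silence_scale=0.2,
)
tts_config.validate()

tts = sherpa_mnn.OfflineTts(tts_config)
print(tts.sample_rate, tts.num_speakers)

def on_audio(chunk, progress):
    # chunk is the float32 array copied by the binding above.
    print(f"got {len(chunk)} samples, progress={progress:.2f}")
    return 1  # assumed convention: non-zero continues generation

audio = tts.generate("Hello from sherpa-mnn.", sid=0, speed=1.0, callback=on_audio)
print(audio)  # __str__ prints GeneratedAudio(sample_rate=..., num_samples=...)
```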
py::class_(*m, "OfflineWhisperModelConfig") + .def(py::init(), + py::arg("encoder"), py::arg("decoder"), py::arg("language"), + py::arg("task"), py::arg("tail_paddings") = -1) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def_readwrite("language", &PyClass::language) + .def_readwrite("task", &PyClass::task) + .def_readwrite("tail_paddings", &PyClass::tail_paddings) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-whisper-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-whisper-model-config.h new file mode 100644 index 00000000..9873ca5c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-whisper-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-whisper-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineWhisperModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.cc new file mode 100644 index 00000000..357888c9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.h" + +#include + +#include "sherpa-mnn/csrc/offline-zipformer-ctc-model-config.h" + +namespace sherpa_mnn { + +void PybindOfflineZipformerCtcModelConfig(py::module *m) { + using PyClass = OfflineZipformerCtcModelConfig; + py::class_(*m, "OfflineZipformerCtcModelConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.h new file mode 100644 index 00000000..4d02bcbd --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/offline-zipformer-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOfflineZipformerCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.cc new file mode 100644 index 00000000..eae80fbb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.cc @@ -0,0 +1,23 @@ +// sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include 
"sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.h" + +#include + +#include "sherpa-mnn/csrc/online-ctc-fst-decoder-config.h" + +namespace sherpa_mnn { + +void PybindOnlineCtcFstDecoderConfig(py::module *m) { + using PyClass = OnlineCtcFstDecoderConfig; + py::class_(*m, "OnlineCtcFstDecoderConfig") + .def(py::init(), py::arg("graph") = "", + py::arg("max_active") = 3000) + .def_readwrite("graph", &PyClass::graph) + .def_readwrite("max_active", &PyClass::max_active) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.h new file mode 100644 index 00000000..aadd9368 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineCtcFstDecoderConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-lm-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-lm-config.cc new file mode 100644 index 00000000..6641da10 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-lm-config.cc @@ -0,0 +1,29 @@ +// sherpa-mnn/python/csrc/online-lm-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/online-lm-config.h" + +#include + +#include "sherpa-mnn//csrc/online-lm-config.h" + +namespace sherpa_mnn { + +void PybindOnlineLMConfig(py::module *m) { + using PyClass = OnlineLMConfig; + py::class_(*m, "OnlineLMConfig") + .def(py::init(), + py::arg("model") = "", py::arg("scale") = 0.5f, + py::arg("lm_num_threads") = 1, py::arg("lm_provider") = "cpu", + py::arg("shallow_fusion") = true) + .def_readwrite("model", &PyClass::model) + .def_readwrite("scale", &PyClass::scale) + .def_readwrite("lm_provider", &PyClass::lm_provider) + .def_readwrite("lm_num_threads", &PyClass::lm_num_threads) + .def_readwrite("shallow_fusion", &PyClass::shallow_fusion) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-lm-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-lm-config.h new file mode 100644 index 00000000..8eb84786 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-lm-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-lm-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineLMConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-model-config.cc new file mode 100644 index 00000000..24db4dc7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-model-config.cc @@ -0,0 +1,66 @@ +// sherpa-mnn/python/csrc/online-model-config.cc 
+// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/online-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/online-model-config.h" +#include "sherpa-mnn/csrc/online-transducer-model-config.h" +#include "sherpa-mnn/csrc/provider-config.h" +#include "sherpa-mnn/python/csrc/online-nemo-ctc-model-config.h" +#include "sherpa-mnn/python/csrc/online-paraformer-model-config.h" +#include "sherpa-mnn/python/csrc/online-transducer-model-config.h" +#include "sherpa-mnn/python/csrc/online-wenet-ctc-model-config.h" +#include "sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.h" +#include "sherpa-mnn/python/csrc/provider-config.h" + +namespace sherpa_mnn { + +void PybindOnlineModelConfig(py::module *m) { + PybindOnlineTransducerModelConfig(m); + PybindOnlineParaformerModelConfig(m); + PybindOnlineWenetCtcModelConfig(m); + PybindOnlineZipformer2CtcModelConfig(m); + PybindOnlineNeMoCtcModelConfig(m); + PybindProviderConfig(m); + + using PyClass = OnlineModelConfig; + py::class_(*m, "OnlineModelConfig") + .def(py::init(), + py::arg("transducer") = OnlineTransducerModelConfig(), + py::arg("paraformer") = OnlineParaformerModelConfig(), + py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), + py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(), + py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), + py::arg("provider_config") = ProviderConfig(), + py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0, + py::arg("debug") = false, py::arg("model_type") = "", + py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "") + .def_readwrite("transducer", &PyClass::transducer) + .def_readwrite("paraformer", &PyClass::paraformer) + .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) + .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc) + .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) + .def_readwrite("provider_config", &PyClass::provider_config) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("warm_up", &PyClass::warm_up) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("model_type", &PyClass::model_type) + .def_readwrite("modeling_unit", &PyClass::modeling_unit) + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-model-config.h new file mode 100644 index 00000000..018f9998 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-nemo-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-nemo-ctc-model-config.cc new file mode 100644 index 00000000..19704ec3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-nemo-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-mnn/python/csrc/online-nemo-ctc-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + 
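In the OnlineModelConfig constructor above, `tokens` and `num_threads` have no default and must always be supplied, while each per-architecture sub-config defaults to an empty instance. A hedged sketch using the transducer branch (paths are placeholders):

    from _sherpa_mnn import OnlineModelConfig, OnlineTransducerModelConfig

    model_cfg = OnlineModelConfig(
        transducer=OnlineTransducerModelConfig(
            encoder="encoder.mnn", decoder="decoder.mnn", joiner="joiner.mnn",
        ),
        tokens="tokens.txt",   # required: no default in the binding above
        num_threads=2,         # required as well
    )
    model_cfg.validate()  # e.g. checks that the model files above actually exist
    print(model_cfg)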
+#include "sherpa-mnn/python/csrc/online-nemo-ctc-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/online-nemo-ctc-model-config.h" + +namespace sherpa_mnn { + +void PybindOnlineNeMoCtcModelConfig(py::module *m) { + using PyClass = OnlineNeMoCtcModelConfig; + py::class_(*m, "OnlineNeMoCtcModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-nemo-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-nemo-ctc-model-config.h new file mode 100644 index 00000000..166a151d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-nemo-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-nemo-ctc-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineNeMoCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-paraformer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-paraformer-model-config.cc new file mode 100644 index 00000000..c9aefad3 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-paraformer-model-config.cc @@ -0,0 +1,24 @@ +// sherpa-mnn/python/csrc/online-paraformer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/online-paraformer-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/online-paraformer-model-config.h" + +namespace sherpa_mnn { + +void PybindOnlineParaformerModelConfig(py::module *m) { + using PyClass = OnlineParaformerModelConfig; + py::class_(*m, "OnlineParaformerModelConfig") + .def(py::init(), + py::arg("encoder"), py::arg("decoder")) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-paraformer-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-paraformer-model-config.h new file mode 100644 index 00000000..3a1536af --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-paraformer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-paraformer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineParaformerModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-punctuation.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-punctuation.cc new file mode 100644 index 00000000..a3ae5c5a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-punctuation.cc @@ -0,0 +1,55 @@ +// sherpa-mnn/python/csrc/online-punctuation.cc +// +// Copyright (c) 2024 + +#include "sherpa-mnn/python/csrc/online-punctuation.h" + +#include + +#include 
"sherpa-mnn/csrc/online-punctuation.h" + +namespace sherpa_mnn { + +static void PybindOnlinePunctuationModelConfig(py::module *m) { + using PyClass = OnlinePunctuationModelConfig; + py::class_(*m, "OnlinePunctuationModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("cnn_bilstm"), py::arg("bpe_vocab"), + py::arg("num_threads") = 1, py::arg("debug") = false, + py::arg("provider") = "cpu") + .def_readwrite("cnn_bilstm", &PyClass::cnn_bilstm) + .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindOnlinePunctuationConfig(py::module *m) { + PybindOnlinePunctuationModelConfig(m); + using PyClass = OnlinePunctuationConfig; + + py::class_(*m, "OnlinePunctuationConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model_config")) + .def_readwrite("model_config", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindOnlinePunctuation(py::module *m) { + PybindOnlinePunctuationConfig(m); + using PyClass = OnlinePunctuation; + + py::class_(*m, "OnlinePunctuation") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def("add_punctuation_with_case", &PyClass::AddPunctuationWithCase, + py::arg("text"), py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-punctuation.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-punctuation.h new file mode 100644 index 00000000..c1eac399 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-punctuation.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-punctuation.h +// +// Copyright (c) 2024 + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlinePunctuation(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_PUNCTUATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-recognizer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-recognizer.cc new file mode 100644 index 00000000..a93c49c7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-recognizer.cc @@ -0,0 +1,123 @@ +// sherpa-mnn/python/csrc/online-recongizer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/online-recognizer.h" + +#include +#include + +#include "sherpa-mnn/csrc/online-recognizer.h" + +namespace sherpa_mnn { + +static void PybindOnlineRecognizerResult(py::module *m) { + using PyClass = OnlineRecognizerResult; + py::class_(*m, "OnlineRecognizerResult") + .def_property_readonly( + "text", + [](PyClass &self) -> py::str { + return py::str(PyUnicode_DecodeUTF8(self.text.c_str(), + self.text.size(), "ignore")); + }) + .def_property_readonly( + "tokens", + [](PyClass &self) -> std::vector { return self.tokens; }) + .def_property_readonly( + "start_time", [](PyClass &self) -> float { return self.start_time; }) + .def_property_readonly( + "timestamps", + [](PyClass &self) -> std::vector { return self.timestamps; }) + .def_property_readonly( + "ys_probs", + [](PyClass &self) -> std::vector { return self.ys_probs; }) + .def_property_readonly( + "lm_probs", + [](PyClass &self) -> std::vector { return self.lm_probs; }) + 
.def_property_readonly("context_scores", + [](PyClass &self) -> std::vector { + return self.context_scores; + }) + .def_property_readonly( + "segment", [](PyClass &self) -> int32_t { return self.segment; }) + .def_property_readonly( + "words", + [](PyClass &self) -> std::vector { return self.words; }) + .def_property_readonly( + "is_final", [](PyClass &self) -> bool { return self.is_final; }) + .def("__str__", &PyClass::AsJsonString, + py::call_guard()) + .def("as_json_string", &PyClass::AsJsonString, + py::call_guard()); +} + +static void PybindOnlineRecognizerConfig(py::module *m) { + using PyClass = OnlineRecognizerConfig; + py::class_(*m, "OnlineRecognizerConfig") + .def(py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("lm_config") = OnlineLMConfig(), + py::arg("endpoint_config") = EndpointConfig(), + py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), + py::arg("enable_endpoint"), py::arg("decoding_method"), + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, + py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", + py::arg("rule_fars") = "") + .def_readwrite("feat_config", &PyClass::feat_config) + .def_readwrite("model_config", &PyClass::model_config) + .def_readwrite("lm_config", &PyClass::lm_config) + .def_readwrite("endpoint_config", &PyClass::endpoint_config) + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config) + .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) + .def_readwrite("decoding_method", &PyClass::decoding_method) + .def_readwrite("max_active_paths", &PyClass::max_active_paths) + .def_readwrite("hotwords_file", &PyClass::hotwords_file) + .def_readwrite("hotwords_score", &PyClass::hotwords_score) + .def_readwrite("blank_penalty", &PyClass::blank_penalty) + .def_readwrite("temperature_scale", &PyClass::temperature_scale) + .def_readwrite("rule_fsts", &PyClass::rule_fsts) + .def_readwrite("rule_fars", &PyClass::rule_fars) + .def("__str__", &PyClass::ToString); +} + +void PybindOnlineRecognizer(py::module *m) { + PybindOnlineRecognizerResult(m); + PybindOnlineRecognizerConfig(m); + + using PyClass = OnlineRecognizer; + py::class_(*m, "OnlineRecognizer") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def( + "create_stream", + [](const PyClass &self) { return self.CreateStream(); }, + py::call_guard()) + .def( + "create_stream", + [](PyClass &self, const std::string &hotwords) { + return self.CreateStream(hotwords); + }, + py::arg("hotwords"), py::call_guard()) + .def("is_ready", &PyClass::IsReady, + py::call_guard()) + .def("decode_stream", &PyClass::DecodeStream, + py::call_guard()) + .def( + "decode_streams", + [](PyClass &self, std::vector ss) { + self.DecodeStreams(ss.data(), ss.size()); + }, + py::call_guard()) + .def("get_result", &PyClass::GetResult, + py::call_guard()) + .def("is_endpoint", &PyClass::IsEndpoint, + py::call_guard()) + .def("reset", &PyClass::Reset, py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-recognizer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-recognizer.h new file mode 100644 index 00000000..f4c4bf8e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-recognizer.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-recognizer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_ +#define 
SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineRecognizer(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_RECOGNIZER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-stream.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-stream.cc new file mode 100644 index 00000000..2fef4ec4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-stream.cc @@ -0,0 +1,60 @@ +// sherpa-mnn/python/csrc/online-stream.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/online-stream.h" + +#include + +#include "sherpa-mnn/csrc/online-stream.h" + +namespace sherpa_mnn { + +constexpr const char *kAcceptWaveformUsage = R"( +Process audio samples. + +Args: + sample_rate: + Sample rate of the input samples. If it is different from the one + expected by the model, we will do resampling inside. + waveform: + A 1-D float32 tensor containing audio samples. It must be normalized + to the range [-1, 1]. +)"; + + +constexpr const char *kGetFramesUsage = R"( +Get n frames starting from the given frame index. +(hint: intended for debugging, for comparing FBANK features across pipelines) + +Args: + frame_index: + The starting frame index + n: + Number of frames to get. +Return: + Return a 2-D tensor of shape (n, feature_dim). + which is flattened into a 1-D vector (flattened in row major). + Unflatten in python with: + `features = np.reshape(arr, (n, feature_dim))` +)"; + +void PybindOnlineStream(py::module *m) { + using PyClass = OnlineStream; + py::class_(*m, "OnlineStream") + .def( + "accept_waveform", + [](PyClass &self, float sample_rate, + const std::vector &waveform) { + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); + }, + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage, + py::call_guard()) + .def("input_finished", &PyClass::InputFinished, + py::call_guard()) + .def("get_frames", &PyClass::GetFrames, + py::arg("frame_index"), py::arg("n"), kGetFramesUsage, + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-stream.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-stream.h new file mode 100644 index 00000000..baa1bc45 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-stream.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-stream.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineStream(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_STREAM_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-transducer-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-transducer-model-config.cc new file mode 100644 index 00000000..8c2eaf99 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-transducer-model-config.cc @@ -0,0 +1,25 @@ +// sherpa-mnn/python/csrc/online-transducer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/csrc/online-transducer-model-config.h" + +#include + +#include "sherpa-mnn/python/csrc/online-transducer-model-config.h" + +namespace sherpa_mnn { + +void PybindOnlineTransducerModelConfig(py::module *m) { + using PyClass = OnlineTransducerModelConfig; 
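Together with the OnlineStream binding above, these low-level classes support the usual streaming loop: feed a chunk, decode while the recognizer reports the stream is ready, and flush at the end. A sketch against the raw `_sherpa_mnn` objects; constructing the OnlineRecognizerConfig is omitted, and 16 kHz float32 chunks normalized to [-1, 1] are assumed:

    from _sherpa_mnn import OnlineRecognizer  # low-level class bound above

    def transcribe(recognizer: OnlineRecognizer, chunks) -> str:
        """chunks: iterable of 1-D float32 arrays normalized to [-1, 1]."""
        stream = recognizer.create_stream()
        for chunk in chunks:
            stream.accept_waveform(16000, chunk)   # resampled internally if needed
            while recognizer.is_ready(stream):
                recognizer.decode_stream(stream)
        stream.input_finished()                    # no more audio will arrive
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)
        return recognizer.get_result(stream).text  # OnlineRecognizerResult.text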
+ py::class_(*m, "OnlineTransducerModelConfig") + .def(py::init(), + py::arg("encoder"), py::arg("decoder"), py::arg("joiner")) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def_readwrite("joiner", &PyClass::joiner) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-transducer-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-transducer-model-config.h new file mode 100644 index 00000000..0b02eedb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-transducer-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-transducer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineTransducerModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-wenet-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-wenet-ctc-model-config.cc new file mode 100644 index 00000000..c286c7b2 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-wenet-ctc-model-config.cc @@ -0,0 +1,25 @@ +// sherpa-mnn/python/csrc/online-wenet-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/online-wenet-ctc-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/online-wenet-ctc-model-config.h" + +namespace sherpa_mnn { + +void PybindOnlineWenetCtcModelConfig(py::module *m) { + using PyClass = OnlineWenetCtcModelConfig; + py::class_(*m, "OnlineWenetCtcModelConfig") + .def(py::init(), py::arg("model"), + py::arg("chunk_size") = 16, py::arg("num_left_chunks") = 4) + .def_readwrite("model", &PyClass::model) + .def_readwrite("chunk_size", &PyClass::chunk_size) + .def_readwrite("num_left_chunks", &PyClass::num_left_chunks) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-wenet-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-wenet-ctc-model-config.h new file mode 100644 index 00000000..86f094bc --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-wenet-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-wenet-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineWenetCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_WENET_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.cc new file mode 100644 index 00000000..ff86b619 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include 
"sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/online-zipformer2-ctc-model-config.h" + +namespace sherpa_mnn { + +void PybindOnlineZipformer2CtcModelConfig(py::module *m) { + using PyClass = OnlineZipformer2CtcModelConfig; + py::class_(*m, "OnlineZipformer2CtcModelConfig") + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.h new file mode 100644 index 00000000..3151f56a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/online-zipformer2-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindOnlineZipformer2CtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_ZIPFORMER2_CTC_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/provider-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/provider-config.cc new file mode 100644 index 00000000..8b8c9d06 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/provider-config.cc @@ -0,0 +1,39 @@ +// sherpa-mnn/python/csrc/provider-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + + +#include "sherpa-mnn/python/csrc/provider-config.h" + +#include + +#include "sherpa-mnn/csrc/provider-config.h" +#include "sherpa-mnn/python/csrc/cuda-config.h" +#include "sherpa-mnn/python/csrc/tensorrt-config.h" + +namespace sherpa_mnn { + +void PybindProviderConfig(py::module *m) { + PybindCudaConfig(m); + PybindTensorrtConfig(m); + + using PyClass = ProviderConfig; + py::class_(*m, "ProviderConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("provider") = "cpu", + py::arg("device") = 0) + .def(py::init(), + py::arg("trt_config") = TensorrtConfig{}, + py::arg("cuda_config") = CudaConfig{}, + py::arg("provider") = "cpu", + py::arg("device") = 0) + .def_readwrite("trt_config", &PyClass::trt_config) + .def_readwrite("cuda_config", &PyClass::cuda_config) + .def_readwrite("provider", &PyClass::provider) + .def_readwrite("device", &PyClass::device) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/provider-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/provider-config.h new file mode 100644 index 00000000..57dbfec1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/provider-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/provider-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindProviderConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/sherpa-mnn.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/sherpa-mnn.cc new 
file mode 100644 index 00000000..0770f93c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/sherpa-mnn.cc @@ -0,0 +1,94 @@ +// sherpa-mnn/python/csrc/sherpa-mnn.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +#include "sherpa-mnn/python/csrc/alsa.h" +#include "sherpa-mnn/python/csrc/audio-tagging.h" +#include "sherpa-mnn/python/csrc/circular-buffer.h" +#include "sherpa-mnn/python/csrc/display.h" +#include "sherpa-mnn/python/csrc/endpoint.h" +#include "sherpa-mnn/python/csrc/features.h" +#include "sherpa-mnn/python/csrc/keyword-spotter.h" +#include "sherpa-mnn/python/csrc/offline-ctc-fst-decoder-config.h" +#include "sherpa-mnn/python/csrc/offline-lm-config.h" +#include "sherpa-mnn/python/csrc/offline-model-config.h" +#include "sherpa-mnn/python/csrc/offline-punctuation.h" +#include "sherpa-mnn/python/csrc/offline-recognizer.h" +#include "sherpa-mnn/python/csrc/offline-speech-denoiser.h" +#include "sherpa-mnn/python/csrc/offline-stream.h" +#include "sherpa-mnn/python/csrc/online-ctc-fst-decoder-config.h" +#include "sherpa-mnn/python/csrc/online-lm-config.h" +#include "sherpa-mnn/python/csrc/online-model-config.h" +#include "sherpa-mnn/python/csrc/online-punctuation.h" +#include "sherpa-mnn/python/csrc/online-recognizer.h" +#include "sherpa-mnn/python/csrc/online-stream.h" +#include "sherpa-mnn/python/csrc/speaker-embedding-extractor.h" +#include "sherpa-mnn/python/csrc/speaker-embedding-manager.h" +#include "sherpa-mnn/python/csrc/spoken-language-identification.h" +#include "sherpa-mnn/python/csrc/vad-model-config.h" +#include "sherpa-mnn/python/csrc/vad-model.h" +#include "sherpa-mnn/python/csrc/voice-activity-detector.h" +#include "sherpa-mnn/python/csrc/wave-writer.h" + +#if SHERPA_MNN_ENABLE_TTS == 1 +#include "sherpa-mnn/python/csrc/offline-tts.h" +#endif + +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 +#include "sherpa-mnn/python/csrc/fast-clustering.h" +#include "sherpa-mnn/python/csrc/offline-speaker-diarization-result.h" +#include "sherpa-mnn/python/csrc/offline-speaker-diarization.h" +#endif + +namespace sherpa_mnn { + +PYBIND11_MODULE(_sherpa_mnn, m) { + m.doc() = "pybind11 binding of sherpa-mnn"; + + PybindWaveWriter(&m); + PybindAudioTagging(&m); + PybindOfflinePunctuation(&m); + PybindOnlinePunctuation(&m); + + PybindFeatures(&m); + PybindOnlineCtcFstDecoderConfig(&m); + PybindOnlineModelConfig(&m); + PybindOnlineLMConfig(&m); + PybindOnlineStream(&m); + PybindEndpoint(&m); + PybindOnlineRecognizer(&m); + PybindKeywordSpotter(&m); + PybindDisplay(&m); + + PybindOfflineStream(&m); + PybindOfflineLMConfig(&m); + PybindOfflineModelConfig(&m); + PybindOfflineCtcFstDecoderConfig(&m); + PybindOfflineRecognizer(&m); + + PybindVadModelConfig(&m); + PybindVadModel(&m); + PybindCircularBuffer(&m); + PybindVoiceActivityDetector(&m); + +#if SHERPA_MNN_ENABLE_TTS == 1 + PybindOfflineTts(&m); +#endif + + PybindSpeakerEmbeddingExtractor(&m); + PybindSpeakerEmbeddingManager(&m); + PybindSpokenLanguageIdentification(&m); + +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 + PybindFastClustering(&m); + PybindOfflineSpeakerDiarizationResult(&m); + PybindOfflineSpeakerDiarization(&m); +#endif + + PybindAlsa(&m); + PybindOfflineSpeechDenoiser(&m); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/sherpa-mnn.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/sherpa-mnn.h new file mode 100644 index 00000000..99193253 --- /dev/null +++ 
b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/sherpa-mnn.h @@ -0,0 +1,15 @@ +// sherpa-mnn/python/csrc/sherpa-mnn.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ + +#include "pybind11/functional.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +#endif // SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/silero-vad-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/silero-vad-model-config.cc new file mode 100644 index 00000000..137defe8 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/silero-vad-model-config.cc @@ -0,0 +1,47 @@ +// sherpa-mnn/python/csrc/silero-vad-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/silero-vad-model-config.h" + +#include +#include + +#include "sherpa-mnn/csrc/silero-vad-model-config.h" + +namespace sherpa_mnn { + +void PybindSileroVadModelConfig(py::module *m) { + using PyClass = SileroVadModelConfig; + py::class_(*m, "SileroVadModelConfig") + .def(py::init<>()) + .def(py::init([](const std::string &model, float threshold, + float min_silence_duration, float min_speech_duration, + int32_t window_size, + float max_speech_duration) -> std::unique_ptr { + auto ans = std::make_unique(); + + ans->model = model; + ans->threshold = threshold; + ans->min_silence_duration = min_silence_duration; + ans->min_speech_duration = min_speech_duration; + ans->window_size = window_size; + ans->max_speech_duration = max_speech_duration; + + return ans; + }), + py::arg("model"), py::arg("threshold") = 0.5, + py::arg("min_silence_duration") = 0.5, + py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512, + py::arg("max_speech_duration") = 20) + .def_readwrite("model", &PyClass::model) + .def_readwrite("threshold", &PyClass::threshold) + .def_readwrite("min_silence_duration", &PyClass::min_silence_duration) + .def_readwrite("min_speech_duration", &PyClass::min_speech_duration) + .def_readwrite("window_size", &PyClass::window_size) + .def_readwrite("max_speech_duration", &PyClass::max_speech_duration) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/silero-vad-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/silero-vad-model-config.h new file mode 100644 index 00000000..ebd95eb9 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/silero-vad-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/silero-vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SILERO_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SILERO_VAD_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindSileroVadModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_SILERO_VAD_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-extractor.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-extractor.cc new file mode 100644 index 00000000..91f347ef --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-extractor.cc @@ -0,0 +1,44 @@ +// sherpa-mnn/python/csrc/speaker-embedding-extractor.cc +// +// Copyright (c) 2023 
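Every field of SileroVadModelConfig is exposed above as a constructor keyword with the defaults listed. A construction sketch; the model path is a placeholder, and the comments paraphrase the usual silero-vad meaning of each field rather than anything stated in this diff:

    from _sherpa_mnn import SileroVadModelConfig

    vad_model_cfg = SileroVadModelConfig(
        model="silero_vad.mnn",       # placeholder path
        threshold=0.5,                # speech probability threshold
        min_silence_duration=0.5,     # seconds of silence that close a segment
        min_speech_duration=0.25,     # shorter detections are discarded
        window_size=512,              # samples fed to the model per step
        max_speech_duration=20,       # seconds before a segment is split
    )
    vad_model_cfg.validate()
    print(vad_model_cfg)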
Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/speaker-embedding-extractor.h" + +#include + +#include "sherpa-mnn/csrc/speaker-embedding-extractor.h" + +namespace sherpa_mnn { + +static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) { + using PyClass = SpeakerEmbeddingExtractorConfig; + py::class_(*m, "SpeakerEmbeddingExtractorConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("model", &PyClass::model) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindSpeakerEmbeddingExtractor(py::module *m) { + PybindSpeakerEmbeddingExtractorConfig(m); + + using PyClass = SpeakerEmbeddingExtractor; + py::class_(*m, "SpeakerEmbeddingExtractor") + .def(py::init(), + py::arg("config"), py::call_guard()) + .def_property_readonly("dim", &PyClass::Dim) + .def("create_stream", &PyClass::CreateStream, + py::call_guard()) + .def("compute", &PyClass::Compute, + py::call_guard()) + .def("is_ready", &PyClass::IsReady, + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-extractor.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-extractor.h new file mode 100644 index 00000000..7e44581e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-extractor.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/speaker-embedding-extractor.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindSpeakerEmbeddingExtractor(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-manager.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-manager.cc new file mode 100644 index 00000000..bf96481f --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-manager.cc @@ -0,0 +1,74 @@ +// sherpa-mnn/python/csrc/speaker-embedding-manager.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/speaker-embedding-manager.h" + +#include +#include + +#include "sherpa-mnn/csrc/speaker-embedding-manager.h" + +namespace sherpa_mnn { + +void PybindSpeakerEmbeddingManager(py::module *m) { + using PyClass = SpeakerEmbeddingManager; + py::class_(*m, "SpeakerEmbeddingManager") + .def(py::init(), py::arg("dim"), + py::call_guard()) + .def_property_readonly("num_speakers", &PyClass::NumSpeakers) + .def_property_readonly("dim", &PyClass::Dim) + .def_property_readonly("all_speakers", &PyClass::GetAllSpeakers) + .def( + "__contains__", + [](const PyClass &self, const std::string &name) -> bool { + return self.Contains(name); + }, + py::arg("name"), py::call_guard()) + .def( + "add", + [](const PyClass &self, const std::string &name, + const std::vector &v) -> bool { + return self.Add(name, v.data()); + }, + py::arg("name"), py::arg("v"), + py::call_guard()) + .def( + "add", + [](const PyClass &self, const std::string &name, + const std::vector> &embedding_list) -> bool { + return self.Add(name, 
embedding_list); + }, + py::arg("name"), py::arg("embedding_list"), + py::call_guard()) + .def( + "remove", + [](const PyClass &self, const std::string &name) -> bool { + return self.Remove(name); + }, + py::arg("name"), py::call_guard()) + .def( + "search", + [](const PyClass &self, const std::vector &v, float threshold) + -> std::string { return self.Search(v.data(), threshold); }, + py::arg("v"), py::arg("threshold"), + py::call_guard()) + .def( + "verify", + [](const PyClass &self, const std::string &name, + const std::vector &v, float threshold) -> bool { + return self.Verify(name, v.data(), threshold); + }, + py::arg("name"), py::arg("v"), py::arg("threshold"), + py::call_guard()) + .def( + "score", + [](const PyClass &self, const std::string &name, + const std::vector &v) -> float { + return self.Score(name, v.data()); + }, + py::arg("name"), py::arg("v"), + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-manager.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-manager.h new file mode 100644 index 00000000..c866adfb --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/speaker-embedding-manager.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/speaker-embedding-manager.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindSpeakerEmbeddingManager(py::module *m); + +} // namespace sherpa_mnn + +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/spoken-language-identification.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/spoken-language-identification.cc new file mode 100644 index 00000000..3b34b00c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/spoken-language-identification.cc @@ -0,0 +1,60 @@ +// sherpa-mnn/python/csrc/spoken-language-identification.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/spoken-language-identification.h" + +#include + +#include "sherpa-mnn/csrc/spoken-language-identification.h" + +namespace sherpa_mnn { + +static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) { + using PyClass = SpokenLanguageIdentificationWhisperConfig; + + py::class_(*m, "SpokenLanguageIdentificationWhisperConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("encoder"), py::arg("decoder"), + py::arg("tail_paddings") = -1) + .def_readwrite("encoder", &PyClass::encoder) + .def_readwrite("decoder", &PyClass::decoder) + .def_readwrite("tail_paddings", &PyClass::tail_paddings) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindSpokenLanguageIdentificationConfig(py::module *m) { + PybindSpokenLanguageIdentificationWhisperConfig(m); + + using PyClass = SpokenLanguageIdentificationConfig; + + py::class_(*m, "SpokenLanguageIdentificationConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("whisper"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("whisper", &PyClass::whisper) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); 
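The SpeakerEmbeddingManager bound above is an in-memory store of embeddings keyed by speaker name. A usage sketch with made-up 3-dimensional vectors; real embeddings come from SpeakerEmbeddingExtractor.compute and must match the manager's `dim`:

    from _sherpa_mnn import SpeakerEmbeddingManager

    manager = SpeakerEmbeddingManager(dim=3)       # dim must match the extractor
    manager.add("alice", [0.1, 0.9, 0.3])          # register a single embedding
    manager.add("bob", [[0.8, 0.2, 0.1],           # or several embeddings at once
                        [0.7, 0.3, 0.2]])

    print("alice" in manager, manager.num_speakers, manager.all_speakers)

    name = manager.search([0.1, 0.9, 0.3], threshold=0.5)   # best matching name, if any
    same = manager.verify("alice", [0.1, 0.9, 0.3], threshold=0.5)
    manager.remove("bob")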
+} + +void PybindSpokenLanguageIdentification(py::module *m) { + PybindSpokenLanguageIdentificationConfig(m); + + using PyClass = SpokenLanguageIdentification; + py::class_(*m, "SpokenLanguageIdentification") + .def(py::init(), + py::arg("config"), py::call_guard()) + .def("create_stream", &PyClass::CreateStream, + py::call_guard()) + .def("compute", &PyClass::Compute, py::arg("s"), + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/spoken-language-identification.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/spoken-language-identification.h new file mode 100644 index 00000000..0346867c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/spoken-language-identification.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/spoken-language-identification.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindSpokenLanguageIdentification(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/tensorrt-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/tensorrt-config.cc new file mode 100644 index 00000000..2cf29e5d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/tensorrt-config.cc @@ -0,0 +1,72 @@ +// sherpa-mnn/python/csrc/tensorrt-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#include "sherpa-mnn/python/csrc/tensorrt-config.h" + +#include +#include +#include "sherpa-mnn/csrc/provider-config.h" + +namespace sherpa_mnn { + +void PybindTensorrtConfig(py::module *m) { + using PyClass = TensorrtConfig; + py::class_(*m, "TensorrtConfig") + .def(py::init<>()) + .def(py::init([](int64_t trt_max_workspace_size, + int32_t trt_max_partition_iterations, + int32_t trt_min_subgraph_size, + bool trt_fp16_enable, + bool trt_detailed_build_log, + bool trt_engine_cache_enable, + bool trt_timing_cache_enable, + const std::string &trt_engine_cache_path, + const std::string &trt_timing_cache_path, + bool trt_dump_subgraphs) -> std::unique_ptr { + auto ans = std::make_unique(); + + ans->trt_max_workspace_size = trt_max_workspace_size; + ans->trt_max_partition_iterations = trt_max_partition_iterations; + ans->trt_min_subgraph_size = trt_min_subgraph_size; + ans->trt_fp16_enable = trt_fp16_enable; + ans->trt_detailed_build_log = trt_detailed_build_log; + ans->trt_engine_cache_enable = trt_engine_cache_enable; + ans->trt_timing_cache_enable = trt_timing_cache_enable; + ans->trt_engine_cache_path = trt_engine_cache_path; + ans->trt_timing_cache_path = trt_timing_cache_path; + ans->trt_dump_subgraphs = trt_dump_subgraphs; + + return ans; + }), + py::arg("trt_max_workspace_size") = 2147483647, + py::arg("trt_max_partition_iterations") = 10, + py::arg("trt_min_subgraph_size") = 5, + py::arg("trt_fp16_enable") = true, + py::arg("trt_detailed_build_log") = false, + py::arg("trt_engine_cache_enable") = true, + py::arg("trt_timing_cache_enable") = true, + py::arg("trt_engine_cache_path") = ".", + py::arg("trt_timing_cache_path") = ".", + py::arg("trt_dump_subgraphs") = false) + + .def_readwrite("trt_max_workspace_size", + &PyClass::trt_max_workspace_size) + .def_readwrite("trt_max_partition_iterations", + &PyClass::trt_max_partition_iterations) + 
.def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size) + .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable) + .def_readwrite("trt_detailed_build_log", + &PyClass::trt_detailed_build_log) + .def_readwrite("trt_engine_cache_enable", + &PyClass::trt_engine_cache_enable) + .def_readwrite("trt_timing_cache_enable", + &PyClass::trt_timing_cache_enable) + .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path) + .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path) + .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/tensorrt-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/tensorrt-config.h new file mode 100644 index 00000000..a22e778d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/tensorrt-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/tensorrt-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#ifndef SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindTensorrtConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model-config.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model-config.cc new file mode 100644 index 00000000..86218573 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model-config.cc @@ -0,0 +1,34 @@ +// sherpa-mnn/python/csrc/vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/vad-model-config.h" + +#include + +#include "sherpa-mnn/csrc/vad-model-config.h" +#include "sherpa-mnn/python/csrc/silero-vad-model-config.h" + +namespace sherpa_mnn { + +void PybindVadModelConfig(py::module *m) { + PybindSileroVadModelConfig(m); + + using PyClass = VadModelConfig; + py::class_(*m, "VadModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("silero_vad"), py::arg("sample_rate") = 16000, + py::arg("num_threads") = 1, py::arg("provider") = "cpu", + py::arg("debug") = false) + .def_readwrite("silero_vad", &PyClass::silero_vad) + .def_readwrite("sample_rate", &PyClass::sample_rate) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("provider", &PyClass::provider) + .def_readwrite("debug", &PyClass::debug) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model-config.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model-config.h new file mode 100644 index 00000000..e03b739a --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model-config.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_CONFIG_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindVadModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_CONFIG_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model.cc new file mode 100644 
index 00000000..d0916768 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model.cc @@ -0,0 +1,36 @@ +// sherpa-mnn/python/csrc/vad-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/vad-model.h" + +#include +#include + +#include "sherpa-mnn/csrc/vad-model.h" + +namespace sherpa_mnn { + +void PybindVadModel(py::module *m) { + using PyClass = VadModel; + py::class_(*m, "VadModel") + .def_static("create", + (std::unique_ptr(*)(const VadModelConfig &))( + &PyClass::Create), + py::arg("config"), py::call_guard()) + .def("reset", &PyClass::Reset, py::call_guard()) + .def( + "is_speech", + [](PyClass &self, const std::vector &samples) -> bool { + return self.IsSpeech(samples.data(), samples.size()); + }, + py::arg("samples"), py::call_guard()) + .def("window_size", &PyClass::WindowSize, + py::call_guard()) + .def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples, + py::call_guard()) + .def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples, + py::call_guard()); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model.h new file mode 100644 index 00000000..50a46374 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/vad-model.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/vad-model.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_H_ +#define SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindVadModel(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/voice-activity-detector.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/voice-activity-detector.cc new file mode 100644 index 00000000..51e6b65c --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/voice-activity-detector.cc @@ -0,0 +1,45 @@ +// sherpa-mnn/python/csrc/voice-activity-detector.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/voice-activity-detector.h" + +#include + +#include "sherpa-mnn/csrc/voice-activity-detector.h" + +namespace sherpa_mnn { + +void PybindSpeechSegment(py::module *m) { + using PyClass = SpeechSegment; + py::class_(*m, "SpeechSegment") + .def_property_readonly("start", + [](const PyClass &self) { return self.start; }) + .def_property_readonly("samples", + [](const PyClass &self) { return self.samples; }); +} + +void PybindVoiceActivityDetector(py::module *m) { + PybindSpeechSegment(m); + using PyClass = VoiceActivityDetector; + py::class_(*m, "VoiceActivityDetector") + .def(py::init(), py::arg("config"), + py::arg("buffer_size_in_seconds") = 60, + py::call_guard()) + .def( + "accept_waveform", + [](PyClass &self, const std::vector &samples) { + self.AcceptWaveform(samples.data(), samples.size()); + }, + py::arg("samples"), py::call_guard()) + .def_property_readonly("config", &PyClass::GetConfig) + .def("empty", &PyClass::Empty, py::call_guard()) + .def("pop", &PyClass::Pop, py::call_guard()) + .def("is_speech_detected", &PyClass::IsSpeechDetected, + py::call_guard()) + .def("reset", &PyClass::Reset, py::call_guard()) + .def("flush", &PyClass::Flush, py::call_guard()) + .def_property_readonly("front", &PyClass::Front); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/voice-activity-detector.h 
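The VadModel and VoiceActivityDetector bindings above combine with the VadModelConfig shown earlier into a simple segmentation loop. A sketch; 16 kHz mono float32 audio is assumed, and the waveform here is only a zero-filled stand-in:

    import numpy as np
    from _sherpa_mnn import SileroVadModelConfig, VadModelConfig, VoiceActivityDetector

    config = VadModelConfig(
        silero_vad=SileroVadModelConfig(model="silero_vad.mnn"),  # placeholder path
        sample_rate=16000,
    )
    vad = VoiceActivityDetector(config, buffer_size_in_seconds=60)

    samples = np.zeros(16000 * 10, dtype=np.float32)  # stand-in for real audio in [-1, 1]
    window = config.silero_vad.window_size            # 512 samples by default

    for offset in range(0, len(samples) - window + 1, window):
        vad.accept_waveform(samples[offset:offset + window])
        while not vad.empty():
            segment = vad.front                       # SpeechSegment: .start and .samples
            print(f"speech at sample {segment.start}: {len(segment.samples)} samples")
            vad.pop()

    vad.flush()  # emit any speech still buffered at the end of the audio
    while not vad.empty():
        print(f"trailing segment at sample {vad.front.start}")
        vad.pop()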
b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/voice-activity-detector.h new file mode 100644 index 00000000..1191ed83 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/voice-activity-detector.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/voice-activity-detector.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_VOICE_ACTIVITY_DETECTOR_H_ +#define SHERPA_ONNX_PYTHON_CSRC_VOICE_ACTIVITY_DETECTOR_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindVoiceActivityDetector(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_VOICE_ACTIVITY_DETECTOR_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/wave-writer.cc b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/wave-writer.cc new file mode 100644 index 00000000..7077083e --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/wave-writer.cc @@ -0,0 +1,27 @@ +// sherpa-mnn/python/csrc/wave-writer.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-mnn/python/csrc/wave-writer.h" + +#include +#include + +#include "sherpa-mnn/csrc/wave-writer.h" + +namespace sherpa_mnn { + +void PybindWaveWriter(py::module *m) { + m->def( + "write_wave", + [](const std::string &filename, const std::vector &samples, + int32_t sample_rate) -> bool { + bool ok = + WriteWave(filename, sample_rate, samples.data(), samples.size()); + + return ok; + }, + py::arg("filename"), py::arg("samples"), py::arg("sample_rate")); +} + +} // namespace sherpa_mnn diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/wave-writer.h b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/wave-writer.h new file mode 100644 index 00000000..fd7214ef --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/csrc/wave-writer.h @@ -0,0 +1,16 @@ +// sherpa-mnn/python/csrc/wave-writer.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_WAVE_WRITER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_WAVE_WRITER_H_ + +#include "sherpa-mnn/python/csrc/sherpa-mnn.h" + +namespace sherpa_mnn { + +void PybindWaveWriter(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_WAVE_WRITER_H_ diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/__init__.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/__init__.py new file mode 100644 index 00000000..c7502109 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/__init__.py @@ -0,0 +1,54 @@ +from _sherpa_mnn import ( + Alsa, + AudioEvent, + AudioTagging, + AudioTaggingConfig, + AudioTaggingModelConfig, + CircularBuffer, + DenoisedAudio, + Display, + FastClustering, + FastClusteringConfig, + OfflinePunctuation, + OfflinePunctuationConfig, + OfflinePunctuationModelConfig, + OfflineSpeakerDiarization, + OfflineSpeakerDiarizationConfig, + OfflineSpeakerDiarizationResult, + OfflineSpeakerDiarizationSegment, + OfflineSpeakerSegmentationModelConfig, + OfflineSpeakerSegmentationPyannoteModelConfig, + OfflineSpeechDenoiser, + OfflineSpeechDenoiserConfig, + OfflineSpeechDenoiserGtcrnModelConfig, + OfflineSpeechDenoiserModelConfig, + OfflineStream, + OfflineTts, + OfflineTtsConfig, + OfflineTtsKokoroModelConfig, + OfflineTtsMatchaModelConfig, + OfflineTtsModelConfig, + OfflineTtsVitsModelConfig, + OfflineZipformerAudioTaggingModelConfig, + OnlinePunctuation, + OnlinePunctuationConfig, + OnlinePunctuationModelConfig, + OnlineStream, + SileroVadModelConfig, + SpeakerEmbeddingExtractor, + SpeakerEmbeddingExtractorConfig, + 
SpeakerEmbeddingManager, + SpeechSegment, + SpokenLanguageIdentification, + SpokenLanguageIdentificationConfig, + SpokenLanguageIdentificationWhisperConfig, + VadModel, + VadModelConfig, + VoiceActivityDetector, + write_wave, +) + +from .keyword_spotter import KeywordSpotter +from .offline_recognizer import OfflineRecognizer +from .online_recognizer import OnlineRecognizer +from .utils import text2token diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/cli.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/cli.py new file mode 100644 index 00000000..3348d305 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/cli.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023 Xiaomi Corporation + +import logging +try: + import click +except ImportError: + print('Please run') + print(' pip install click') + print('before you continue') + raise + +from pathlib import Path +from sherpa_mnn import text2token + + +@click.group() +def cli(): + """ + The shell entry point to sherpa-mnn. + """ + logging.basicConfig( + format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + level=logging.INFO, + ) + + +@cli.command(name="text2token") +@click.argument("input", type=click.Path(exists=True, dir_okay=False)) +@click.argument("output", type=click.Path()) +@click.option( + "--tokens", + type=str, + required=True, + help="The path to tokens.txt.", +) +@click.option( + "--tokens-type", + type=click.Choice( + ["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True + ), + required=True, + help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin. + fpinyin means full pinyin, each cjkchar has a pinyin(with tone). + ppinyin means partial pinyin, it splits pinyin into initial and final, + """, +) +@click.option( + "--bpe-model", + type=str, + help="The path to bpe.model. Only required when tokens-type is bpe or cjkchar+bpe.", +) +def encode_text( + input: Path, output: Path, tokens: Path, tokens_type: str, bpe_model: Path +): + """ + Encode the texts given by the INPUT to tokens and write the results to the OUTPUT. + Each line in the texts contains the original phrase, it might also contain some + extra items, for example, the boosting score (startting with :), the triggering + threshold (startting with #, only used in keyword spotting task) and the original + phrase (startting with @). Note: the extra items will be kept same in the output. 
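The VAD bindings and the `write_wave` helper exported above can be driven entirely from Python. Below is a minimal sketch, assuming the config bindings follow the upstream sherpa-onnx layout (`VadModelConfig.silero_vad.model`, `silero_vad.window_size`, `sample_rate`); the model path is a placeholder.

```python
import numpy as np
import sherpa_mnn

# Placeholder path: point this at a real Silero VAD model file.
config = sherpa_mnn.VadModelConfig()
config.silero_vad.model = "./silero-vad.onnx"
config.sample_rate = 16000

# buffer_size_in_seconds matches the constructor bound in voice-activity-detector.cc.
vad = sherpa_mnn.VoiceActivityDetector(config, buffer_size_in_seconds=30)

samples = np.zeros(16000 * 5, dtype=np.float32)  # replace with real 16 kHz mono audio
window = config.silero_vad.window_size
for start in range(0, len(samples), window):
    vad.accept_waveform(samples[start : start + window])
    while not vad.empty():
        segment = vad.front  # SpeechSegment with .start and .samples
        ok = sherpa_mnn.write_wave(
            f"segment-{segment.start}.wav", segment.samples, config.sample_rate
        )
        assert ok
        vad.pop()
vad.flush()
```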
+ + example input 1 (tokens_type = ppinyin): + + 小爱同学 :2.0 #0.6 @小爱同学 + 你好问问 :3.5 @你好问问 + 小艺小艺 #0.6 @小艺小艺 + + example output 1: + + x iǎo ài t óng x ué :2.0 #0.6 @小爱同学 + n ǐ h ǎo w èn w èn :3.5 @你好问问 + x iǎo y ì x iǎo y ì #0.6 @小艺小艺 + + example input 2 (tokens_type = bpe): + + HELLO WORLD :1.5 #0.4 + HI GOOGLE :2.0 #0.8 + HEY SIRI #0.35 + + example output 2: + + ▁HE LL O ▁WORLD :1.5 #0.4 + ▁HI ▁GO O G LE :2.0 #0.8 + ▁HE Y ▁S I RI #0.35 + """ + texts = [] + # extra information like boosting score (start with :), triggering threshold (start with #) + # original keyword (start with @) + extra_info = [] + with open(input, "r", encoding="utf8") as f: + for line in f: + extra = [] + text = [] + toks = line.strip().split() + for tok in toks: + if tok[0] == ":" or tok[0] == "#" or tok[0] == "@": + extra.append(tok) + else: + text.append(tok) + texts.append(" ".join(text)) + extra_info.append(extra) + + encoded_texts = text2token( + texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model + ) + with open(output, "w", encoding="utf8") as f: + for i, txt in enumerate(encoded_texts): + txt += extra_info[i] + f.write(" ".join(txt) + "\n") diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/keyword_spotter.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/keyword_spotter.py new file mode 100644 index 00000000..44f61fe6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/keyword_spotter.py @@ -0,0 +1,159 @@ +# Copyright (c) 2023 Xiaomi Corporation + +from pathlib import Path +from typing import List, Optional + +from _sherpa_mnn import ( + FeatureExtractorConfig, + KeywordSpotterConfig, + OnlineModelConfig, + OnlineTransducerModelConfig, + OnlineStream, + ProviderConfig, +) + +from _sherpa_mnn import KeywordSpotter as _KeywordSpotter + + +def _assert_file_exists(f: str): + assert Path(f).is_file(), f"{f} does not exist" + + +class KeywordSpotter(object): + """A class for keyword spotting. + + Please refer to the following files for usages + - https://github.com/k2-fsa/sherpa-mnn/blob/master/python-api-examples/keyword-spotter.py + - https://github.com/k2-fsa/sherpa-mnn/blob/master/python-api-examples/keyword-spotter-from-microphone.py + """ + + def __init__( + self, + tokens: str, + encoder: str, + decoder: str, + joiner: str, + keywords_file: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + max_active_paths: int = 4, + keywords_score: float = 1.0, + keywords_threshold: float = 0.25, + num_trailing_blanks: int = 1, + provider: str = "cpu", + device: int = 0, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + joiner: + Path to ``joiner.onnx``. + keywords_file: + The file containing keywords, one word/phrase per line, and for each + phrase the bpe/cjkchar/pinyin are separated by a space. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + max_active_paths: + Use only when decoding_method is modified_beam_search. It specifies + the maximum number of active paths during beam search. + keywords_score: + The boosting score of each token for keywords. 
The larger the easier to + survive beam search. + keywords_threshold: + The trigger threshold (i.e. probability) of the keyword. The larger the + harder to trigger. + num_trailing_blanks: + The number of trailing blanks a keyword should be followed. Setting + to a larger value (e.g. 8) when your keywords has overlapping tokens + between each other. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + device: + onnxruntime cuda device index. + """ + _assert_file_exists(tokens) + _assert_file_exists(encoder) + _assert_file_exists(decoder) + _assert_file_exists(joiner) + + assert num_threads > 0, num_threads + + transducer_config = OnlineTransducerModelConfig( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + + model_config = OnlineModelConfig( + transducer=transducer_config, + tokens=tokens, + num_threads=num_threads, + provider_config=provider_config, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + keywords_spotter_config = KeywordSpotterConfig( + feat_config=feat_config, + model_config=model_config, + max_active_paths=max_active_paths, + num_trailing_blanks=num_trailing_blanks, + keywords_score=keywords_score, + keywords_threshold=keywords_threshold, + keywords_file=keywords_file, + ) + self.keyword_spotter = _KeywordSpotter(keywords_spotter_config) + + def reset_stream(self, s: OnlineStream): + self.keyword_spotter.reset(s) + + def create_stream(self, keywords: Optional[str] = None): + if keywords is None: + return self.keyword_spotter.create_stream() + else: + return self.keyword_spotter.create_stream(keywords) + + def decode_stream(self, s: OnlineStream): + self.keyword_spotter.decode_stream(s) + + def decode_streams(self, ss: List[OnlineStream]): + self.keyword_spotter.decode_streams(ss) + + def is_ready(self, s: OnlineStream) -> bool: + return self.keyword_spotter.is_ready(s) + + def get_result(self, s: OnlineStream) -> str: + return self.keyword_spotter.get_result(s).keyword.strip() + + def tokens(self, s: OnlineStream) -> List[str]: + return self.keyword_spotter.get_result(s).tokens + + def timestamps(self, s: OnlineStream) -> List[float]: + return self.keyword_spotter.get_result(s).timestamps diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/offline_recognizer.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/offline_recognizer.py new file mode 100644 index 00000000..f5f3804d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/offline_recognizer.py @@ -0,0 +1,885 @@ +# Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation +from pathlib import Path +from typing import List, Optional + +from _sherpa_mnn import ( + FeatureExtractorConfig, + OfflineCtcFstDecoderConfig, + OfflineFireRedAsrModelConfig, + OfflineLMConfig, + OfflineModelConfig, + OfflineMoonshineModelConfig, + OfflineNemoEncDecCtcModelConfig, + OfflineParaformerModelConfig, +) +from _sherpa_mnn import OfflineRecognizer as _Recognizer +from _sherpa_mnn import ( + OfflineRecognizerConfig, + OfflineSenseVoiceModelConfig, + OfflineStream, + OfflineTdnnModelConfig, + OfflineTransducerModelConfig, + OfflineWenetCtcModelConfig, + OfflineWhisperModelConfig, + OfflineZipformerCtcModelConfig, +) + + +def _assert_file_exists(f: str): + assert Path(f).is_file(), f"{f} does not exist" + + +class OfflineRecognizer(object): + """A class for offline speech recognition. 
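The `KeywordSpotter` wrapper defined above follows the streaming recognizer workflow. A minimal sketch, with placeholder model file names, assuming `OnlineStream.accept_waveform(sample_rate, samples)` follows the usual sherpa-onnx signature:

```python
import sherpa_mnn

# Placeholder file names; use a real pre-trained keyword-spotting transducer.
spotter = sherpa_mnn.KeywordSpotter(
    tokens="tokens.txt",
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    keywords_file="keywords.txt",
    num_threads=2,
)

stream = spotter.create_stream()
samples = [0.0] * 16000  # replace with real 16 kHz mono samples in [-1, 1]
stream.accept_waveform(16000, samples)

while spotter.is_ready(stream):
    spotter.decode_stream(stream)
    keyword = spotter.get_result(stream)
    if keyword:
        print("detected:", keyword, spotter.timestamps(stream))
        spotter.reset_stream(stream)
```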
+ + Please refer to the following files for usages + - https://github.com/k2-fsa/sherpa-mnn/blob/master/sherpa-mnn/python/tests/test_offline_recognizer.py + - https://github.com/k2-fsa/sherpa-mnn/blob/master/python-api-examples/offline-decode-files.py + """ + + @classmethod + def from_transducer( + cls, + encoder: str, + decoder: str, + joiner: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + max_active_paths: int = 4, + hotwords_file: str = "", + hotwords_score: float = 1.5, + blank_penalty: float = 0.0, + modeling_unit: str = "cjkchar", + bpe_vocab: str = "", + debug: bool = False, + provider: str = "cpu", + model_type: str = "transducer", + rule_fsts: str = "", + rule_fars: str = "", + lm: str = "", + lm_scale: float = 0.1, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + joiner: + Path to ``joiner.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values: greedy_search, modified_beam_search. + max_active_paths: + Maximum number of active paths to keep. Used only when + decoding_method is modified_beam_search. + hotwords_file: + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. + hotwords_score: + The hotword score of each token for biasing word/phrase. Used only if + hotwords_file is given with modified_beam_search as decoding method. + blank_penalty: + The penalty applied on blank symbol during decoding. + modeling_unit: + The modeling unit of the model, commonly used units are bpe, cjkchar, + cjkchar+bpe, etc. Currently, it is needed only when hotwords are + provided, we need it to encode the hotwords into token sequence. + and the modeling unit is bpe or cjkchar+bpe. + bpe_vocab: + The vocabulary generated by google's sentencepiece program. + It is a file has two columns, one is the token, the other is + the log probability, you can get it from the directory where + your bpe model is generated. Only used when hotwords provided + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
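As a usage sketch for `from_transducer` (placeholder file names; `stream.accept_waveform(sample_rate, samples)` and `stream.result.text` assume the `OfflineStream` bindings exposed elsewhere in the package):

```python
import sherpa_mnn

# Placeholder paths; download a real offline transducer model first.
recognizer = sherpa_mnn.OfflineRecognizer.from_transducer(
    encoder="encoder.onnx",
    decoder="decoder.onnx",
    joiner="joiner.onnx",
    tokens="tokens.txt",
    num_threads=2,
    decoding_method="greedy_search",
)

stream = recognizer.create_stream()
samples = [0.0] * 16000  # replace with real 16 kHz mono samples in [-1, 1]
stream.accept_waveform(16000, samples)
recognizer.decode_stream(stream)
print(stream.result.text)
```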
+ """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + transducer=OfflineTransducerModelConfig( + encoder_filename=encoder, + decoder_filename=decoder, + joiner_filename=joiner, + ), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + modeling_unit=modeling_unit, + bpe_vocab=bpe_vocab, + model_type=model_type, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + if len(hotwords_file) > 0 and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--hotwords-file. Currently given: {decoding_method}" + ) + + if lm and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--lm. Currently given: {decoding_method}" + ) + + lm_config = OfflineLMConfig( + model=lm, + scale=lm_scale, + lm_num_threads=num_threads, + lm_provider=provider, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + lm_config=lm_config, + decoding_method=decoding_method, + max_active_paths=max_active_paths, + hotwords_file=hotwords_file, + hotwords_score=hotwords_score, + blank_penalty=blank_penalty, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_sense_voice( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + language: str = "", + use_itn: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + language: + If not empty, then valid values are: auto, zh, en, ja, ko, yue + use_itn: + True to enable inverse text normalization; False to disable it. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
+ """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + sense_voice=OfflineSenseVoiceModelConfig( + model=model, + language=language, + use_itn=use_itn, + ), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_paraformer( + cls, + paraformer: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + paraformer: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + paraformer=OfflineParaformerModelConfig(model=paraformer), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="paraformer", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_telespeech_ctc( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 40, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models. + + Args: + model: + Path to ``model.onnx``. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. It is + ignored and is hard-coded in C++ to 40. + feature_dim: + Dimension of the feature used to train the model. It is ignored + and is hard-coded in C++ to 40. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. 
+ rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + telespeech_ctc=model, + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="nemo_ctc", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_nemo_ctc( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + model: + Path to ``model.onnx``. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="nemo_ctc", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_whisper( + cls, + encoder: str, + decoder: str, + tokens: str, + language: str = "en", + task: str = "transcribe", + num_threads: int = 1, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + tail_paddings: int = -1, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different kinds of whisper models, + e.g., tiny, tiny.en, base, base.en, etc. + + Args: + encoder: + Path to the encoder model, e.g., tiny-encoder.onnx, + tiny-encoder.int8.onnx, tiny-encoder.ort, etc. + decoder: + Path to the decoder model, e.g., tiny-decoder.onnx, + tiny-decoder.int8.onnx, tiny-decoder.ort, etc. + tokens: + Path to ``tokens.txt``. 
Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + language: + The spoken language in the audio file. Example values: en, de, zh, + jp, fr. See https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10 + for all possible values. Note that for non-multilingual models, the + only valid value is 'en'. + task: + Valid values are: transcribe, translate. Note that for + non-multilingual models, the only valid value is 'transcribe'. + num_threads: + Number of threads for neural network computation. + decoding_method: + Valid values: greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + whisper=OfflineWhisperModelConfig( + encoder=encoder, + decoder=decoder, + language=language, + task=task, + tail_paddings=tail_paddings, + ), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="whisper", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=16000, + feature_dim=80, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_fire_red_asr( + cls, + encoder: str, + decoder: str, + tokens: str, + num_threads: int = 1, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different kinds of FireRedAsr models, + e.g., xs, large, etc. + + Args: + encoder: + Path to the encoder model. + decoder: + Path to the decoder model. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + num_threads: + Number of threads for neural network computation. + decoding_method: + Valid values: greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
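The Whisper and FireRedAsr factories above are used the same way; for instance, a hedged sketch for `from_whisper` with placeholder file names (the feature extractor is fixed to 16 kHz / 80-dim internally):

```python
import sherpa_mnn

# Placeholder paths; e.g. tiny-encoder.onnx / tiny-decoder.onnx from a Whisper export.
recognizer = sherpa_mnn.OfflineRecognizer.from_whisper(
    encoder="tiny-encoder.onnx",
    decoder="tiny-decoder.onnx",
    tokens="tiny-tokens.txt",
    language="en",      # must be "en" for non-multilingual models
    task="transcribe",
)

stream = recognizer.create_stream()
stream.accept_waveform(16000, [0.0] * 16000)  # replace with real audio
recognizer.decode_stream(stream)
print(stream.result.text)
```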
+ """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + fire_red_asr=OfflineFireRedAsrModelConfig( + encoder=encoder, + decoder=decoder, + ), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=16000, + feature_dim=80, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_moonshine( + cls, + preprocessor: str, + encoder: str, + uncached_decoder: str, + cached_decoder: str, + tokens: str, + num_threads: int = 1, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different kinds of moonshine models, + e.g., tiny, base, etc. + + Args: + preprocessor: + Path to the preprocessor model, e.g., preprocess.onnx + encoder: + Path to the encoder model, e.g., encode.int8.onnx + uncached_decoder: + Path to the uncached decoder model, e.g., uncached_decode.int8.onnx, + cached_decoder: + Path to the cached decoder model, e.g., cached_decode.int8.onnx, + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + num_threads: + Number of threads for neural network computation. + decoding_method: + Valid values: greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + moonshine=OfflineMoonshineModelConfig( + preprocessor=preprocessor, + encoder=encoder, + uncached_decoder=uncached_decoder, + cached_decoder=cached_decoder, + ), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + ) + + unused_feat_config = FeatureExtractorConfig( + sampling_rate=16000, + feature_dim=80, + ) + + recognizer_config = OfflineRecognizerConfig( + model_config=model_config, + feat_config=unused_feat_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_tdnn_ctc( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 8000, + feature_dim: int = 23, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models. + + Args: + model: + Path to ``model.onnx``. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. 
+ provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + tdnn=OfflineTdnnModelConfig(model=model), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="tdnn", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_wenet_ctc( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + model: + Path to ``model.onnx``. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
+ """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + wenet_ctc=OfflineWenetCtcModelConfig(model=model), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="wenet_ctc", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + def create_stream(self, hotwords: Optional[str] = None): + if hotwords is None: + return self.recognizer.create_stream() + else: + return self.recognizer.create_stream(hotwords) + + def decode_stream(self, s: OfflineStream): + self.recognizer.decode_stream(s) + + def decode_streams(self, ss: List[OfflineStream]): + self.recognizer.decode_streams(ss) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/online_recognizer.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/online_recognizer.py new file mode 100644 index 00000000..84afcd6d --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/online_recognizer.py @@ -0,0 +1,855 @@ +# Copyright (c) 2023 Xiaomi Corporation +from pathlib import Path +from typing import List, Optional + +from _sherpa_mnn import ( + EndpointConfig, + FeatureExtractorConfig, + OnlineLMConfig, + OnlineModelConfig, + OnlineParaformerModelConfig, +) +from _sherpa_mnn import OnlineRecognizer as _Recognizer +from _sherpa_mnn import ( + CudaConfig, + TensorrtConfig, + ProviderConfig, + OnlineRecognizerConfig, + OnlineRecognizerResult, + OnlineStream, + OnlineTransducerModelConfig, + OnlineWenetCtcModelConfig, + OnlineNeMoCtcModelConfig, + OnlineZipformer2CtcModelConfig, + OnlineCtcFstDecoderConfig, +) + + +def _assert_file_exists(f: str): + assert Path(f).is_file(), f"{f} does not exist" + + +class OnlineRecognizer(object): + """A class for streaming speech recognition. 
+ + Please refer to the following files for usages + - https://github.com/k2-fsa/sherpa-mnn/blob/master/sherpa-mnn/python/tests/test_online_recognizer.py + - https://github.com/k2-fsa/sherpa-mnn/blob/master/python-api-examples/online-decode-files.py + """ + + @classmethod + def from_transducer( + cls, + tokens: str, + encoder: str, + decoder: str, + joiner: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + low_freq: float = 20.0, + high_freq: float = -400.0, + dither: float = 0.0, + normalize_samples: bool = True, + snip_edges: bool = False, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + max_active_paths: int = 4, + hotwords_score: float = 1.5, + blank_penalty: float = 0.0, + hotwords_file: str = "", + model_type: str = "", + modeling_unit: str = "cjkchar", + bpe_vocab: str = "", + lm: str = "", + lm_scale: float = 0.1, + lm_shallow_fusion: bool = True, + temperature_scale: float = 2.0, + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + provider: str = "cpu", + device: int = 0, + cudnn_conv_algo_search: int = 1, + trt_max_workspace_size: int = 2147483647, + trt_max_partition_iterations: int = 10, + trt_min_subgraph_size: int = 5, + trt_fp16_enable: bool = True, + trt_detailed_build_log: bool = False, + trt_engine_cache_enable: bool = True, + trt_timing_cache_enable: bool = True, + trt_engine_cache_path: str ="", + trt_timing_cache_path: str ="", + trt_dump_subgraphs: bool = False, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + joiner: + Path to ``joiner.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + low_freq: + Low cutoff frequency for mel bins in feature extraction. + high_freq: + High cutoff frequency for mel bins in feature extraction + (if <= 0, offset from Nyquist) + dither: + Dithering constant (0.0 means no dither). + By default the audio samples are in range [-1,+1], + so dithering constant 0.00003 is a good value, + equivalent to the default 1.0 from kaldi + normalize_samples: + True for +/- 1.0 range of audio samples (default, zipformer feats), + False for +/- 32k samples (ebranchformer features). + snip_edges: + handling of end of audio signal in kaldi feature extraction. + If true, end effects will be handled by outputting only frames that + completely fit in the file, and the number of frames depends on the + frame-length. If false, the number of frames depends only on the + frame-shift, and we reflect the data at the ends. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. 
If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + Valid values are greedy_search, modified_beam_search. + max_active_paths: + Use only when decoding_method is modified_beam_search. It specifies + the maximum number of active paths during beam search. + blank_penalty: + The penalty applied on blank symbol during decoding. + hotwords_file: + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. + hotwords_score: + The hotword score of each token for biasing word/phrase. Used only if + hotwords_file is given with modified_beam_search as decoding method. + temperature_scale: + Temperature scaling for output symbol confidence estiamation. + It affects only confidence values, the decoding uses the original + logits without temperature. + model_type: + Online transducer model type. Valid values are: conformer, lstm, + zipformer, zipformer2. All other values lead to loading the model twice. + modeling_unit: + The modeling unit of the model, commonly used units are bpe, cjkchar, + cjkchar+bpe, etc. Currently, it is needed only when hotwords are + provided, we need it to encode the hotwords into token sequence. + bpe_vocab: + The vocabulary generated by google's sentencepiece program. + It is a file has two columns, one is the token, the other is + the log probability, you can get it from the directory where + your bpe model is generated. Only used when hotwords provided + and the modeling unit is bpe or cjkchar+bpe. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + device: + onnxruntime cuda device index. + cudnn_conv_algo_search: + onxrt CuDNN convolution search algorithm selection. CUDA EP + trt_max_workspace_size: + Set TensorRT EP GPU memory usage limit. TensorRT EP + trt_max_partition_iterations: + Limit partitioning iterations for model conversion. TensorRT EP + trt_min_subgraph_size: + Set minimum size for subgraphs in partitioning. TensorRT EP + trt_fp16_enable: bool = True, + Enable FP16 precision for faster performance. TensorRT EP + trt_detailed_build_log: bool = False, + Enable detailed logging of build steps. TensorRT EP + trt_engine_cache_enable: bool = True, + Enable caching of TensorRT engines. TensorRT EP + trt_timing_cache_enable: bool = True, + "Enable use of timing cache to speed up builds." TensorRT EP + trt_engine_cache_path: str ="", + "Set path to store cached TensorRT engines." TensorRT EP + trt_timing_cache_path: str ="", + "Set path for storing timing cache." TensorRT EP + trt_dump_subgraphs: bool = False, + "Dump optimized subgraphs for debugging." 
TensorRT EP + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(encoder) + _assert_file_exists(decoder) + _assert_file_exists(joiner) + + assert num_threads > 0, num_threads + + transducer_config = OnlineTransducerModelConfig( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + + cuda_config = CudaConfig( + cudnn_conv_algo_search=cudnn_conv_algo_search, + ) + + trt_config = TensorrtConfig( + trt_max_workspace_size=trt_max_workspace_size, + trt_max_partition_iterations=trt_max_partition_iterations, + trt_min_subgraph_size=trt_min_subgraph_size, + trt_fp16_enable=trt_fp16_enable, + trt_detailed_build_log=trt_detailed_build_log, + trt_engine_cache_enable=trt_engine_cache_enable, + trt_timing_cache_enable=trt_timing_cache_enable, + trt_engine_cache_path=trt_engine_cache_path, + trt_timing_cache_path=trt_timing_cache_path, + trt_dump_subgraphs=trt_dump_subgraphs, + ) + + provider_config = ProviderConfig( + trt_config=trt_config, + cuda_config=cuda_config, + provider=provider, + device=device, + ) + + model_config = OnlineModelConfig( + transducer=transducer_config, + tokens=tokens, + num_threads=num_threads, + provider_config=provider_config, + model_type=model_type, + modeling_unit=modeling_unit, + bpe_vocab=bpe_vocab, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + normalize_samples=normalize_samples, + snip_edges=snip_edges, + feature_dim=feature_dim, + low_freq=low_freq, + high_freq=high_freq, + dither=dither, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + if len(hotwords_file) > 0 and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--hotwords-file. Currently given: {decoding_method}" + ) + + if lm and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--lm. Currently given: {decoding_method}" + ) + + lm_config = OnlineLMConfig( + model=lm, + scale=lm_scale, + shallow_fusion=lm_shallow_fusion, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + lm_config=lm_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + max_active_paths=max_active_paths, + hotwords_score=hotwords_score, + hotwords_file=hotwords_file, + blank_penalty=blank_penalty, + temperature_scale=temperature_scale, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_paraformer( + cls, + tokens: str, + encoder: str, + decoder: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + device: int = 0, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. 
Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(encoder) + _assert_file_exists(decoder) + + assert num_threads > 0, num_threads + + paraformer_config = OnlineParaformerModelConfig( + encoder=encoder, + decoder=decoder, + ) + + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + + model_config = OnlineModelConfig( + paraformer=paraformer_config, + tokens=tokens, + num_threads=num_threads, + provider_config=provider_config, + model_type="paraformer", + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_zipformer2_ctc( + cls, + tokens: str, + model: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + ctc_graph: str = "", + ctc_max_active: int = 3000, + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + device: int = 0, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. 
Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + ctc_graph: + If not empty, decoding_method is ignored. It contains the path to + H.fst, HL.fst, or HLG.fst + ctc_max_active: + Used only when ctc_graph is not empty. It specifies the maximum + active paths at a time. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(model) + + assert num_threads > 0, num_threads + + zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) + + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + + model_config = OnlineModelConfig( + zipformer2_ctc=zipformer2_ctc_config, + tokens=tokens, + num_threads=num_threads, + provider_config=provider_config, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + ctc_fst_decoder_config = OnlineCtcFstDecoderConfig( + graph=ctc_graph, + max_active=ctc_max_active, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + ctc_fst_decoder_config=ctc_fst_decoder_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_nemo_ctc( + cls, + tokens: str, + model: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + device: int = 0, + ): + """ + Please 
refer to + ``_ + to download pre-trained models. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + debug: + True to show meta data in the model. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(model) + + assert num_threads > 0, num_threads + + nemo_ctc_config = OnlineNeMoCtcModelConfig( + model=model, + ) + + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + + model_config = OnlineModelConfig( + nemo_ctc=nemo_ctc_config, + tokens=tokens, + num_threads=num_threads, + provider_config=provider_config, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_wenet_ctc( + cls, + tokens: str, + model: str, + chunk_size: int = 16, + num_left_chunks: int = 4, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + device: int = 0, + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. 
Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + chunk_size: + The --chunk-size parameter from WeNet. + num_left_chunks: + The --num-left-chunks parameter from WeNet. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(model) + + assert num_threads > 0, num_threads + + wenet_ctc_config = OnlineWenetCtcModelConfig( + model=model, + chunk_size=chunk_size, + num_left_chunks=num_left_chunks, + ) + + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + + model_config = OnlineModelConfig( + wenet_ctc=wenet_ctc_config, + tokens=tokens, + num_threads=num_threads, + provider_config=provider_config, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + def create_stream(self, hotwords: Optional[str] = None): + if hotwords is None: + return self.recognizer.create_stream() + else: + return self.recognizer.create_stream(hotwords) + + def decode_stream(self, s: OnlineStream): + self.recognizer.decode_stream(s) + + def decode_streams(self, ss: List[OnlineStream]): + self.recognizer.decode_streams(ss) + + def is_ready(self, s: OnlineStream) -> bool: + return self.recognizer.is_ready(s) + + def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult: + return self.recognizer.get_result(s) + + def get_result(self, s: OnlineStream) -> str: + return self.recognizer.get_result(s).text.strip() + + def get_result_as_json_string(self, s: OnlineStream) 
-> str: + return self.recognizer.get_result(s).as_json_string() + + def tokens(self, s: OnlineStream) -> List[str]: + return self.recognizer.get_result(s).tokens + + def timestamps(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).timestamps + + def start_time(self, s: OnlineStream) -> float: + return self.recognizer.get_result(s).start_time + + def ys_probs(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).ys_probs + + def lm_probs(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).lm_probs + + def context_scores(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).context_scores + + def is_endpoint(self, s: OnlineStream) -> bool: + return self.recognizer.is_endpoint(s) + + def reset(self, s: OnlineStream) -> bool: + return self.recognizer.reset(s) diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/utils.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/utils.py new file mode 100644 index 00000000..fd36f2c0 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/sherpa_mnn/utils.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023 Xiaomi Corporation +import re + +from pathlib import Path +from typing import List, Optional, Union + + +def text2token( + texts: List[str], + tokens: str, + tokens_type: str = "cjkchar", + bpe_model: Optional[str] = None, + output_ids: bool = False, +) -> List[List[Union[str, int]]]: + """ + Encode the given texts (a list of string) to a list of a list of tokens. + + Args: + texts: + The given contexts list (a list of string). + tokens: + The path of the tokens.txt. + tokens_type: + The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin. + fpinyin means full pinyin, each cjkchar has a pinyin(with tone). + ppinyin means partial pinyin, it splits pinyin into initial and final, + bpe_model: + The path of the bpe model. Only required when tokens_type is bpe or + cjkchar+bpe. + output_ids: + True to output token ids otherwise tokens. + Returns: + Return the encoded texts, it is a list of a list of token ids if output_ids + is True, or it is a list of list of tokens. 
+ """ + try: + import sentencepiece as spm + except ImportError: + print("Please run") + print(" pip install sentencepiece") + print("before you continue") + raise + + try: + from pypinyin import pinyin + from pypinyin.contrib.tone_convert import to_initials, to_finals_tone + except ImportError: + print("Please run") + print(" pip install pypinyin") + print("before you continue") + raise + + assert Path(tokens).is_file(), f"File not exists, {tokens}" + tokens_table = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f: + toks = line.strip().split() + assert len(toks) == 2, len(toks) + assert toks[0] not in tokens_table, f"Duplicate token: {toks} " + tokens_table[toks[0]] = int(toks[1]) + + if "bpe" in tokens_type: + assert Path(bpe_model).is_file(), f"File not exists, {bpe_model}" + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + texts_list: List[List[str]] = [] + + if tokens_type == "cjkchar": + texts_list = [list("".join(text.split())) for text in texts] + elif tokens_type == "bpe": + texts_list = sp.encode(texts, out_type=str) + elif "pinyin" in tokens_type: + for txt in texts: + py = [x[0] for x in pinyin(txt)] + if "ppinyin" == tokens_type: + res = [] + for x in py: + initial = to_initials(x, strict=False) + final = to_finals_tone(x, strict=False) + if initial == "" and final == "": + res.append(x) + else: + if initial != "": + res.append(initial) + if final != "": + res.append(final) + texts_list.append(res) + else: + texts_list.append(py) + else: + assert ( + tokens_type == "cjkchar+bpe" + ), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}" + # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + pattern = re.compile(r"([\u4e00-\u9fff])") + for text in texts: + # Example: + # txt = "你好 ITS'S OKAY 的" + # chars = ["你", "好", " ITS'S OKAY ", "的"] + chars = pattern.split(text) + mix_chars = [w for w in chars if len(w.strip()) > 0] + text_list = [] + for ch_or_w in mix_chars: + # ch_or_w is a single CJK charater(i.e., "你"), do nothing. + if pattern.fullmatch(ch_or_w) is not None: + text_list.append(ch_or_w) + # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), + # encode ch_or_w using bpe_model. + else: + text_list += sp.encode_as_pieces(ch_or_w) + texts_list.append(text_list) + + result: List[List[Union[int, str]]] = [] + for text in texts_list: + text_list = [] + contain_oov = False + for txt in text: + if txt in tokens_table: + text_list.append(tokens_table[txt] if output_ids else txt) + else: + print( + f"Can't find token {txt} in token table, check your " + f"tokens.txt see if {txt} in it. skipping text : {text}." 
+ ) + contain_oov = True + break + if contain_oov: + continue + else: + result.append(text_list) + return result diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/CMakeLists.txt b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/CMakeLists.txt new file mode 100644 index 00000000..05c3afa1 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/CMakeLists.txt @@ -0,0 +1,35 @@ +function(sherpa_onnx_add_py_test source) + get_filename_component(name ${source} NAME_WE) + set(name "${name}_py") + + add_test(NAME ${name} + COMMAND + "${PYTHON_EXECUTABLE}" + "${CMAKE_CURRENT_SOURCE_DIR}/${source}" + WORKING_DIRECTORY + ${CMAKE_CURRENT_SOURCE_DIR} + ) + + get_filename_component(sherpa_onnx_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY) + + set_property(TEST ${name} + PROPERTY ENVIRONMENT "PYTHONPATH=${sherpa_onnx_path}:$:$ENV{PYTHONPATH}" + ) +endfunction() + +# please sort the files in alphabetic order +set(py_test_files + test_fast_clustering.py + test_feature_extractor_config.py + test_keyword_spotter.py + test_offline_recognizer.py + test_online_recognizer.py + test_online_transducer_model_config.py + test_speaker_recognition.py + test_text2token.py +) + +foreach(source IN LISTS py_test_files) + sherpa_onnx_add_py_test(${source}) +endforeach() + diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_fast_clustering.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_fast_clustering.py new file mode 100755 index 00000000..0601c9a6 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_fast_clustering.py @@ -0,0 +1,162 @@ +# sherpa-mnn/python/tests/test_fast_clustering.py +# +# Copyright (c) 2024 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_fast_clustering_py +import unittest + +import sherpa_mnn +import numpy as np +from pathlib import Path +from typing import Tuple + +import soundfile as sf + + +def load_audio(filename: str) -> np.ndarray: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + assert sample_rate == 16000, f"Expect sample_rate 16000. 
Given: {sample_rate}" + return samples + + +class TestFastClustering(unittest.TestCase): + def test_construct_by_num_clusters(self): + config = sherpa_mnn.FastClusteringConfig(num_clusters=4) + assert config.validate() is True + + print(config) + + clustering = sherpa_mnn.FastClustering(config) + features = np.array( + [ + [0.2, 0.3], # cluster 0 + [0.3, -0.4], # cluster 1 + [-0.1, -0.2], # cluster 2 + [-0.3, -0.5], # cluster 2 + [0.1, -0.2], # cluster 1 + [0.1, 0.2], # cluster 0 + [-0.8, 1.9], # cluster 3 + [-0.4, -0.6], # cluster 2 + [-0.7, 0.9], # cluster 3 + ] + ) + labels = clustering(features) + assert isinstance(labels, list) + assert len(labels) == features.shape[0] + + expected = [0, 1, 2, 2, 1, 0, 3, 2, 3] + assert labels == expected, (labels, expected) + + def test_construct_by_threshold(self): + config = sherpa_mnn.FastClusteringConfig(threshold=0.2) + assert config.validate() is True + + print(config) + + clustering = sherpa_mnn.FastClustering(config) + features = np.array( + [ + [0.2, 0.3], # cluster 0 + [0.3, -0.4], # cluster 1 + [-0.1, -0.2], # cluster 2 + [-0.3, -0.5], # cluster 2 + [0.1, -0.2], # cluster 1 + [0.1, 0.2], # cluster 0 + [-0.8, 1.9], # cluster 3 + [-0.4, -0.6], # cluster 2 + [-0.7, 0.9], # cluster 3 + ] + ) + labels = clustering(features) + assert isinstance(labels, list) + assert len(labels) == features.shape[0] + + expected = [0, 1, 2, 2, 1, 0, 3, 2, 3] + assert labels == expected, (labels, expected) + + def test_cluster_speaker_embeddings(self): + d = Path("/tmp/test-cluster") + + # Please download the onnx file from + # https://github.com/k2-fsa/sherpa-mnn/releases/tag/speaker-recongition-models + model_file = d / "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + + if not model_file.exists(): + print(f"skip test since {model_file} does not exist") + return + + # Please download the test wave files from + # https://github.com/csukuangfj/sr-data + wave_dir = d / "sr-data" + if not wave_dir.is_dir(): + print(f"skip test since {wave_dir} does not exist") + return + + wave_files = [ + "enroll/fangjun-sr-1.wav", # cluster 0 + "enroll/fangjun-sr-2.wav", # cluster 0 + "enroll/fangjun-sr-3.wav", # cluster 0 + "enroll/leijun-sr-1.wav", # cluster 1 + "enroll/leijun-sr-2.wav", # cluster 1 + "enroll/liudehua-sr-1.wav", # cluster 2 + "enroll/liudehua-sr-2.wav", # cluster 2 + "test/fangjun-test-sr-1.wav", # cluster 0 + "test/fangjun-test-sr-2.wav", # cluster 0 + "test/leijun-test-sr-1.wav", # cluster 1 + "test/leijun-test-sr-2.wav", # cluster 1 + "test/leijun-test-sr-3.wav", # cluster 1 + "test/liudehua-test-sr-1.wav", # cluster 2 + "test/liudehua-test-sr-2.wav", # cluster 2 + ] + for w in wave_files: + f = d / "sr-data" / w + if not f.is_file(): + print(f"skip testing since {f} does not exist") + return + + extractor_config = sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=str(model_file), + num_threads=1, + debug=0, + ) + if not extractor_config.validate(): + raise ValueError(f"Invalid extractor config. 
{config}") + + extractor = sherpa_mnn.SpeakerEmbeddingExtractor(extractor_config) + + features = [] + + for w in wave_files: + f = d / "sr-data" / w + audio = load_audio(str(f)) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=16000, waveform=audio) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + features.append(embedding) + features = np.array(features) + + config = sherpa_mnn.FastClusteringConfig(num_clusters=3) + # config = sherpa_mnn.FastClusteringConfig(threshold=0.5) + clustering = sherpa_mnn.FastClustering(config) + labels = clustering(features) + + expected = [0, 0, 0, 1, 1, 2, 2] + expected += [0, 0, 1, 1, 1, 2, 2] + + assert labels == expected, (labels, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_feature_extractor_config.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_feature_extractor_config.py new file mode 100755 index 00000000..761be5c7 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_feature_extractor_config.py @@ -0,0 +1,29 @@ +# sherpa-mnn/python/tests/test_feature_extractor_config.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_feature_extractor_config_py + +import unittest + +import _sherpa_mnn + + +class TestFeatureExtractorConfig(unittest.TestCase): + def test_default_constructor(self): + config = _sherpa_mnn.FeatureExtractorConfig() + assert config.sampling_rate == 16000, config.sampling_rate + assert config.feature_dim == 80, config.feature_dim + print(config) + + def test_constructor(self): + config = _sherpa_mnn.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40) + assert config.sampling_rate == 8000, config.sampling_rate + assert config.feature_dim == 40, config.feature_dim + print(config) + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_keyword_spotter.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_keyword_spotter.py new file mode 100755 index 00000000..b36a2625 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_keyword_spotter.py @@ -0,0 +1,176 @@ +# sherpa-mnn/python/tests/test_keyword_spotter.py +# +# Copyright (c) 2024 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_keyword_spotter_py + +import unittest +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_mnn + +d = "/tmp/onnx-models" +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html +# to download pre-trained models for testing + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +class TestKeywordSpotter(unittest.TestCase): + def test_zipformer_transducer_en(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + decoder = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" + joiner = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + else: + encoder = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx" + decoder = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" + joiner = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx" + + tokens = ( + f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt" + ) + keywords_file = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" + wave0 = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav" + wave1 = f"{d}/sherpa-mnn-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav" + + if not Path(encoder).is_file(): + print("skipping test_zipformer_transducer_en()") + return + keyword_spotter = sherpa_mnn.KeywordSpotter( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + keywords_file=keywords_file, + provider="cpu", + ) + streams = [] + waves = [wave0, wave1] + for wave in waves: + s = keyword_spotter.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + results = [""] * len(streams) + while True: + ready_list = [] + for i, s in enumerate(streams): + if keyword_spotter.is_ready(s): + ready_list.append(s) + r = keyword_spotter.get_result(s) + if r: + print(f"{r} is detected.") + results[i] += f"{r}/" + + keyword_spotter.reset_stream(s) + + if len(ready_list) == 0: + break + keyword_spotter.decode_streams(ready_list) + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result[0:-1]}") + print("-" * 10) + + def test_zipformer_transducer_cn(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + decoder = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" + joiner = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.int8.onnx" + else: + encoder = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx" + decoder = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" + joiner = 
f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx" + + tokens = ( + f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt" + ) + keywords_file = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" + wave0 = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" + wave1 = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav" + wave2 = f"{d}/sherpa-mnn-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/5.wav" + + if not Path(encoder).is_file(): + print("skipping test_zipformer_transducer_cn()") + return + keyword_spotter = sherpa_mnn.KeywordSpotter( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + keywords_file=keywords_file, + provider="cpu", + ) + streams = [] + waves = [wave0, wave1, wave2] + for wave in waves: + s = keyword_spotter.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + results = [""] * len(streams) + while True: + ready_list = [] + for i, s in enumerate(streams): + if keyword_spotter.is_ready(s): + ready_list.append(s) + r = keyword_spotter.get_result(s) + if r: + print(f"{r} is detected.") + results[i] += f"{r}/" + + keyword_spotter.reset_stream(s) + + if len(ready_list) == 0: + break + keyword_spotter.decode_streams(ready_list) + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result[0:-1]}") + print("-" * 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_offline_recognizer.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_offline_recognizer.py new file mode 100755 index 00000000..8e18e13b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_offline_recognizer.py @@ -0,0 +1,319 @@ +# sherpa-mnn/python/tests/test_offline_recognizer.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_offline_recognizer_py + +import unittest +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_mnn + +d = "/tmp/icefall-models" +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html +# and +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html +# to download pre-trained models for testing + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +class TestOfflineRecognizer(unittest.TestCase): + def test_transducer_single_file(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/tokens.txt" + wave0 = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/test_wavs/0.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_single_file()") + return + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + recognizer.decode_stream(s) + print(s.result.text) + + def test_transducer_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/tokens.txt" + wave0 = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/test_wavs/0.wav" + wave1 = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/test_wavs/1.wav" + wave2 = f"{d}/sherpa-mnn-zipformer-en-2023-04-01/test_wavs/8k.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_multiple_files()") + return + + recognizer = sherpa_mnn.OfflineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + recognizer.decode_streams([s0, s1, s2]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + + def test_paraformer_single_file(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/model.int8.onnx" + else: + model = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/model.onnx" + + 
tokens = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/tokens.txt" + wave0 = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/test_wavs/0.wav" + + if not Path(model).is_file(): + print("skipping test_paraformer_single_file()") + return + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + recognizer.decode_stream(s) + print(s.result.text) + + def test_paraformer_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/model.int8.onnx" + else: + model = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/model.onnx" + + tokens = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/tokens.txt" + wave0 = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/test_wavs/0.wav" + wave1 = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/test_wavs/1.wav" + wave2 = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/test_wavs/2.wav" + wave3 = f"{d}/sherpa-mnn-paraformer-zh-2023-09-14/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_paraformer_multiple_files()") + return + + recognizer = sherpa_mnn.OfflineRecognizer.from_paraformer( + paraformer=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + s3 = recognizer.create_stream() + samples3, sample_rate3 = read_wave(wave3) + s3.accept_waveform(sample_rate3, samples3) + + recognizer.decode_streams([s0, s1, s2, s3]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + print(s3.result.text) + + def test_nemo_ctc_single_file(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/model.int8.onnx" + else: + model = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/model.onnx" + + tokens = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/tokens.txt" + wave0 = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/test_wavs/0.wav" + + if not Path(model).is_file(): + print("skipping test_nemo_ctc_single_file()") + return + + recognizer = sherpa_mnn.OfflineRecognizer.from_nemo_ctc( + model=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + recognizer.decode_stream(s) + print(s.result.text) + + def test_nemo_ctc_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + model = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/model.int8.onnx" + else: + model = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/model.onnx" + + tokens = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/tokens.txt" + wave0 = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/test_wavs/0.wav" + wave1 = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/test_wavs/1.wav" + wave2 = f"{d}/sherpa-mnn-nemo-ctc-en-citrinet-512/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_nemo_ctc_multiple_files()") + return + + recognizer = sherpa_mnn.OfflineRecognizer.from_nemo_ctc( + model=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = 
read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + recognizer.decode_streams([s0, s1, s2]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + + def _test_wenet_ctc(self): + models = [ + "sherpa-mnn-zh-wenet-aishell", + "sherpa-mnn-zh-wenet-aishell2", + "sherpa-mnn-zh-wenet-wenetspeech", + "sherpa-mnn-zh-wenet-multi-cn", + "sherpa-mnn-en-wenet-librispeech", + "sherpa-mnn-en-wenet-gigaspeech", + ] + for m in models: + for use_int8 in [True, False]: + name = "model.int8.onnx" if use_int8 else "model.onnx" + model = f"{d}/{m}/{name}" + tokens = f"{d}/{m}/tokens.txt" + + wave0 = f"{d}/{m}/test_wavs/0.wav" + wave1 = f"{d}/{m}/test_wavs/1.wav" + wave2 = f"{d}/{m}/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_wenet_ctc()") + return + + recognizer = sherpa_mnn.OfflineRecognizer.from_wenet_ctc( + model=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + recognizer.decode_streams([s0, s1, s2]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_online_recognizer.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_online_recognizer.py new file mode 100755 index 00000000..e768c974 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_online_recognizer.py @@ -0,0 +1,257 @@ +# sherpa-mnn/python/tests/test_online_recognizer.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_online_recognizer_py + +import unittest +import wave +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_mnn + +d = "/tmp/icefall-models" +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html +# to download pre-trained models for testing + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. 
+ - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +class TestOnlineRecognizer(unittest.TestCase): + def test_transducer_single_file(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" + wave0 = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_single_file()") + return + + for decoding_method in ["greedy_search", "modified_beam_search"]: + recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + decoding_method=decoding_method, + provider="cpu", + ) + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave0) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + + s.input_finished() + while recognizer.is_ready(s): + recognizer.decode_stream(s) + print(recognizer.get_result(s)) + + def test_transducer_multiple_files(self): + for use_int8 in [True, False]: + if use_int8: + encoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx" + decoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx" + else: + encoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.onnx" + decoder = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx" + joiner = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.onnx" + + tokens = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt" + wave0 = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav" + wave1 = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/1.wav" + wave2 = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/2.wav" + wave3 = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/3.wav" + wave4 = f"{d}/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/8k.wav" + + if not Path(encoder).is_file(): + print("skipping test_transducer_multiple_files()") + return + + for 
decoding_method in ["greedy_search", "modified_beam_search"]: + recognizer = sherpa_mnn.OnlineRecognizer.from_transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + tokens=tokens, + num_threads=1, + decoding_method=decoding_method, + provider="cpu", + ) + streams = [] + waves = [wave0, wave1, wave2, wave3, wave4] + for wave in waves: + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + while True: + ready_list = [] + for s in streams: + if recognizer.is_ready(s): + ready_list.append(s) + if len(ready_list) == 0: + break + recognizer.decode_streams(ready_list) + results = [recognizer.get_result(s) for s in streams] + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + + def test_zipformer2_ctc(self): + m = "sherpa-mnn-streaming-zipformer-ctc-multi-zh-hans-2023-12-13" + for use_int8 in [True, False]: + name = ( + "ctc-epoch-20-avg-1-chunk-16-left-128.int8.onnx" + if use_int8 + else "ctc-epoch-20-avg-1-chunk-16-left-128.onnx" + ) + model = f"{d}/{m}/{name}" + tokens = f"{d}/{m}/tokens.txt" + wave0 = f"{d}/{m}/test_wavs/DEV_T0000000000.wav" + wave1 = f"{d}/{m}/test_wavs/DEV_T0000000001.wav" + wave2 = f"{d}/{m}/test_wavs/DEV_T0000000002.wav" + if not Path(model).is_file(): + print("skipping test_zipformer2_ctc()") + return + print(f"testing {model}") + + recognizer = sherpa_mnn.OnlineRecognizer.from_zipformer2_ctc( + model=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + streams = [] + waves = [wave0, wave1, wave2] + for wave in waves: + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + while True: + ready_list = [] + for s in streams: + if recognizer.is_ready(s): + ready_list.append(s) + if len(ready_list) == 0: + break + recognizer.decode_streams(ready_list) + + results = [recognizer.get_result(s) for s in streams] + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + + def test_wenet_ctc(self): + models = [ + "sherpa-mnn-zh-wenet-aishell", + "sherpa-mnn-zh-wenet-aishell2", + "sherpa-mnn-zh-wenet-wenetspeech", + "sherpa-mnn-zh-wenet-multi-cn", + "sherpa-mnn-en-wenet-librispeech", + "sherpa-mnn-en-wenet-gigaspeech", + ] + for m in models: + for use_int8 in [True, False]: + name = ( + "model-streaming.int8.onnx" if use_int8 else "model-streaming.onnx" + ) + model = f"{d}/{m}/{name}" + tokens = f"{d}/{m}/tokens.txt" + + wave0 = f"{d}/{m}/test_wavs/0.wav" + wave1 = f"{d}/{m}/test_wavs/1.wav" + wave2 = f"{d}/{m}/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_wenet_ctc()") + return + + recognizer = sherpa_mnn.OnlineRecognizer.from_wenet_ctc( + model=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + streams = [] + waves = [wave0, wave1, wave2] + for wave in waves: + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + while True: + ready_list = [] 
+ for s in streams: + if recognizer.is_ready(s): + ready_list.append(s) + if len(ready_list) == 0: + break + recognizer.decode_streams(ready_list) + + results = [recognizer.get_result(s) for s in streams] + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_online_transducer_model_config.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_online_transducer_model_config.py new file mode 100755 index 00000000..ffe16ead --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_online_transducer_model_config.py @@ -0,0 +1,28 @@ +# sherpa-mnn/python/tests/test_online_transducer_model_config.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_online_transducer_model_config_py + +import unittest + +import _sherpa_mnn + + +class TestOnlineTransducerModelConfig(unittest.TestCase): + def test_constructor(self): + config = _sherpa_mnn.OnlineTransducerModelConfig( + encoder="encoder.onnx", + decoder="decoder.onnx", + joiner="joiner.onnx", + ) + assert config.encoder == "encoder.onnx", config.encoder + assert config.decoder == "decoder.onnx", config.decoder + assert config.joiner == "joiner.onnx", config.joiner + print(config) + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_speaker_recognition.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_speaker_recognition.py new file mode 100755 index 00000000..3dea5a18 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_speaker_recognition.py @@ -0,0 +1,217 @@ +# sherpa-mnn/python/tests/test_speaker_recognition.py +# +# Copyright (c) 2024 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_speaker_recognition_py + +import unittest +import wave +from collections import defaultdict +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_mnn + +d = "/tmp/sr-models" + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def load_speaker_embedding_model(model_filename): + config = sherpa_mnn.SpeakerEmbeddingExtractorConfig( + model=model_filename, + num_threads=1, + debug=True, + provider="cpu", + ) + if not config.validate(): + raise ValueError(f"Invalid config. 
{config}") + extractor = sherpa_mnn.SpeakerEmbeddingExtractor(config) + return extractor + + +def test_zh_models(model_filename: str): + model_filename = str(model_filename) + if "en" in model_filename: + print(f"skip {model_filename}") + return + extractor = load_speaker_embedding_model(model_filename) + filenames = [ + "leijun-sr-1", + "leijun-sr-2", + "fangjun-sr-1", + "fangjun-sr-2", + "fangjun-sr-3", + ] + tmp = defaultdict(list) + for filename in filenames: + print(filename) + name = filename.split("-", maxsplit=1)[0] + data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav") + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + tmp[name].append(embedding) + + manager = sherpa_mnn.SpeakerEmbeddingManager(extractor.dim) + for name, embedding_list in tmp.items(): + print(name, len(embedding_list)) + embedding = sum(embedding_list) / len(embedding_list) + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + filenames = [ + "leijun-test-sr-1", + "leijun-test-sr-2", + "leijun-test-sr-3", + "fangjun-test-sr-1", + "fangjun-test-sr-2", + ] + for filename in filenames: + name = filename.split("-", maxsplit=1)[0] + data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav") + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + status = manager.verify(name, embedding, threshold=0.5) + if not status: + raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav") + + ans = manager.search(embedding, threshold=0.5) + assert ans == name, (name, ans) + + +def test_en_and_zh_models(model_filename: str): + model_filename = str(model_filename) + extractor = load_speaker_embedding_model(model_filename) + manager = sherpa_mnn.SpeakerEmbeddingManager(extractor.dim) + + filenames = [ + "speaker1_a_cn_16k", + "speaker2_a_cn_16k", + "speaker1_a_en_16k", + "speaker2_a_en_16k", + ] + is_en = "en" in model_filename + for filename in filenames: + if is_en and "cn" in filename: + continue + + if not is_en and "en" in filename: + continue + + name = filename.rsplit("_", maxsplit=1)[0] + data, sample_rate = read_wave( + f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" + ) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + filenames = [ + "speaker1_b_cn_16k", + "speaker1_b_en_16k", + ] + for filename in filenames: + if is_en and "cn" in filename: + continue + + if not is_en and "en" in filename: + continue + print(filename) + name = filename.rsplit("_", maxsplit=1)[0] + name = name.replace("b_cn", "a_cn") + name = name.replace("b_en", "a_en") + print(name) + + data, sample_rate = read_wave( + f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" + ) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = 
extractor.compute(stream) + embedding = np.array(embedding) + status = manager.verify(name, embedding, threshold=0.5) + if not status: + raise RuntimeError( + f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}" + ) + + ans = manager.search(embedding, threshold=0.5) + assert ans == name, (name, ans) + + +class TestSpeakerRecognition(unittest.TestCase): + def test_wespeaker_models(self): + model_dir = Path(d) / "wespeaker" + if not model_dir.is_dir(): + print(f"{model_dir} does not exist - skip it") + return + for filename in model_dir.glob("*.onnx"): + print(filename) + test_zh_models(filename) + test_en_and_zh_models(filename) + + def _test_3dpeaker_models(self): + model_dir = Path(d) / "3dspeaker" + if not model_dir.is_dir(): + print(f"{model_dir} does not exist - skip it") + return + for filename in model_dir.glob("*.onnx"): + print(filename) + test_en_and_zh_models(filename) + + def test_nemo_models(self): + model_dir = Path(d) / "nemo" + if not model_dir.is_dir(): + print(f"{model_dir} does not exist - skip it") + return + for filename in model_dir.glob("*.onnx"): + print(filename) + test_en_and_zh_models(filename) + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_text2token.py b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_text2token.py new file mode 100755 index 00000000..532e15f5 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/sherpa-mnn/python/tests/test_text2token.py @@ -0,0 +1,121 @@ +# sherpa-mnn/python/tests/test_text2token.py +# +# Copyright (c) 2023 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_text2token_py + +import unittest +from pathlib import Path + +import sherpa_mnn + +d = "/tmp/sherpa-test-data" +# Please refer to +# https://github.com/pkufool/sherpa-test-data +# to download test data for testing + + +class TestText2Token(unittest.TestCase): + def test_bpe(self): + tokens = f"{d}/text2token/tokens_en.txt" + bpe_model = f"{d}/text2token/bpe_en.model" + + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): + print( + f"No test data found, skipping test_bpe().\n" + f"You can download the test data by: \n" + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + + texts = ["HELLO WORLD", "I LOVE YOU"] + encoded_texts = sherpa_mnn.text2token( + texts, + tokens=tokens, + tokens_type="bpe", + bpe_model=bpe_model, + ) + assert encoded_texts == [ + ["▁HE", "LL", "O", "▁WORLD"], + ["▁I", "▁LOVE", "▁YOU"], + ], encoded_texts + + encoded_ids = sherpa_mnn.text2token( + texts, + tokens=tokens, + tokens_type="bpe", + bpe_model=bpe_model, + output_ids=True, + ) + assert encoded_ids == [[22, 58, 24, 425], [19, 370, 47]], encoded_ids + + def test_cjkchar(self): + tokens = f"{d}/text2token/tokens_cn.txt" + + if not Path(tokens).is_file(): + print( + f"No test data found, skipping test_cjkchar().\n" + f"You can download the test data by: \n" + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + + texts = ["世界人民大团结", "中国 VS 美国"] + encoded_texts = sherpa_mnn.text2token( + texts, tokens=tokens, tokens_type="cjkchar" + ) + assert encoded_texts == [ + ["世", "界", "人", "民", "大", "团", "结"], + ["中", "国", "V", "S", "美", "国"], + ], encoded_texts + encoded_ids = sherpa_mnn.text2token( + texts, + tokens=tokens, + tokens_type="cjkchar", + output_ids=True, + ) + assert encoded_ids == [ + [379, 380, 72, 874, 93, 1251, 489], + [262, 147, 3423, 2476, 21, 147], + ], 
encoded_ids + + def test_cjkchar_bpe(self): + tokens = f"{d}/text2token/tokens_mix.txt" + bpe_model = f"{d}/text2token/bpe_mix.model" + + if not Path(tokens).is_file() or not Path(bpe_model).is_file(): + print( + f"No test data found, skipping test_cjkchar_bpe().\n" + f"You can download the test data by: \n" + f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data" + ) + return + + texts = ["世界人民 GOES TOGETHER", "中国 GOES WITH 美国"] + encoded_texts = sherpa_mnn.text2token( + texts, + tokens=tokens, + tokens_type="cjkchar+bpe", + bpe_model=bpe_model, + ) + assert encoded_texts == [ + ["世", "界", "人", "民", "▁GO", "ES", "▁TOGETHER"], + ["中", "国", "▁GO", "ES", "▁WITH", "美", "国"], + ], encoded_texts + encoded_ids = sherpa_mnn.text2token( + texts, + tokens=tokens, + tokens_type="cjkchar+bpe", + bpe_model=bpe_model, + output_ids=True, + ) + assert encoded_ids == [ + [1368, 1392, 557, 680, 275, 178, 475], + [685, 736, 275, 178, 179, 921, 736], + ], encoded_ids + + +if __name__ == "__main__": + unittest.main() diff --git a/apps/frameworks/sherpa-mnn/toolchains/aarch64-linux-gnu.toolchain.cmake b/apps/frameworks/sherpa-mnn/toolchains/aarch64-linux-gnu.toolchain.cmake new file mode 100644 index 00000000..e72e0cba --- /dev/null +++ b/apps/frameworks/sherpa-mnn/toolchains/aarch64-linux-gnu.toolchain.cmake @@ -0,0 +1,18 @@ +# Copied from https://github.com/Tencent/ncnn/blob/master/toolchains/aarch64-linux-gnu.toolchain.cmake + +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR aarch64) + +set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + +set(CMAKE_C_FLAGS "-march=armv8-a") +set(CMAKE_CXX_FLAGS "-march=armv8-a") + +# cache flags +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") diff --git a/apps/frameworks/sherpa-mnn/toolchains/arm-linux-gnueabihf.toolchain.cmake b/apps/frameworks/sherpa-mnn/toolchains/arm-linux-gnueabihf.toolchain.cmake new file mode 100644 index 00000000..abe1a22b --- /dev/null +++ b/apps/frameworks/sherpa-mnn/toolchains/arm-linux-gnueabihf.toolchain.cmake @@ -0,0 +1,17 @@ +# Copied from https://github.com/Tencent/ncnn/blob/master/toolchains/arm-linux-gnueabihf.toolchain.cmake +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR arm) + +set(CMAKE_C_COMPILER "arm-linux-gnueabihf-gcc") +set(CMAKE_CXX_COMPILER "arm-linux-gnueabihf-g++") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + +set(CMAKE_C_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon") +set(CMAKE_CXX_FLAGS "-march=armv7-a -mfloat-abi=hard -mfpu=neon") + +# cache flags +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags") diff --git a/apps/frameworks/sherpa-mnn/toolchains/ios.toolchain.cmake b/apps/frameworks/sherpa-mnn/toolchains/ios.toolchain.cmake new file mode 100644 index 00000000..a99bdff4 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/toolchains/ios.toolchain.cmake @@ -0,0 +1,948 @@ +# This file is part of the ios-cmake project. 
It was retrieved from +# https://github.com/leetal/ios-cmake.git, which is a fork of +# https://github.com/gerstrong/ios-cmake.git, which is a fork of +# https://github.com/cristeab/ios-cmake.git, which is a fork of +# https://code.google.com/p/ios-cmake/. Which in turn is based off of +# the Platform/Darwin.cmake and Platform/UnixPaths.cmake files which +# are included with CMake 2.8.4 +# +# The ios-cmake project is licensed under the new BSD license. +# +# Copyright (c) 2014, Bogdan Cristea and LTE Engineering Software, +# Kitware, Inc., Insight Software Consortium. All rights reserved. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# This file is based off of the Platform/Darwin.cmake and +# Platform/UnixPaths.cmake files which are included with CMake 2.8.4 +# It has been altered for iOS development. +# +# Updated by Alex Stewart (alexs.mac@gmail.com) +# +# ***************************************************************************** +# Now maintained by Alexander Widerberg (widerbergaren [at] gmail.com) +# under the BSD-3-Clause license +# https://github.com/leetal/ios-cmake +# ***************************************************************************** +# +# INFORMATION / HELP +# +# The following options control the behaviour of this toolchain: +# +# PLATFORM: (default "OS64") +# OS = Build for iPhoneOS. +# OS64 = Build for arm64 iphoneOS. +# OS64COMBINED = Build for arm64 x86_64 iphoneOS. Combined into FAT STATIC lib (supported on 3.14+ of CMakewith "-G Xcode" argument ONLY) +# SIMULATOR = Build for x86 i386 iphoneOS Simulator. +# SIMULATOR64 = Build for x86_64 iphoneOS Simulator. +# SIMULATORARM64 = Build for arm64 iphoneOS Simulator. +# TVOS = Build for arm64 tvOS. +# TVOSCOMBINED = Build for arm64 x86_64 tvOS. Combined into FAT STATIC lib (supported on 3.14+ of CMake with "-G Xcode" argument ONLY) +# SIMULATOR_TVOS = Build for x86_64 tvOS Simulator. +# WATCHOS = Build for armv7k arm64_32 for watchOS. +# WATCHOSCOMBINED = Build for armv7k arm64_32 x86_64 watchOS. 
Combined into FAT STATIC lib (supported on 3.14+ of CMake with "-G Xcode" argument ONLY) +# SIMULATOR_WATCHOS = Build for x86_64 for watchOS Simulator. +# MAC = Build for x86_64 macOS. +# MAC_ARM64 = Build for Apple Silicon macOS. +# MAC_CATALYST = Build for x86_64 macOS with Catalyst support (iOS toolchain on macOS). +# Note: The build argument "MACOSX_DEPLOYMENT_TARGET" can be used to control min-version of macOS +# MAC_CATALYST_ARM64 = Build for Apple Silicon macOS with Catalyst support (iOS toolchain on macOS). +# Note: The build argument "MACOSX_DEPLOYMENT_TARGET" can be used to control min-version of macOS +# +# CMAKE_OSX_SYSROOT: Path to the SDK to use. By default this is +# automatically determined from PLATFORM and xcodebuild, but +# can also be manually specified (although this should not be required). +# +# CMAKE_DEVELOPER_ROOT: Path to the Developer directory for the platform +# being compiled for. By default this is automatically determined from +# CMAKE_OSX_SYSROOT, but can also be manually specified (although this should +# not be required). +# +# DEPLOYMENT_TARGET: Minimum SDK version to target. Default 2.0 on watchOS and 9.0 on tvOS+iOS +# +# ENABLE_BITCODE: (1|0) Enables or disables bitcode support. Default 1 (true) +# +# ENABLE_ARC: (1|0) Enables or disables ARC support. Default 1 (true, ARC enabled by default) +# +# ENABLE_VISIBILITY: (1|0) Enables or disables symbol visibility support. Default 0 (false, visibility hidden by default) +# +# ENABLE_STRICT_TRY_COMPILE: (1|0) Enables or disables strict try_compile() on all Check* directives (will run linker +# to actually check if linking is possible). Default 0 (false, will set CMAKE_TRY_COMPILE_TARGET_TYPE to STATIC_LIBRARY) +# +# ARCHS: (armv7 armv7s armv7k arm64 arm64_32 i386 x86_64) If specified, will override the default architectures for the given PLATFORM +# OS = armv7 armv7s arm64 (if applicable) +# OS64 = arm64 (if applicable) +# SIMULATOR = i386 +# SIMULATOR64 = x86_64 +# SIMULATORARM64 = arm64 +# TVOS = arm64 +# SIMULATOR_TVOS = x86_64 (i386 has since long been deprecated) +# WATCHOS = armv7k arm64_32 (if applicable) +# SIMULATOR_WATCHOS = x86_64 (i386 has since long been deprecated) +# MAC = x86_64 +# MAC_ARM64 = arm64 +# MAC_CATALYST = x86_64 +# MAC_CATALYST_ARM64 = arm64 +# +# This toolchain defines the following properties (available via get_property()) for use externally: +# +# PLATFORM: The currently targeted platform. +# XCODE_VERSION: Version number (not including Build version) of Xcode detected. +# SDK_VERSION: Version of SDK being used. +# OSX_ARCHITECTURES: Architectures being compiled for (generated from PLATFORM). +# APPLE_TARGET_TRIPLE: Used by autoconf build systems. NOTE: If "ARCHS" are overridden, this will *NOT* be set! +# +# This toolchain defines the following macros for use externally: +# +# set_xcode_property (TARGET XCODE_PROPERTY XCODE_VALUE XCODE_VARIANT) +# A convenience macro for setting xcode specific properties on targets. +# Available variants are: All, Release, RelWithDebInfo, Debug, MinSizeRel +# example: set_xcode_property (myioslib IPHONEOS_DEPLOYMENT_TARGET "3.1" "all"). +# +# find_host_package (PROGRAM ARGS) +# A macro used to find executable programs on the host system, not within the +# environment. Thanks to the android-cmake project for providing the +# command. +# + +cmake_minimum_required(VERSION 3.8.0) + +# CMake invokes the toolchain file twice during the first build, but only once during subsequent rebuilds. 
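+# The guard below handles that: the first pass records IOS_TOOLCHAIN_HAS_RUN, so any
+# later invocation returns immediately instead of repeating the configuration below.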
+if(IOS_TOOLCHAIN_HAS_RUN)
+ return()
+endif(IOS_TOOLCHAIN_HAS_RUN)
+set(IOS_TOOLCHAIN_HAS_RUN true)
+
+###############################################################################
+# OPTIONS #
+###############################################################################
+
+option(DROP_32_BIT "Drops the 32-bit targets universally." YES)
+
+###############################################################################
+# END OPTIONS #
+###############################################################################
+
+# List of supported platform values
+list(APPEND _supported_platforms
+ "OS" "OS64" "OS64COMBINED" "SIMULATOR" "SIMULATOR64" "SIMULATORARM64"
+ "TVOS" "TVOSCOMBINED" "SIMULATOR_TVOS"
+ "WATCHOS" "WATCHOSCOMBINED" "SIMULATOR_WATCHOS"
+ "MAC" "MAC_ARM64"
+ "MAC_CATALYST" "MAC_CATALYST_ARM64"
+ "XROS" "XRSIMULATOR")
+
+# Cache what generator is used
+set(USED_CMAKE_GENERATOR "${CMAKE_GENERATOR}")
+
+# Check if using a CMake version capable of building combined FAT builds (simulator and target slices combined in one static lib)
+if(${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.14")
+ set(MODERN_CMAKE YES)
+endif()
+
+# Get the Xcode version being used.
+# Problem: CMake runs toolchain files multiple times, but can't read cache variables on some runs.
+# Workaround: On first run (in which cache variables are always accessible), set an intermediary environment variable.
+#
+# NOTE: This pattern is used in many places in this toolchain to speed up checks of all sorts
+if(DEFINED XCODE_VERSION_INT)
+ # Environment variables are always preserved.
+ set(ENV{_XCODE_VERSION_INT} "${XCODE_VERSION_INT}")
+elseif(DEFINED ENV{_XCODE_VERSION_INT})
+ set(XCODE_VERSION_INT "$ENV{_XCODE_VERSION_INT}")
+elseif(NOT DEFINED XCODE_VERSION_INT)
+ find_program(XCODEBUILD_EXECUTABLE xcodebuild)
+ if(NOT XCODEBUILD_EXECUTABLE)
+ message(FATAL_ERROR "xcodebuild not found. Please install either the standalone command line tools or Xcode.")
+ endif()
+ execute_process(COMMAND ${XCODEBUILD_EXECUTABLE} -version
+ OUTPUT_VARIABLE XCODE_VERSION_INT
+ ERROR_QUIET
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ string(REGEX MATCH "Xcode [0-9\\.]+" XCODE_VERSION_INT "${XCODE_VERSION_INT}")
+ string(REGEX REPLACE "Xcode ([0-9\\.]+)" "\\1" XCODE_VERSION_INT "${XCODE_VERSION_INT}")
+ set(XCODE_VERSION_INT "${XCODE_VERSION_INT}" CACHE INTERNAL "")
+endif()
+
+# Assuming that Xcode 12.0 is installed, you most probably have the iOS 14.0 SDK or later installed (tested on Big Sur).
+# If you don't set a deployment target, it will be set so that you only get 64-bit builds.
+if(NOT DEFINED DEPLOYMENT_TARGET AND XCODE_VERSION_INT VERSION_GREATER 12.0)
+ # Temporarily fix the arm64 issues in CMake install-combined by excluding arm64 for simulator builds (needed for Apple Silicon...)
+ set(CMAKE_XCODE_ATTRIBUTE_EXCLUDED_ARCHS[sdk=iphonesimulator*] "arm64")
+endif()
+
+# Check if the platform variable is set
+if(DEFINED PLATFORM)
+ # Environment variables are always preserved.
+ set(ENV{_PLATFORM} "${PLATFORM}")
+elseif(DEFINED ENV{_PLATFORM})
+ set(PLATFORM "$ENV{_PLATFORM}")
+elseif(NOT DEFINED PLATFORM)
+ message(FATAL_ERROR "PLATFORM argument not set. Bailing configure since I don't know what target you want to build for!")
+endif ()
+
+# Safeguard that the platform value is set and is one of the supported values
+list(FIND _supported_platforms ${PLATFORM} contains_PLATFORM)
+if("${contains_PLATFORM}" EQUAL "-1")
+ string(REPLACE ";" "\n * " _supported_platforms_formatted "${_supported_platforms}")
+ message(FATAL_ERROR " Invalid PLATFORM specified! Current value: ${PLATFORM}.\n"
+ " Supported PLATFORM values: \n * ${_supported_platforms_formatted}")
+endif()
+
+# Check if Apple Silicon is supported
+if(PLATFORM MATCHES "^(MAC_ARM64)$|^(MAC_CATALYST_ARM64)$" AND ${CMAKE_VERSION} VERSION_LESS "3.19.5")
+ message(FATAL_ERROR "Apple Silicon builds require a minimum of CMake 3.19.5")
+endif()
+
+# Touch toolchain variable to suppress "unused variable" warning.
+# This happens if CMake is invoked with the same command line the second time.
+if(CMAKE_TOOLCHAIN_FILE)
+endif()
+
+# Fix for PThread library not in path
+set(CMAKE_THREAD_LIBS_INIT "-lpthread")
+set(CMAKE_HAVE_THREADS_LIBRARY 1)
+set(CMAKE_USE_WIN32_THREADS_INIT 0)
+set(CMAKE_USE_PTHREADS_INIT 1)
+
+# Specify minimum version of deployment target.
+if(NOT DEFINED DEPLOYMENT_TARGET)
+ if (PLATFORM MATCHES "WATCHOS")
+ # Unless specified, SDK version 4.0 is used by default as minimum target version (watchOS).
+ set(DEPLOYMENT_TARGET "4.0")
+ elseif(PLATFORM STREQUAL "MAC")
+ # Unless specified, SDK version 10.13 (High Sierra) is used by default as minimum target version (macOS).
+ set(DEPLOYMENT_TARGET "10.13")
+ elseif(PLATFORM STREQUAL "MAC_ARM64")
+ # Unless specified, SDK version 11.0 (Big Sur) is used by default as minimum target version (macOS on arm).
+ set(DEPLOYMENT_TARGET "11.0")
+ elseif(PLATFORM STREQUAL "MAC_CATALYST" OR PLATFORM STREQUAL "MAC_CATALYST_ARM64")
+ # Unless specified, SDK version 13.0 is used by default as minimum target version (mac catalyst minimum requirement).
+ set(DEPLOYMENT_TARGET "13.0")
+ elseif(PLATFORM STREQUAL "XROS" OR PLATFORM STREQUAL "XRSIMULATOR")
+ set(DEPLOYMENT_TARGET "1.0")
+ else()
+ # Unless specified, SDK version 11.0 is used by default as minimum target version (iOS, tvOS).
+ set(DEPLOYMENT_TARGET "11.0")
+ endif()
+ message(STATUS "[DEFAULTS] Using the default min-version since DEPLOYMENT_TARGET not provided!")
+elseif(DEFINED DEPLOYMENT_TARGET AND PLATFORM STREQUAL "MAC_CATALYST" AND ${DEPLOYMENT_TARGET} VERSION_LESS "13.0")
+ message(FATAL_ERROR "Mac Catalyst builds require a minimum deployment target of 13.0!")
+endif()
+
+# Store the DEPLOYMENT_TARGET in the cache
+set(DEPLOYMENT_TARGET "${DEPLOYMENT_TARGET}" CACHE INTERNAL "")
+
+# Handle the case where we are targeting iOS and a version above 10.3.4 (32-bit support dropped officially)
+if(PLATFORM STREQUAL "OS" AND DEPLOYMENT_TARGET VERSION_GREATER_EQUAL 10.3.4)
+ set(PLATFORM "OS64")
+ message(STATUS "Targeting minimum SDK version ${DEPLOYMENT_TARGET}. Dropping 32-bit support.")
+elseif(PLATFORM STREQUAL "SIMULATOR" AND DEPLOYMENT_TARGET VERSION_GREATER_EQUAL 10.3.4)
+ set(PLATFORM "SIMULATOR64")
+ message(STATUS "Targeting minimum SDK version ${DEPLOYMENT_TARGET}. Dropping 32-bit support.")
+endif()
+
+set(PLATFORM_INT "${PLATFORM}")
+
+if(DEFINED ARCHS)
+ string(REPLACE ";" "-" ARCHS_SPLIT "${ARCHS}")
+endif()
+
+# Determine the platform name and architectures for use in xcodebuild commands
+# from the specified PLATFORM_INT name.
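+# A few representative results of the branches below (illustrative, assuming
+# ARCHS is not overridden on the command line):
+#   OS64           -> SDK_NAME iphoneos,        ARCHS arm64
+#   SIMULATORARM64 -> SDK_NAME iphonesimulator, ARCHS arm64
+#   MAC_ARM64      -> SDK_NAME macosx,          ARCHS arm64
+#   XRSIMULATOR    -> SDK_NAME xrsimulator,     ARCHS arm64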
+if(PLATFORM_INT STREQUAL "OS")
+ set(SDK_NAME iphoneos)
+ if(NOT ARCHS)
+ set(ARCHS armv7 armv7s arm64)
+ set(APPLE_TARGET_TRIPLE_INT arm-apple-ios)
+ endif()
+elseif(PLATFORM_INT STREQUAL "OS64")
+ set(SDK_NAME iphoneos)
+ if(NOT ARCHS)
+ if (XCODE_VERSION_INT VERSION_GREATER 10.0)
+ set(ARCHS arm64) # Add arm64e when Apple has fixed the integration issues with it, libarclite_iphoneos.a is currently missing bitcode markers for example
+ else()
+ set(ARCHS arm64)
+ endif()
+ set(APPLE_TARGET_TRIPLE_INT aarch64-apple-ios)
+ else()
+ set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-ios)
+ endif()
+elseif(PLATFORM_INT STREQUAL "OS64COMBINED")
+ set(SDK_NAME iphoneos)
+ if(MODERN_CMAKE)
+ if(NOT ARCHS)
+ if (XCODE_VERSION_INT VERSION_GREATER 10.0)
+ set(ARCHS arm64 x86_64) # Add arm64e when Apple has fixed the integration issues with it, libarclite_iphoneos.a is currently missing bitcode markers for example
+ set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=iphoneos*] "arm64")
+ set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=iphonesimulator*] "x86_64")
+ set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=iphoneos*] "arm64")
+ set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=iphonesimulator*] "x86_64")
+ else()
+ set(ARCHS arm64 x86_64)
+ set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=iphoneos*] "arm64")
+ set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=iphonesimulator*] "x86_64")
+ set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=iphoneos*] "arm64")
+ set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=iphonesimulator*] "x86_64")
+ endif()
+ set(APPLE_TARGET_TRIPLE_INT aarch64-x86_64-apple-ios)
+ else()
+ set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-ios)
+ endif()
+ else()
+ message(FATAL_ERROR "Please make sure that you are running CMake 3.14+ to make the OS64COMBINED setting work")
+ endif()
+elseif(PLATFORM_INT STREQUAL "SIMULATOR")
+ set(SDK_NAME iphonesimulator)
+ if(NOT ARCHS)
+ set(ARCHS i386)
+ set(APPLE_TARGET_TRIPLE_INT i386-apple-ios)
+ else()
+ set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-ios)
+ endif()
+ message(DEPRECATION "SIMULATOR IS DEPRECATED.
Consider using SIMULATOR64 instead.") +elseif(PLATFORM_INT STREQUAL "SIMULATOR64") + set(SDK_NAME iphonesimulator) + if(NOT ARCHS) + set(ARCHS x86_64) + set(APPLE_TARGET_TRIPLE_INT x86_64-apple-ios) + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-ios) + endif() +elseif(PLATFORM_INT STREQUAL "SIMULATORARM64") + set(SDK_NAME iphonesimulator) + if(NOT ARCHS) + set(ARCHS arm64) + set(APPLE_TARGET_TRIPLE_INT aarch64-apple-ios) + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-ios) + endif() +elseif(PLATFORM_INT STREQUAL "TVOS") + set(SDK_NAME appletvos) + if(NOT ARCHS) + set(ARCHS arm64) + set(APPLE_TARGET_TRIPLE_INT aarch64-apple-tvos) + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-tvos) + endif() +elseif (PLATFORM_INT STREQUAL "TVOSCOMBINED") + set(SDK_NAME appletvos) + if(MODERN_CMAKE) + if(NOT ARCHS) + set(ARCHS arm64 x86_64) + set(APPLE_TARGET_TRIPLE_INT aarch64-x86_64-apple-tvos) + set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=appletvos*] "arm64") + set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=appletvsimulator*] "x86_64") + set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=appletvos*] "arm64") + set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=appletvsimulator*] "x86_64") + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-tvos) + endif() + else() + message(FATAL_ERROR "Please make sure that you are running CMake 3.14+ to make the TVOSCOMBINED setting work") + endif() +elseif(PLATFORM_INT STREQUAL "SIMULATOR_TVOS") + set(SDK_NAME appletvsimulator) + if(NOT ARCHS) + set(ARCHS x86_64) + set(APPLE_TARGET_TRIPLE_INT x86_64-apple-tvos) + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-tvos) + endif() +elseif(PLATFORM_INT STREQUAL "WATCHOS") + set(SDK_NAME watchos) + if(NOT ARCHS) + if (XCODE_VERSION_INT VERSION_GREATER 10.0) + set(ARCHS armv7k arm64_32) + set(APPLE_TARGET_TRIPLE_INT aarch64_32-apple-watchos) + else() + set(ARCHS armv7k) + set(APPLE_TARGET_TRIPLE_INT arm-apple-watchos) + endif() + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-watchos) + endif() +elseif(PLATFORM_INT STREQUAL "WATCHOSCOMBINED") + set(SDK_NAME watchos) + if(MODERN_CMAKE) + if(NOT ARCHS) + if (XCODE_VERSION_INT VERSION_GREATER 10.0) + set(ARCHS armv7k arm64_32 i386) + set(APPLE_TARGET_TRIPLE_INT aarch64_32-i386-apple-watchos) + set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=watchos*] "armv7k arm64_32") + set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=watchsimulator*] "i386") + set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=watchos*] "armv7k arm64_32") + set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=watchsimulator*] "i386") + else() + set(ARCHS armv7k i386) + set(APPLE_TARGET_TRIPLE_INT arm-i386-apple-watchos) + set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=watchos*] "armv7k") + set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=watchsimulator*] "i386") + set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=watchos*] "armv7k") + set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=watchsimulator*] "i386") + endif() + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-watchos) + endif() + else() + message(FATAL_ERROR "Please make sure that you are running CMake 3.14+ to make the WATCHOSCOMBINED setting work") + endif() +elseif(PLATFORM_INT STREQUAL "SIMULATOR_WATCHOS") + set(SDK_NAME watchsimulator) + if(NOT ARCHS) + set(ARCHS i386) + set(APPLE_TARGET_TRIPLE_INT i386-apple-watchos) + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-watchos) + endif() +elseif(PLATFORM_INT STREQUAL "MAC" OR PLATFORM_INT STREQUAL "MAC_CATALYST") + set(SDK_NAME macosx) + if(NOT ARCHS) + set(ARCHS x86_64) + endif() + string(REPLACE ";" "-" ARCHS_SPLIT "${ARCHS}") + 
if(PLATFORM_INT STREQUAL "MAC") + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-macosx) + elseif(PLATFORM_INT STREQUAL "MAC_CATALYST") + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-ios${DEPLOYMENT_TARGET}-macabi) + endif() +elseif(PLATFORM_INT MATCHES "^(MAC_ARM64)$|^(MAC_CATALYST_ARM64)$") + set(SDK_NAME macosx) + if(NOT ARCHS) + set(ARCHS arm64) + endif() + string(REPLACE ";" "-" ARCHS_SPLIT "${ARCHS}") + if(PLATFORM_INT STREQUAL "MAC_ARM64") + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-macosx) + elseif(PLATFORM_INT STREQUAL "MAC_CATALYST_ARM64") + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-ios${DEPLOYMENT_TARGET}-macabi) + endif() +elseif(PLATFORM_INT STREQUAL "XROS") + set(SDK_NAME xros) + if(NOT ARCHS) + set(ARCHS arm64) + set(APPLE_TARGET_TRIPLE_INT arm64-apple-xros) + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-xros) + endif() +elseif(PLATFORM_INT STREQUAL "XRSIMULATOR") + set(SDK_NAME xrsimulator) + if(NOT ARCHS) + set(ARCHS arm64) + set(APPLE_TARGET_TRIPLE_INT arm64-apple-xros-simulator) + else() + set(APPLE_TARGET_TRIPLE_INT ${ARCHS_SPLIT}-apple-xros-simulator) + endif() +else() + message(FATAL_ERROR "Invalid PLATFORM: ${PLATFORM_INT}") +endif() + +if(MODERN_CMAKE AND PLATFORM_INT MATCHES ".*COMBINED" AND NOT CMAKE_GENERATOR MATCHES "Xcode") + message(FATAL_ERROR "The COMBINED options only work with Xcode generator, -G Xcode") +endif() + +if(CMAKE_GENERATOR MATCHES "Xcode" AND PLATFORM_INT MATCHES "MAC_CATALYST_.*") + set(CMAKE_XCODE_ATTRIBUTE_CLANG_CXX_LIBRARY "libc++") + set(CMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS "macosx") + set(CMAKE_XCODE_EFFECTIVE_PLATFORMS "-maccatalyst") + if(NOT DEFINED MACOSX_DEPLOYMENT_TARGET) + set(CMAKE_XCODE_ATTRIBUTE_MACOSX_DEPLOYMENT_TARGET "10.15") + else() + set(CMAKE_XCODE_ATTRIBUTE_MACOSX_DEPLOYMENT_TARGET "${MACOSX_DEPLOYMENT_TARGET}") + endif() +elseif(CMAKE_GENERATOR MATCHES "Xcode") + set(CMAKE_XCODE_ATTRIBUTE_IPHONEOS_DEPLOYMENT_TARGET "${DEPLOYMENT_TARGET}") + if(NOT PLATFORM_INT MATCHES ".*COMBINED") + set(CMAKE_XCODE_ATTRIBUTE_ARCHS[sdk=${SDK_NAME}*] "${ARCHS}") + set(CMAKE_XCODE_ATTRIBUTE_VALID_ARCHS[sdk=${SDK_NAME}*] "${ARCHS}") + endif() +endif() + +# If user did not specify the SDK root to use, then query xcodebuild for it. +if(DEFINED CMAKE_OSX_SYSROOT_INT) + # Environment variables are always preserved. + set(ENV{_CMAKE_OSX_SYSROOT_INT} "${CMAKE_OSX_SYSROOT_INT}") +elseif(DEFINED ENV{_CMAKE_OSX_SYSROOT_INT}) + set(CMAKE_OSX_SYSROOT_INT "$ENV{_CMAKE_OSX_SYSROOT_INT}") +elseif(NOT DEFINED CMAKE_OSX_SYSROOT_INT) + execute_process(COMMAND ${XCODEBUILD_EXECUTABLE} -version -sdk ${SDK_NAME} Path + OUTPUT_VARIABLE CMAKE_OSX_SYSROOT_INT + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +endif() + +if (NOT DEFINED CMAKE_OSX_SYSROOT_INT AND NOT DEFINED CMAKE_OSX_SYSROOT) + message(SEND_ERROR "Please make sure that Xcode is installed and that the toolchain" + "is pointing to the correct path. Please run:" + "sudo xcode-select -s /Applications/Xcode.app/Contents/Developer" + "and see if that fixes the problem for you.") + message(FATAL_ERROR "Invalid CMAKE_OSX_SYSROOT: ${CMAKE_OSX_SYSROOT} " + "does not exist.") +elseif(DEFINED CMAKE_OSX_SYSROOT_INT) + set(CMAKE_OSX_SYSROOT_INT "${CMAKE_OSX_SYSROOT_INT}" CACHE INTERNAL "") + # Specify the location or name of the platform SDK to be used in CMAKE_OSX_SYSROOT. 
+ set(CMAKE_OSX_SYSROOT "${CMAKE_OSX_SYSROOT_INT}" CACHE INTERNAL "") +endif() + +# Use bitcode or not +if(NOT DEFINED ENABLE_BITCODE AND NOT ARCHS MATCHES "((^|;|, )(i386|x86_64))+") + # Unless specified, enable bitcode support by default + message(STATUS "[DEFAULTS] Enabling bitcode support by default. ENABLE_BITCODE not provided!") + set(ENABLE_BITCODE TRUE) +elseif(NOT DEFINED ENABLE_BITCODE) + message(STATUS "[DEFAULTS] Disabling bitcode support by default on simulators. ENABLE_BITCODE not provided for override!") + set(ENABLE_BITCODE FALSE) +endif() +set(ENABLE_BITCODE_INT ${ENABLE_BITCODE} CACHE BOOL + "Whether or not to enable bitcode" FORCE) +# Use ARC or not +if(NOT DEFINED ENABLE_ARC) + # Unless specified, enable ARC support by default + set(ENABLE_ARC TRUE) + message(STATUS "[DEFAULTS] Enabling ARC support by default. ENABLE_ARC not provided!") +endif() +set(ENABLE_ARC_INT ${ENABLE_ARC} CACHE BOOL "Whether or not to enable ARC" FORCE) +# Use hidden visibility or not +if(NOT DEFINED ENABLE_VISIBILITY) + # Unless specified, disable symbols visibility by default + set(ENABLE_VISIBILITY FALSE) + message(STATUS "[DEFAULTS] Hiding symbols visibility by default. ENABLE_VISIBILITY not provided!") +endif() +set(ENABLE_VISIBILITY_INT ${ENABLE_VISIBILITY} CACHE BOOL "Whether or not to hide symbols from the dynamic linker (-fvisibility=hidden)" FORCE) +# Set strict compiler checks or not +if(NOT DEFINED ENABLE_STRICT_TRY_COMPILE) + # Unless specified, disable strict try_compile() + set(ENABLE_STRICT_TRY_COMPILE FALSE) + message(STATUS "[DEFAULTS] Using NON-strict compiler checks by default. ENABLE_STRICT_TRY_COMPILE not provided!") +endif() +set(ENABLE_STRICT_TRY_COMPILE_INT ${ENABLE_STRICT_TRY_COMPILE} CACHE BOOL + "Whether or not to use strict compiler checks" FORCE) + +# Get the SDK version information. +if(DEFINED SDK_VERSION) + # Environment variables are always preserved. + set(ENV{_SDK_VERSION} "${SDK_VERSION}") +elseif(DEFINED ENV{_SDK_VERSION}) + set(SDK_VERSION "$ENV{_SDK_VERSION}") +elseif(NOT DEFINED SDK_VERSION) + execute_process(COMMAND ${XCODEBUILD_EXECUTABLE} -sdk ${CMAKE_OSX_SYSROOT_INT} -version SDKVersion + OUTPUT_VARIABLE SDK_VERSION + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +endif() + +# Find the Developer root for the specific iOS platform being compiled for +# from CMAKE_OSX_SYSROOT. Should be ../../ from SDK specified in +# CMAKE_OSX_SYSROOT. There does not appear to be a direct way to obtain +# this information from xcrun or xcodebuild. +if (NOT DEFINED CMAKE_DEVELOPER_ROOT AND NOT CMAKE_GENERATOR MATCHES "Xcode") + get_filename_component(PLATFORM_SDK_DIR ${CMAKE_OSX_SYSROOT_INT} PATH) + get_filename_component(CMAKE_DEVELOPER_ROOT ${PLATFORM_SDK_DIR} PATH) + if (NOT EXISTS "${CMAKE_DEVELOPER_ROOT}") + message(FATAL_ERROR "Invalid CMAKE_DEVELOPER_ROOT: ${CMAKE_DEVELOPER_ROOT} does not exist.") + endif() +endif() + +# Find the C & C++ compilers for the specified SDK. +if(DEFINED CMAKE_C_COMPILER) + # Environment variables are always preserved. + set(ENV{_CMAKE_C_COMPILER} "${CMAKE_C_COMPILER}") +elseif(DEFINED ENV{_CMAKE_C_COMPILER}) + set(CMAKE_C_COMPILER "$ENV{_CMAKE_C_COMPILER}") +elseif(NOT DEFINED CMAKE_C_COMPILER) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT_INT} -find clang + OUTPUT_VARIABLE CMAKE_C_COMPILER + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +endif() +if(DEFINED CMAKE_CXX_COMPILER) + # Environment variables are always preserved. 
+ set(ENV{_CMAKE_CXX_COMPILER} "${CMAKE_CXX_COMPILER}") +elseif(DEFINED ENV{_CMAKE_CXX_COMPILER}) + set(CMAKE_CXX_COMPILER "$ENV{_CMAKE_CXX_COMPILER}") +elseif(NOT DEFINED CMAKE_CXX_COMPILER) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT_INT} -find clang++ + OUTPUT_VARIABLE CMAKE_CXX_COMPILER + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +endif() +# Find (Apple's) libtool. +if(DEFINED BUILD_LIBTOOL) + # Environment variables are always preserved. + set(ENV{_BUILD_LIBTOOL} "${BUILD_LIBTOOL}") +elseif(DEFINED ENV{_BUILD_LIBTOOL}) + set(BUILD_LIBTOOL "$ENV{_BUILD_LIBTOOL}") +elseif(NOT DEFINED BUILD_LIBTOOL) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT_INT} -find libtool + OUTPUT_VARIABLE BUILD_LIBTOOL + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) +endif() +# Find the toolchain's provided install_name_tool if none is found on the host +if(DEFINED CMAKE_INSTALL_NAME_TOOL) + # Environment variables are always preserved. + set(ENV{_CMAKE_INSTALL_NAME_TOOL} "${CMAKE_INSTALL_NAME_TOOL}") +elseif(DEFINED ENV{_CMAKE_INSTALL_NAME_TOOL}) + set(CMAKE_INSTALL_NAME_TOOL "$ENV{_CMAKE_INSTALL_NAME_TOOL}") +elseif(NOT DEFINED CMAKE_INSTALL_NAME_TOOL) + execute_process(COMMAND xcrun -sdk ${CMAKE_OSX_SYSROOT_INT} -find install_name_tool + OUTPUT_VARIABLE CMAKE_INSTALL_NAME_TOOL_INT + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE) + set(CMAKE_INSTALL_NAME_TOOL ${CMAKE_INSTALL_NAME_TOOL_INT} CACHE INTERNAL "") +endif() + +# Configure libtool to be used instead of ar + ranlib to build static libraries. +# This is required on Xcode 7+, but should also work on previous versions of +# Xcode. +get_property(languages GLOBAL PROPERTY ENABLED_LANGUAGES) +foreach(lang ${languages}) + set(CMAKE_${lang}_CREATE_STATIC_LIBRARY "${BUILD_LIBTOOL} -static -o " CACHE INTERNAL "") +endforeach() + +# CMake 3.14+ support building for iOS, watchOS and tvOS out of the box. +if(MODERN_CMAKE) + if(SDK_NAME MATCHES "iphone") + set(CMAKE_SYSTEM_NAME iOS) + elseif(SDK_NAME MATCHES "macosx") + set(CMAKE_SYSTEM_NAME Darwin) + elseif(SDK_NAME MATCHES "appletv") + set(CMAKE_SYSTEM_NAME tvOS) + elseif(SDK_NAME MATCHES "watch") + set(CMAKE_SYSTEM_NAME watchOS) + elseif(SDK_NAME MATCHES "xros" OR SDK_NAME MATCHES "xrsimulator") + set(CMAKE_SYSTEM_NAME visionOS) + endif() + # Provide flags for a combined FAT library build on newer CMake versions + if(PLATFORM_INT MATCHES ".*COMBINED") + set(CMAKE_XCODE_ATTRIBUTE_ONLY_ACTIVE_ARCH "NO") + set(CMAKE_IOS_INSTALL_COMBINED YES) + message(STATUS "Will combine built (static) artifacts into FAT lib...") + endif() +elseif(NOT DEFINED CMAKE_SYSTEM_NAME AND ${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.10") + # Legacy code path prior to CMake 3.14 or fallback if no CMAKE_SYSTEM_NAME specified + set(CMAKE_SYSTEM_NAME iOS) +elseif(NOT DEFINED CMAKE_SYSTEM_NAME) + # Legacy code path prior to CMake 3.14 or fallback if no CMAKE_SYSTEM_NAME specified + set(CMAKE_SYSTEM_NAME Darwin) +endif() +# Standard settings. 
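+# (Illustrative note, not from upstream) The cache flags set below (IOS, MACOS,
+# CMAKE_OSX_ARCHITECTURES, ...) are what a consuming CMakeLists.txt typically
+# branches on, for example:
+#   if(IOS)
+#     target_compile_definitions(my_target PRIVATE BUILD_FOR_IOS=1)
+#   endif()
+# where my_target and BUILD_FOR_IOS are placeholder names.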
+set(CMAKE_SYSTEM_VERSION ${SDK_VERSION} CACHE INTERNAL "") +set(UNIX TRUE CACHE BOOL "") +set(APPLE TRUE CACHE BOOL "") +if(PLATFORM STREQUAL "MAC" OR PLATFORM STREQUAL "MAC_ARM64") + set(IOS FALSE CACHE BOOL "") + set(MACOS TRUE CACHE BOOL "") +elseif(PLATFORM STREQUAL "MAC_CATALYST" OR PLATFORM STREQUAL "MAC_CATALYST_ARM64") + set(IOS TRUE CACHE BOOL "") + set(MACOS TRUE CACHE BOOL "") +else() + set(IOS TRUE CACHE BOOL "") +endif() +set(CMAKE_AR ar CACHE FILEPATH "" FORCE) +set(CMAKE_RANLIB ranlib CACHE FILEPATH "" FORCE) +set(CMAKE_STRIP strip CACHE FILEPATH "" FORCE) +# Set the architectures for which to build. +set(CMAKE_OSX_ARCHITECTURES ${ARCHS} CACHE INTERNAL "") +# Change the type of target generated for try_compile() so it'll work when cross-compiling, weak compiler checks +if(NOT ENABLE_STRICT_TRY_COMPILE_INT) + set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) +endif() +# All iOS/Darwin specific settings - some may be redundant. +set(CMAKE_MACOSX_BUNDLE YES) +set(CMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_REQUIRED "NO") +set(CMAKE_SHARED_LIBRARY_PREFIX "lib") +set(CMAKE_SHARED_LIBRARY_SUFFIX ".dylib") +set(CMAKE_SHARED_MODULE_PREFIX "lib") +set(CMAKE_SHARED_MODULE_SUFFIX ".so") +set(CMAKE_C_COMPILER_ABI ELF) +set(CMAKE_CXX_COMPILER_ABI ELF) +set(CMAKE_C_HAS_ISYSROOT 1) +set(CMAKE_CXX_HAS_ISYSROOT 1) +set(CMAKE_MODULE_EXISTS 1) +set(CMAKE_DL_LIBS "") +set(CMAKE_C_OSX_COMPATIBILITY_VERSION_FLAG "-compatibility_version ") +set(CMAKE_C_OSX_CURRENT_VERSION_FLAG "-current_version ") +set(CMAKE_CXX_OSX_COMPATIBILITY_VERSION_FLAG "${CMAKE_C_OSX_COMPATIBILITY_VERSION_FLAG}") +set(CMAKE_CXX_OSX_CURRENT_VERSION_FLAG "${CMAKE_C_OSX_CURRENT_VERSION_FLAG}") + +if(ARCHS MATCHES "((^|;|, )(arm64|arm64e|x86_64))+") + set(CMAKE_C_SIZEOF_DATA_PTR 8) + set(CMAKE_CXX_SIZEOF_DATA_PTR 8) + if(ARCHS MATCHES "((^|;|, )(arm64|arm64e))+") + set(CMAKE_SYSTEM_PROCESSOR "aarch64") + else() + set(CMAKE_SYSTEM_PROCESSOR "x86_64") + endif() +else() + set(CMAKE_C_SIZEOF_DATA_PTR 4) + set(CMAKE_CXX_SIZEOF_DATA_PTR 4) + set(CMAKE_SYSTEM_PROCESSOR "arm") +endif() + +# Note that only Xcode 7+ supports the newer more specific: +# -m${SDK_NAME}-version-min flags, older versions of Xcode use: +# -m(ios/ios-simulator)-version-min instead. +if(${CMAKE_VERSION} VERSION_LESS "3.11") + if(PLATFORM_INT STREQUAL "OS" OR PLATFORM_INT STREQUAL "OS64") + if(XCODE_VERSION_INT VERSION_LESS 7.0) + set(SDK_NAME_VERSION_FLAGS + "-mios-version-min=${DEPLOYMENT_TARGET}") + else() + # Xcode 7.0+ uses flags we can build directly from SDK_NAME. + set(SDK_NAME_VERSION_FLAGS + "-m${SDK_NAME}-version-min=${DEPLOYMENT_TARGET}") + endif() + elseif(PLATFORM_INT STREQUAL "TVOS") + set(SDK_NAME_VERSION_FLAGS + "-mtvos-version-min=${DEPLOYMENT_TARGET}") + elseif(PLATFORM_INT STREQUAL "SIMULATOR_TVOS") + set(SDK_NAME_VERSION_FLAGS + "-mtvos-simulator-version-min=${DEPLOYMENT_TARGET}") + elseif(PLATFORM_INT STREQUAL "WATCHOS") + set(SDK_NAME_VERSION_FLAGS + "-mwatchos-version-min=${DEPLOYMENT_TARGET}") + elseif(PLATFORM_INT STREQUAL "SIMULATOR_WATCHOS") + set(SDK_NAME_VERSION_FLAGS + "-mwatchos-simulator-version-min=${DEPLOYMENT_TARGET}") + elseif(PLATFORM_INT STREQUAL "MAC") + set(SDK_NAME_VERSION_FLAGS + "-mmacosx-version-min=${DEPLOYMENT_TARGET}") + else() + # SIMULATOR or SIMULATOR64 both use -mios-simulator-version-min. 
+ set(SDK_NAME_VERSION_FLAGS
+ "-mios-simulator-version-min=${DEPLOYMENT_TARGET}")
+ endif()
+elseif(NOT PLATFORM_INT STREQUAL "MAC_CATALYST")
+ # Newer versions of CMake set the version min flags correctly; skip this for Mac Catalyst targets
+ set(CMAKE_OSX_DEPLOYMENT_TARGET ${DEPLOYMENT_TARGET})
+endif()
+
+if(DEFINED APPLE_TARGET_TRIPLE_INT)
+ set(APPLE_TARGET_TRIPLE ${APPLE_TARGET_TRIPLE_INT} CACHE INTERNAL "")
+endif()
+
+if(PLATFORM_INT STREQUAL "MAC_CATALYST")
+ set(C_TARGET_FLAGS "-target ${APPLE_TARGET_TRIPLE_INT} -isystem ${CMAKE_OSX_SYSROOT_INT}/System/iOSSupport/usr/include")
+endif()
+
+if(ENABLE_BITCODE_INT)
+ set(BITCODE "-fembed-bitcode")
+ set(CMAKE_XCODE_ATTRIBUTE_BITCODE_GENERATION_MODE "bitcode")
+ set(CMAKE_XCODE_ATTRIBUTE_ENABLE_BITCODE "YES")
+else()
+ set(BITCODE "")
+ set(CMAKE_XCODE_ATTRIBUTE_ENABLE_BITCODE "NO")
+endif()
+
+if(ENABLE_ARC_INT)
+ set(FOBJC_ARC "-fobjc-arc")
+ set(CMAKE_XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC "YES")
+else()
+ set(FOBJC_ARC "-fno-objc-arc")
+ set(CMAKE_XCODE_ATTRIBUTE_CLANG_ENABLE_OBJC_ARC "NO")
+endif()
+
+if(NOT ENABLE_VISIBILITY_INT)
+ foreach(lang ${languages})
+ set(CMAKE_${lang}_VISIBILITY_PRESET "hidden" CACHE INTERNAL "")
+ endforeach()
+ set(CMAKE_XCODE_ATTRIBUTE_GCC_SYMBOLS_PRIVATE_EXTERN "YES")
+ set(VISIBILITY "-fvisibility=hidden -fvisibility-inlines-hidden")
+else()
+ foreach(lang ${languages})
+ set(CMAKE_${lang}_VISIBILITY_PRESET "default" CACHE INTERNAL "")
+ endforeach()
+ set(CMAKE_XCODE_ATTRIBUTE_GCC_SYMBOLS_PRIVATE_EXTERN "NO")
+ set(VISIBILITY "-fvisibility=default")
+endif()
+
+# Check if the Xcode generator is used, since that will handle these flags automagically
+if(CMAKE_GENERATOR MATCHES "Xcode")
+ message(STATUS "Not setting any manual command-line build flags, since Xcode is selected as generator.")
+else()
+ # Hidden visibility is required for C++ on iOS.
+ set(CMAKE_C_FLAGS "${C_TARGET_FLAGS} ${SDK_NAME_VERSION_FLAGS} ${BITCODE} -fobjc-abi-version=2 ${FOBJC_ARC} ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${C_TARGET_FLAGS} ${SDK_NAME_VERSION_FLAGS} ${BITCODE} ${VISIBILITY} -fobjc-abi-version=2 ${FOBJC_ARC} ${CMAKE_CXX_FLAGS}") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} -O0 -g ${CMAKE_CXX_FLAGS_DEBUG}") + set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS} -DNDEBUG -Os -ffast-math ${CMAKE_CXX_FLAGS_MINSIZEREL}") + set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS} -DNDEBUG -O2 -g -ffast-math ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} -DNDEBUG -O3 -ffast-math ${CMAKE_CXX_FLAGS_RELEASE}") + set(CMAKE_C_LINK_FLAGS "${C_TARGET_FLAGS} ${SDK_NAME_VERSION_FLAGS} -Wl,-search_paths_first ${CMAKE_C_LINK_FLAGS}") + set(CMAKE_CXX_LINK_FLAGS "${C_TARGET_FLAGS} ${SDK_NAME_VERSION_FLAGS} -Wl,-search_paths_first ${CMAKE_CXX_LINK_FLAGS}") + set(CMAKE_ASM_FLAGS "${CMAKE_C_FLAGS} -x assembler-with-cpp -arch ${CMAKE_OSX_ARCHITECTURES}") +endif() + +## Print status messages to inform of the current state +message(STATUS "Configuring ${SDK_NAME} build for platform: ${PLATFORM_INT}, architecture(s): ${ARCHS}") +message(STATUS "Using SDK: ${CMAKE_OSX_SYSROOT_INT}") +message(STATUS "Using C compiler: ${CMAKE_C_COMPILER}") +message(STATUS "Using CXX compiler: ${CMAKE_CXX_COMPILER}") +message(STATUS "Using libtool: ${BUILD_LIBTOOL}") +message(STATUS "Using install name tool: ${CMAKE_INSTALL_NAME_TOOL}") +if(DEFINED APPLE_TARGET_TRIPLE) + message(STATUS "Autoconf target triple: ${APPLE_TARGET_TRIPLE}") +endif() +message(STATUS "Using minimum deployment version: ${DEPLOYMENT_TARGET}" + " (SDK version: ${SDK_VERSION})") +if(MODERN_CMAKE) + message(STATUS "Merging integrated CMake 3.14+ iOS,tvOS,watchOS,macOS toolchain(s) with this toolchain!") +endif() +if(CMAKE_GENERATOR MATCHES "Xcode") + message(STATUS "Using Xcode version: ${XCODE_VERSION_INT}") +endif() +message(STATUS "CMake version: ${CMAKE_VERSION}") +if(DEFINED SDK_NAME_VERSION_FLAGS) + message(STATUS "Using version flags: ${SDK_NAME_VERSION_FLAGS}") +endif() +message(STATUS "Using a data_ptr size of: ${CMAKE_CXX_SIZEOF_DATA_PTR}") +if(ENABLE_BITCODE_INT) + message(STATUS "Bitcode: Enabled") +else() + message(STATUS "Bitcode: Disabled") +endif() + +if(ENABLE_ARC_INT) + message(STATUS "ARC: Enabled") +else() + message(STATUS "ARC: Disabled") +endif() + +if(ENABLE_VISIBILITY_INT) + message(STATUS "Hiding symbols: Disabled") +else() + message(STATUS "Hiding symbols: Enabled") +endif() + +# Set global properties +set_property(GLOBAL PROPERTY PLATFORM "${PLATFORM}") +set_property(GLOBAL PROPERTY APPLE_TARGET_TRIPLE "${APPLE_TARGET_TRIPLE_INT}") +set_property(GLOBAL PROPERTY SDK_VERSION "${SDK_VERSION}") +set_property(GLOBAL PROPERTY XCODE_VERSION "${XCODE_VERSION_INT}") +set_property(GLOBAL PROPERTY OSX_ARCHITECTURES "${CMAKE_OSX_ARCHITECTURES}") + +# Export configurable variables for the try_compile() command. 
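+# (Illustrative note, not from upstream) A project can append its own cache
+# variables to this list before running try_compile()/check_* tests so that
+# they are forwarded into the nested configure step, e.g.:
+#   list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES MY_EXTRA_OPTION)
+# where MY_EXTRA_OPTION is a placeholder for a project-defined variable.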
+set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES + PLATFORM + XCODE_VERSION_INT + SDK_VERSION + DEPLOYMENT_TARGET + CMAKE_DEVELOPER_ROOT + CMAKE_OSX_SYSROOT_INT + ENABLE_BITCODE + ENABLE_ARC + CMAKE_C_COMPILER + CMAKE_CXX_COMPILER + BUILD_LIBTOOL + CMAKE_INSTALL_NAME_TOOL + CMAKE_C_FLAGS + CMAKE_CXX_FLAGS + CMAKE_CXX_FLAGS_DEBUG + CMAKE_CXX_FLAGS_MINSIZEREL + CMAKE_CXX_FLAGS_RELWITHDEBINFO + CMAKE_CXX_FLAGS_RELEASE + CMAKE_C_LINK_FLAGS + CMAKE_CXX_LINK_FLAGS + CMAKE_ASM_FLAGS + ) + +set(CMAKE_PLATFORM_HAS_INSTALLNAME 1) +set(CMAKE_SHARED_LINKER_FLAGS "-rpath @executable_path/Frameworks -rpath @loader_path/Frameworks") +set(CMAKE_SHARED_LIBRARY_CREATE_C_FLAGS "-dynamiclib -Wl,-headerpad_max_install_names") +set(CMAKE_SHARED_MODULE_CREATE_C_FLAGS "-bundle -Wl,-headerpad_max_install_names") +set(CMAKE_SHARED_MODULE_LOADER_C_FLAG "-Wl,-bundle_loader,") +set(CMAKE_SHARED_MODULE_LOADER_CXX_FLAG "-Wl,-bundle_loader,") +set(CMAKE_FIND_LIBRARY_SUFFIXES ".tbd" ".dylib" ".so" ".a") +set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-install_name") + +# Set the find root to the SDK developer roots. +# Note: CMAKE_FIND_ROOT_PATH is only useful when cross-compiling. Thus, do not set on macOS builds. +if(NOT PLATFORM_INT STREQUAL "MAC" AND NOT PLATFORM_INT STREQUAL "MAC_ARM64") + list(APPEND CMAKE_FIND_ROOT_PATH "${CMAKE_OSX_SYSROOT_INT}" CACHE INTERNAL "") + set(CMAKE_IGNORE_PATH "/System/Library/Frameworks;/usr/local/lib" CACHE INTERNAL "") +endif() + +# Default to searching for frameworks first. +set(CMAKE_FIND_FRAMEWORK FIRST) + +# Set up the default search directories for frameworks. +if(PLATFORM_INT MATCHES "MAC_CATALYST.*") + set(CMAKE_FRAMEWORK_PATH + ${CMAKE_DEVELOPER_ROOT}/Library/PrivateFrameworks + ${CMAKE_OSX_SYSROOT_INT}/System/Library/Frameworks + ${CMAKE_OSX_SYSROOT_INT}/System/iOSSupport/System/Library/Frameworks + ${CMAKE_FRAMEWORK_PATH} CACHE INTERNAL "") +else() + set(CMAKE_FRAMEWORK_PATH + ${CMAKE_DEVELOPER_ROOT}/Library/PrivateFrameworks + ${CMAKE_OSX_SYSROOT_INT}/System/Library/Frameworks + ${CMAKE_FRAMEWORK_PATH} CACHE INTERNAL "") +endif() + +# By default, search both the specified iOS SDK and the remainder of the host filesystem. +if(NOT CMAKE_FIND_ROOT_PATH_MODE_PROGRAM) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH CACHE INTERNAL "") +endif() +if(NOT CMAKE_FIND_ROOT_PATH_MODE_LIBRARY) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH CACHE INTERNAL "") +endif() +if(NOT CMAKE_FIND_ROOT_PATH_MODE_INCLUDE) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH CACHE INTERNAL "") +endif() +if(NOT CMAKE_FIND_ROOT_PATH_MODE_PACKAGE) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH CACHE INTERNAL "") +endif() + +# +# Some helper-macros below to simplify and beautify the CMakeFile +# + +# This little macro lets you set any Xcode specific property. +macro(set_xcode_property TARGET XCODE_PROPERTY XCODE_VALUE XCODE_RELVERSION) + set(XCODE_RELVERSION_I "${XCODE_RELVERSION}") + if(XCODE_RELVERSION_I STREQUAL "All") + set_property(TARGET ${TARGET} PROPERTY XCODE_ATTRIBUTE_${XCODE_PROPERTY} "${XCODE_VALUE}") + else() + set_property(TARGET ${TARGET} PROPERTY XCODE_ATTRIBUTE_${XCODE_PROPERTY}[variant=${XCODE_RELVERSION_I}] "${XCODE_VALUE}") + endif() +endmacro(set_xcode_property) + +# This macro lets you find executable programs on the host system. 
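+# (More precisely, it temporarily disables the cross-compiling find-root
+# re-rooting and clears the IOS flag so that find_package() searches the host.)
+# Illustrative usage from a consuming CMakeLists.txt, with Python3 as an
+# example package only:
+#   find_host_package(Python3 COMPONENTS Interpreter)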
+macro(find_host_package) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE NEVER) + set(_TOOLCHAIN_IOS ${IOS}) + set(IOS FALSE) + find_package(${ARGN}) + set(IOS ${_TOOLCHAIN_IOS}) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) +endmacro(find_host_package) diff --git a/apps/frameworks/sherpa-mnn/toolchains/riscv64-linux-gnu.toolchain.cmake b/apps/frameworks/sherpa-mnn/toolchains/riscv64-linux-gnu.toolchain.cmake new file mode 100644 index 00000000..8f07b964 --- /dev/null +++ b/apps/frameworks/sherpa-mnn/toolchains/riscv64-linux-gnu.toolchain.cmake @@ -0,0 +1,17 @@ +# Copied from https://github.com/Tencent/ncnn/blob/master/toolchains/riscv64-linux-gnu.toolchain.cmake +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +set(CMAKE_C_COMPILER "riscv64-unknown-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "riscv64-unknown-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + +set(CMAKE_C_FLAGS "-march=rv64gc") +set(CMAKE_CXX_FLAGS "-march=rv64gc") + +# cache flags +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags")