Compare commits

...

219 Commits

Author SHA1 Message Date
Xiang Lyu 55e3e370a0
Merge pull request #1758 from orbisai0security/fix/V-005-pickle-deserialization
[Security] Fix CRITICAL vulnerability: V-005
2025-12-31 11:45:44 +08:00
lyuxiang.lx 46788c7379 update readme 2025-12-31 11:28:47 +08:00
Xiang Lyu 881177287c
Merge pull request #1640 from Jzz1943/main
support vLLM >=0.11.0 (V1 engine) for better performance
2025-12-31 10:40:02 +08:00
Xiang Lyu f88a14e41d
Merge branch 'main' into main 2025-12-31 10:37:18 +08:00
lyuxiang.lx dac6566fc3 update tokens 2025-12-30 16:14:36 +00:00
lyuxiang.lx cc91e40db8 more silent_token 2025-12-30 15:53:33 +00:00
lyuxiang.lx ab7f1f4a86 update silent token 2025-12-30 15:12:18 +00:00
lyuxiang.lx e15222b17c refine code 2025-12-30 09:24:52 +00:00
lyuxiang.lx cfa1c115b2 add silent_token 2025-12-30 09:18:17 +00:00
lyuxiang.lx dd5cdb6ebf Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-30 14:03:52 +08:00
lyuxiang.lx 2d7ef0b719 remove instruct warning 2025-12-30 12:36:18 +08:00
Xiang Lyu ba5db602a9
Merge pull request #1622 from GoyoUijin/Fix/token2wav-cache-thread-unsafe
fix triton token2wav model cache thread unsafety
2025-12-30 11:24:25 +08:00
orbisai0security 5b94675f62 fix: resolve critical vulnerability V-005
Automatically generated security fix
2025-12-29 13:25:05 +00:00
lyuxiang.lx 4c19646b9a update dataset 2025-12-29 12:46:34 +00:00
lyuxiang.lx 63a06227d1 Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-29 18:32:54 +08:00
lyuxiang.lx 3b44913782 fix bug 2025-12-29 10:30:54 +00:00
Xiang Lyu 055f64d002
Merge pull request #1728 from yigitcatak/main
fixed a typo in Dockerfile regarding environment variable
2025-12-29 18:24:39 +08:00
lyuxiang.lx 4d7295a9a7 fix cosyvoice3 training 2025-12-29 10:03:14 +00:00
lyuxiang.lx 8524c81acd fix cv3 train 2025-12-29 10:03:14 +00:00
Xiang Lyu a14e063ead
Merge pull request #1722 from majiayu000/fix/issue-1683-ja-language-tag
docs: fix Japanese language tag in example.py comment
2025-12-29 16:12:54 +08:00
lyuxiang.lx 2db78e7058 fix export_jit.py 2025-12-23 17:23:23 +08:00
lyuxiang.lx 7538c6a73d update export_jit 2025-12-23 15:23:29 +08:00
yigitcatak 823ae2c60d fix a typo in Dockerfile 2025-12-23 16:22:48 +09:00
lyuxiang.lx 59cb2bf16c fix jit_export 2025-12-23 14:06:40 +08:00
majiayu000 80bebb1978 docs: fix Japanese language tag in example.py comment
Changed <|jp|> to <|ja|> to match the actual tokenizer implementation.

The LANGUAGES dict in cosyvoice/tokenizer/tokenizer.py defines 'ja' for Japanese,
not 'jp'. This fixes the misleading comment that could cause issues like #621.

Fixes #1683
2025-12-22 19:11:47 +08:00
Xiang Lyu bc34459bb8
Merge pull request #1693 from whiteshirt0429/main
Fix CosyVoice3 config error
2025-12-22 13:56:19 +08:00
lyuxiang.lx 9f27b42cd9 update 2025-12-17 18:58:50 +08:00
lyuxiang.lx a7d6e2251a update libritts cosyvoice3.yaml 2025-12-17 17:15:10 +08:00
lyuxiang.lx 7baefaf0f2 update libritts cosyvoice3.yaml 2025-12-17 17:14:17 +08:00
di.wu ff0d05c380 Fix CosyVoice3 config error 2025-12-17 14:57:17 +08:00
lyuxiang.lx f5816b4e51 update readme 2025-12-17 03:16:00 +00:00
lyuxiang.lx 8b54619760 update 2025-12-16 15:00:03 +08:00
lyuxiang.lx 2abd42220e add x-transformer version 2025-12-16 14:45:02 +08:00
lyuxiang.lx 2d6bb9bd80 Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-12-15 16:16:35 +08:00
lyuxiang.lx 0b80c0746a update 2025-12-15 16:13:53 +08:00
Xiangang Li e98b828f33
Update README.md 2025-12-15 15:54:32 +08:00
lyuxiang.lx 4d4c787be0 update 2025-12-15 15:33:52 +08:00
lyuxiang.lx 781a49acb4 update metric 2025-12-15 14:52:32 +08:00
lyuxiang.lx 9476a063b3 update metric 2025-12-15 14:48:17 +08:00
Xiang Lyu 3426ceb70f
Merge pull request #1671 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-12-15 14:00:40 +08:00
hengwu.zty a460960ade [mod]readme 2025-12-15 13:53:11 +08:00
hengwu.zty f51f5c5c6a [mod]readme 2025-12-15 13:39:56 +08:00
hengwu.zty f11ba4024c [mod]readme 2025-12-15 13:24:44 +08:00
lyuxiang.lx 089343ab0a update 2025-12-15 12:44:06 +08:00
Xiang Lyu 0c50894d49
Merge pull request #1670 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-12-15 12:17:39 +08:00
lyuxiang.lx 95d56cba64 update 2025-12-15 03:56:42 +00:00
lyuxiang.lx 095f7bad55 update readme 2025-12-15 03:53:10 +00:00
lyuxiang.lx a6eb2c56da update dingding 2025-12-14 14:00:40 +00:00
lyuxiang.lx ca3b054a52 fix bistream bug 2025-12-12 14:28:27 +00:00
lyuxiang.lx b02d7e61f7 update prompt wav 2025-12-12 14:28:26 +00:00
hengwu.zty 6b6a5a7bd1 Merge branch 'dev/lyuxiang.lx' of http://gitlab.alibaba-inc.com/NLS/CosyVoice into dev/lyuxiang.lx 2025-12-12 18:39:18 +08:00
hengwu.zty 5640545406 update readme 2025-12-12 18:38:30 +08:00
lyuxiang.lx 5bc4b23f02 use amp in flow 2025-12-12 07:44:36 +00:00
lyuxiang.lx ebef63066f add instruct 2025-12-12 07:44:36 +00:00
quantu 3298d6f3e3 readme 2025-12-11 15:55:26 +08:00
lyuxiang.lx f21c4764ec fix bug 2025-12-11 06:04:36 +00:00
lyuxiang.lx 927addadd8 fix lint 2025-12-10 02:17:00 +00:00
lyuxiang.lx a051a09ba4 remove unncessary code 2025-12-09 15:41:02 +00:00
lyuxiang.lx 0c65d3c7ab use automodel 2025-12-09 15:15:05 +00:00
lyuxiang.lx 56d9876037 fix bug 2025-12-09 07:57:10 +00:00
lyuxiang.lx b35ece675b dit need higher trt version 2025-12-08 11:28:45 +00:00
lyuxiang.lx 59f02cb85d add vllm example 2025-12-08 11:14:12 +00:00
lyuxiang.lx b4dd67a8af add cosyvoice3 vllm example 2025-12-08 10:55:53 +00:00
lyuxiang.lx bfa835a74b add cosyvoice3 inference code 2025-12-08 10:04:11 +00:00
lyuxiang.lx 622a3a19b0 use wav file rather than tensor 2025-12-08 08:43:09 +00:00
lyuxiang.lx d985100326 Merge branch 'main' into dev/lyuxiang.lx 2025-12-04 18:00:17 +08:00
zhongze.jiang 6816fc6a6f support vLLM >=0.11.0 (V1 engine only) 2025-11-10 16:30:42 +08:00
김의진 e8bf717333 Fix: generate token2wav_request_id from cosyvoice2
- Since all token2wav requests within a single cosyvoice2 request must share the same request_id, modify the logic so that a new request_id is generated only if it does not already exist, and ensure that the same request_id is sent consistently.
2025-10-27 18:12:17 +09:00
김의진 fa2781405f Revert "fix triton token2wav model cache thread unsafety"
This reverts commit cd26dd1932.
2025-10-27 18:07:30 +09:00
김의진 cd26dd1932 fix triton token2wav model cache thread unsafety 2025-10-27 17:20:14 +09:00
Xiang Lyu 6e01309e01
Merge pull request #1598 from yuekaizhang/streaming
[Runtime] StepAudio2 Streaming DiT Token2Wav Integration
2025-10-21 15:30:45 +08:00
yuekaiz 1fc8435146 add disaggregated deployment 2025-10-16 15:58:22 +08:00
yuekaiz a224be6117 fix lint 2025-10-09 15:18:09 +08:00
yuekaiz 33aee03ed5 fix lint 2025-10-09 15:13:43 +08:00
root 8811e9f33a fix white space 2025-10-09 14:49:22 +08:00
root 807bb6ee0b add dit results 2025-10-08 22:01:24 +08:00
root aceede59ba fix bug 2025-10-08 18:13:09 +08:00
root 7cbd490253 add docker compose for streaming tts 2025-10-08 17:20:04 +08:00
root a019a2504e clean code 2025-10-08 16:48:00 +08:00
root f186ec3338 clean code 2025-10-08 15:21:52 +08:00
root 988d395162 mark multi client 2025-10-08 14:06:19 +08:00
lyuxiang.lx 4d60ff6abc Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-09-28 15:27:41 +08:00
lyuxiang.lx be005c825f fix vllm transformer version bug 2025-09-28 15:25:42 +08:00
root 79116ac32e remove cache router 2025-09-26 15:14:31 +08:00
root 31a0adc73d mark stateless token2wav 2025-09-26 14:51:41 +08:00
yuekaiz 482464ea27 add streaming dit 2025-09-24 15:18:01 +08:00
root 444b7ff5df fix cache shallow copy 2025-09-19 13:48:32 +08:00
yuekaiz b207c60885 init step-audio2 token2wav 2025-09-18 19:07:23 +08:00
Xiang Lyu 0b357ba25d
Merge pull request #1583 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-09-17 10:57:18 +08:00
Xiang Lyu 0867ebcb8c
Merge pull request #1566 from yuekaizhang/streaming
[runtime: TRT-LLM] support prompt audio cache & offline inference mode
2025-09-12 11:52:28 +08:00
root 52556a6de9 fix lint 2025-09-08 09:59:58 +00:00
root 66ef5a097b fix lint 2025-09-08 09:55:33 +00:00
root cc1991870b add cosyvoice2 offline inference 2025-09-08 17:37:33 +08:00
yuekaiz 8ded65e611 set use_spk2info_cache=False 2025-09-05 14:03:23 +08:00
yuekaiz 6971536358 add prompt audio cache 2025-09-05 13:55:41 +08:00
Xiang Lyu 86e7c2d731
Merge pull request #1561 from yuekaizhang/streaming
[Runtime] Support Streaming TTS for Triton + TensorRT-LLM runtime
2025-09-04 09:48:43 +08:00
Yuekai Zhang 8a4309d89c fix space 2025-09-03 03:51:38 -07:00
Yuekai Zhang ad257b06e3 fix lint 2025-09-03 03:45:17 -07:00
yuekaiz 633b991290 update readme 2025-09-03 17:42:14 +08:00
yuekaiz e04699c6da add spk trt 2025-09-03 11:44:36 +08:00
root 73d261dd48 support streaming tts 2025-09-02 18:32:12 +08:00
Xiang Lyu b7ec6c4678
Merge pull request #1459 from huiwq1990/feat-fixpip
pip install disable cache
2025-09-01 17:34:19 +08:00
lyuxiang.lx f76f5abcc1 update 2025-08-22 14:42:34 +08:00
lyuxiang.lx 6b5eef62cc update 2025-08-22 14:24:27 +08:00
lyuxiang.lx dc96e4c984 update 2025-08-21 21:03:58 +08:00
lyuxiang.lx 70991d7327 update 2025-08-21 20:08:08 +08:00
lyuxiang.lx 8c96081f94 update 2025-08-21 11:45:36 +08:00
lyuxiang.lx dd2d926147 update 2025-08-20 16:55:03 +08:00
lyuxiang.lx da41f6175b update 2025-08-19 18:53:18 +08:00
lyuxiang.lx e3c2400abb update 2025-08-19 16:23:59 +08:00
lyuxiang.lx a976519ada fix transformer version 2025-08-14 16:24:24 +08:00
lyuxiang.lx cf615011ce update readme 2025-08-07 16:55:46 +08:00
Xiang Lyu 9ddb9e4a83
Merge pull request #1491 from yuekaizhang/rl
[recipe] Add GRPO training recipe for cosyvoice2 llm
2025-08-07 16:45:49 +08:00
Xiang Lyu 0a496c18f7
Merge pull request #1507 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-08-05 11:18:46 +08:00
lyuxiang.lx 05bdf4c769 add contributor info 2025-08-05 11:15:42 +08:00
Xiang Lyu 1850e2a56e
Merge pull request #1489 from yuekaizhang/triton
[runtime] Support Cosyvoice2 Nvidia TensorRT-LLM Inference Solution
2025-08-05 10:21:40 +08:00
root 47e4137651 update test set 2025-07-30 11:07:50 +00:00
root 0bc48c1180 update readme 2025-07-30 11:05:49 +00:00
root 62d082634e fix lint 2025-07-29 08:40:51 +00:00
root 07cbc51cd1 fix lint 2025-07-29 08:39:41 +00:00
root d1c354eac7 add huggingface to pretrained 2025-07-29 07:54:42 +00:00
yuekaiz 1b8d194b67 fix commit 2025-07-29 12:01:55 +08:00
yuekaiz b44f121102 update readme 2025-07-29 11:58:23 +08:00
yuekaiz dc196df940 fix decoupled mode 2025-07-29 11:13:07 +08:00
Yuekai Zhang 178da09993 clean code 2025-07-27 23:33:10 -07:00
lyuxiang.lx 11515d0d5a use bf16 for amp 2025-07-28 11:55:38 +08:00
Yuekai Zhang 5427c274e3 add triton solution 2025-07-22 06:50:13 -07:00
huiwq1990 3387f07266 pip install disable cache
Signed-off-by: huiwq1990 <huiwq1990@163.com>
2025-07-17 10:35:44 +08:00
lyuxiang.lx b048a2d6db fix dpo bug 2025-07-16 20:09:10 +08:00
lyuxiang.lx 8555549e88 fix lint 2025-07-07 14:52:12 +08:00
lyuxiang.lx 3047591fad fix lint 2025-07-07 14:49:51 +08:00
lyuxiang.lx 5a00aefa20 update 2025-07-07 14:27:03 +08:00
lyuxiang.lx 35abb1f3b5 Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-07-07 10:56:06 +08:00
lyuxiang.lx 21a5efd8ae fix lint 2025-06-29 13:21:44 +08:00
Xiang Lyu 44316c3475
Merge pull request #1408 from xxnuo/wetext
Replace WeTextProcessing with wetext to enhance compatibility on all platforms
2025-06-29 12:56:57 +08:00
xxnuo 116c99bf39 Fix: remove overwrite_cache 2025-06-27 15:57:17 +08:00
xxnuo 6eaef42126 Fix: remove full_to_half 2025-06-27 15:55:00 +08:00
xxnuo 525531d8a3 Update README.md 2025-06-27 14:52:58 +08:00
xxnuo c788bca1a6 Update requirements.txt 2025-06-27 14:50:35 +08:00
xxnuo ff9694dc2c Fix: Use wetext cache 2025-06-27 14:45:07 +08:00
xxnuo 4505924608 Fix: Use wetext replace WeTextProcessing 2025-06-27 14:43:57 +08:00
lyuxiang.lx 46dfe0439b fix possible bug 2025-06-25 14:31:05 +08:00
lyuxiang.lx 63856565f3 update dpo 2025-06-24 11:49:20 +08:00
Xiang Lyu cc234bd322
Merge pull request #1369 from UncleLLD/Update-README.md
Update README.md CosyVoice2 vllm Usage
2025-06-12 16:21:43 +08:00
UncleLLD 98ef35b1c0
Update README.md CosyVoice2 vllm Usage
First change env and then pip libs
2025-06-12 15:12:58 +08:00
Xiang Lyu fca1df3ea7
Merge pull request #1340 from Soooda/main
Fix README Typos
2025-06-02 09:53:05 +08:00
lyuxiang.lx c939c80480 update 2025-06-01 08:46:47 +00:00
Hilbert Kong 5bf6befd70
Fix README Typos 2025-06-01 17:19:50 +09:00
lyuxiang.lx 1c7976779b update 2025-05-30 13:07:56 +00:00
lyuxiang.lx a8e1774e82 add vllm example 2025-05-30 10:06:14 +00:00
lyuxiang.lx 1f50ae259b update 2025-05-30 09:25:26 +00:00
Xiang Lyu 793a7fe6ad
Merge pull request #1337 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-05-30 17:16:30 +08:00
Xiang Lyu 79b2ea818b
Merge branch 'main' into dev/lyuxiang.lx 2025-05-30 17:16:22 +08:00
lyuxiang.lx 7d9d84d32d update 2025-05-30 09:12:03 +00:00
lyuxiang.lx 9b052a94c4 fix lint 2025-05-30 07:51:49 +00:00
lyuxiang.lx 6dd68b9d5e add vllm inference 2025-05-30 07:22:35 +00:00
root 9f55c5af8f add vllm export 2025-05-28 08:35:48 +00:00
Xiang Lyu b6c5f9dfd2
Merge pull request #1331 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-05-27 14:07:56 +08:00
lyuxiang.lx cbfed4a9ee send streaming as args 2025-05-27 13:47:12 +08:00
lyuxiang.lx 54d21b40f0 update 2025-05-27 10:40:45 +08:00
lyuxiang.lx c3250c222f add trt_concurrent arg 2025-05-27 10:34:59 +08:00
lyuxiang.lx 82219cdd27 update 2025-05-27 10:23:06 +08:00
Xiang Lyu 4159a18469
Merge pull request #1325 from FunAudioLLM/dev/lyuxiang.lx
fix trt wrapper bug
2025-05-26 18:04:02 +08:00
lyuxiang.lx 3e12bb86bd fix trt wrapper bug 2025-05-26 18:03:15 +08:00
Xiang Lyu cbfbe2bc33
Merge pull request #1317 from Lsnsh/patch-1
docs(README): set Markdown headings for paragraphs to support quick anchor points
2025-05-26 11:28:28 +08:00
Lsnsh Xin 3660da4a19
docs(README): set Markdown headings for paragraphs to support quick anchor points 2025-05-24 04:17:28 +08:00
Xiang Lyu 3c921daede
Merge pull request #1315 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-05-23 16:11:40 +08:00
lyuxiang.lx 68100c267a remove flow_cache 2025-05-23 16:07:27 +08:00
burkliu fbab274b6a [feature] modify pad to trim
Conflicts:
	cosyvoice/dataset/processor.py
2025-05-06 10:53:37 +08:00
lyuxiang.lx 97f0bc61cd remove unnecessary file 2025-05-06 10:41:55 +08:00
lyuxiang.lx afb1a70f7a Merge branch 'main' of github.com:FunAudioLLM/CosyVoice into main 2025-05-06 10:37:48 +08:00
Xiang Lyu 5c77e40304
Merge pull request #1256 from hwangsihu/main
Reorder requirements.txt
2025-05-04 16:27:47 +08:00
hwangsihu b4c4d848ca Reorder requirements.txt 2025-05-01 13:28:15 +09:00
Xiang Lyu 88f467a8ac
Merge pull request #1232 from boji123/bj_dev_feat_len_pad
a better solution for mismatch of speech feat len and speech token len when trainning
2025-04-30 09:41:50 +08:00
burkliu 038ff9f353 [feature] modify pad to trim 2025-04-25 10:36:25 +08:00
burkliu 65ad448714 [debug] a better solution for mismatch of speech feat len and speech token len, refer to https://github.com/FunAudioLLM/CosyVoice/issues/1051 2025-04-25 10:36:08 +08:00
lyuxiang.lx a96ae13616 fix instruct2 bug 2025-04-23 15:40:59 +08:00
bearlu 587604b2b4 fix inference_instruct2 speaker ID bug 2025-04-21 09:26:34 -07:00
lyuxiang.lx e97cd1b655 fix cross_lingual bug 2025-04-19 09:08:47 +08:00
lyuxiang.lx 8d67d17f73 update 2025-04-16 20:18:49 +08:00
lyuxiang.lx a442317d17 add flow trt wrapper 2025-04-16 17:57:02 +08:00
lyuxiang.lx 7f8bea2669 Merge branch 'dev/lyuxiang.lx' of github.com:FunAudioLLM/CosyVoice into dev/lyuxiang.lx 2025-04-16 15:00:24 +08:00
ShengqiangLi 6d876f573c feat: Support DPO 2025-04-16 14:57:35 +08:00
lyuxiang.lx 3770c1c8b1 fix bug 2025-04-16 13:48:13 +08:00
lyuxiang.lx 2c193781cc fix export_onnx.py 2025-04-15 17:29:42 +08:00
lyuxiang.lx efe1d15960 fix export_onnx.py 2025-04-15 16:51:25 +08:00
lyuxiang.lx 9ebcf7b1ad fix bug 2025-04-15 16:36:10 +08:00
lyuxiang.lx 37e48dd318 optimize vc code 2025-04-15 16:15:20 +08:00
lyuxiang.lx c07cd3d730 fix lint 2025-04-15 15:00:29 +08:00
lyuxiang.lx 36aec2c0f7 optimize flow cache code 2025-04-15 13:58:52 +08:00
lyuxiang.lx d71d790f55 fix flow cache bug 2025-04-15 13:12:25 +08:00
lyuxiang.lx e1ffb1e978 fix lint 2025-04-15 11:46:53 +08:00
lyuxiang.lx 9fea0f0836 add_zero_shot_spk 2025-04-15 11:45:00 +08:00
lyuxiang.lx 9dc559fc2a force set use_flow_cache 2025-04-08 12:23:26 +08:00
Xiang Lyu 634edfadf0
Merge pull request #983 from Shengqiang-Li/main
feat: Support DPO
2025-04-08 12:14:51 +08:00
Xiang Lyu b56dfa223d
Merge pull request #1140 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
2025-04-07 23:04:37 +08:00
lyuxiang.lx f0b8e892f6 fix lint 2025-04-07 23:04:01 +08:00
lyuxiang.lx cfc68f379c only keep online export 2025-04-07 22:54:13 +08:00
lyuxiang.lx 4951d2ad1a update 2025-04-07 22:45:47 +08:00
lyuxiang.lx d9ffd592f6 use static_chunk_size in flow training 2025-04-07 22:34:45 +08:00
lyuxiang.lx 7902d1c17f set use_flow_cache=True when export 2025-04-07 22:27:17 +08:00
lyuxiang.lx 39ffc50dec add flow cache inference code 2025-04-07 21:23:09 +08:00
Xiang Lyu 08312f4c46
Merge pull request #1100 from jingfelix/fix/dockerfile-dependency
fix: missing dependency in `runtime/python/Dockerfile`
2025-03-25 08:30:26 +08:00
jingfelix c6d8737336 fix: missing dependency in runtime/python/Dockerfile
Signed-off-by: jingfelix <jingfelix@outlook.com>
2025-03-22 23:34:59 +08:00
ShengqiangLi a22873e360 feat: Support DPO 2025-03-16 18:33:42 +08:00
Xiang Lyu c97b445df4
Merge pull request #1036 from hwangsihu/main
Fixed an issue where onnxruntime would not install on Windows
2025-03-10 15:48:04 +08:00
hwangsihu 265507f213 Fixed an issue where onnxruntime would not install on Windows 2025-03-10 16:07:01 +09:00
lyuxiang.lx a69b7e275d fix vocoder train 2025-03-07 17:34:13 +08:00
lyuxiang.lx fcc054f64e fix hifigan bug 2025-02-18 14:45:43 +08:00
lyuxiang.lx 890300513c fix bug 2025-02-11 15:56:20 +08:00
lyuxiang.lx f77c6a85aa fix bug 2025-02-11 00:07:12 +08:00
lyuxiang.lx b6d66ce2e3 update 2025-02-08 13:58:41 +08:00
lyuxiang.lx 8e4f252d32 add llm bistream 2025-02-08 12:15:37 +08:00
lyuxiang.lx 79b7dff8d2 add llm train 2025-02-07 17:17:12 +08:00
lyuxiang.lx 2a3e033ee1 fix lint 2025-02-07 16:23:03 +08:00
lyuxiang.lx 24f796a2b1 Merge branch 'main' into dev/lyuxiang.lx 2025-01-26 16:56:18 +08:00
lyuxiang.lx fd1a951a6c add flow unified training 2025-01-26 16:56:06 +08:00
lyuxiang.lx aea75207dd fix cache bug 2025-01-24 11:07:26 +08:00
lyuxiang.lx 1c062ab381 add flow decoder cache 2025-01-23 16:48:13 +08:00
114 changed files with 13295 additions and 1386 deletions

View File

@ -52,5 +52,5 @@ jobs:
set -eux
pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
flake8 --version
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F722,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
if [ $? != 0 ]; then exit 1; fi

283
README.md
View File

@ -1,39 +1,52 @@
[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)](https://github.com/Akshay090/svg-banners)
![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)
## 👉🏻 CosyVoice 👈🏻
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/abs/2412.10117); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/spaces/FunAudioLLM/CosyVoice2-0.5B)
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/studios/iic/CosyVoice-300M)
**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)
## Highlight🔥
**CosyVoice 2.0** has been released! Compared to version 1.0, the new version offers more accurate, more stable, faster, and better speech generation capabilities.
### Multilingual
- **Supported Language**: Chinese, English, Japanese, Korean, Chinese dialects (Cantonese, Sichuanese, Shanghainese, Tianjinese, Wuhanese, etc.)
- **Crosslingual & Mixlingual**Support zero-shot voice cloning for cross-lingual and code-switching scenarios.
### Ultra-Low Latency
- **Bidirectional Streaming Support**: CosyVoice 2.0 integrates offline and streaming modeling technologies.
- **Rapid First Packet Synthesis**: Achieves latency as low as 150ms while maintaining high-quality audio output.
### High Accuracy
- **Improved Pronunciation**: Reduces pronunciation errors by 30% to 50% compared to CosyVoice 1.0.
- **Benchmark Achievements**: Attains the lowest character error rate on the hard test set of the Seed-TTS evaluation set.
### Strong Stability
- **Consistency in Timbre**: Ensures reliable voice consistency for zero-shot and cross-language speech synthesis.
- **Cross-language Synthesis**: Marked improvements compared to version 1.0.
### Natural Experience
- **Enhanced Prosody and Sound Quality**: Improved alignment of synthesized audio, raising MOS evaluation scores from 5.4 to 5.53.
- **Emotional and Dialectal Flexibility**: Now supports more granular emotional controls and accent adjustments.
**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLM), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
### Key Features
- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian), 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.) and meanwhile supports both multi-lingual/cross-lingual zero-shot voice cloning.
- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
- **Pronunciation Inpainting**: Supports pronunciation inpainting of Chinese Pinyin and English CMU phonemes, providing more controllability and thus suitable for production use.
- **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
- **Bi-Streaming**: Support both text-in streaming and audio-out streaming, and achieves latency as low as 150ms while maintaining high-quality audio output.
- **Instruct Support**: Supports various instructions such as languages, dialects, emotions, speed, volume, etc.
## Roadmap
- [x] 2025/12
- [x] release Fun-CosyVoice3-0.5B-2512 base model, rl model and its training/inference script
- [x] release Fun-CosyVoice3-0.5B modelscope gradio space
- [x] 2025/08
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
- [x] 2025/07
- [x] release Fun-CosyVoice 3.0 eval set
- [x] 2025/05
- [x] add CosyVoice2-0.5B vllm support
- [x] 2024/12
- [x] 25hz cosyvoice 2.0 released
- [x] 25hz CosyVoice2-0.5B released
- [x] 2024/09
- [x] 25hz cosyvoice base model
- [x] 25hz cosyvoice voice conversion model
- [x] 25hz CosyVoice-300M base model
- [x] 25hz CosyVoice-300M voice conversion function
- [x] 2024/08
@ -46,65 +59,82 @@
- [x] WeTextProcessing support when ttsfrd is not available
- [x] Fastapi server and client
## Evaluation
| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |
## Install
**Clone and install**
### Clone and install
- Clone the repo
``` sh
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# If you failed to clone submodule due to network failures, please run following command until success
cd CosyVoice
git submodule update --init --recursive
```
``` sh
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# If you failed to clone the submodule due to network failures, please run the following command until success
cd CosyVoice
git submodule update --init --recursive
```
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
- Create Conda env:
``` sh
conda create -n cosyvoice -y python=3.10
conda activate cosyvoice
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platform.
conda install -y -c conda-forge pynini==2.1.5
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
``` sh
conda create -n cosyvoice -y python=3.10
conda activate cosyvoice
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# If you encounter sox compatibility issues
# ubuntu
sudo apt-get install sox libsox-dev
# centos
sudo yum install sox sox-devel
```
# If you encounter sox compatibility issues
# ubuntu
sudo apt-get install sox libsox-dev
# centos
sudo yum install sox sox-devel
```
**Model download**
### Model download
We strongly recommend that you download our pretrained `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
``` python
# SDK模型下载
# modelscope SDK model download
from modelscope import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
# for oversea users, huggingface SDK model download
from huggingface_hub import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
```
``` sh
# git模型下载请确保已安装git lfs
mkdir -p pretrained_models
git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git pretrained_models/CosyVoice2-0.5B
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
```
Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
Optionally, you can unzip `ttsfrd` resouce and install `ttsfrd` package for better text normalization performance.
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use wetext by default.
``` sh
cd pretrained_models/CosyVoice-ttsfrd/
@ -113,79 +143,31 @@ pip install ttsfrd_dependency-0.1-py3-none-any.whl
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
```
**Basic Usage**
### Basic Usage
We strongly recommend using `CosyVoice2-0.5B` for better performance.
Follow code below for detailed usage of each model.
``` python
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio
We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
Follow the code in `example.py` for detailed usage of each model.
```sh
python example.py
```
**CosyVoice2 Usage**
```python
cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)
#### vLLM Usage
CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
Older vllm version(<0.9.0) do not support CosyVoice inference, and versions in between (e.g., 0.10.x) are not tested.
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
Notice that `vllm` has a lot of specific requirements. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# instruct usage
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# bistream usage, you can use generator as input, this is useful when using text llm model as input
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
def text_generator():
yield '收到好友从远方寄来的生日礼物,'
yield '那份意外的惊喜与深深的祝福'
yield '让我心中充满了甜蜜的快乐,'
yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
``` sh
conda create -n cosyvoice_vllm --clone cosyvoice
conda activate cosyvoice_vllm
# for vllm==0.9.0
pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# for vllm>=0.11.0
pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
python vllm_example.py
```
**CosyVoice Usage**
```python
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_trt=False, fp16=False)
# sft usage
print(cosyvoice.list_available_spks())
# change stream=True for chunk stream inference
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M') # or change to pretrained_models/CosyVoice-300M-25Hz for 25Hz inference
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# cross_lingual usage
prompt_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# vc usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
source_speech_16k = load_wav('./asset/cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong><strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
```
**Start web demo**
#### Start web demo
You can use our web demo page to get familiar with CosyVoice quickly.
@ -196,14 +178,14 @@ Please see the demo website for details.
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
```
**Advanced Usage**
#### Advanced Usage
For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
For advanced users, we have provided training and inference scripts in `examples/libritts`.
**Build for deployment**
#### Build for deployment
Optionally, if you want service deployment,
you can run following steps.
You can run the following steps.
``` sh
cd runtime/python
@ -217,6 +199,17 @@ docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /o
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```
#### Using Nvidia TensorRT-LLM for deployment
Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
To quick start:
``` sh
cd runtime/triton_trtllm
docker compose up -d
```
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
## Discussion & Communication
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
@ -233,5 +226,39 @@ You can also scan the QR code to join our official Dingding chat group.
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
## Citations
``` bibtex
@article{du2024cosyvoice,
title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
journal={arXiv preprint arXiv:2407.05407},
year={2024}
}
@article{du2024cosyvoice,
title={Cosyvoice 2: Scalable streaming speech synthesis with large language models},
author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
journal={arXiv preprint arXiv:2412.10117},
year={2024}
}
@article{du2025cosyvoice,
title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
journal={arXiv preprint arXiv:2505.17589},
year={2025}
}
@inproceedings{lyu2025build,
title={Build LLM-Based Zero-Shot Streaming TTS System with Cosyvoice},
author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={1--2},
year={2025},
organization={IEEE}
}
```
## Disclaimer
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 94 KiB

After

Width:  |  Height:  |  Size: 120 KiB

Binary file not shown.

View File

@ -75,10 +75,11 @@ def main():
print('Processing {}'.format(path))
states = torch.load(path, map_location=torch.device('cpu'))
for k in states.keys():
if k not in avg.keys():
avg[k] = states[k].clone()
else:
avg[k] += states[k]
if k not in ['step', 'epoch']:
if k not in avg.keys():
avg[k] = states[k].clone()
else:
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:

View File

@ -23,7 +23,8 @@ import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
def get_args():
@ -56,21 +57,16 @@ def main():
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
try:
model = CosyVoice(args.model_dir)
except Exception:
try:
model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
if not isinstance(model, CosyVoice2):
if model.__class__.__name__ == 'CosyVoice':
# 1. export llm text_encoder
llm_text_encoder = model.model.llm.text_encoder
script = get_optimized_script(llm_text_encoder)
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(llm_text_encoder.half())
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export llm_text_encoder')
# 2. export llm llm
llm_llm = model.model.llm.llm
@ -78,13 +74,25 @@ def main():
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
logging.info('successfully export llm_llm')
# 3. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
# 3. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
elif model.__class__.__name__ == 'CosyVoice2':
# 1. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
else:
raise ValueError('unsupported model type')
if __name__ == '__main__':

View File

@ -27,7 +27,8 @@ from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
def get_dummy_input(batch_size, seq_len, out_channels, device):
@ -51,21 +52,17 @@ def get_args():
return args
@torch.no_grad()
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
try:
model = CosyVoice(args.model_dir)
except Exception:
try:
model = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
model = AutoModel(model_dir=args.model_dir)
# 1. export flow decoder estimator
estimator = model.model.flow.decoder.estimator
estimator.eval()
device = model.model.device
batch_size, seq_len = 2, 256
@ -110,6 +107,7 @@ def main():
}
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
logging.info('successfully export estimator')
if __name__ == "__main__":

View File

@ -1,10 +0,0 @@
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
# download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability
# for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
TRT_DIR=<YOUR_TRT_DIR>
MODEL_DIR=<COSYVOICE2_MODEL_DIR>
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp32.mygpu.plan --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw,fp32:chw --outputIOFormats=fp32:chw
$TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw

View File

@ -1,115 +0,0 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.dataset.dataset import Dataset
def get_args():
parser = argparse.ArgumentParser(description='inference with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
parser.add_argument('--tts_text', required=True, help='tts input file')
parser.add_argument('--llm_model', required=True, help='llm model file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--mode',
default='sft',
choices=['sft', 'zero_shot'],
help='inference mode')
parser.add_argument('--result_dir', required=True, help='asr result file')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init cosyvoice models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
model.load(args.llm_model, args.flow_model, args.hifigan_model)
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
assert len(utts) == 1, "inference mode only support batchsize 1"
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
tts_index = batch["tts_index"]
tts_text_token = batch["tts_text_token"].to(device)
tts_text_token_len = batch["tts_text_token_len"].to(device)
speech_token = batch["speech_token"].to(device)
speech_token_len = batch["speech_token_len"].to(device)
speech_feat = batch["speech_feat"].to(device)
speech_feat_len = batch["speech_feat_len"].to(device)
utt_embedding = batch["utt_embedding"].to(device)
spk_embedding = batch["spk_embedding"].to(device)
if args.mode == 'sft':
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
else:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': text_token, 'prompt_text_len': text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
tts_speeches = []
for model_output in model.tts(**model_input):
tts_speeches.append(model_output['tts_speech'])
tts_speeches = torch.concat(tts_speeches, dim=1)
tts_key = '{}_{}'.format(utts[0], tts_index[0])
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
f.write('{} {}\n'.format(tts_key, tts_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
main()

View File

@ -27,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml
from torch.distributed.elastic.multiprocessing.errors import record
from cosyvoice.utils.losses import DPOLoss
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
init_distributed,
@ -43,9 +44,11 @@ def get_args():
choices=['torch_ddp', 'deepspeed'],
help='Engine for paralleled training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
@ -72,6 +75,10 @@ def get_args():
action='store_true',
default=False,
help='Use automatic mixed precision training')
parser.add_argument('--dpo',
action='store_true',
default=False,
help='Use Direct Preference Optimization')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
@ -97,8 +104,12 @@ def main():
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
if gan is True:
override_dict.pop('hift')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
try:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
except Exception:
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
if gan is True:
configs['train_conf'] = configs['train_conf_gan']
configs['train_conf'].update(vars(args))
@ -108,7 +119,7 @@ def main():
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs, gan)
init_dataset_and_dataloader(args, configs, gan, args.dpo)
# Do some sanity checks and save config to arsg.model_dir
configs = check_modify_and_save_config(args, configs)
@ -117,6 +128,8 @@ def main():
writer = init_summarywriter(args)
# load checkpoint
if args.dpo is True:
configs[args.model].forward = configs[args.model].forward_dpo
model = configs[args.model]
start_step, start_epoch = 0, -1
if args.checkpoint is not None:
@ -145,13 +158,25 @@ def main():
info_dict['epoch'] = start_epoch
save_model(model, 'init', info_dict)
# DPO related
if args.dpo is True:
ref_model = deepcopy(configs[args.model])
state_dict = torch.load(args.ref_model, map_location='cpu')
ref_model.load_state_dict(state_dict, strict=False)
dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
# NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
ref_model = wrap_cuda_model(args, ref_model)
else:
ref_model, dpo_loss = None, None
# Get executor
executor = Executor(gan=gan)
executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
executor.step = start_step
# Init scaler, used for pytorch amp mixed precision training
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
print('start step {} start epoch {}'.format(start_step, start_epoch))
# Start training loop
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
executor.epoch = epoch
@ -162,7 +187,7 @@ def main():
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
writer, info_dict, scaler, group_join)
else:
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
dist.destroy_process_group(group_join)

View File

@ -19,22 +19,24 @@ from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
class CosyVoice:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
self.instruct = True if '-Instruct' in model_dir else False
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f)
assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
@ -56,6 +58,7 @@ class CosyVoice:
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
@ -63,6 +66,17 @@ class CosyVoice:
spks = list(self.frontend.spk2info.keys())
return spks
def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
del model_input['text']
del model_input['text_len']
self.frontend.spk2info[zero_shot_spk_id] = model_input
return True
def save_spkinfo(self):
torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_sft(i, spk_id)
@ -74,12 +88,14 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate)
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@ -88,9 +104,9 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate)
model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@ -100,9 +116,7 @@ class CosyVoice:
start_time = time.time()
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
if self.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
@ -114,10 +128,10 @@ class CosyVoice:
yield model_output
start_time = time.time()
def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
start_time = time.time()
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
@ -126,13 +140,15 @@ class CosyVoice:
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
self.instruct = True if '-Instruct' in model_dir else False
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
@ -142,28 +158,27 @@ class CosyVoice2(CosyVoice):
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_vllm:
self.model.load_vllm('{}/vllm'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
def inference_instruct(self, *args, **kwargs):
raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, stream=False, speed=1.0, text_frontend=True):
assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate)
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
@ -171,3 +186,55 @@ class CosyVoice2(CosyVoice):
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
class CosyVoice3(CosyVoice2):
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v3.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
load_trt, fp16 = False, False
logging.warning('no cuda device, set load_trt/fp16 to False')
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_vllm:
self.model.load_vllm('{}/vllm'.format(model_dir))
if load_trt:
if self.fp16 is True:
logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
def AutoModel(**kwargs):
if not os.path.exists(kwargs['model_dir']):
kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
return CosyVoice(**kwargs)
elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
return CosyVoice2(**kwargs)
elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
return CosyVoice3(**kwargs)
else:
raise TypeError('No valid model type found!')

View File

@ -20,19 +20,10 @@ import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
import ttsfrd
use_ttsfrd = True
except ImportError:
print("failed to import ttsfrd, use WeTextProcessing instead")
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
use_ttsfrd = False
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@ -56,21 +47,33 @@ class CosyVoiceFrontEnd:
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"])
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device)
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
else:
self.spk2info = {}
self.allowed_special = allowed_special
self.use_ttsfrd = use_ttsfrd
if self.use_ttsfrd:
self.inflect_parser = inflect.engine()
# NOTE compatible when no text frontend tool is avaliable
try:
import ttsfrd
self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyinvg')
else:
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine()
self.text_frontend = 'ttsfrd'
logging.info('use ttsfrd frontend')
except:
try:
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
self.en_tn_model = EnNormalizer()
self.text_frontend = 'wetext'
logging.info('use wetext frontend')
except:
self.text_frontend = ''
logging.info('no frontend is avaliable')
def _extract_text_token(self, text):
if isinstance(text, Generator):
@ -89,7 +92,8 @@ class CosyVoiceFrontEnd:
for i in range(text_token.shape[1]):
yield text_token[:, i: i + 1]
def _extract_speech_token(self, speech):
def _extract_speech_token(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
speech_token = self.speech_tokenizer_session.run(None,
@ -101,7 +105,8 @@ class CosyVoiceFrontEnd:
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
def _extract_spk_embedding(self, speech):
def _extract_spk_embedding(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
@ -112,7 +117,8 @@ class CosyVoiceFrontEnd:
embedding = torch.tensor([embedding]).to(self.device)
return embedding
def _extract_speech_feat(self, speech):
def _extract_speech_feat(self, prompt_wav):
speech = load_wav(prompt_wav, 24000)
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
@ -122,15 +128,19 @@ class CosyVoiceFrontEnd:
if isinstance(text, Generator):
logging.info('get tts_text generator, will skip text_normalize!')
return [text]
if text_frontend is False:
# NOTE skip text_frontend when ssml symbol in text
if '<|' in text and '|>' in text:
text_frontend = False
if text_frontend is False or text == '':
return [text] if split is True else text
text = text.strip()
if self.use_ttsfrd:
if self.text_frontend == 'ttsfrd':
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
text = ''.join(texts)
else:
if contains_chinese(text):
text = self.zh_tn_model.normalize(text)
if self.text_frontend == 'wetext':
text = self.zh_tn_model.normalize(text)
text = text.replace("\n", "")
text = replace_blank(text)
text = replace_corner_mark(text)
@ -141,7 +151,8 @@ class CosyVoiceFrontEnd:
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
else:
text = self.en_tn_model.normalize(text)
if self.text_frontend == 'wetext':
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
@ -154,28 +165,31 @@ class CosyVoiceFrontEnd:
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
if resample_rate == 24000:
# cosyvoice2, force speech_feat % speech_token = 2
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
embedding = self._extract_spk_embedding(prompt_speech_16k)
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding}
if zero_shot_spk_id == '':
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
if resample_rate == 24000:
# cosyvoice2, force speech_feat % speech_token = 2
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
embedding = self._extract_spk_embedding(prompt_wav)
model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding}
else:
model_input = {**self.spk2info[zero_shot_spk_id]}
model_input['text'] = tts_text_token
model_input['text_len'] = tts_text_token_len
return model_input
def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate):
model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate)
def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
# in cross lingual mode, we remove prompt in llm
del model_input['prompt_text']
del model_input['prompt_text_len']
@ -187,22 +201,21 @@ class CosyVoiceFrontEnd:
model_input = self.frontend_sft(tts_text, spk_id)
# in instruct mode, we remove spk_embedding in llm due to information leakage
del model_input['llm_embedding']
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
model_input['prompt_text'] = instruct_text_token
model_input['prompt_text_len'] = instruct_text_token_len
return model_input
def frontend_instruct2(self, tts_text, instruct_text, prompt_speech_16k, resample_rate):
model_input = self.frontend_zero_shot(tts_text, instruct_text + '<|endofprompt|>', prompt_speech_16k, resample_rate)
def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input
def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
embedding = self._extract_spk_embedding(prompt_speech_16k)
def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
embedding = self._extract_spk_embedding(prompt_wav)
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,

View File

@ -1,4 +1,5 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,7 +22,8 @@ from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
class CosyVoiceModel:
@ -30,22 +32,15 @@ class CosyVoiceModel:
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool):
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
self.llm.fp16 = fp16
self.flow.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
self.token_min_hop_len = 2 * self.flow.input_frame_rate
self.token_max_hop_len = 4 * self.flow.input_frame_rate
self.token_overlap_len = 20
# here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibability
self.flow.decoder.estimator.static_chunk_size = 0
# mel fade in out
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
self.mel_window = np.hamming(2 * self.mel_overlap_len)
@ -65,14 +60,15 @@ class CosyVoiceModel:
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
self.llm.to(self.device).eval()
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
@ -84,52 +80,68 @@ class CosyVoiceModel:
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, fp16):
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model):
convert_onnx_to_trt(flow_decoder_estimator_model, flow_decoder_onnx_model, fp16)
if os.path.getsize(flow_decoder_estimator_model) == 0:
raise ValueError('{} is empty file, delete it and export again!'.format(flow_decoder_estimator_model))
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
self.flow.decoder.estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
if self.flow.decoder.estimator_engine is None:
raise ValueError('failed to load trt {}'.format(flow_decoder_estimator_model))
self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context()
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
with self.llm_context:
cur_silent_token_num, max_silent_token_num = 0, 5
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
if isinstance(text, Generator):
assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!'
for i in self.llm.inference_bistream(text=text,
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
token_generator = self.llm.inference_bistream(text=text,
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device))
else:
token_generator = self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device)):
self.tts_speech_token_dict[uuid].append(i)
else:
for i in self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device)):
self.tts_speech_token_dict[uuid].append(i)
embedding=llm_embedding.to(self.device),
uuid=uuid)
for i in token_generator:
if i in self.silent_tokens:
cur_silent_token_num += 1
if cur_silent_token_num > max_silent_token_num:
continue
else:
cur_silent_token_num = 0
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
def vc_job(self, source_speech_token, uuid):
self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
self.llm_end_dict[uuid] = True
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
flow_cache=self.flow_cache_dict[uuid])
self.flow_cache_dict[uuid] = flow_cache
with torch.cuda.amp.autocast(self.fp16):
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
flow_cache=self.flow_cache_dict[uuid])
# mel overlap fade in out
if self.mel_overlap_dict[uuid].shape[2] != 0:
@ -160,11 +172,11 @@ class CosyVoiceModel:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
@ -172,7 +184,10 @@ class CosyVoiceModel:
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
p.start()
if stream is True:
token_hop_len = self.token_min_hop_len
@ -222,61 +237,9 @@ class CosyVoiceModel:
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
torch.cuda.empty_cache()
def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
.unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=False)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better speech quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
class CosyVoice2Model(CosyVoiceModel):
@ -285,48 +248,54 @@ class CosyVoice2Model(CosyVoiceModel):
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool):
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
self.llm.fp16 = fp16
self.flow.fp16 = fp16
if self.fp16 is True:
self.llm.half()
self.flow.half()
self.token_hop_len = 2 * self.flow.input_frame_rate
# here we fix flow encoder/decoder decoding_chunk_size, in the future we will send it as arguments, or use cache
self.flow.encoder.static_chunk_size = 2 * self.flow.input_frame_rate
self.flow.decoder.estimator.static_chunk_size = 2 * self.flow.input_frame_rate * self.flow.token_mel_ratio
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# hift cache
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
# speech fade in out
self.speech_window = np.hamming(2 * self.source_cache_len)
# rtf and decoding related
self.stream_scale_factor = 1
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
finalize=finalize)
def load_vllm(self, model_dir):
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
from vllm import EngineArgs, LLMEngine
engine_args = EngineArgs(model=model_dir,
skip_tokenizer_init=True,
enable_prompt_embeds=True,
gpu_memory_utilization=0.2)
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
self.llm.lock = threading.Lock()
del self.llm.llm.model.model.layers
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
@ -352,34 +321,40 @@ class CosyVoice2Model(CosyVoiceModel):
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
p.start()
if stream is True:
token_offset = 0
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=token_offset,
uuid=this_uuid,
stream=stream,
finalize=False)
token_offset += self.token_hop_len
token_offset += this_token_hop_len
yield {'tts_speech': this_tts_speech.cpu()}
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < self.token_hop_len + self.flow.pre_lookahead_len:
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
@ -388,8 +363,8 @@ class CosyVoice2Model(CosyVoiceModel):
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=token_offset,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
@ -400,12 +375,67 @@ class CosyVoice2Model(CosyVoiceModel):
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
token_offset=0,
uuid=this_uuid,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
torch.cuda.empty_cache()
self.hift_cache_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
class CosyVoice3Model(CosyVoice2Model):
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
# FSQ silent and breath token
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append mel cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
self.hift_cache_dict[uuid]['mel'] = tts_mel
else:
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
if speed != 1.0:
assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
return tts_speech

View File

@ -14,14 +14,13 @@
# limitations under the License.
import random
import json
import math
from functools import partial
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from cosyvoice.utils.file_utils import read_lists, read_json_lists
from cosyvoice.utils.file_utils import read_lists
class Processor(IterableDataset):
@ -127,10 +126,9 @@ def Dataset(data_list_file,
data_pipeline,
mode='train',
gan=False,
dpo=False,
shuffle=True,
partition=True,
tts_file='',
prompt_utt2data=''):
partition=True):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
@ -142,23 +140,16 @@ def Dataset(data_list_file,
tokenizer (BaseTokenizer): tokenizer to tokenize
partition(bool): whether to do data partition in terms of rank
"""
assert mode in ['train', 'inference']
lists = read_lists(data_list_file)
if mode == 'inference':
with open(tts_file) as f:
tts_data = json.load(f)
utt2lists = read_json_lists(prompt_utt2data)
# filter unnecessary file in inference mode
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
dataset = DataList(lists,
shuffle=shuffle,
partition=partition)
if mode == 'inference':
# map partial arg to parquet_opener func in inference mode
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
if gan is True:
# map partial arg to padding func in gan mode
data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
# map partial arg to padding func
for i in range(1, len(data_pipeline)):
if data_pipeline[i].func.__name__ == 'compute_fbank':
data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
if data_pipeline[i].func.__name__ == 'padding':
data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset

View File

@ -26,7 +26,7 @@ import pyworld as pw
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def parquet_opener(data, mode='train', tts_data={}):
def parquet_opener(data, mode='train'):
""" Give url or local file, return file descriptor
Inplace operation.
@ -43,15 +43,9 @@ def parquet_opener(data, mode='train', tts_data={}):
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
df = df.to_pandas()
for i in range(len(df)):
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
continue
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
@ -100,6 +94,8 @@ def filter(data,
continue
if len(sample['speech_token']) == 0:
continue
if 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio:
continue
@ -159,6 +155,7 @@ def truncate(data, truncate_length=24576, mode='train'):
def compute_fbank(data,
feat_extractor,
token_mel_ratio=0,
mode='train'):
""" Extract fbank
@ -174,8 +171,13 @@ def compute_fbank(data,
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
if token_mel_ratio != 0:
# trim to align speech_token and speech_feat
token_len = int(min(feat.shape[0] / token_mel_ratio, sample["speech_token"].shape[0]))
feat = feat[:token_mel_ratio * token_len]
sample["speech_token"] = sample["speech_token"][:token_len]
sample['speech_feat'] = feat
yield sample
@ -196,8 +198,8 @@ def compute_f0(data, sample_rate, hop_size, mode='train'):
assert 'text_token' in sample
waveform = sample['speech']
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
sample['pitch_feat'] = f0
@ -236,8 +238,10 @@ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if mode == 'inference':
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
if 'instruct' in sample:
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
else:
sample['instruct_token'] = tokenizer.encode('', allowed_special=allowed_special)
yield sample
@ -345,18 +349,15 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, use_spk_embedding, mode='train', gan=False):
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
""" Padding the data into training data
Args:
@ -389,6 +390,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
instruct_token_len = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
instruct_token = pad_sequence(instruct_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
@ -402,6 +406,8 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"instruct_token": instruct_token,
"instruct_token_len": instruct_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
@ -418,16 +424,14 @@ def padding(data, use_spk_embedding, mode='train', gan=False):
# only gan train needs speech, delete it to save memory
del batch["speech"]
del batch["speech_len"]
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
if dpo is True:
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
reject_speech_token_len = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
reject_speech_token = pad_sequence(reject_speech_token,
batch_first=True,
padding_value=0)
batch['reject_speech_token'] = reject_speech_token
batch['reject_speech_token_len'] = reject_speech_token_len
if use_spk_embedding is True:
batch["embedding"] = batch["spk_embedding"]
else:

176
cosyvoice/flow/DiT/dit.py Normal file
View File

@ -0,0 +1,176 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from x_transformers.x_transformers import RotaryEmbedding
from cosyvoice.utils.mask import add_optional_chunk_mask
from cosyvoice.flow.DiT.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
CausalConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
batch, text_len = text.shape[0], text.shape[1]
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text_len), value=0)
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
super().__init__()
spk_dim = 0 if spk_dim is None else spk_dim
self.spk_dim = spk_dim
self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
def forward(
self,
x: float["b n d"],
cond: float["b n d"],
text_embed: float["b n d"],
spks: float["b d"],
):
to_cat = [x, cond, text_embed]
if self.spk_dim > 0:
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
to_cat.append(spks)
x = self.proj(torch.cat(to_cat, dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=80,
mu_dim=None,
long_skip_connection=False,
spk_dim=None,
out_channels=None,
static_chunk_size=50,
num_decoding_left_chunks=2
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
if mu_dim is None:
mu_dim = mel_dim
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.out_channels = out_channels
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
x = x.transpose(1, 2)
mu = mu.transpose(1, 2)
cond = cond.transpose(1, 2)
spks = spks.unsqueeze(dim=1)
batch, seq_len = x.shape[0], x.shape[1]
if t.ndim == 0:
t = t.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(t)
x = self.input_embed(x, cond, mu, spks.squeeze(1))
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
residual = x
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
else:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
for block in self.transformer_blocks:
x = block(x, t, mask=attn_mask.bool(), rope=rope)
if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
x = self.norm_out(x, t)
output = self.proj_out(x).transpose(1, 2)
return output

View File

@ -0,0 +1,616 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Optional
import math
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
class MelSpec(nn.Module):
def __init__(
self,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
normalize=False,
power=1,
norm=None,
center=True,
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=filter_length,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=power,
center=center,
normalized=normalize,
norm=norm,
)
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
assert len(inp.shape) == 2
if self.dummy.device != inp.device:
self.to(inp.device)
mel = self.mel_stft(inp)
mel = mel.clamp(min=1e-5).log()
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
class CausalConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.kernel_size = kernel_size
self.conv1 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
self.conv2 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv1(x)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv2(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = (
start.unsqueeze(1)
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if self.context_dim is not None:
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(-1)
else:
mask = mask[:, 0, -1].unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, : residual.shape[1]],
x[:, residual.shape[1]:],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(ff_norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
# attention
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -27,34 +28,11 @@ class Transpose(torch.nn.Module):
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.transpose(x, self.dim0, self.dim1)
return x
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor):
output = self.block(x * mask)
return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
@ -76,20 +54,42 @@ class CausalConv1d(torch.nn.Conv1d):
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = (kernel_size - 1, 0)
self.causal_padding = kernel_size - 1
def forward(self, x: torch.Tensor):
x = F.pad(x, self.causal_padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
x = super(CausalConv1d, self).forward(x)
return x
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
output = self.block(x * mask)
return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
causal=False,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
@ -106,7 +106,7 @@ class ConditionalDecoder(nn.Module):
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.causal = causal
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
@ -123,8 +123,7 @@ class ConditionalDecoder(nn.Module):
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
@ -138,16 +137,14 @@ class ConditionalDecoder(nn.Module):
]
)
downsample = (
Downsample1D(output_channel) if not is_last else
CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
@ -169,11 +166,7 @@ class ConditionalDecoder(nn.Module):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = CausalResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
) if self.causal else ResnetBlock1D(
resnet = ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
@ -193,10 +186,10 @@ class ConditionalDecoder(nn.Module):
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
@ -214,7 +207,7 @@ class ConditionalDecoder(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
"""Forward pass of the UNet1DConditional model.
Args:
@ -249,9 +242,8 @@ class ConditionalDecoder(nn.Module):
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@ -268,9 +260,8 @@ class ConditionalDecoder(nn.Module):
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
@ -285,9 +276,211 @@ class ConditionalDecoder(nn.Module):
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
# attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
class CausalConditionalDecoder(ConditionalDecoder):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
static_chunk_size=50,
num_decoding_left_chunks=2,
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
"""
torch.nn.Module.__init__(self)
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = CausalResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else CausalConv1d(output_channel, output_channel, 3)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = CausalBlock1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,

View File

@ -37,14 +37,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
@ -91,7 +88,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
# NOTE this is unnecessary, feat/h already same shape
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
@ -111,16 +108,12 @@ class MaskedDiffWithXvec(torch.nn.Module):
prompt_feat_len,
embedding,
flow_cache):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
# concat speech token and prompt speech token
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
@ -145,7 +138,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
cond=conds,
n_timesteps=10,
prompt_len=mel_len1,
flow_cache=flow_cache
cache=flow_cache
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
@ -169,14 +162,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
@ -190,6 +180,52 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified training, static_chunk_size > 0 or = 0
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
h = self.encoder_proj(h)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
streaming=streaming,
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
@ -199,11 +235,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
prompt_feat,
prompt_feat_len,
embedding,
streaming,
finalize):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
@ -215,9 +248,11 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
if finalize is False:
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
if finalize is True:
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
else:
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
@ -232,8 +267,166 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10
n_timesteps=10,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
class CausalMaskedDiffWithDiT(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
pre_lookahead_layer: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.pre_lookahead_len = pre_lookahead_len
self.pre_lookahead_layer = pre_lookahead_layer
self.decoder = decoder
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified training, static_chunk_size > 0 or = 0
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h = self.pre_lookahead_layer(token)
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
streaming=streaming,
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
streaming,
finalize):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
if finalize is True:
h = self.pre_lookahead_layer(token)
else:
h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from hyperpyyaml import load_hyperpyyaml
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
model = configs['flow']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
max_len = 10 * model.decoder.estimator.static_chunk_size
chunk_size = model.decoder.estimator.static_chunk_size
context_size = model.pre_lookahead_layer.pre_lookahead_len
token = torch.randint(0, 6561, size=(1, max_len)).to(device)
token_len = torch.tensor([max_len]).to(device)
prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
prompt_token_len = torch.tensor([chunk_size]).to(device)
prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
prompt_embedding = torch.rand(1, 192).to(device)
pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
for i in range(0, max_len, chunk_size):
finalize = True if i + chunk_size + context_size >= max_len else False
pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())

View File

@ -1,4 +1,5 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,10 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
from cosyvoice.utils.common import set_all_random_seed
class ConditionalCFM(BASECFM):
@ -31,10 +32,9 @@ class ConditionalCFM(BASECFM):
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
self.lock = threading.Lock()
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
"""Forward diffusion
Args:
@ -54,21 +54,21 @@ class ConditionalCFM(BASECFM):
"""
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
cache_size = flow_cache.shape[2]
cache_size = cache.shape[2]
# fix prompt and overlap part mu and z
if cache_size != 0:
z[:, :, :cache_size] = flow_cache[:, :, :, 0]
mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
z[:, :, :cache_size] = cache[:, :, :, 0]
mu[:, :, :cache_size] = cache[:, :, :, 1]
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
cache = torch.stack([z_cache, mu_cache], dim=-1)
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
def solve_euler(self, x, t_span, mu, mask, spks, cond):
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
"""
Fixed euler solver for ODEs.
Args:
@ -91,12 +91,13 @@ class ConditionalCFM(BASECFM):
sol = []
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
# NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
@ -109,7 +110,8 @@ class ConditionalCFM(BASECFM):
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in
cond_in,
streaming
)
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
@ -121,28 +123,36 @@ class ConditionalCFM(BASECFM):
return sol[-1].float()
def forward_estimator(self, x, mask, mu, t, spks, cond):
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond)
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
else:
with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
# NOTE need to synchronize when switching stream
torch.cuda.current_stream().synchronize()
with stream:
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator, stream)
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
"""Computes diffusion loss
Args:
@ -179,7 +189,7 @@ class ConditionalCFM(BASECFM):
spks = spks * cfg_mask.view(-1, 1)
cond = cond * cfg_mask.view(-1, 1, 1)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
@ -187,10 +197,11 @@ class ConditionalCFM(BASECFM):
class CausalConditionalCFM(ConditionalCFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
set_all_random_seed(0)
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
"""Forward diffusion
Args:
@ -214,4 +225,4 @@ class CausalConditionalCFM(ConditionalCFM):
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None

View File

@ -51,6 +51,7 @@ class InterpolateRegulator(nn.Module):
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
# NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
# x in (B, T, D)
if x2.shape[1] > 40:
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')

View File

@ -1,10 +1,16 @@
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
import torch.nn.functional as F
try:
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
except ImportError:
from torch.nn.utils import weight_norm, spectral_norm
from typing import List, Optional, Tuple
from einops import rearrange
from torchaudio.transforms import Spectrogram
LRELU_SLOPE = 0.1
class MultipleDiscriminator(nn.Module):
def __init__(
@ -138,3 +144,87 @@ class DiscriminatorR(nn.Module):
x += h
return x, fmap
class MultiResSpecDiscriminator(torch.nn.Module):
def __init__(self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window="hann_window"):
super(MultiResSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for _, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
return torch.abs(x_stft).transpose(2, 1)
class SpecDiscriminator(nn.Module):
"""docstring for Discriminator."""
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
super(SpecDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.discriminators = nn.ModuleList([
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
])
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
def forward(self, y):
fmap = []
y = y.squeeze(1)
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
y = y.unsqueeze(1)
for _, d in enumerate(self.discriminators):
y = d(y)
y = F.leaky_relu(y, LRELU_SLOPE)
fmap.append(y)
y = self.out(y)
fmap.append(y)
return torch.flatten(y, 1, -1), fmap

View File

@ -13,7 +13,11 @@
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
try:
from torch.nn.utils.parametrizations import weight_norm
except ImportError:
from torch.nn.utils import weight_norm
from cosyvoice.transformer.convolution import CausalConv1d
class ConvRNNF0Predictor(nn.Module):
@ -53,3 +57,47 @@ class ConvRNNF0Predictor(nn.Module):
x = self.condnet(x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))
class CausalConvRNNF0Predictor(nn.Module):
def __init__(self,
num_class: int = 1,
in_channels: int = 80,
cond_channels: int = 512
):
super().__init__()
self.num_class = num_class
self.condnet = nn.Sequential(
weight_norm(
CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
weight_norm(
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
),
nn.ELU(),
)
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
if finalize is True:
x = self.condnet[0](x)
else:
x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
for i in range(1, len(self.condnet)):
x = self.condnet[i](x)
x = x.transpose(1, 2)
return torch.abs(self.classifier(x).squeeze(-1))

View File

@ -23,9 +23,12 @@ import torch.nn.functional as F
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
try:
from torch.nn.utils.parametrizations import weight_norm
except ImportError:
from torch.nn.utils import weight_norm
from torch.distributions.uniform import Uniform
from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
from cosyvoice.transformer.activation import Snake
from cosyvoice.utils.common import get_padding
from cosyvoice.utils.common import init_weights
@ -47,8 +50,10 @@ class ResBlock(torch.nn.Module):
channels: int = 512,
kernel_size: int = 3,
dilations: List[int] = [1, 3, 5],
causal: bool = False,
):
super(ResBlock, self).__init__()
self.causal = causal
self.convs1 = nn.ModuleList()
self.convs2 = nn.ModuleList()
@ -61,7 +66,14 @@ class ResBlock(torch.nn.Module):
kernel_size,
1,
dilation=dilation,
padding=get_padding(kernel_size, dilation)
padding=get_padding(kernel_size, dilation)) if causal is False else
CausalConv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
causal_type='left'
)
)
)
@ -73,7 +85,14 @@ class ResBlock(torch.nn.Module):
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)
padding=get_padding(kernel_size, 1)) if causal is False else
CausalConv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
causal_type='left'
)
)
)
@ -136,11 +155,13 @@ class SineGen(torch.nn.Module):
@torch.no_grad()
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, dim=1, length)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
:param f0: [B, 1, sample_len], Hz
:return: [B, 1, sample_len]
"""
f0 = f0.transpose(1, 2)
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
for i in range(self.harmonic_num + 1):
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
@ -162,6 +183,134 @@ class SineGen(torch.nn.Module):
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
class SineGen2(torch.nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False,
causal=False):
super(SineGen2, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
self.upsample_scale = upsample_scale
self.causal = causal
if causal is True:
self.rand_ini = torch.rand(1, 9)
self.rand_ini[:, 0] = 0
self.sine_waves = torch.rand(1, 300 * 24000, 9)
def _f02uv(self, f0):
# generate uv signal
uv = (f0 > self.voiced_threshold).type(torch.float32)
return uv
def _f02sine(self, f0_values):
""" f0_values: (batchsize, length, dim)
where dim indicates fundamental tone and overtones
"""
# convert to F0 in rad. The interger part n can be ignored
# because 2 * np.pi * n doesn't affect phase
rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component)
if self.training is False and self.causal is True:
rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
else:
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
if not self.flag_for_pulse:
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
scale_factor=1 / self.upsample_scale,
mode="linear").transpose(1, 2)
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
sines = torch.sin(phase)
else:
# If necessary, make sure that the first time step of every
# voiced segments is sin(pi) or cos(0)
# This is used for pulse-train generation
# identify the last time step in unvoiced segments
uv = self._f02uv(f0_values)
uv_1 = torch.roll(uv, shifts=-1, dims=1)
uv_1[:, -1, :] = 1
u_loc = (uv < 1) * (uv_1 > 0)
# get the instantanouse phase
tmp_cumsum = torch.cumsum(rad_values, dim=1)
# different batch needs to be processed differently
for idx in range(f0_values.shape[0]):
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
# stores the accumulation of i.phase within
# each voiced segments
tmp_cumsum[idx, :, :] = 0
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
# rad_values - tmp_cumsum: remove the accumulation of i.phase
# within the previous voiced segment.
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
# get the sines
sines = torch.cos(i_phase * 2 * np.pi)
return sines
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
# fundamental component
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
# generate sine waveforms
sine_waves = self._f02sine(fn) * self.sine_amp
# generate uv signal
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
if self.training is False and self.causal is True:
noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
else:
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
@ -187,19 +336,24 @@ class SourceModuleHnNSF(torch.nn.Module):
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
if sinegen_type == '1':
self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
else:
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
self.causal = causal
if causal is True:
self.uv = torch.rand(1, 300 * 24000, 1)
def forward(self, x):
"""
@ -210,13 +364,14 @@ class SourceModuleHnNSF(torch.nn.Module):
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
sine_wavs = sine_wavs.transpose(1, 2)
uv = uv.transpose(1, 2)
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
if self.training is False and self.causal is True:
noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
else:
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
@ -256,13 +411,16 @@ class HiFTGenerator(nn.Module):
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
# NOTE in CosyVoice2, we use the original SineGen implementation
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold)
voiced_threshod=nsf_voiced_threshold,
sinegen_type='1' if self.sampling_rate == 22050 else '2',
causal=False)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
@ -409,3 +567,180 @@ class HiFTGenerator(nn.Module):
s[:, :, :cache_source.shape[2]] = cache_source
generated_speech = self.decode(x=speech_feat, s=s)
return generated_speech, s
class CausalHiFTGenerator(HiFTGenerator):
"""
HiFTNet Generator: Neural Source Filter + ISTFTNet
https://arxiv.org/abs/2309.09493
"""
def __init__(
self,
in_channels: int = 80,
base_channels: int = 512,
nb_harmonics: int = 8,
sampling_rate: int = 22050,
nsf_alpha: float = 0.1,
nsf_sigma: float = 0.003,
nsf_voiced_threshold: float = 10,
upsample_rates: List[int] = [8, 8],
upsample_kernel_sizes: List[int] = [16, 16],
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
resblock_kernel_sizes: List[int] = [3, 7, 11],
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
source_resblock_kernel_sizes: List[int] = [7, 11],
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
lrelu_slope: float = 0.1,
audio_limit: float = 0.99,
conv_pre_look_right: int = 4,
f0_predictor: torch.nn.Module = None,
):
torch.nn.Module.__init__(self)
self.out_channels = 1
self.nb_harmonics = nb_harmonics
self.sampling_rate = sampling_rate
self.istft_params = istft_params
self.lrelu_slope = lrelu_slope
self.audio_limit = audio_limit
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.m_source = SourceModuleHnNSF(
sampling_rate=sampling_rate,
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
harmonic_num=nb_harmonics,
sine_amp=nsf_alpha,
add_noise_std=nsf_sigma,
voiced_threshod=nsf_voiced_threshold,
sinegen_type='1' if self.sampling_rate == 22050 else '2',
causal=True)
self.upsample_rates = upsample_rates
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
self.conv_pre = weight_norm(
CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
)
# Up
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
CausalConv1dUpsample(
base_channels // (2**i),
base_channels // (2**(i + 1)),
k,
u,
)
)
)
# Down
self.source_downs = nn.ModuleList()
self.source_resblocks = nn.ModuleList()
downsample_rates = [1] + upsample_rates[::-1][:-1]
downsample_cum_rates = np.cumprod(downsample_rates)
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
if u == 1:
self.source_downs.append(
CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
)
else:
self.source_downs.append(
CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
)
self.source_resblocks.append(
ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = base_channels // (2**(i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(ResBlock(ch, k, d, causal=True))
self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = nn.ReflectionPad1d((1, 0))
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
self.conv_pre_look_right = conv_pre_look_right
self.f0_predictor = f0_predictor
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
if finalize is True:
x = self.conv_pre(x)
else:
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, self.lrelu_slope)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
# fusion
si = self.source_downs[i](s_stft)
si = self.source_resblocks[i](si)
x = x + si
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
x = self._istft(magnitude, phase)
if finalize is False:
x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
return x
@torch.inference_mode()
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
self.f0_predictor.to('cpu')
f0 = self.f0_predictor(speech_feat.cpu(), finalize=finalize).to(speech_feat)
# f0->source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
if finalize is True:
generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
else:
generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
return generated_speech, s
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from hyperpyyaml import load_hyperpyyaml
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
model = configs['hift']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
max_len, chunk_size, context_size = 300, 30, 8
mel = torch.rand(1, 80, max_len).to(device)
pred_gt, _ = model.inference(mel)
for i in range(0, max_len, chunk_size):
finalize = True if i + chunk_size + context_size >= max_len else False
pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
pred_chunk = pred_chunk[:, i * 480:]
print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())

View File

@ -41,7 +41,7 @@ class HiFiGan(nn.Module):
loss_fm = feature_loss(fmap_rs, fmap_gs)
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
if self.tpr_loss_weight != 0:
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
else:
loss_tpr = torch.zeros(1).to(device)
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
@ -56,7 +56,7 @@ class HiFiGan(nn.Module):
with torch.no_grad():
generated_speech, generated_f0 = self.generator(batch, device)
# 2. calculate discriminator outputs
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
# 3. calculate discriminator losses, tpr losses [Optional]
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
if self.tpr_loss_weight != 0:

View File

@ -1,4 +1,5 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,7 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import queue
import random
import time
import threading
from typing import Dict, Optional, Callable, List, Generator
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
@ -21,6 +27,7 @@ from cosyvoice.utils.common import IGNORE_ID
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice.utils.common import th_accuracy
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.mask import make_pad_mask
class TransformerLM(torch.nn.Module):
@ -50,8 +57,9 @@ class TransformerLM(torch.nn.Module):
)
# 2. build speech token language model related modules
self.sos_eos = 0
self.sos = 0
self.task_id = 1
self.eos_token = self.speech_token_size
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
@ -79,10 +87,10 @@ class TransformerLM(torch.nn.Module):
encoder_out = self.text_encoder_affine_layer(encoder_out)
return encoder_out, encoder_out_lens
def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
for i in range(len(text_token))]
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
@ -120,15 +128,15 @@ class TransformerLM(torch.nn.Module):
embedding = self.spk_embed_affine_layer(embedding)
embedding = embedding.unsqueeze(1)
# 3. eos and task_id
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 4. encode speech_token
speech_token = self.speech_embedding(speech_token)
# 5. unpad and pad
lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
task_id_emb, speech_token, speech_token_len)
# 6. run lm forward
@ -148,7 +156,7 @@ class TransformerLM(torch.nn.Module):
num_trials, max_trials = 0, 100
while True:
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
if (not ignore_eos) or (self.speech_token_size not in top_ids):
if (not ignore_eos) or (top_ids < self.speech_token_size):
break
num_trials += 1
if num_trials > max_trials:
@ -168,10 +176,8 @@ class TransformerLM(torch.nn.Module):
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
if self.fp16 is True:
embedding = embedding.half()
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
@ -189,13 +195,13 @@ class TransformerLM(torch.nn.Module):
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
@ -211,11 +217,8 @@ class TransformerLM(torch.nn.Module):
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
device=lm_input.device)).to(torch.bool))
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
# force continue decode first token
if i == 0:
logp[:, self.speech_token_size] = -float('inf')
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
if top_ids == self.eos_token:
break
# in stream mode, yield token one by one
yield top_ids
@ -229,6 +232,17 @@ class Qwen2Encoder(torch.nn.Module):
super().__init__()
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T)
outs = self.model(
inputs_embeds=xs,
attention_mask=masks,
output_hidden_states=True,
return_dict=True,
)
return outs.hidden_states[-1], masks.unsqueeze(1)
def forward_one_step(self, xs, masks, cache=None):
input_masks = masks[:, -1, :]
outs = self.model(
@ -260,11 +274,11 @@ class Qwen2LM(TransformerLM):
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos_eos = 0
self.sos = 0
self.task_id = 1
self.fill_token = 2
self.eos_token = speech_token_size
self.fill_token = speech_token_size + 2
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
self.llm = llm
@ -283,6 +297,146 @@ class Qwen2LM(TransformerLM):
self.sampling = sampling
self.mix_ratio = mix_ratio
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(3)]
self.vllm_output_queue = {}
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
lm_target, lm_input = [], []
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
# NOTE add instruct_token in CosyVoice3
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
for i in range(len(text_token)):
# bistream sequence
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
this_lm_input.append(instruct_token_emb[i])
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
if len(this_text_token) == self.mix_ratio[0]:
assert len(this_speech_token) == self.mix_ratio[1]
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
this_lm_target += this_speech_token
this_lm_target.append(self.fill_token)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
else:
this_lm_target += [-1] * len(this_text_token)
this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
this_lm_target.append(self.eos_token)
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
this_lm_input.append(task_id_emb.squeeze(dim=0))
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
# unistream sequence
else:
this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
lm_target.append(this_lm_target)
lm_input.append(this_lm_input)
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
return lm_target, lm_input, lm_input_len
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token_emb = self.speech_embedding(speech_token)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
speech_token, speech_token_emb, speech_token_len)
lm_target = lm_target.to(device)
# 4. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target.to(device))
acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
def forward_dpo(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
reject_speech_token = batch['reject_speech_token'].to(device)
reject_speech_token_len = batch['reject_speech_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
# 3. sos and task_id
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
speech_token_combined = speech_token + reject_speech_token
speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
speech_token_combined_emb = self.speech_embedding(speech_token_combined)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
lm_target = lm_target.to(device)
# 4. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
chosen_logits = logits[:text_token.shape[0]]
rejected_logits = logits[text_token.shape[0]:]
chosen_lm_target = lm_target[:text_token.shape[0]]
rejected_lm_target = lm_target[text_token.shape[0]:]
loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
# 5. calculate dpo logits
chosen_lm_mask = chosen_lm_target == IGNORE_ID
rejected_lm_mask = rejected_lm_target == IGNORE_ID
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
@torch.inference_mode()
def inference(
self,
@ -296,6 +450,7 @@ class Qwen2LM(TransformerLM):
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
@ -303,35 +458,68 @@ class Qwen2LM(TransformerLM):
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
out_tokens = []
cache = None
for i in range(max_len):
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
if top_ids == self.speech_token_size:
break
if top_ids > self.speech_token_size:
continue
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
yield token
@torch.inference_mode()
def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
if hasattr(self, 'vllm'):
from vllm import SamplingParams, RequestOutput
sampling_params = SamplingParams(top_k=sampling,
stop_token_ids=self.stop_token_ids,
min_tokens=min_len,
max_tokens=max_len)
with self.lock:
self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
self.vllm_output_queue[uuid] = queue.Queue()
out_tokens = []
while True:
with self.lock:
if self.vllm_output_queue[uuid].empty() is True:
request_outputs: List[RequestOutput] = self.vllm.step()
for request_output in request_outputs:
top_ids = list(request_output.outputs[0].token_ids)[-1]
self.vllm_output_queue[request_output.request_id].put(top_ids)
if self.vllm_output_queue[uuid].empty() is False:
top_ids = self.vllm_output_queue[uuid].get()
if top_ids in self.stop_token_ids:
break
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
if len(out_tokens) == max_len:
break
time.sleep(0.001)
with self.lock:
self.vllm_output_queue.pop(uuid)
else:
out_tokens = []
cache = None
for i in range(max_len):
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
if top_ids in self.stop_token_ids:
break
# in stream mode, yield token one by one
yield top_ids
out_tokens.append(top_ids)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
@torch.inference_mode()
def inference_bistream(
@ -349,20 +537,20 @@ class Qwen2LM(TransformerLM):
device = prompt_text.device
# 1. prepare input
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
lm_input = torch.concat([sos_eos_emb], dim=1)
lm_input = torch.concat([sos_emb], dim=1)
# 2. iterate text
out_tokens = []
cache = None
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
text_cache = self.llm.model.model.embed_tokens(prompt_text)
next_fill_index = -1
next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
for this_text in text:
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
# prompt_speech_token_emb not empty, try append to lm_input
@ -377,12 +565,12 @@ class Qwen2LM(TransformerLM):
break
# no prompt_speech_token_emb remain, can decode some speech token
if prompt_speech_token_emb.size(1) == 0:
if (len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
logging.info('get fill token, need to append more text token')
if text_cache.size(1) >= self.mix_ratio[0]:
lm_input_text = text_cache[:, :self.mix_ratio[0]]
logging.info('append {} text token'.format(lm_input_text.size(1)))
if len(out_tokens) != 0 and out_tokens[-1] == self.speech_token_size + 2:
if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
lm_input = lm_input_text
else:
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
@ -393,20 +581,20 @@ class Qwen2LM(TransformerLM):
while True:
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
y_pred, cache = self.llm.forward_one_step(lm_input,
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
top_ids = self.speech_token_size + 2
top_ids = self.fill_token
next_fill_index += (self.mix_ratio[1] + 1)
else:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True).item()
if top_ids == self.speech_token_size + 2:
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
if top_ids == self.fill_token:
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size + 2:
if top_ids == self.fill_token:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
@ -422,13 +610,136 @@ class Qwen2LM(TransformerLM):
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
cache=cache)
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item()
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
out_tokens.append(top_ids)
if top_ids >= self.speech_token_size:
if top_ids == self.speech_token_size:
if top_ids == self.eos_token:
break
else:
raise ValueError('should not get token {}'.format(top_ids))
# in stream mode, yield token one by one
yield top_ids
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
class CosyVoice3LM(Qwen2LM):
def __init__(
self,
llm_input_size: int,
llm_output_size: int,
speech_token_size: int,
llm: torch.nn.Module,
sampling: Callable,
length_normalized_loss: bool = True,
lsm_weight: float = 0.0,
mix_ratio: List[int] = [5, 15],
):
torch.nn.Module.__init__(self)
self.llm_input_size = llm_input_size
self.llm_output_size = llm_output_size
self.speech_token_size = speech_token_size
# 2. build speech token language model related modules
self.sos = speech_token_size + 0
self.eos_token = speech_token_size + 1
self.task_id = speech_token_size + 2
self.fill_token = speech_token_size + 3
self.llm = llm
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
self.criterion_ce = LabelSmoothingLoss(
size=speech_token_size + 200,
padding_idx=IGNORE_ID,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
# 3. [Optional] build speech token related modules
self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
# 4. sampling method
self.sampling = sampling
self.mix_ratio = mix_ratio
# 5. vllm related
self.stop_token_ids = [speech_token_size + i for i in range(200)]
self.vllm_output_queue = {}
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
"""
Args:
text: (B, L, D)
text_lengths: (B,)
audio: (B, T, N) or (B, T)
audio_lengths: (B,)
"""
text_token = batch['text_token'].to(device)
text_token_len = batch['text_token_len'].to(device)
speech_token = batch['speech_token'].to(device)
speech_token_len = batch['speech_token_len'].to(device)
# NOTE should append instruct_token to sequence, not implemented yet
instruct_token = batch['instruct_token'].to(device)
instruct_token_len = batch['instruct_token_len'].to(device)
# 1. encode text_token
text_token_emb = self.llm.model.model.embed_tokens(text_token)
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
# 3. sos and task_id
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
# 2. encode speech_token
speech_token_emb = self.speech_embedding(speech_token)
# 3. prepare llm_input/target
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
lm_target = lm_target.to(device)
# 4. run lm forward
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
logits = self.llm_decoder(lm_output)
loss = self.criterion_ce(logits, lm_target.to(device))
acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
return {'loss': loss, 'acc': acc}
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_len: torch.Tensor,
prompt_text: torch.Tensor,
prompt_text_len: torch.Tensor,
prompt_speech_token: torch.Tensor,
prompt_speech_token_len: torch.Tensor,
embedding: torch.Tensor,
sampling: int = 25,
max_token_text_ratio: float = 20,
min_token_text_ratio: float = 2,
uuid: str = '',
) -> Generator[torch.Tensor, None, None]:
device = text.device
text = torch.concat([prompt_text, text], dim=1)
text_len += prompt_text_len
text = self.llm.model.model.embed_tokens(text)
# 3. concat llm_input
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
if prompt_speech_token_len != 0:
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
else:
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
# 4. cal min/max_length
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
# 5. step by step decode
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
yield token

View File

@ -238,7 +238,7 @@ def get_tokenizer(
)
class QwenTokenizer():
class CosyVoice2Tokenizer():
def __init__(self, token_path, skip_special_tokens=True):
super().__init__()
# NOTE: non-chat model, all these special tokens keep randomly initialized.
@ -271,9 +271,57 @@ class QwenTokenizer():
return text
class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
def __init__(self, token_path, skip_special_tokens=True):
# NOTE: non-chat model, all these special tokens keep randomly initialized.
special_tokens = {
'eos_token': '<|endoftext|>',
'pad_token': '<|endoftext|>',
'additional_special_tokens': [
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
'[breath]', '<strong>', '</strong>', '[noise]',
'[laughter]', '[cough]', '[clucking]', '[accent]',
'[quick_breath]',
"<laughter>", "</laughter>",
"[hissing]", "[sigh]", "[vocalized-noise]",
"[lipsmack]", "[mn]", "<|endofsystem|>",
"[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
"[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
"[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
"[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
"[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
"[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
"[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
"[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
"[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
"[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
"[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
"[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
"[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
"[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
"[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
"[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
"[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
"[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
"[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
"[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
]
}
self.special_tokens = special_tokens
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
self.tokenizer.add_special_tokens(special_tokens)
self.skip_special_tokens = skip_special_tokens
@lru_cache(maxsize=None)
def get_qwen_tokenizer(
token_path: str,
skip_special_tokens: bool
) -> QwenTokenizer:
return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
skip_special_tokens: bool,
version: str = 'cosyvoice2'
):
if version == 'cosyvoice2':
return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
elif version == 'cosyvoice3':
return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
else:
raise ValueError

View File

@ -19,6 +19,7 @@ from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
class ConvolutionModule(nn.Module):
@ -143,3 +144,115 @@ class ConvolutionModule(nn.Module):
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
causal_type: str = 'left',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride=1,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
assert causal_type in ['left', 'right']
self.causal_type = causal_type
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
input_timestep = x.shape[2]
if cache.size(2) == 0:
cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
assert cache.size(2) == self.causal_padding
if self.causal_type == 'left':
x = torch.concat([cache, x], dim=2)
else:
x = torch.concat([x, cache], dim=2)
x = super(CausalConv1d, self).forward(x)
assert x.shape[2] == input_timestep
return x
class CausalConv1dDownSample(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride != 1 and dilation == 1
assert kernel_size % stride == 0
self.causal_padding = stride - 1
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
if cache.size(2) == 0:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
else:
assert cache.size(2) == self.causal_padding
x = torch.concat([cache, x], dim=2)
x = super(CausalConv1dDownSample, self).forward(x)
return x
class CausalConv1dUpsample(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
kernel_size, 1,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert dilation == 1
self.causal_padding = kernel_size - 1
self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.upsample(x)
input_timestep = x.shape[2]
if cache.size(2) == 0:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
else:
assert cache.size(2) == self.causal_padding
x = torch.concat([cache, x], dim=2)
x = super(CausalConv1dUpsample, self).forward(x)
assert input_timestep == x.shape[2]
return x

View File

@ -287,8 +287,16 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
Returns:
torch.Tensor: Corresponding encoding
"""
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
]
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if isinstance(offset, int):
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
]
elif isinstance(offset, torch.Tensor):
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
]
return pos_emb

View File

@ -56,7 +56,7 @@ class Upsample1D(nn.Module):
# In this mode, first repeat interpolate, than conv with stride=1
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
outputs = self.conv(outputs)
@ -64,30 +64,37 @@ class Upsample1D(nn.Module):
class PreLookaheadLayer(nn.Module):
def __init__(self, channels: int, pre_lookahead_len: int = 1):
def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.pre_lookahead_len = pre_lookahead_len
self.conv1 = nn.Conv1d(
channels, channels,
in_channels, channels,
kernel_size=pre_lookahead_len + 1,
stride=1, padding=0,
)
self.conv2 = nn.Conv1d(
channels, channels,
channels, in_channels,
kernel_size=3, stride=1, padding=0,
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
"""
inputs: (batch_size, seq_len, channels)
"""
outputs = inputs.transpose(1, 2).contiguous()
context = context.transpose(1, 2).contiguous()
# look ahead
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
if context.size(2) == 0:
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
else:
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
assert context.size(2) == self.pre_lookahead_len
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
outputs = F.leaky_relu(self.conv1(outputs))
# outputs
outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
outputs = self.conv2(outputs)
outputs = outputs.transpose(1, 2).contiguous()
@ -193,7 +200,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
cnn_module_norm, causal)
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
@ -238,8 +245,10 @@ class UpsampleConformerEncoder(torch.nn.Module):
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
context: torch.Tensor = torch.zeros(0, 0, 0),
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
streaming: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
@ -269,15 +278,14 @@ class UpsampleConformerEncoder(torch.nn.Module):
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
if context.size(1) != 0:
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
# lookahead + conformer encoder
xs = self.pre_lookahead_layer(xs)
xs = self.pre_lookahead_layer(xs, context=context)
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
# upsample + conformer encoder
@ -288,12 +296,7 @@ class UpsampleConformerEncoder(torch.nn.Module):
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
xs, pos_emb, masks = self.up_embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size * self.up_layer.stride,
num_decoding_left_chunks)
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
if self.normalize_before:

View File

@ -32,10 +32,10 @@ from cosyvoice.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
from cosyvoice.llm.llm import TransformerLM, Qwen2LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
from cosyvoice.hifigan.generator import HiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
COSYVOICE_ACTIVATION_CLASSES = {
@ -80,4 +80,6 @@ def get_model_type(configs):
return CosyVoiceModel
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
return CosyVoice2Model
if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
return CosyVoice3Model
raise TypeError('No valid model type found!')

View File

@ -1,5 +1,6 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,6 +16,7 @@
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Unility functions for Transformer."""
import queue
import random
from typing import List
@ -23,6 +25,33 @@ import torch
IGNORE_ID = -1
instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
"You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
"You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
"You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
"You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
"You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
"You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
"You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
"You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
"You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
"You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
"You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
"You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
"You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
"You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
"You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
"You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
"You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
"You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
"You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
@ -128,12 +157,12 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
break
prob = torch.tensor(prob).to(weighted_scores)
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
top_ids = indices[prob.multinomial(1, replacement=True)]
top_ids = indices[prob.multinomial(1, replacement=True)].item()
return top_ids
def random_sampling(weighted_scores, decoded_tokens, sampling):
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
return top_ids
@ -164,3 +193,21 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
mask = (1.0 - mask) * -1.0e+10
return mask
class TrtContextWrapper:
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
self.trt_engine = trt_engine
for _ in range(trt_concurrent):
trt_context = trt_engine.create_execution_context()
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
self.trt_context_pool.put([trt_context, trt_stream])
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
def acquire_estimator(self):
return self.trt_context_pool.get(), self.trt_engine
def release_estimator(self, context, stream):
self.trt_context_pool.put([context, stream])

View File

@ -25,14 +25,16 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
class Executor:
def __init__(self, gan: bool = False):
def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
self.gan = gan
self.ref_model = ref_model
self.dpo_loss = dpo_loss
self.step = 0
self.epoch = 0
self.rank = int(os.environ.get('RANK', 0))
self.device = torch.device('cuda:{}'.format(self.rank))
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
''' Train one epoch
'''
@ -44,6 +46,8 @@ class Executor:
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
model.train()
if self.ref_model is not None:
self.ref_model.eval()
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
with model_context():
for batch_idx, batch_dict in enumerate(train_data_loader):
@ -65,7 +69,7 @@ class Executor:
context = nullcontext
with context():
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
info_dict = batch_backward(model, scaler, info_dict)
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
@ -162,7 +166,7 @@ class Executor:
for k, v in info_dict['loss_dict'].items():
if k not in total_loss_dict:
total_loss_dict[k] = []
total_loss_dict[k].append(v.item() * num_utts)
total_loss_dict[k].append(v.mean().item() * num_utts)
log_per_step(None, info_dict)
for k, v in total_loss_dict.items():
total_loss_dict[k] = sum(v) / total_num_utts

View File

@ -1,5 +1,6 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import torch
import torchaudio
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
@ -38,22 +41,17 @@ def read_json_lists(list_file):
return results
def load_wav(wav, target_sr):
def load_wav(wav, target_sr, min_sr=16000):
speech, sample_rate = torchaudio.load(wav, backend='soundfile')
speech = speech.mean(dim=0, keepdim=True)
if sample_rate != target_sr:
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
return speech
def convert_onnx_to_trt(trt_model, onnx_model, fp16):
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
import tensorrt as trt
_min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
_opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
_max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
input_names = ["x", "mask", "mu", "t", "spks", "cond"]
logging.info("Converting onnx to trt...")
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
logger = trt.Logger(trt.Logger.INFO)
@ -61,7 +59,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
network = builder.create_network(network_flags)
parser = trt.OnnxParser(network, logger)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
profile = builder.create_optimization_profile()
@ -72,8 +70,8 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
print(parser.get_error(error))
raise ValueError('failed to parse {}'.format(onnx_model))
# set input shapes
for i in range(len(input_names)):
profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
for i in range(len(trt_kwargs['input_names'])):
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
# set input and output data type
for i in range(network.num_inputs):
@ -87,3 +85,34 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
# save trt engine
with open(trt_model, "wb") as f:
f.write(engine_bytes)
logging.info("Succesfully convert onnx to trt...")
# NOTE do not support bistream inference as only speech token embedding/head is kept
def export_cosyvoice2_vllm(model, model_path, device):
if os.path.exists(model_path):
return
dtype = torch.bfloat16
# lm_head
use_bias = True if model.llm_decoder.bias is not None else False
model.llm.model.lm_head = model.llm_decoder
# embed_tokens
embed_tokens = model.llm.model.model.embed_tokens
model.llm.model.set_input_embeddings(model.speech_embedding)
model.llm.model.to(device)
model.llm.model.to(dtype)
tmp_vocab_size = model.llm.model.config.vocab_size
tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
del model.llm.model.generation_config.eos_token_id
del model.llm.model.config.bos_token_id
del model.llm.model.config.eos_token_id
model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
model.llm.model.config.tie_word_embeddings = False
model.llm.model.config.use_bias = use_bias
model.llm.model.save_pretrained(model_path)
if use_bias is True:
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
model.llm.model.config.vocab_size = tmp_vocab_size
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
model.llm.model.set_input_embeddings(embed_tokens)

View File

@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from typing import Tuple
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
@ -18,3 +19,39 @@ def mel_loss(real_speech, generated_speech, mel_transforms):
mel_g = transform(generated_speech)
loss += F.l1_loss(mel_g, mel_r)
return loss
class DPOLoss(torch.nn.Module):
"""
DPO Loss
"""
def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
super().__init__()
self.beta = beta
self.label_smoothing = label_smoothing
self.ipo = ipo
def forward(
self,
policy_chosen_logps: torch.Tensor,
policy_rejected_logps: torch.Tensor,
reference_chosen_logps: torch.Tensor,
reference_rejected_logps: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
logits = pi_logratios - ref_logratios
if self.ipo:
losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
else:
# Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
loss = losses.mean()
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
return loss, chosen_rewards, rejected_rewards

View File

@ -15,7 +15,6 @@
# limitations under the License.
import torch
from cosyvoice.utils.file_utils import logging
'''
def subsequent_mask(
size: int,
@ -153,7 +152,6 @@ def subsequent_chunk_mask(
[1, 1, 1, 1]]
"""
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
# actually this is not needed after we have inference cache implemented, will remove it later
pos_idx = torch.arange(size, device=device)
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
@ -233,8 +231,8 @@ def add_optional_chunk_mask(xs: torch.Tensor,
chunk_masks = masks
assert chunk_masks.dtype == torch.bool
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
chunk_masks[chunk_masks.sum(dim=-1)==0] = True
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
return chunk_masks

View File

@ -50,10 +50,10 @@ def init_distributed(args):
return world_size, local_rank, rank
def init_dataset_and_dataloader(args, configs, gan):
def init_dataset_and_dataloader(args, configs, gan, dpo):
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
train_data_loader = DataLoader(train_dataset,
@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan):
def check_modify_and_save_config(args, configs):
if args.train_engine == "torch_ddp":
configs['train_conf']["dtype"] = 'fp32'
configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
else:
with open(args.deepspeed_config, 'r') as fin:
ds_configs = json.load(fin)
@ -164,18 +164,18 @@ def init_optimizer_and_scheduler(args, configs, model, gan):
raise ValueError("unknown scheduler: " + configs['train_conf'])
if configs['train_conf']['optim_d'] == 'adam':
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
elif configs['train_conf']['optim_d'] == 'adamw':
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf'])
if configs['train_conf']['scheduler_d'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
elif configs['train_conf']['scheduler'] == 'constantlr':
scheduler_type = ConstantLR
scheduler_d = ConstantLR(optimizer_d)
@ -235,7 +235,7 @@ def cosyvoice_join(group_join, info_dict):
return False
def batch_forward(model, batch, scaler, info_dict):
def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
device = int(os.environ.get('LOCAL_RANK', 0))
dtype = info_dict["dtype"]
@ -247,12 +247,30 @@ def batch_forward(model, batch, scaler, info_dict):
dtype = torch.float32
if info_dict['train_engine'] == 'torch_ddp':
autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
else:
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
with autocast:
info_dict['loss_dict'] = model(batch, device)
if ref_model is not None and dpo_loss is not None:
chosen_logps = info_dict['loss_dict']["chosen_logps"]
rejected_logps = info_dict['loss_dict']["rejected_logps"]
sft_loss = info_dict['loss_dict']['loss']
with torch.no_grad():
ref_loss_dict = ref_model(batch, device)
reference_chosen_logps = ref_loss_dict["chosen_logps"]
reference_rejected_logps = ref_loss_dict["rejected_logps"]
preference_loss, chosen_reward, reject_reward = dpo_loss(
chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
)
dpo_acc = (chosen_reward > reject_reward).float().mean()
info_dict['loss_dict']["loss"] = preference_loss + sft_loss
info_dict['loss_dict']["sft_loss"] = sft_loss
info_dict['loss_dict']["dpo_loss"] = preference_loss
info_dict['loss_dict']["dpo_acc"] = dpo_acc
info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
return info_dict
@ -286,11 +304,15 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
# optimizer.step().
if torch.isfinite(grad_norm):
scaler.step(optimizer)
else:
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
scaler.update()
else:
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
if torch.isfinite(grad_norm):
optimizer.step()
else:
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
optimizer.zero_grad()
scheduler.step()
info_dict["lr"] = optimizer.param_groups[0]['lr']
@ -336,7 +358,7 @@ def log_per_save(writer, info_dict):
rank = int(os.environ.get('RANK', 0))
logging.info(
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
if writer is not None:
for k in ['epoch', 'lr']:

View File

@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Optional
from packaging.version import parse as vparse
import vllm
# vLLM-0.11.0+ only support V1 engine
VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
if VLLM_V1_ENGINE_ONLY:
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.model_executor.models.qwen2 import *
class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
True,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
) -> Optional[torch.Tensor]:
if VLLM_V1_ENGINE_ONLY:
logits = self.logits_processor(self.lm_head, hidden_states,
self.lm_head.bias)
else:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

View File

@ -4,7 +4,7 @@ ARG VENV_NAME="cosyvoice"
ENV VENV=$VENV_NAME
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
ENV DEBIAN_FRONTEN=noninteractive
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
SHELL ["/bin/bash", "--login", "-c"]
@ -46,6 +46,6 @@ RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
RUN conda activate ${VENV} && cd CosyVoice && \
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
WORKDIR /workspace/CosyVoice

106
example.py Normal file
View File

@ -0,0 +1,106 @@
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import AutoModel
import torchaudio
def cosyvoice_example():
""" CosyVoice Usage, check https://fun-audio-llm.github.io/ for more details
"""
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-SFT')
# sft usage
print(cosyvoice.list_available_spks())
# change stream=True for chunk stream inference
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
# zero_shot usage
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# cross_lingual usage, <|zh|><|en|><|ja|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
'./asset/cross_lingual_prompt.wav')):
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# vc usage
for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男',
'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.<|endofprompt|>')):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
def cosyvoice2_example():
""" CosyVoice2 Usage, check https://funaudiollm.github.io/cosyvoice2/ for more details
"""
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# save zero_shot spk for future usage
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', 'my_zero_shot_spk') is True
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk')):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
cosyvoice.save_spkinfo()
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', './asset/zero_shot_prompt.wav')):
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# instruct usage
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话<|endofprompt|>', './asset/zero_shot_prompt.wav')):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# bistream usage, you can use generator as input, this is useful when using text llm model as input
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
def text_generator():
yield '收到好友从远方寄来的生日礼物,'
yield '那份意外的惊喜与深深的祝福'
yield '让我心中充满了甜蜜的快乐,'
yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
def cosyvoice3_example():
""" CosyVoice3 Usage, check https://funaudiollm.github.io/cosyvoice3/ for more details
"""
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
# zero_shot usage
for i, j in enumerate(cosyvoice.inference_zero_shot('八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L280
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>[breath]因为他们那一辈人[breath]在乡里面住的要习惯一点,[breath]邻居都很活络,[breath]嗯,都很熟悉。[breath]',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# instruct usage, for supported control, check cosyvoice/utils/common.py#L28
for i, j in enumerate(cosyvoice.inference_instruct2('好少咯,一般系放嗰啲国庆啊,中秋嗰啲可能会咯。', 'You are a helpful assistant. 请用广东话表达。<|endofprompt|>',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
# hotfix usage
for i, j in enumerate(cosyvoice.inference_zero_shot('高管也通过电话、短信、微信等方式对报道[j][ǐ]予好评。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
'./asset/zero_shot_prompt.wav', stream=False)):
torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
def main():
# cosyvoice_example()
# cosyvoice2_example()
cosyvoice3_example()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,6 @@
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
COPY requirements.txt /myworkspace/requirements.txt
RUN pip install -r /myworkspace/requirements.txt
RUN pip install -U nvidia-pytriton
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .

View File

@ -0,0 +1,125 @@
# CosyVoice2 LLM Reinforcement Learning Recipe
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
## Table of Contents
- [Environment Setup](#environment-setup)
- [Data Preparation](#data-preparation)
- [Reward Function & ASR Server](#reward-function--asr-server)
- [Training](#training)
- [Evaluation](#evaluation)
- [Export Model](#export-model)
- [Results](#results)
- [Acknowledgement](#acknowledgement)
## Environment Setup
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
```bash
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
```
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
## Data Preparation
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
```jsonc
{
"text": "An example sentence to be synthesized."
}
```
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
```bash
bash run.sh 0 0
```
Create two JSONL files—`train.jsonl` and `test.jsonl`.
The script will then generate two Parquet files:
```
data/parquet_tiny/train.parquet
data/parquet_tiny/test.parquet
```
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
## Reward Function & ASR Server
To compute rewards, we run a lightweight server that:
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
2. Transcribes the waveform with **SenseVoice** ASR.
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
```bash
bash run.sh 1 1
# Triton server listens on ports 8000/8001/8002
```
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
## Training
Run stage `2` to start GRPO training:
```bash
bash run.sh 2 2
```
Key CLI arguments passed to `verl.trainer.main_ppo`:
* `algorithm.adv_estimator=grpo` use GRPO instead of PPO.
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
* `custom_reward_function.path=reward_tts.py` custom reward function described above.
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
> [!TIP]
> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model.
## Evaluation
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
```bash
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
```
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
```bash
bash run.sh 4 4
```
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
> [!TIP]
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
## Export Model
To use the RL-trained model with the official CosyVoice repository:
```bash
bash run.sh 5 5
```
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
> [!TIP]
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
## Results
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|-------|------------------------|------------------------------|---------|
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
## Acknowledgement
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).

View File

@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
"""
from argparse import ArgumentParser
import torch
from safetensors import safe_open
from transformers import AutoTokenizer
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--hf-cosyvoice2-llm-path",
type=str,
default=None,
help="The RL trained CosyVoice2 model path in HuggingFace format",
)
parser.add_argument(
"--output-path",
type=str,
default="./llm.pt",
help="The path to save the llm.pt",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
cosyvoice2_token_size = 6561 + 3
llm_embedding_vocab_size = 2
hf_tensors = {}
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
for k in f.keys():
if k.startswith("lm_head.bias"):
# RL trained model disable bias for lm_head
continue
new_k = "llm.model." + k
hf_tensors[new_k] = f.get_tensor(k)
if k.startswith("lm_head"):
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
if k.startswith("model.embed_tokens"):
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
# use tie_word_embeddings=True
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
torch.save(hf_tensors, args.output_path)

View File

@ -0,0 +1,397 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Example Usage
dataset=zero_shot_zh
output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
token2wav_path=/workspace/CosyVoice2-0.5B
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
infer_dataset.py \
--output-dir $output_dir \
--llm-model-name-or-path $llm_path/merged_hf_model \
--token2wav-path $token2wav_path \
--split-name ${dataset} || exit 1
"""
import argparse
import json
import os
import sys
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
import soundfile as sf
import s3tokenizer
from functools import partial
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
try:
torch.multiprocessing.set_start_method("spawn")
except RuntimeError:
pass
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
"""
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def extract_speech_ids(speech_tokens_str):
"""Extract speech IDs from token strings like <|s_23456|>"""
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith('<|s_') and token_str.endswith('|>'):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
speech_id_str = ""
for token in cosy2_tokens:
speech_id_str += f"<|s_{token}|>"
return speech_id_str
def get_args():
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
)
parser.add_argument(
"--output-dir", required=True, type=str, help="dir to save result"
)
parser.add_argument(
"--batch-size",
default=1,
type=int,
help="batch size (per-device) for inference",
)
parser.add_argument(
"--num-workers", type=int, default=1, help="workers for dataloader"
)
parser.add_argument(
"--prefetch", type=int, default=5, help="prefetch for dataloader"
)
parser.add_argument(
"--llm-model-name-or-path",
required=True,
type=str,
help="LLM model path (includes both model and tokenizer)",
)
parser.add_argument(
"--token2wav-path",
required=True,
type=str,
help="CosyVoice2 token2wav model path",
)
parser.add_argument(
"--prompt-text",
type=str,
default=None,
help="The prompt text for CosyVoice2",
)
parser.add_argument(
"--prompt-speech-path",
type=str,
default=None,
help="The path to the prompt speech for CosyVoice2",
)
parser.add_argument(
"--top-p",
type=float,
default=0.95,
help="top p for sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=0.8,
help="temperature for sampling",
)
parser.add_argument(
"--top-k",
type=int,
default=50,
help="top k for sampling",
)
args = parser.parse_args()
return args
def data_collator(batch, tokenizer, s3_tokenizer):
"""Simplified data collator for batch_size=1 processing"""
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
mels, prompt_audio_cosy2tokens_list = [], []
for item in batch:
prompt_text, target_text = (
item["prompt_text"],
item["target_text"],
)
prompt_text_list.append(prompt_text)
# Combine prompt and target text
full_text = prompt_text + target_text
# get prompt audio for CosyVoice2 (convert to 16kHz)
ref_audio_org, ref_sr = (
item["prompt_audio"]["array"],
item["prompt_audio"]["sampling_rate"],
)
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
print(ref_audio_org.shape)
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio_org)
else:
ref_audio = ref_audio_org
prompt_audio_list.append(ref_audio)
if "prompt_audio_cosy2_tokens" in item:
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
else:
# convert to float first
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
if len(mels) > 0:
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
for i in range(len(codes)):
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
# Create chat template for LLM generation
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
]
if 'system' in tokenizer.chat_template:
tokenizer.chat_template = TEMPLATE
input_ids = tokenizer.apply_chat_template(
chat,
tokenize=True,
return_tensors='pt',
continue_final_message=True
)
input_ids_list.append(input_ids.squeeze(0))
# For batch_size=1, no need to pad
if len(input_ids_list) == 1:
input_ids = input_ids_list[0].unsqueeze(0)
else:
# Handle batch > 1 if needed
max_len = max([len(input_ids) for input_ids in input_ids_list])
input_ids_list = [
torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
for input_ids in input_ids_list
]
input_ids = torch.stack(input_ids_list)
ids = [item["id"] for item in batch]
return {
"input_ids": input_ids,
"ids": ids,
"prompt_text": prompt_text_list,
"prompt_audio_list": prompt_audio_list,
}
def init_distributed():
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))
print(
"Inference on multiple gpus, this gpu {}".format(local_rank)
+ ", rank {}, world_size {}".format(rank, world_size)
)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
return world_size, local_rank, rank
def main():
args = get_args()
os.makedirs(args.output_dir, exist_ok=True)
assert torch.cuda.is_available()
world_size, local_rank, rank = init_distributed()
device = torch.device(f"cuda:{local_rank}")
# Load LLM model and tokenizer directly
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
model.eval()
model.to(device)
cosyvoice_codec = CosyVoice2(
args.token2wav_path, load_jit=True, load_trt=True, fp16=True
)
if args.prompt_speech_path:
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
else:
prompt_speech_16k = None
s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
sampler=sampler,
shuffle=False,
num_workers=args.num_workers,
prefetch_factor=args.prefetch,
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
)
total_steps = len(dataset)
if rank == 0:
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
for batch in dataloader:
with torch.no_grad():
input_ids = batch["input_ids"].to(device)
# Generate speech tokens using LLM
outputs = model.generate(
input_ids,
max_new_tokens=2048, # Max length for generation
do_sample=True,
top_p=args.top_p,
temperature=args.temperature,
top_k=args.top_k,
)
# Process each sample in the batch
for i in range(len(batch["ids"])):
# Extract generated tokens (excluding input)
input_length = input_ids[i].shape[0]
generated_ids = outputs[i][input_length:-1] # Remove last token if needed
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Extract speech IDs from token strings like <|s_23456|>
speech_ids = extract_speech_ids(speech_tokens_str)
if len(speech_ids) == 0:
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
continue
# Convert to tensor for CosyVoice2
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
if args.prompt_text is not None:
current_prompt_text = args.prompt_text
current_prompt_audio = prompt_speech_16k
else:
current_prompt_text = batch["prompt_text"][i]
current_prompt_audio = batch["prompt_audio_list"][i]
if current_prompt_audio is not None:
# Generate audio using CosyVoice2
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
current_prompt_text,
current_prompt_audio,
cosyvoice_codec,
)
# Convert to numpy and save
generated_wave = audio_hat.squeeze(0).cpu().numpy()
target_sample_rate = 24000
utt = batch["ids"][i]
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
else:
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
if rank == 0:
progress_bar.update(world_size * len(batch["ids"]))
if rank == 0:
progress_bar.close()
dist.barrier()
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,86 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the Text to Speech dataset to parquet format
"""
import argparse
import os
import re
import datasets
from verl.utils.hdfs_io import copy, makedirs
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
parser.add_argument("--local_dir", default=None, required=True)
parser.add_argument("--hdfs_dir", default=None)
args = parser.parse_args()
# Load datasets from local JSON files
train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
# add a row to each data item that represents a unique id
def make_map_fn(split):
def process_fn(example, idx):
text = example.pop("text")
# use cosyvoice2 official huggingface compatible checkpoint template
question = text
answer = ""
data = {
"data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
"prompt": [
{
"role": "user",
"content": question,
},
{
"role": "assistant",
"content": answer,
},
],
"ability": "text-to-speech",
"reward_model": {"style": "rule", "ground_truth": text},
"extra_info": {
"split": split,
"index": idx,
"text": text,
},
}
return data
return process_fn
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
local_dir = args.local_dir
hdfs_dir = args.hdfs_dir
print(train_dataset)
print(test_dataset)
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
if hdfs_dir is not None:
makedirs(hdfs_dir)
copy(src=local_dir, dst=hdfs_dir)

View File

@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage: Instruct TTS
python3 infer.py \
--token2wav-path /workspace/CosyVoice2-0.5B \
--prompt-text "吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。" \
--prompt-speech-path ./assets/prompt_audio.wav \
--model-path ./transformers_cosyvoice2_llm \
--input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
"""
from cosyvoice.cli.cosyvoice import CosyVoice2
import sys
from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
def get_args():
parser = ArgumentParser()
parser.add_argument(
"--pretrained-cosyvoice2-path",
type=str,
default="/workspace/CosyVoice2-0.5B",
help="Token2Wav path, default to %(default)r",
)
parser.add_argument(
"--save-path",
type=str,
default='./transformers_cosyvoice2_llm',
help="The path to save the model",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
cosy2_model = CosyVoice2(
args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
)
llm = cosy2_model.model.llm.llm.model
speech_embedding = cosy2_model.model.llm.speech_embedding
llm_decoder = cosy2_model.model.llm.llm_decoder
llm_embedding = cosy2_model.model.llm.llm_embedding
tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
special_tokens = {
'eos_token': '<|endoftext|>',
'pad_token': '<|endoftext|>',
'additional_special_tokens': [
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
'[breath]', '<strong>', '</strong>', '[noise]',
'[laughter]', '[cough]', '[clucking]', '[accent]',
'[quick_breath]',
"<laughter>", "</laughter>",
"[hissing]", "[sigh]", "[vocalized-noise]",
"[lipsmack]", "[mn]"
]
}
tokenizer.add_special_tokens(special_tokens)
original_tokenizer_vocab_size = len(tokenizer)
cosyvoice2_token_size = 6561
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
"<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
]
num_added_tokens = tokenizer.add_tokens(new_tokens)
llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
vocab_size = llm.get_input_embeddings().weight.shape[0]
feature_size = speech_embedding.embedding_dim
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
with torch.no_grad():
# set the weight and bias of the new lm_head to 0
new_lm_head.weight.data.zero_()
# make bias value -inf
new_lm_head.bias.data.fill_(-float('inf'))
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
llm.lm_head = new_lm_head
input_embeddings = llm.get_input_embeddings()
with torch.no_grad():
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
llm.generation_config.eos_token_id = eos_token_ids
llm.generation_config.temperature = 1.0
llm.generation_config.top_p = 0.8
llm.generation_config.top_k = 25
llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
llm.config.vocab_size = vocab_size
llm.config.tie_word_embeddings = False
llm.config.use_bias = True
llm.to(torch.bfloat16)
llm.save_pretrained(args.save_path)
TEMPLATE = (
"{%- for message in messages %}"
"{%- if message['role'] == 'user' %}"
"{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
"{%- elif message['role'] == 'assistant' %}"
"{{- message['content']}}"
"{%- endif %}"
"{%- endfor %}"
)
tokenizer.chat_template = TEMPLATE
tokenizer.save_pretrained(args.save_path)

View File

@ -0,0 +1,31 @@
conformer==0.3.2
diffusers==0.29.0
gdown==5.1.0
gradio
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
networkx==3.1
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0
protobuf==4.25
pydantic==2.7.0
pyworld==0.3.4
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
wget==3.2
WeTextProcessing==1.0.3
s3tokenizer
tensorrt
sherpa_onnx
jiwer
zhon
numpy==1.25.2
pypinyin
openai-whisper

View File

@ -0,0 +1,233 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reward calculation for CosyVoice2-0.5B.
"""
from __future__ import annotations
import re
import json
import time
import argparse
from typing import List
import numpy as np
import requests
REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
def _parse_ids(token_str: str) -> List[int]:
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
"""Send token IDs and ground-truth text to the Triton server and get reward."""
tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
payload = {
"inputs": [
{
"name": "TOKENS",
"shape": list(tokens_arr.shape),
"datatype": "INT32",
"data": tokens_arr.tolist(),
},
{
"name": "TOKEN_LENS",
"shape": list(lens_arr.shape),
"datatype": "INT32",
"data": lens_arr.tolist(),
},
{
"name": "GT_TEXT",
"shape": [1, 1],
"datatype": "BYTES",
"data": [ground_truth],
},
]
}
rsp = requests.post(
REWARD_SERVER_URL,
headers={"Content-Type": "application/json"},
json=payload,
timeout=timeout,
verify=False,
params={"request_id": "0"},
)
rsp.raise_for_status()
result = rsp.json()
try:
# Reward is returned as the first output
return float(result["outputs"][0]["data"][0])
except (KeyError, IndexError, TypeError):
return 0.0
def compute_score(
data_source: str,
solution_str: str,
ground_truth: str,
extra_info: dict | None = None,
*,
debug_dump: bool = False,
) -> float:
"""Return reward in [0, 1] using the Triton ASR service.
The reward is based on the pinyin-level WER between the ASR transcript
produced from *solution_str* and the provided *ground_truth* text.
"""
# Decode token IDs
ids = _parse_ids(solution_str)
# Query remote server for reward
try:
reward = _remote_reward(ids, ground_truth)
except Exception as e:
reward = 0.0
if debug_dump:
print(
f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
)
return reward
# CLI quick test
if __name__ == "__main__":
import sys
def get_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Test TTS CER scoring with data from JSONL file",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--input", "-i",
type=str,
default="data/emilia_zh-cosy-tiny-test.jsonl",
help="Path to input JSONL file"
)
parser.add_argument(
"--max-samples", "-n",
type=int,
default=None,
help="Maximum number of samples to process (default: all)"
)
parser.add_argument(
"--no-interactive",
action="store_true",
help="Run in non-interactive mode (process all samples without prompts)"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode"
)
return parser.parse_args()
def load_jsonl(file_path: str):
"""Load data from jsonl file."""
data = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data.append(json.loads(line.strip()))
return data
def code_to_solution_str(code_list: List[int]) -> str:
"""Convert code list to solution string format."""
return ''.join([f"<|s_{code}|>" for code in code_list])
# Parse command line arguments
args = get_args()
try:
# Load data from jsonl file
print(f"Loading data from: {args.input}")
data_list = load_jsonl(args.input)
print(f"Loaded {len(data_list)} samples")
# Limit samples if specified
if args.max_samples is not None:
data_list = data_list[:args.max_samples]
print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
# Process each sample
begin_time = time.time()
for i, sample in enumerate(data_list):
print(f"\n--- Sample {i+1}/{len(data_list)} ---")
print(f"Index: {sample.get('index', 'unknown')}")
print(f"Text: {sample['text']}")
# Extract required fields
code_list = sample['code']
ground_truth = sample['text']
data_source = sample.get('index', f'sample_{i}') # Use index as data_source
# Convert code list to solution string
solution_str = code_to_solution_str(code_list)
print(f"Solution tokens: {len(code_list)} tokens")
if args.debug:
print(f"Solution string: {solution_str}")
else:
print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
# Call compute_score function
try:
score = compute_score(
data_source=data_source,
solution_str=solution_str,
ground_truth=ground_truth,
extra_info=None,
debug_dump=args.debug
)
print(f"Final Score: {score:.4f}")
except Exception as e:
print(f"Error computing score: {e}")
# Ask user if they want to continue (for interactive mode)
if not args.no_interactive and i < len(data_list) - 1:
try:
response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
if response == 'q':
break
except KeyboardInterrupt:
print("\nStopped by user")
break
print(f"\nProcessed {min(i+1, len(data_list))} samples")
end_time = time.time()
print(f"Time taken: {end_time - begin_time} seconds")
except FileNotFoundError:
print(f"Error: File not found - {args.input}")
print("Please check the file path or use --input to specify correct path")
print("Run with --help for usage information")
except Exception as e:
print(f"Error: {e}")

View File

@ -0,0 +1,159 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=4
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
export PYTHONPATH=/workspace/CosyVoice
model_scope_model_path=./CosyVoice2-0.5B
sft_model_path=./transformers_cosyvoice2_llm
if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
log "stage -2: install dependencies locally if pre-built docker image is not available"
conda create -n cosyvoice2 python=3.10 -y
conda activate cosyvoice2
# install verl
git clone https://github.com/yuekaizhang/verl.git -b thread
cd verl
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
pip install --no-deps -e .
cd -
# install requirements
pip install -r requirements.txt
pip install -U nvidia-pytriton
git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
fi
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
python3 pretrained_to_huggingface.py \
--pretrained-cosyvoice2-path $model_scope_model_path \
--save-path $sft_model_path
# Or, you could use the following command to download the huggingface compatible checkpoint
# huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
# Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
fi
data_dir=data/parquet_aishell3
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "stage 0: prepare data into verl format"
mkdir -p $data_dir
wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
# total 88035 samples
head -n 80000 data/aishell-3.jsonl > data/train.jsonl
tail -n 100 data/aishell-3.jsonl > data/test.jsonl
python prepare_data.py \
--train_file data/train.jsonl \
--test_file data/test.jsonl \
--local_dir $data_dir
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: start token2wav asr server for reward function"
python3 token2wav_asr_server.py --number-of-devices 8
fi
exp_name=official_llm_aishell3_grpo
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "stage 2: grpo train"
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
export MKL_SERVICE_FORCE_INTEL=TRUE
n_gpus_per_node=8
micro_batch_size=4
train_batch_size=32
python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$data_dir/train.parquet \
data.val_files=$data_dir/test.parquet \
data.train_batch_size=$train_batch_size \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.truncation='error' \
actor_rollout_ref.model.use_remove_padding=False \
actor_rollout_ref.model.path=$sft_model_path \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.do_sample=true \
actor_rollout_ref.rollout.temperature=0.8 \
actor_rollout_ref.rollout.top_p=0.95 \
actor_rollout_ref.rollout.top_k=25 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.rollout.val_kwargs.do_sample=true \
actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
actor_rollout_ref.rollout.val_kwargs.top_k=25 \
reward_model.reward_manager=prime \
custom_reward_function.path=reward_tts.py \
custom_reward_function.name=compute_score \
trainer.project_name='cosyvoice2_grpo' \
trainer.experiment_name=$exp_name \
trainer.logger=['console','wandb'] \
trainer.n_gpus_per_node=$n_gpus_per_node \
trainer.nnodes=1 \
trainer.save_freq=100 \
trainer.test_freq=100 \
trainer.resume_mode='auto' \
trainer.total_epochs=1 \
trainer.val_before_train=False
fi
steps=(100 200 300 400 500)
for step in ${steps[@]}; do
llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "stage 3: merge the model"
python -m verl.model_merger merge \
--backend fsdp \
--local_dir $llm_path/actor \
--target_dir $llm_path/merged_hf_model || exit 1
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "stage 4: Test the model"
dataset=zero_shot_zh # from CosyVoice3 test set
# dataset=test_zh # from seed_tts test set
output_dir=./outputs_${exp_name}_${step}_${dataset}
token2wav_path=/workspace/CosyVoice2-0.5B
model_path=$llm_path/merged_hf_model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --nproc_per_node=8 \
infer_dataset.py \
--output-dir $output_dir \
--llm-model-name-or-path $model_path \
--token2wav-path $token2wav_path \
--split-name ${dataset} || exit 1
bash scripts/compute_wer.sh $output_dir ${dataset}
fi
done
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "stage 5: Convert the RL trained model to CosyVoice repo format"
python3 huggingface_to_pretrained.py \
--hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
--output-path /workspace/CosyVoice2-0.5B/llm-new.pt
# You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
# However, we found that the RL trained model accuracy would slightly drop after this conversion.
# Please be careful or use the huggingface format inference code.
fi

View File

@ -0,0 +1,33 @@
wav_dir=$1
wav_files=$(ls $wav_dir/*.wav)
# if wav_files is empty, then exit
if [ -z "$wav_files" ]; then
exit 1
fi
split_name=$2
model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
pip install sherpa-onnx
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
mkdir models
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
fi
python3 scripts/offline-decode-files.py \
--tokens=$model_path/tokens.txt \
--paraformer=$model_path/model.int8.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=24000 \
--log-dir $wav_dir \
--feature-dim=80 \
--split-name $split_name \
--name sherpa_onnx \
$wav_files
# python3 scripts/paraformer-pytriton-client.py \
# --log-dir $wav_dir \
# --split-name $split_name \
# $wav_files

View File

@ -0,0 +1,754 @@
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
(4) For Whisper models
python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--whisper-task=transcribe \
--num-threads=1 \
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For CTC models from WeNet
python3 ./python-api-examples/offline-decode-files.py \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
(6) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
--feature-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple, Dict, Iterable, TextIO, Union
import numpy as np
import sherpa_onnx
import soundfile as sf
from datasets import load_dataset
import logging
from collections import defaultdict
import kaldialign
from zhon.hanzi import punctuation
import string
punctuation_all = punctuation + string.punctuation
Pathlike = Union[str, Path]
def remove_punctuation(text: str) -> str:
for x in punctuation_all:
if x == '\'':
continue
text = text.replace(x, '')
return text
def store_transcripts(
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
"""Save predicted results and reference transcripts to a file.
Args:
filename:
File to save the results to.
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
If it is a multi-talker ASR system, the ref and hyp may also be lists of
strings.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf8") as f:
for cut_id, ref, hyp in texts:
if char_level:
ref = list("".join(ref))
hyp = list("".join(hyp))
print(f"{cut_id}:\tref={ref}", file=f)
print(f"{cut_id}:\thyp={hyp}", file=f)
def write_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
compute_CER: bool = False,
sclite_mode: bool = False,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
It will write the following to the given file:
- WER
- number of insertions, deletions, substitutions, corrects and total
reference words. For example::
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
reference words (2337 correct)
- The difference between the reference transcript and predicted result.
An instance is given below::
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
The above example shows that the reference word is `EDISON`,
but it is predicted to `ADDISON` (a substitution error).
Another example is::
FOR THE FIRST DAY (SIR->*) I THINK
The reference word `SIR` is missing in the predicted
results (a deletion error).
results:
An iterable of tuples. The first element is the cut_id, the second is
the reference transcript and the third element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return None.
"""
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for _cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for _, r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one words/phrases per line, like
HELLO WORLD
你好世界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The hotword score of each token for biasing word/phrase. Used only if
--hotwords-file is given.
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="",
help="""
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
Used only when hotwords-file is given.
""",
)
parser.add_argument(
"--bpe-vocab",
type=str,
default="",
help="""
The path to the bpe vocabulary, the bpe vocabulary is generated by
sentencepiece, you can also export the bpe vocabulary through a bpe model
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
and modeling-unit is bpe or cjkchar+bpe.
""",
)
parser.add_argument(
"--encoder",
default="",
type=str,
help="Path to the encoder model",
)
parser.add_argument(
"--decoder",
default="",
type=str,
help="Path to the decoder model",
)
parser.add_argument(
"--joiner",
default="",
type=str,
help="Path to the joiner model",
)
parser.add_argument(
"--paraformer",
default="",
type=str,
help="Path to the model.onnx from Paraformer",
)
parser.add_argument(
"--nemo-ctc",
default="",
type=str,
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the model.onnx from WeNet CTC",
)
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
parser.add_argument(
"--whisper-tail-paddings",
default=-1,
type=int,
help="""Number of tail padding frames.
We have removed the 30-second constraint from whisper, so you need to
choose the amount of tail padding frames by yourself.
Use -1 to use a default value for tail padding.
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="""Sample rate of the feature extractor. Must match the one
expected by the model. Note: The input sound files can have a
different sample rate from this argument.""",
)
parser.add_argument(
"--feature-dim",
type=int,
default=80,
help="Feature dimension. Must match the one expected by the model",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
parser.add_argument(
"--name",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--log-dir",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--label",
type=str,
default=None,
help="wav_base_name label",
)
# Dataset related arguments for loading labels when label file is not provided
parser.add_argument(
"--dataset-name",
type=str,
default="yuekai/seed_tts_cosy2",
help="Huggingface dataset name for loading labels",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
help="Dataset split name for loading labels",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and can be of type
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples,
which are normalized to the range [-1, 1].
- Sample rate of the wave file.
"""
samples, sample_rate = sf.read(wave_filename, dtype="float32")
assert (
samples.ndim == 1
), f"Expected single channel, but got {samples.ndim} channels."
samples_float32 = samples.astype(np.float32)
return samples_float32, sample_rate
def normalize_text_alimeeting(text: str) -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
import re
text = text.replace('\u00A0', '') # test_hard
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = remove_punctuation(text)
return text
def main():
args = get_args()
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=args.paraformer,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
print("Started!")
start_time = time.time()
streams, results = [], []
total_duration = 0
for i, wave_filename in enumerate(args.sound_files):
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
streams.append(s)
if i % 10 == 0:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
streams = []
print(f"Processed {i} files")
# process the last batch
if streams:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
end_time = time.time()
print("Done!")
results_dict = {}
for wave_filename, result in zip(args.sound_files, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
wave_basename = Path(wave_filename).stem
results_dict[wave_basename] = result
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
# Load labels either from file or from dataset
labels_dict = {}
if args.label:
# Load labels from file (original functionality)
print(f"Loading labels from file: {args.label}")
with open(args.label, "r") as f:
for line in f:
# fields = line.strip().split(" ")
# fields = [item for item in fields if item]
# assert len(fields) == 4
# prompt_text, prompt_audio, text, audio_path = fields
fields = line.strip().split("|")
fields = [item for item in fields if item]
assert len(fields) == 4
audio_path, prompt_text, prompt_audio, text = fields
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
else:
# Load labels from dataset (new functionality)
print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
if 'zero' in args.split_name:
dataset_name = "yuekai/CV3-Eval"
else:
dataset_name = "yuekai/seed_tts_cosy2"
dataset = load_dataset(
dataset_name,
split=args.split_name,
trust_remote_code=True,
)
for item in dataset:
audio_id = item["id"]
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
print(f"Loaded {len(labels_dict)} labels from dataset")
# Perform evaluation if labels are available
if labels_dict:
final_results = []
for key, value in results_dict.items():
if key in labels_dict:
final_results.append((key, labels_dict[key], value))
else:
print(f"Warning: No label found for {key}, skipping...")
if final_results:
store_transcripts(
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
)
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
write_error_stats(f, "test-set", final_results, enable_log=True)
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
print(f.readline()) # WER
print(f.readline()) # Detailed errors
else:
print("No matching labels found for evaluation")
else:
print("No labels available for evaluation")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,346 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytriton server for token2wav conversion and ASR"""
from datasets import load_dataset
from cosyvoice.cli.cosyvoice import CosyVoice2
from omnisense.models import OmniSenseVoiceSmall
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
from pytriton.decorators import batch
import argparse
import io
import logging
from typing import Any, List
import numpy as np
import torch
from scipy.signal import resample
import sys
import random
import re
from jiwer import wer
from pypinyin import lazy_pinyin, Style
from tn.chinese.normalizer import Normalizer as ZhNormalizer
# Chinese text normalizer (cached globally)
zh_tn_model = ZhNormalizer(
cache_dir="./cache",
remove_erhua=False,
remove_interjections=False,
remove_puncts=True,
overwrite_cache=True,
)
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
logger = logging.getLogger("token2wav_asr_server")
class _ASR_Server:
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
def __init__(self, device_id: int):
self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
@batch
def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
"""
WAV: np.ndarray, WAV_LENS: np.ndarray
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
"""
logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
results = self._model.transcribe_single_batch(
wavs,
language="zh",
textnorm="woitn",
)
texts = [result.text for result in results]
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
return {"TRANSCRIPTS": transcripts}
def audio_decode_cosyvoice2(
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
):
"""
Generate audio from tokens with optional tone and prompt embedding.
"""
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
"empty", prompt_text, prompt_speech_16k, 24000
)
tts_mel, _ = codec_decoder.model.flow.inference(
token=audio_tokens.to(codec_decoder.model.device),
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
codec_decoder.model.device
),
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
codec_decoder.model.device
),
prompt_token_len=torch.tensor(
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
).to(codec_decoder.model.device),
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
codec_decoder.model.device
),
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
codec_decoder.model.device
),
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
finalize=True,
)
audio_hat, _ = codec_decoder.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
return audio_hat
def get_random_prompt_from_dataset(dataset):
"""
Get random prompt text and speech from the pre-loaded dataset.
Returns (prompt_text, prompt_speech_16k)
"""
random_idx = random.randint(0, len(dataset) - 1)
sample = dataset[random_idx]
# Extract audio data
audio_data = sample["audio"]
audio_array = audio_data["array"]
sample_rate = audio_data["sampling_rate"]
# Convert audio to 16kHz if needed
if sample_rate != 16000:
num_samples = int(len(audio_array) * (16000 / sample_rate))
audio_array = resample(audio_array, num_samples)
# Convert to torch tensor
prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
prompt_text = sample["text"]
# remove space in prompt_text
prompt_text = prompt_text.replace(" ", "")
return prompt_text, prompt_speech_16k
class _Token2Wav_ASR:
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
def __init__(self, device_id: int):
self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
# Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
# CosyVoice2 internally uses generic "cuda" device, so we first switch the
# current CUDA context to the desired card before the object is created.
# Afterwards, all parameters loaded with the generic "cuda" device will
# reside on this GPU. We keep the selected id in `self.device_id` and
# will set the context again for every forward call to avoid race
# conditions when several instances are used in the same process.
self.device_id = device_id
# Construct the TTS codec decoder under the correct CUDA device context
with torch.cuda.device(self.device_id):
self.codec_decoder = CosyVoice2(
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
)
@batch
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
"""
WAV: np.ndarray, WAV_LENS: np.ndarray
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
"""
# Ensure the default CUDA device is set correctly for this invocation
torch.cuda.set_device(self.device_id)
if self.device_id == 0:
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
# Decode ground-truth text strings (BYTES → str)
if GT_TEXT.ndim == 2:
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
else:
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
wavs = []
for tokens in tokens_list:
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
audio_hat = audio_decode_cosyvoice2(
audio_tokens,
prompt_text,
prompt_speech_16k,
self.codec_decoder,
)
# resample to 16000 using soundfile
audio_hat = audio_hat.squeeze(0).float().cpu()
audio_hat = audio_hat.numpy()
num_samples = int(len(audio_hat) * (16000 / 24000))
audio_hat = resample(audio_hat, num_samples)
wavs.append(audio_hat)
results = self.asr_model.transcribe_single_batch(
wavs,
language="zh",
textnorm="woitn",
)
texts = [result.text for result in results]
# ---------------- Reward computation ----------------
rewards = []
for gt_text, hyp_text in zip(gt_texts, texts):
gt_norm = zh_tn_model.normalize(gt_text).lower()
hyp_norm = zh_tn_model.normalize(hyp_text).lower()
gt_pinyin = lazy_pinyin(
gt_norm,
style=Style.TONE3,
tone_sandhi=True,
neutral_tone_with_five=True,
)
hyp_pinyin = lazy_pinyin(
hyp_norm,
style=Style.TONE3,
tone_sandhi=True,
neutral_tone_with_five=True,
)
c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
reward_val = 1.0 - np.tanh(3.0 * c)
reward_val = max(0.0, min(1.0, reward_val))
rewards.append(reward_val)
print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
def _infer_function_factory(device_ids: List[int], model_name: str):
"""Creates a list of inference functions, one for each requested device ID."""
infer_funcs = []
for device_id in device_ids:
if model_name == "sensevoice":
infer_funcs.append(_ASR_Server(device_id=device_id))
else:
infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
return infer_funcs
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--max-batch-size",
type=int,
default=32,
help="Batch size of request.",
required=False,
)
parser.add_argument(
"--verbose",
action="store_true",
default=False,
)
parser.add_argument(
"--number-of-instances-per-device",
type=int,
default=1,
help="Number of model instances to load.",
required=False,
)
parser.add_argument(
"--number-of-devices",
type=int,
default=8,
help="Number of devices to use.",
)
parser.add_argument(
"--model-name",
type=str,
default="token2wav_asr",
choices=["token2wav_asr", "sensevoice"],
help="Model name.",
)
args = parser.parse_args()
log_level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
triton_config = TritonConfig(
http_port=8000,
grpc_port=8001,
metrics_port=8002,
)
device_ids = list(range(args.number_of_devices))
device_ids = device_ids * args.number_of_instances_per_device
with Triton(config=triton_config) as triton:
logger.info("Loading SenseVoice model on device ids: %s", device_ids)
if args.model_name == "sensevoice":
triton.bind(
model_name="sensevoice",
infer_func=_infer_function_factory(device_ids, args.model_name),
inputs=[
Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
],
outputs=[
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
],
config=ModelConfig(
max_batch_size=args.max_batch_size,
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
),
strict=True,
)
else:
triton.bind(
model_name="token2wav_asr",
infer_func=_infer_function_factory(device_ids, args.model_name),
inputs=[
Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
],
outputs=[
Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
],
config=ModelConfig(
max_batch_size=args.max_batch_size,
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
),
strict=True,
)
logger.info("Serving inference")
triton.serve()
if __name__ == "__main__":
main()

View File

@ -147,7 +147,7 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [
!ref <mel_spec_transform1>
]

View File

@ -40,6 +40,10 @@ def main():
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
for k, v in spk2utt.items():
f.write('{} {}\n'.format(k, ' '.join(v)))
if args.instruct != '':
with open('{}/instruct'.format(args.des_dir), 'w') as f:
for k, v in utt2text.items():
f.write('{} {}\n'.format(k, args.instruct))
return
@ -49,5 +53,7 @@ if __name__ == "__main__":
type=str)
parser.add_argument('--des_dir',
type=str)
parser.add_argument('--instruct',
type=str)
args = parser.parse_args()
main()

View File

@ -0,0 +1,50 @@
import argparse
import logging
import os
from tqdm import tqdm
import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice2
from cosyvoice.utils.file_utils import load_wav
logger = logging.getLogger()
def main():
cosyvoice = CosyVoice2(args.ref_model)
utt2wav, utt2text = {}, {}
with open('{}/wav.scp'.format(args.src_dir)) as f:
for l in f:
l = l.split('\n')[0].split()
utt2wav[l[0]] = l[1]
with open('{}/text'.format(args.src_dir)) as f:
for l in f:
l = l.split('\n')[0].split()
utt2text[l[0]] = ' '.join(l[1:])
os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
for utt, wav in tqdm(utt2wav.items()):
prompt_speech_16k = load_wav(wav, 16000)
if prompt_speech_16k.shape[1] >= 30 * 16000:
continue
speech_list = []
for _, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
speech_list.append(j['tts_speech'])
negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
f.write('{} {}\n'.format(utt, negative_wav))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir',
type=str)
parser.add_argument('--des_dir',
type=str)
parser.add_argument('--ref_model',
type=str)
args = parser.parse_args()
main()

View File

@ -51,23 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
done
fi
# inference
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
for mode in sft zero_shot; do
python cosyvoice/bin/inference.py --mode $mode \
--gpu 0 \
--config conf/cosyvoice.yaml \
--prompt_data data/test-clean/parquet/data.list \
--prompt_utt2data data/test-clean/parquet/utt2data.list \
--tts_text `pwd`/tts_text.json \
--llm_model $pretrained_model_dir/llm.pt \
--flow_model $pretrained_model_dir/flow.pt \
--hifigan_model $pretrained_model_dir/hift.pt \
--result_dir `pwd`/exp/cosyvoice/test-clean/$mode
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
@ -77,7 +60,7 @@ num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi

View File

@ -5,74 +5,51 @@ __set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 22050
text_encoder_input_size: 512
llm_input_size: 1024
llm_output_size: 1024
sample_rate: 24000
llm_input_size: 896
llm_output_size: 896
spk_embed_dim: 192
qwen_pretrain_path: ''
token_frame_rate: 25
token_mel_ratio: 2
# stream related params
chunk_size: 25 # streaming inference chunk size, in token
num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
# model params
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm.TransformerLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm: !new:cosyvoice.llm.llm.Qwen2LM
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
speech_token_size: 4096
speech_token_size: 6561
length_normalized_loss: True
lsm_weight: 0
spk_embed_dim: !ref <spk_embed_dim>
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
input_size: !ref <text_encoder_input_size>
output_size: 1024
attention_heads: 8
linear_units: 2048
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
use_cnn_module: False
macaron_style: False
use_dynamic_chunk: False
use_dynamic_left_chunk: False
static_chunk_size: 1
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
input_size: !ref <llm_input_size>
output_size: !ref <llm_output_size>
attention_heads: 8
linear_units: 2048
num_blocks: 7
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: 'linear_legacy'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
static_chunk_size: 1
mix_ratio: [5, 15]
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
pretrain_path: !ref <qwen_pretrain_path>
sampling: !name:cosyvoice.utils.common.ras_sampling
top_p: 0.8
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
vocab_size: 6561
input_frame_rate: !ref <token_frame_rate>
only_mask_loss: True
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
token_mel_ratio: !ref <token_mel_ratio>
pre_lookahead_len: 3
encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
output_size: 512
attention_heads: 4
linear_units: 1024
num_blocks: 3
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
@ -83,10 +60,8 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
input_size: 512
use_cnn_module: False
macaron_style: False
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
channels: 80
sampling_ratios: [1, 1, 1, 1]
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
static_chunk_size: !ref <chunk_size>
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
@ -98,16 +73,18 @@ flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
in_channels: 320
out_channels: 80
channels: [256, 256]
channels: [256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 8
num_mid_blocks: 12
num_heads: 8
act_fn: 'gelu'
static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
num_decoding_left_chunks: !ref <num_decoding_left_chunks>
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
@ -117,15 +94,15 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
upsample_rates: [8, 5, 3]
upsample_kernel_sizes: [16, 11, 7]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
@ -135,11 +112,11 @@ hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
# gan related module
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
hop_size: 480
win_size: 1920
fmin: 0
fmax: null
center: False
@ -147,45 +124,44 @@ hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [
!ref <mel_spec_transform1>
]
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
token_path: !ref <qwen_pretrain_path>
skip_special_tokens: True
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
min_length: 100
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
truncate: !name:cosyvoice.dataset.processor.truncate
truncate_length: 24576 # must be a multiplier of hop_size
truncate_length: 24480 # must be a multiplier of hop_size
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
hop_size: 480
win_size: 1920
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
token_mel_ratio: 2
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
sample_rate: !ref <sample_rate>
hop_size: 256
hop_size: 480
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
@ -194,10 +170,11 @@ sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 12000
max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: False # change to True during sft
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
@ -230,10 +207,10 @@ data_pipeline_gan: [
train_conf:
optim: adam
optim_conf:
lr: 0.002 # change to 0.001 if you want to train flow from scratch
scheduler: warmuplr
lr: 1e-5 # change to 1e-5 during sft
scheduler: constantlr # change to constantlr during sft
scheduler_conf:
warmup_steps: 25000
warmup_steps: 2500
max_epoch: 200
grad_clip: 5
accum_grad: 2

View File

@ -0,0 +1 @@
../cosyvoice/local

View File

@ -0,0 +1 @@
../cosyvoice/path.sh

View File

@ -0,0 +1,110 @@
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
. ./path.sh || exit 1;
stage=-1
stop_stage=3
data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--src_dir data/$x \
--des_dir data/$x/parquet
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
job_id=1986
dist_backend="nccl"
num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
for model in llm flow hifigan; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice2.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--use_amp \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
done
fi
# average model
average_num=5
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
for model in llm flow hifigan; do
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python cosyvoice/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
--num ${average_num} \
--val_best
done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
fi

View File

@ -0,0 +1,123 @@
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
. ./path.sh || exit 1;
stage=-1
stop_stage=3
data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Prepare negative samples using CosyVoice2-0.5B, this is also our reference model.
Here we use CosyVoice2-0.5B generated audio as reject sample for simplicity, you can use metric like wer/similarity."
for x in train-clean-100 train-clean-360 train-other-500; do
mkdir -p data/${x}_reject
python local/prepare_reject_sample.py --src_dir data/$x --des_dir data/${x}_reject --ref_model $pretrained_model_dir
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--dpo \
--src_dir data/$x \
--des_dir data/$x/parquet
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
job_id=1986
dist_backend="nccl"
num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
# NOTE only llm supports dpo
for model in llm; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice2.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \
--ref_model $pretrained_model_dir/llm.pt \
--model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--use_amp \
--dpo \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
done
fi
# average model
average_num=5
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
for model in llm flow hifigan; do
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python cosyvoice/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
--num ${average_num} \
--val_best
done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
fi

View File

@ -0,0 +1 @@
../cosyvoice/tts_text.json

View File

@ -0,0 +1,224 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 24000
llm_input_size: 896
llm_output_size: 896
spk_embed_dim: 192
qwen_pretrain_path: ''
token_frame_rate: 25
token_mel_ratio: 2
# stream related params
chunk_size: 25 # streaming inference chunk size, in token
num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
# model params
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm.CosyVoice3LM
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
speech_token_size: 6561
length_normalized_loss: True
lsm_weight: 0
mix_ratio: [5, 15]
llm: !new:cosyvoice.llm.llm.Qwen2Encoder
pretrain_path: !ref <qwen_pretrain_path>
sampling: !name:cosyvoice.utils.common.ras_sampling
top_p: 0.8
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithDiT
input_size: 80
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 6561
input_frame_rate: !ref <token_frame_rate>
only_mask_loss: True
token_mel_ratio: !ref <token_mel_ratio>
pre_lookahead_len: 3
pre_lookahead_layer: !new:cosyvoice.transformer.upsample_encoder.PreLookaheadLayer
in_channels: 80
channels: 1024
pre_lookahead_len: 3
decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.DiT.dit.DiT
dim: 1024
depth: 22
heads: 16
dim_head: 64
ff_mult: 2
mel_dim: 80
mu_dim: 80
spk_dim: 80
out_channels: 80
static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
num_decoding_left_chunks: !ref <num_decoding_left_chunks>
hift: !new:cosyvoice.hifigan.generator.CausalHiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 5, 3]
upsample_kernel_sizes: [16, 11, 7]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
conv_pre_look_right: 4
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.CausalConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# gan related module
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 480
win_size: 1920
fmin: 0
fmax: null
center: False
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
generator: !ref <hift>
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
mel_spec_transform: [
!ref <mel_spec_transform1>
]
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
token_path: !ref <qwen_pretrain_path>
skip_special_tokens: True
version: cosyvoice3
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 100
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
truncate: !name:cosyvoice.dataset.processor.truncate
truncate_length: 24960 # must be a multiplier of hop_size and token_mel_ratio
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1920
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 480
win_size: 1920
fmin: 0
fmax: null
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
token_mel_ratio: 2
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
sample_rate: !ref <sample_rate>
hop_size: 480
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: False # change to True during sft
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
data_pipeline_gan: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <truncate>,
!ref <compute_fbank>,
!ref <compute_f0>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# llm flow train conf
train_conf:
optim: adam
optim_conf:
lr: 1e-5 # change to 1e-5 during sft
scheduler: constantlr # change to constantlr during sft
scheduler_conf:
warmup_steps: 2500
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1
# gan train conf
train_conf_gan:
optim: adam
optim_conf:
lr: 0.0002 # use small lr for gan training
scheduler: constantlr
optim_d: adam
optim_conf_d:
lr: 0.0002 # use small lr for gan training
scheduler_d: constantlr
max_epoch: 200
grad_clip: 5
accum_grad: 1 # in gan training, accum_grad must be 1
log_interval: 100
save_per_step: -1

View File

@ -0,0 +1,42 @@
{
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"steps_per_print": 100,
"gradient_clipping": 5,
"fp16": {
"enabled": false,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 256,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
},
"zero_force_ds_cpu_optimizer": false,
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": false,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients" : true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.001,
"weight_decay": 0.0001,
"torch_adam": true,
"adam_w_mode": true
}
}
}

View File

@ -0,0 +1 @@
../../../cosyvoice

View File

@ -0,0 +1 @@
../cosyvoice/local

View File

@ -0,0 +1 @@
../cosyvoice/path.sh

View File

@ -0,0 +1,112 @@
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
. ./path.sh || exit 1;
stage=-1
stop_stage=3
data_url=www.openslr.org/resources/60
data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
pretrained_model_dir=../../../pretrained_models/Fun-CosyVoice3-0.5B
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
local/download_and_untar.sh ${data_dir} ${data_url} ${part}
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x
# NOTE in CosyVoice3, we add instruct in sequence
python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x --instruct "You are a helpful assistant.<|endofprompt|>"
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_embedding.py --dir data/$x \
--onnx_path $pretrained_model_dir/campplus.onnx
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
tools/extract_speech_token.py --dir data/$x \
--onnx_path $pretrained_model_dir/speech_tokenizer_v3.onnx
done
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mkdir -p data/$x/parquet
tools/make_parquet_list.py --num_utts_per_parquet 1000 \
--num_processes 10 \
--instruct \
--src_dir data/$x \
--des_dir data/$x/parquet
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
job_id=1986
dist_backend="nccl"
num_workers=2
prefetch=100
train_engine=torch_ddp
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Run train. We only support llm traning for now"
if [ $train_engine == 'deepspeed' ]; then
echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
fi
cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
for model in llm flow hifigan; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
cosyvoice/bin/train.py \
--train_engine $train_engine \
--config conf/cosyvoice3.yaml \
--train_data data/train.data.list \
--cv_data data/dev.data.list \
--qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
--model $model \
--checkpoint $pretrained_model_dir/$model.pt \
--model_dir `pwd`/exp/cosyvoice3/$model/$train_engine \
--tensorboard_dir `pwd`/tensorboard/cosyvoice3/$model/$train_engine \
--ddp.dist_backend $dist_backend \
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--use_amp \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
done
fi
# average model
average_num=5
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
for model in llm flow hifigan; do
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python cosyvoice/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
--num ${average_num} \
--val_best
done
fi
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
fi

View File

@ -0,0 +1 @@
../../../tools

View File

@ -0,0 +1 @@
../../libritts/cosyvoice/conf

View File

@ -1,203 +0,0 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 22050
text_encoder_input_size: 512
llm_input_size: 1024
llm_output_size: 1024
spk_embed_dim: 192
# model params
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm.TransformerLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
speech_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
spk_embed_dim: !ref <spk_embed_dim>
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
input_size: !ref <text_encoder_input_size>
output_size: 1024
attention_heads: 8
linear_units: 2048
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
use_cnn_module: False
macaron_style: False
use_dynamic_chunk: False
use_dynamic_left_chunk: False
static_chunk_size: 1
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
input_size: !ref <llm_input_size>
output_size: !ref <llm_output_size>
attention_heads: 8
linear_units: 2048
num_blocks: 7
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: 'linear_legacy'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
static_chunk_size: 1
sampling: !name:cosyvoice.utils.common.ras_sampling
top_p: 0.8
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
only_mask_loss: True
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 4
linear_units: 1024
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
channels: 80
sampling_ratios: [1, 1, 1, 1]
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
in_channels: 320
out_channels: 80
channels: [256, 256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 8
num_heads: 8
act_fn: 'gelu'
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 12000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: False # change to True during sft
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# train conf
train_conf:
optim: adam
optim_conf:
lr: 0.002 # change to 0.001 if you want to train flow from scratch
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1

View File

@ -1,203 +0,0 @@
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1986]
__set_seed2: !apply:numpy.random.seed [1986]
__set_seed3: !apply:torch.manual_seed [1986]
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
# fixed params
sample_rate: 22050
text_encoder_input_size: 512
llm_input_size: 1024
llm_output_size: 1024
spk_embed_dim: 192
# model params
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
# for system/third_party class/function, we do not require this.
llm: !new:cosyvoice.llm.llm.TransformerLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
speech_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
spk_embed_dim: !ref <spk_embed_dim>
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
input_size: !ref <text_encoder_input_size>
output_size: 1024
attention_heads: 16
linear_units: 4096
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
use_cnn_module: False
macaron_style: False
use_dynamic_chunk: False
use_dynamic_left_chunk: False
static_chunk_size: 1
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
input_size: !ref <llm_input_size>
output_size: !ref <llm_output_size>
attention_heads: 16
linear_units: 4096
num_blocks: 14
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: 'linear_legacy'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
static_chunk_size: 1
sampling: !name:cosyvoice.utils.common.ras_sampling
top_p: 0.8
top_k: 25
win_size: 10
tau_r: 0.1
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
input_size: 512
output_size: 80
spk_embed_dim: !ref <spk_embed_dim>
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
only_mask_loss: True
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 512
use_cnn_module: False
macaron_style: False
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
channels: 80
sampling_ratios: [1, 1, 1, 1]
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
in_channels: 240
n_spks: 1
spk_emb_dim: 80
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
in_channels: 320
out_channels: 80
channels: [256, 256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 12
num_heads: 8
act_fn: 'gelu'
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
# processor functions
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
multilingual: True
num_languages: 100
language: 'en'
task: 'transcribe'
allowed_special: 'all'
tokenize: !name:cosyvoice.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:cosyvoice.dataset.processor.filter
max_length: 40960
min_length: 0
token_max_length: 200
token_min_length: 1
resample: !name:cosyvoice.dataset.processor.resample
resample_rate: !ref <sample_rate>
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 80
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 8000
center: False
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
normalize: True
shuffle: !name:cosyvoice.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:cosyvoice.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:cosyvoice.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 2000
padding: !name:cosyvoice.dataset.processor.padding
use_spk_embedding: False # change to True during sft
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <compute_fbank>,
!ref <parse_embedding>,
!ref <shuffle>,
!ref <sort>,
!ref <batch>,
!ref <padding>,
]
# train conf
train_conf:
optim: adam
optim_conf:
lr: 0.001 # change to 1e-5 during sft
scheduler: warmuplr # change to constantlr during sft
scheduler_conf:
warmup_steps: 2500
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: -1

View File

@ -1,3 +0,0 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH

View File

@ -0,0 +1 @@
../../libritts/cosyvoice/path.sh

View File

@ -51,23 +51,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
done
fi
# inference
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Run inference. Please make sure utt in tts_text is in prompt_data"
for mode in sft zero_shot; do
python cosyvoice/bin/inference.py --mode $mode \
--gpu 0 \
--config conf/cosyvoice.yaml \
--prompt_data data/test/parquet/data.list \
--prompt_utt2data data/test/parquet/utt2data.list \
--tts_text `pwd`/tts_text.json \
--llm_model $pretrained_model_dir/llm.pt \
--flow_model $pretrained_model_dir/flow.pt \
--hifigan_model $pretrained_model_dir/hift.pt \
--result_dir `pwd`/exp/cosyvoice/test/$mode
done
fi
# train llm
export CUDA_VISIBLE_DEVICES="0,1,2,3"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
@ -83,7 +66,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
fi
cp data/train/parquet/data.list data/train.data.list
cp data/dev/parquet/data.list data/dev.data.list
for model in llm flow; do
for model in llm flow hifigan; do
torchrun --nnodes=1 --nproc_per_node=$num_gpus \
--rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
cosyvoice/bin/train.py \
@ -99,11 +82,26 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
--num_workers ${num_workers} \
--prefetch ${prefetch} \
--pin_memory \
--use_amp \
--deepspeed_config ./conf/ds_stage2.json \
--deepspeed.save_states model+optimizer
done
fi
# average model
average_num=5
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
for model in llm flow hifigan; do
decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python cosyvoice/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path `pwd`/exp/cosyvoice/$model/$train_engine \
--num ${average_num} \
--val_best
done
fi
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir

View File

@ -1,8 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cu121
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
conformer==0.3.2
deepspeed==0.14.2; sys_platform == 'linux'
deepspeed==0.15.1; sys_platform == 'linux'
diffusers==0.29.0
fastapi==0.115.6
fastapi-cli==0.0.4
gdown==5.1.0
gradio==5.4.0
grpcio==1.57.0
@ -13,27 +15,28 @@ inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
modelscope==1.20.0
networkx==3.1
numpy==1.26.4
omegaconf==2.3.0
onnx==1.16.0
onnxruntime-gpu==1.18.0; sys_platform == 'linux'
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'windows'
onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
openai-whisper==20231117
protobuf==4.25
pyarrow==18.1.0
pydantic==2.7.0
pyworld==0.3.4
rich==13.7.1
soundfile==0.12.1
tensorboard==2.14.0
tensorrt-cu12==10.0.1; sys_platform == 'linux'
tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
tensorrt-cu12==10.13.3.9; sys_platform == 'linux'
tensorrt-cu12-bindings==10.13.3.9; sys_platform == 'linux'
tensorrt-cu12-libs==10.13.3.9; sys_platform == 'linux'
torch==2.3.1
torchaudio==2.3.1
transformers==4.40.1
transformers==4.51.3
x-transformers==2.11.24
uvicorn==0.30.0
wetext==0.0.4
wget==3.2
fastapi==0.115.6
fastapi-cli==0.0.4
WeTextProcessing==1.0.3

View File

@ -5,9 +5,9 @@ WORKDIR /opt/CosyVoice
RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
RUN apt-get update -y
RUN apt-get -y install git unzip git-lfs
RUN apt-get -y install git unzip git-lfs g++
RUN git lfs install
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto

View File

@ -24,7 +24,7 @@ import numpy as np
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import load_wav
app = FastAPI()
@ -72,6 +72,7 @@ async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instr
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
return StreamingResponse(generate_data(model_output))
@app.get("/inference_instruct2")
@app.post("/inference_instruct2")
async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()):
@ -80,7 +81,6 @@ async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(
return StreamingResponse(generate_data(model_output))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port',
@ -88,14 +88,8 @@ if __name__ == '__main__':
default=50000)
parser.add_argument('--model_dir',
type=str,
default='iic/CosyVoice-300M',
default='iic/CosyVoice2-0.5B',
help='local path or modelscope repo id')
args = parser.parse_args()
try:
cosyvoice = CosyVoice(args.model_dir)
except Exception:
try:
cosyvoice = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
uvicorn.run(app, host="0.0.0.0", port=args.port)
cosyvoice = AutoModel(model_dir=args.model_dir)
uvicorn.run(app, host="0.0.0.0", port=args.port)

View File

@ -25,7 +25,7 @@ import numpy as np
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../../..'.format(ROOT_DIR))
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.cli.cosyvoice import AutoModel
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
@ -33,13 +33,7 @@ logging.basicConfig(level=logging.DEBUG,
class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
def __init__(self, args):
try:
self.cosyvoice = CosyVoice(args.model_dir)
except Exception:
try:
self.cosyvoice = CosyVoice2(args.model_dir)
except Exception:
raise TypeError('no valid model_type!')
self.cosyvoice = AutoModel(model_dir=args.model_dir)
logging.info('grpc service initialized')
def Inference(self, request, context):
@ -90,7 +84,7 @@ if __name__ == '__main__':
default=4)
parser.add_argument('--model_dir',
type=str,
default='iic/CosyVoice-300M',
default='iic/CosyVoice2-0.5B',
help='local path or modelscope repo id')
args = parser.parse_args()
main()

View File

@ -0,0 +1,8 @@
FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
LABEL maintainer="zhangyuekai@foxmail.com"
RUN apt-get update && apt-get install -y cmake
RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
COPY ./requirements.txt /workspace/requirements.txt
RUN pip install -r /workspace/requirements.txt
WORKDIR /workspace

View File

@ -0,0 +1,141 @@
## Accelerating CosyVoice with DiT-based Token2Wav, NVIDIA Triton Inference Server and TensorRT-LLM
Contributed by Yuekai Zhang (NVIDIA).
This document describes how to accelerate CosyVoice with a DiT-based Token2Wav module from Step-Audio2, using NVIDIA Triton Inference Server and TensorRT-LLM.
### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose -f docker-compose.dit.yml up
```
### Build the Docker Image
To build the image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```
### Run a Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
```
### Understanding `run_stepaudio2_dit_token2wav.sh`
The `run_stepaudio2_dit_token2wav.sh` script orchestrates the entire workflow through numbered stages.
You can run a subset of stages with:
```sh
bash run_stepaudio2_dit_token2wav.sh <start_stage> <stop_stage>
```
- `<start_stage>`: The stage to start from.
- `<stop_stage>`: The stage to stop after.
**Stages:**
- **Stage -1**: Clones the `Step-Audio2` and `CosyVoice` repositories.
- **Stage 0**: Downloads the `cosyvoice2_llm`, `CosyVoice2-0.5B`, and `Step-Audio-2-mini` models.
- **Stage 1**: Converts the HuggingFace checkpoint for the LLM to the TensorRT-LLM format and builds the TensorRT engines.
- **Stage 2**: Creates the Triton model repository, including configurations for `cosyvoice2_dit` and `token2wav_dit`.
- **Stage 3**: Launches the Triton Inference Server for Token2Wav module and uses `trtllm-serve` to deploy Cosyvoice2 LLM.
- **Stage 4**: Runs the gRPC benchmark client for performance testing.
- **Stage 5**: Runs the offline TTS inference benchmark test.
- **Stage 6**: Runs a standalone inference script for the Step-Audio2-mini DiT Token2Wav model.
- **Stage 7**: Launches servers in a disaggregated setup, with the LLM on GPU 0 and Token2Wav servers on GPUs 1-3.
- **Stage 8**: Runs the benchmark client for the disaggregated server configuration.
### Export Models and Launch Server
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
```sh
# This command runs stages 0, 1, 2, and 3
bash run_stepaudio2_dit_token2wav.sh 0 3
```
### Benchmark with client-server mode
To benchmark the running Triton server, run stage 4:
```sh
bash run_stepaudio2_dit_token2wav.sh 4 4
# You can customize parameters such as the number of tasks inside the script.
```
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
#### Total Request Latency
| Concurrent Tasks | RTF | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
| ---------------- | ------ | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
| 1 | 0.1228 | 833.66 | 779.98 | 1297.05 | 1555.97 | 1653.02 |
| 2 | 0.0901 | 1166.23 | 1124.69 | 1762.76 | 1900.64 | 2204.14 |
| 4 | 0.0741 | 1849.30 | 1759.42 | 2624.50 | 2822.20 | 3128.42 |
| 6 | 0.0774 | 2936.13 | 3054.64 | 3849.60 | 3900.49 | 4245.79 |
| 8 | 0.0691 | 3408.56 | 3434.98 | 4547.13 | 5047.76 | 5346.53 |
| 10 | 0.0707 | 4306.56 | 4343.44 | 5769.64 | 5876.09 | 5939.79 |
#### First Chunk Latency
| Concurrent Tasks | Average (ms) | 50th Percentile (ms) | 90th Percentile (ms) | 95th Percentile (ms) | 99th Percentile (ms) |
| ---------------- | ------------ | -------------------- | -------------------- | -------------------- | -------------------- |
| 1 | 197.50 | 196.13 | 214.65 | 215.96 | 229.21 |
| 2 | 281.15 | 278.20 | 345.18 | 361.79 | 395.97 |
| 4 | 510.65 | 530.50 | 630.13 | 642.44 | 666.65 |
| 6 | 921.54 | 918.86 | 1079.97 | 1265.22 | 1524.41 |
| 8 | 1019.95 | 1085.26 | 1371.05 | 1402.24 | 1410.66 |
| 10 | 1214.98 | 1293.54 | 1575.36 | 1654.51 | 2161.76 |
### Benchmark with offline inference mode
For offline inference mode benchmark, please run stage 5:
```sh
bash run_stepaudio2_dit_token2wav.sh 5 5
```
The following results were obtained by decoding on a single L20 GPU with the `yuekai/seed_tts_cosy2` dataset.
#### Offline TTS (Cosyvoice2 0.5B LLM + StepAudio2 DiT Token2Wav)
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|---------|------------|------------------|-----------------------|--|
| TRTLLM | 16 | 2.01 | 5.03 | 0.0292 |
### Disaggregated Server
When the LLM and token2wav components are deployed on the same GPU, they compete for resources. To optimize performance, we use a disaggregated setup where the LLM is deployed on one dedicated L20 GPU, taking advantage of in-flight batching for inference. The token2wav module is deployed on separate, dedicated GPUs.
The table below shows the first chunk latency results for this configuration. In our tests, we deploy two token2wav instances on each dedicated token2wav GPU.
| token2wav_num_gpu | concurrent_task_per_instance | concurrent_tasks_per_gpu | avg (ms) | p50 (ms) | p90 (ms) | p99 (ms) |
|---|---|---|---|---|---|---|
| 1 | 1 | 1.00 | 218.53 | 217.86 | 254.07 | 296.49 |
| 2 | 1 | 1.33 | 218.82 | 219.21 | 256.62 | 303.13 |
| 3 | 1 | 1.50 | 229.08 | 223.27 | 302.13 | 324.41 |
| 4 | 1 | 1.60 | 203.87 | 198.23 | 254.92 | 279.31 |
| 1 | 2 | 2.00 | 293.46 | 280.53 | 370.81 | 407.40 |
| 2 | 2 | 2.67 | 263.38 | 236.84 | 350.82 | 397.39 |
| 3 | 2 | 3.00 | 308.09 | 275.48 | 385.22 | 521.45 |
| 4 | 2 | 3.20 | 271.85 | 253.25 | 359.03 | 387.91 |
| 1 | 3 | 3.00 | 389.15 | 373.01 | 469.22 | 542.89 |
| 2 | 3 | 4.00 | 403.48 | 394.80 | 481.24 | 507.75 |
| 3 | 3 | 4.50 | 406.33 | 391.28 | 495.43 | 571.29 |
| 4 | 3 | 4.80 | 436.72 | 383.81 | 638.44 | 879.23 |
| 1 | 4 | 4.00 | 520.12 | 493.98 | 610.38 | 739.85 |
| 2 | 4 | 5.33 | 494.60 | 490.50 | 605.93 | 708.09 |
| 3 | 4 | 6.00 | 538.23 | 508.33 | 687.62 | 736.96 |
| 4 | 4 | 6.40 | 579.68 | 546.20 | 721.53 | 958.04 |
| 1 | 5 | 5.00 | 635.02 | 623.30 | 786.85 | 819.84 |
| 2 | 5 | 6.67 | 598.23 | 617.09 | 741.00 | 788.96 |
| 3 | 5 | 7.50 | 644.78 | 684.40 | 786.45 | 1009.45 |
| 4 | 5 | 8.00 | 733.92 | 642.26 | 1024.79 | 1281.55 |
| 1 | 6 | 6.00 | 715.38 | 745.68 | 887.04 | 906.68 |
| 2 | 6 | 8.00 | 748.31 | 753.94 | 873.59 | 1007.14 |
| 3 | 6 | 9.00 | 900.27 | 822.28 | 1431.14 | 1800.23 |
| 4 | 6 | 9.60 | 857.54 | 820.33 | 1150.30 | 1298.53 |
The `concurrent_task_per_gpu` is calculated as:
`concurrent_task_per_gpu = concurrent_task_per_instance * num_token2wav_instance_per_gpu (2) * token2wav_gpus / (token2wav_gpus + llm_gpus (1))`
### Acknowledgements
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

View File

@ -0,0 +1,146 @@
## Accelerating CosyVoice with NVIDIA Triton Inference Server and TensorRT-LLM
Contributed by Yuekai Zhang (NVIDIA).
### Quick Start
Launch the service directly with Docker Compose:
```sh
docker compose up
```
### Build the Docker Image
To build the image from scratch:
```sh
docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
```
### Run a Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
```
### Understanding `run.sh`
The `run.sh` script orchestrates the entire workflow through numbered stages.
You can run a subset of stages with:
```sh
bash run.sh <start_stage> <stop_stage> [service_type]
```
- `<start_stage>`: The stage to start from (0-5).
- `<stop_stage>`: The stage to stop after (0-5).
**Stages:**
- **Stage 0**: Downloads the `cosyvoice-2 0.5B` model from HuggingFace.
- **Stage 1**: Converts the HuggingFace checkpoint to the TensorRT-LLM format and builds the TensorRT engines.
- **Stage 2**: Creates the Triton model repository and configures the model files. The configuration is adjusted based on whether `Decoupled=True` (streaming) or `Decoupled=False` (offline) will be used.
- **Stage 3**: Launches the Triton Inference Server.
- **Stage 4**: Runs the single-utterance HTTP client for testing.
- **Stage 5**: Runs the gRPC benchmark client.
- **Stage 6**: Runs the offline inference benchmark test.
### Export Models and Launch Server
Inside the Docker container, prepare the models and start the Triton server by running stages 0-3:
```sh
# This command runs stages 0, 1, 2, and 3
bash run.sh 0 3
```
> [!TIP]
> Both streaming and offline (non-streaming) TTS modes are supported. For streaming TTS, set `Decoupled=True`. For offline TTS, set `Decoupled=False`. You need to rerun stage 2 if you switch between modes.
### Single-Utterance HTTP Client
Sends a single HTTP inference request. This is intended for testing the offline TTS mode (`Decoupled=False`):
```sh
bash run.sh 4 4
```
### Benchmark with client-server mode
To benchmark the running Triton server, pass `streaming` or `offline` as the third argument:
```sh
bash run.sh 5 5 # [streaming|offline]
# You can also customize parameters such as the number of tasks and the dataset split:
# python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline]
```
> [!TIP]
> It is recommended to run the benchmark multiple times to get stable results after the initial server warm-up.
### Benchmark with offline inference mode
For offline inference mode benchmark, please check the below command:
```sh
# install FlashCosyVoice for token2wav batching
# git clone https://github.com/yuekaizhang/FlashCosyVoice.git /workspace/FlashCosyVoice -b trt
# cd /workspace/FlashCosyVoice
# pip install -e .
# cd -
# wget https://huggingface.co/yuekai/cosyvoice2_flow_onnx/resolve/main/flow.decoder.estimator.fp32.dynamic_batch.onnx -O $model_scope_model_local_dir/flow.decoder.estimator.fp32.dynamic_batch.onnx
bash run.sh 6 6
# You can also switch to huggingface backend by setting backend=hf
```
### Benchmark Results
The following results were obtained by decoding on a single L20 GPU with 26 prompt audio/target text pairs from the [yuekai/seed_tts](https://huggingface.co/datasets/yuekai/seed_tts) dataset (approximately 170 seconds of audio):
**Client-Server Mode: Streaming TTS (First Chunk Latency)**
| Mode | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|
| Streaming, use_spk2info_cache=False | 1 | 220.43 | 218.07 | 0.1237 |
| Streaming, use_spk2info_cache=False | 2 | 476.97 | 369.25 | 0.1022 |
| Streaming, use_spk2info_cache=False | 4 | 1107.34 | 1243.75| 0.0922 |
| Streaming, use_spk2info_cache=True | 1 | 189.88 | 184.81 | 0.1155 |
| Streaming, use_spk2info_cache=True | 2 | 323.04 | 316.83 | 0.0905 |
| Streaming, use_spk2info_cache=True | 4 | 977.68 | 903.68| 0.0733 |
> If your service only needs a fixed speaker, you can set `use_spk2info_cache=True` in `run.sh`. To add more speakers, refer to the instructions [here](https://github.com/qi-hua/async_cosyvoice?tab=readme-ov-file#9-spk2info-%E8%AF%B4%E6%98%8E).
**Client-Server Mode: Offline TTS (Full Sentence Latency)**
| Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF |
|---|---|---|---|---|---|
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 |
| Offline, Decoupled=False, use_spk2info_cache=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 |
**Offline Inference Mode: Hugginface LLM V.S. TensorRT-LLM**
| Backend | Batch Size | llm_time_seconds | total_time_seconds | RTF |
|---------|------------|------------------|-----------------------|--|
| HF | 1 | 39.26 | 44.31 | 0.2494 |
| HF | 2 | 30.54 | 35.62 | 0.2064 |
| HF | 4 | 18.63 | 23.90 | 0.1421 |
| HF | 8 | 11.22 | 16.45 | 0.0947 |
| HF | 16 | 8.42 | 13.78 | 0.0821 |
| TRTLLM | 1 | 12.46 | 17.31 | 0.0987 |
| TRTLLM | 2 | 7.64 |12.65 | 0.0739 |
| TRTLLM | 4 | 4.89 | 9.38 | 0.0539 |
| TRTLLM | 8 | 2.92 | 7.23 | 0.0418 |
| TRTLLM | 16 | 2.01 | 6.63 | 0.0386 |
### OpenAI-Compatible Server
To launch an OpenAI-compatible API service, run the following commands:
```sh
git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git
cd Triton-OpenAI-Speech
pip install -r requirements.txt
# After the Triton service is running, start the FastAPI bridge:
python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000
# Test the service with curl:
bash test/test_cosyvoice.sh
```
> [!NOTE]
> Currently, only the offline TTS mode is compatible with the OpenAI-compatible server.
### Acknowledgements
This work originates from the NVIDIA CISI project. For more multimodal resources, please see [mair-hub](https://github.com/nvidia-china-sae/mair-hub).

View File

@ -0,0 +1,922 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
# 2023 Nvidia (authors: Yuekai Zhang)
# 2023 Recurrent.ai (authors: Songtao Shi)
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script supports to load dataset from huggingface and sends it to the server
for decoding, in parallel.
Usage:
num_task=2
# For offline F5-TTS
python3 client_grpc.py \
--server-addr localhost \
--model-name f5_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name test_zh \
--log-dir ./log_concurrent_tasks_${num_task}
# For offline Spark-TTS-0.5B
python3 client_grpc.py \
--server-addr localhost \
--model-name spark_tts \
--num-tasks $num_task \
--huggingface-dataset yuekai/seed_tts \
--split-name wenetspeech4tts \
--log-dir ./log_concurrent_tasks_${num_task}
"""
import argparse
import asyncio
import json
import queue
import uuid
import functools
import os
import time
import types
from pathlib import Path
import numpy as np
import soundfile as sf
import tritonclient
import tritonclient.grpc.aio as grpcclient_aio
import tritonclient.grpc as grpcclient_sync
from tritonclient.utils import np_to_triton_dtype, InferenceServerException
class UserData:
def __init__(self):
self._completed_requests = queue.Queue()
self._first_chunk_time = None
self._second_chunk_time = None
self._start_time = None
def record_start_time(self):
self._start_time = time.time()
def get_first_chunk_latency(self):
if self._first_chunk_time and self._start_time:
return self._first_chunk_time - self._start_time
return None
def get_second_chunk_latency(self):
if self._first_chunk_time and self._second_chunk_time:
return self._second_chunk_time - self._first_chunk_time
return None
def callback(user_data, result, error):
if not error:
if user_data._first_chunk_time is None:
user_data._first_chunk_time = time.time()
elif user_data._second_chunk_time is None:
user_data._second_chunk_time = time.time()
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)
def stream_callback(user_data_map, result, error):
request_id = None
if error:
print(f"An error occurred in the stream callback: {error}")
else:
request_id = result.get_response().id
if request_id:
user_data = user_data_map.get(request_id)
if user_data:
callback(user_data, result, error)
else:
print(f"Warning: Could not find user_data for request_id {request_id}")
def write_triton_stats(stats, summary_file):
with open(summary_file, "w") as summary_f:
model_stats = stats["model_stats"]
for model_state in model_stats:
if "last_inference" not in model_state:
continue
summary_f.write(f"model name is {model_state['name']} \n")
model_inference_stats = model_state["inference_stats"]
total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9
total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9
total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9
total_output_time_s = int(model_inference_stats["compute_output"]["ns"]) / 1e9
summary_f.write(
f"queue time {total_queue_time_s:<5.2f} s, "
f"compute infer time {total_infer_time_s:<5.2f} s, "
f"compute input time {total_input_time_s:<5.2f} s, "
f"compute output time {total_output_time_s:<5.2f} s \n"
)
model_batch_stats = model_state["batch_stats"]
for batch in model_batch_stats:
batch_size = int(batch["batch_size"])
compute_input = batch["compute_input"]
compute_output = batch["compute_output"]
compute_infer = batch["compute_infer"]
batch_count = int(compute_infer["count"])
if batch_count == 0:
continue
assert compute_infer["count"] == compute_output["count"] == compute_input["count"]
compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
compute_input_time_ms = int(compute_input["ns"]) / 1e6
compute_output_time_ms = int(compute_output["ns"]) / 1e6
summary_f.write(
f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, "
f"total_infer_time {compute_infer_time_ms:<9.2f} ms, "
f"avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}="
f"{compute_infer_time_ms / batch_count:.2f} ms, "
f"avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}="
f"{compute_infer_time_ms / batch_count / batch_size:.2f} ms \n"
)
summary_f.write(
f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms / batch_count:.2f} ms, "
)
summary_f.write(
f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms / batch_count:.2f} ms \n"
)
def subtract_stats(stats_after, stats_before):
"""Subtracts two Triton inference statistics objects."""
stats_diff = json.loads(json.dumps(stats_after))
model_stats_before_map = {
s["name"]: {
"version": s["version"],
"last_inference": s.get("last_inference", 0),
"inference_count": s.get("inference_count", 0),
"execution_count": s.get("execution_count", 0),
"inference_stats": s.get("inference_stats", {}),
"batch_stats": s.get("batch_stats", []),
}
for s in stats_before["model_stats"]
}
for model_stat_after in stats_diff["model_stats"]:
model_name = model_stat_after["name"]
if model_name in model_stats_before_map:
model_stat_before = model_stats_before_map[model_name]
model_stat_after["inference_count"] = str(
int(model_stat_after.get("inference_count", 0)) - int(model_stat_before.get("inference_count", 0))
)
model_stat_after["execution_count"] = str(
int(model_stat_after.get("execution_count", 0)) - int(model_stat_before.get("execution_count", 0))
)
if "inference_stats" in model_stat_after and "inference_stats" in model_stat_before:
for key in ["success", "fail", "queue", "compute_input", "compute_infer", "compute_output", "cache_hit", "cache_miss"]:
if key in model_stat_after["inference_stats"] and key in model_stat_before["inference_stats"]:
if "ns" in model_stat_after["inference_stats"][key]:
ns_after = int(model_stat_after["inference_stats"][key]["ns"])
ns_before = int(model_stat_before["inference_stats"][key]["ns"])
model_stat_after["inference_stats"][key]["ns"] = str(ns_after - ns_before)
if "count" in model_stat_after["inference_stats"][key]:
count_after = int(model_stat_after["inference_stats"][key]["count"])
count_before = int(model_stat_before["inference_stats"][key]["count"])
model_stat_after["inference_stats"][key]["count"] = str(count_after - count_before)
if "batch_stats" in model_stat_after and "batch_stats" in model_stat_before:
batch_stats_before_map = {b["batch_size"]: b for b in model_stat_before["batch_stats"]}
for batch_stat_after in model_stat_after["batch_stats"]:
bs = batch_stat_after["batch_size"]
if bs in batch_stats_before_map:
batch_stat_before = batch_stats_before_map[bs]
for key in ["compute_input", "compute_infer", "compute_output"]:
if key in batch_stat_after and key in batch_stat_before:
count_after = int(batch_stat_after[key]["count"])
count_before = int(batch_stat_before[key]["count"])
batch_stat_after[key]["count"] = str(count_after - count_before)
ns_after = int(batch_stat_after[key]["ns"])
ns_before = int(batch_stat_before[key]["ns"])
batch_stat_after[key]["ns"] = str(ns_after - ns_before)
return stats_diff
def get_args():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
type=str,
default="localhost",
help="Address of the server",
)
parser.add_argument(
"--server-port",
type=int,
default=8001,
help="Grpc port of the triton server, default is 8001",
)
parser.add_argument(
"--reference-audio",
type=str,
default=None,
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="",
help="",
)
parser.add_argument(
"--huggingface-dataset",
type=str,
default="yuekai/seed_tts",
help="dataset name in huggingface dataset hub",
)
parser.add_argument(
"--split-name",
type=str,
default="wenetspeech4tts",
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
help="dataset split name, default is 'test'",
)
parser.add_argument(
"--manifest-path",
type=str,
default=None,
help="Path to the manifest dir which includes wav.scp trans.txt files.",
)
parser.add_argument(
"--model-name",
type=str,
default="f5_tts",
choices=[
"f5_tts",
"spark_tts",
"cosyvoice2",
"cosyvoice2_dit"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--num-tasks",
type=int,
default=1,
help="Number of concurrent tasks for sending",
)
parser.add_argument(
"--log-interval",
type=int,
default=5,
help="Controls how frequently we print the log.",
)
parser.add_argument(
"--compute-wer",
action="store_true",
default=False,
help="""True to compute WER.
""",
)
parser.add_argument(
"--log-dir",
type=str,
required=False,
default="./tmp",
help="log directory",
)
parser.add_argument(
"--mode",
type=str,
default="offline",
choices=["offline", "streaming"],
help="Select offline or streaming benchmark mode."
)
parser.add_argument(
"--chunk-overlap-duration",
type=float,
default=0.1,
help="Chunk overlap duration for streaming reconstruction (in seconds)."
)
parser.add_argument(
"--use-spk2info-cache",
type=str,
default="False",
help="Use spk2info cache for reference audio.",
)
return parser.parse_args()
def load_audio(wav_path, target_sample_rate=16000):
assert target_sample_rate == 16000, "hard coding in server"
if isinstance(wav_path, dict):
waveform = wav_path["array"]
sample_rate = wav_path["sampling_rate"]
else:
waveform, sample_rate = sf.read(wav_path)
if sample_rate != target_sample_rate:
from scipy.signal import resample
num_samples = int(len(waveform) * (target_sample_rate / sample_rate))
waveform = resample(waveform, num_samples)
return waveform, target_sample_rate
def prepare_request_input_output(
protocol_client,
waveform,
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None,
use_spk2info_cache: bool = False
):
"""Prepares inputs for Triton inference (offline or streaming)."""
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
if padding_duration:
duration = len(waveform) / sample_rate
if reference_text:
estimated_target_duration = duration / len(reference_text) * len(target_text)
else:
estimated_target_duration = duration
required_total_samples = padding_duration * sample_rate * (
(int(estimated_target_duration + duration) // padding_duration) + 1
)
samples = np.zeros((1, required_total_samples), dtype=np.float32)
samples[0, : len(waveform)] = waveform
else:
samples = waveform.reshape(1, -1).astype(np.float32)
inputs = [
protocol_client.InferInput("reference_wav", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput(
"reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype)
),
protocol_client.InferInput("reference_text", [1, 1], "BYTES"),
protocol_client.InferInput("target_text", [1, 1], "BYTES"),
]
inputs[0].set_data_from_numpy(samples)
inputs[1].set_data_from_numpy(lengths)
input_data_numpy = np.array([reference_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[2].set_data_from_numpy(input_data_numpy)
input_data_numpy = np.array([target_text], dtype=object)
input_data_numpy = input_data_numpy.reshape((1, 1))
inputs[3].set_data_from_numpy(input_data_numpy)
outputs = [protocol_client.InferRequestedOutput("waveform")]
if use_spk2info_cache:
inputs = inputs[-1:]
return inputs, outputs
def run_sync_streaming_inference(
sync_triton_client: tritonclient.grpc.InferenceServerClient,
model_name: str,
inputs: list,
outputs: list,
request_id: str,
user_data: UserData,
chunk_overlap_duration: float,
save_sample_rate: int,
audio_save_path: str,
):
"""Helper function to run the blocking sync streaming call."""
start_time_total = time.time()
user_data.record_start_time()
sync_triton_client.async_stream_infer(
model_name,
inputs,
request_id=request_id,
outputs=outputs,
enable_empty_final_response=True,
)
audios = []
while True:
try:
result = user_data._completed_requests.get(timeout=200)
if isinstance(result, InferenceServerException):
print(f"Received InferenceServerException: {result}")
return None, None, None, None
response = result.get_response()
final = response.parameters["triton_final_response"].bool_param
if final is True:
break
audio_chunk = result.as_numpy("waveform").reshape(-1)
if audio_chunk.size > 0:
audios.append(audio_chunk)
else:
print("Warning: received empty audio chunk.")
except queue.Empty:
print(f"Timeout waiting for response for request id {request_id}")
return None, None, None, None
end_time_total = time.time()
total_request_latency = end_time_total - start_time_total
first_chunk_latency = user_data.get_first_chunk_latency()
second_chunk_latency = user_data.get_second_chunk_latency()
if audios:
if model_name == "spark_tts":
cross_fade_samples = int(chunk_overlap_duration * save_sample_rate)
fade_out = np.linspace(1, 0, cross_fade_samples)
fade_in = np.linspace(0, 1, cross_fade_samples)
reconstructed_audio = None
if not audios:
print("Warning: No audio chunks received.")
reconstructed_audio = np.array([], dtype=np.float32)
elif len(audios) == 1:
reconstructed_audio = audios[0]
else:
reconstructed_audio = audios[0][:-cross_fade_samples]
for i in range(1, len(audios)):
cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
audios[i - 1][-cross_fade_samples:] * fade_out)
middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
if reconstructed_audio is not None and reconstructed_audio.size > 0:
actual_duration = len(reconstructed_audio) / save_sample_rate
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received or reconstructed.")
actual_duration = 0
else:
reconstructed_audio = np.concatenate(audios)
actual_duration = len(reconstructed_audio) / save_sample_rate
sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
else:
print("Warning: No audio chunks received.")
actual_duration = 0
return total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration
async def send_streaming(
manifest_item_list: list,
name: str,
server_url: str,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
chunk_overlap_duration: float = 0.1,
padding_duration: int = None,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
sync_triton_client = None
user_data_map = {}
try:
print(f"{name}: Initializing sync client for streaming...")
sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
sync_triton_client.start_stream(callback=functools.partial(stream_callback, user_data_map))
print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: Processing item {i}/{len(manifest_item_list)}")
try:
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
reference_text, target_text = item["reference_text"], item["target_text"]
inputs, outputs = prepare_request_input_output(
protocol_client,
waveform,
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
request_id = str(uuid.uuid4())
user_data = UserData()
user_data_map[request_id] = user_data
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration = await asyncio.to_thread(
run_sync_streaming_inference,
sync_triton_client,
model_name,
inputs,
outputs,
request_id,
user_data,
chunk_overlap_duration,
save_sample_rate,
audio_save_path
)
if total_request_latency is not None:
print(
f"{name}: Item {i} - First Chunk Latency: {first_chunk_latency:.4f}s, "
f"Second Chunk Latency: {second_chunk_latency if second_chunk_latency is not None else 'N/A'}, "
f"Total Latency: {total_request_latency:.4f}s, Duration: {actual_duration:.4f}s"
)
latency_data.append((total_request_latency, first_chunk_latency, second_chunk_latency, actual_duration))
total_duration += actual_duration
else:
print(f"{name}: Item {i} failed.")
del user_data_map[request_id]
except FileNotFoundError:
print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
except Exception as e:
print(f"Error processing item {i} ({item['target_audio_path']}): {e}")
import traceback
traceback.print_exc()
finally:
if sync_triton_client:
try:
print(f"{name}: Closing stream and sync client...")
sync_triton_client.stop_stream()
sync_triton_client.close()
except Exception as e:
print(f"{name}: Error closing sync client: {e}")
print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
return total_duration, latency_data
async def send(
manifest_item_list: list,
name: str,
triton_client: tritonclient.grpc.aio.InferenceServerClient,
protocol_client: types.ModuleType,
log_interval: int,
model_name: str,
padding_duration: int = None,
audio_save_dir: str = "./",
save_sample_rate: int = 16000,
use_spk2info_cache: bool = False,
):
total_duration = 0.0
latency_data = []
task_id = int(name[5:])
for i, item in enumerate(manifest_item_list):
if i % log_interval == 0:
print(f"{name}: {i}/{len(manifest_item_list)}")
waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000)
reference_text, target_text = item["reference_text"], item["target_text"]
inputs, outputs = prepare_request_input_output(
protocol_client,
waveform,
reference_text,
target_text,
sample_rate,
padding_duration=padding_duration,
use_spk2info_cache=use_spk2info_cache
)
sequence_id = 100000000 + i + task_id * 10
start = time.time()
response = await triton_client.infer(model_name, inputs, request_id=str(sequence_id), outputs=outputs)
audio = response.as_numpy("waveform").reshape(-1)
actual_duration = len(audio) / save_sample_rate
end = time.time() - start
audio_save_path = os.path.join(audio_save_dir, f"{item['target_audio_path']}.wav")
sf.write(audio_save_path, audio, save_sample_rate, "PCM_16")
latency_data.append((end, actual_duration))
total_duration += actual_duration
return total_duration, latency_data
def load_manifests(manifest_path):
with open(manifest_path, "r") as f:
manifest_list = []
for line in f:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav)
manifest_list.append(
{
"audio_filepath": prompt_wav,
"reference_text": prompt_text,
"target_text": gt_text,
"target_audio_path": utt,
}
)
return manifest_list
def split_data(data, k):
n = len(data)
if n < k:
print(f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}.")
k = n
quotient = n // k
remainder = n % k
result = []
start = 0
for i in range(k):
if i < remainder:
end = start + quotient + 1
else:
end = start + quotient
result.append(data[start:end])
start = end
return result
async def main():
args = get_args()
url = f"{args.server_addr}:{args.server_port}"
triton_client = None
protocol_client = None
if args.mode == "offline":
print("Initializing gRPC client for offline mode...")
triton_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
protocol_client = grpcclient_aio
elif args.mode == "streaming":
print("Initializing gRPC client for streaming mode...")
protocol_client = grpcclient_sync
else:
raise ValueError(f"Invalid mode: {args.mode}")
if args.reference_audio:
args.num_tasks = 1
args.log_interval = 1
manifest_item_list = [
{
"reference_text": args.reference_text,
"target_text": args.target_text,
"audio_filepath": args.reference_audio,
"target_audio_path": "test",
}
]
elif args.huggingface_dataset:
import datasets
dataset = datasets.load_dataset(
args.huggingface_dataset,
split=args.split_name,
trust_remote_code=True,
)
manifest_item_list = []
for i in range(len(dataset)):
manifest_item_list.append(
{
"audio_filepath": dataset[i]["prompt_audio"],
"reference_text": dataset[i]["prompt_text"],
"target_audio_path": dataset[i]["id"],
"target_text": dataset[i]["target_text"],
}
)
else:
manifest_item_list = load_manifests(args.manifest_path)
stats_client = None
stats_before = None
try:
print("Initializing temporary async client for fetching stats...")
stats_client = grpcclient_aio.InferenceServerClient(url=url, verbose=False)
print("Fetching inference statistics before running tasks...")
stats_before = await stats_client.get_inference_statistics(model_name="", as_json=True)
except Exception as e:
print(f"Could not retrieve statistics before running tasks: {e}")
num_tasks = min(args.num_tasks, len(manifest_item_list))
manifest_item_list = split_data(manifest_item_list, num_tasks)
os.makedirs(args.log_dir, exist_ok=True)
args.use_spk2info_cache = args.use_spk2info_cache == "True" or args.use_spk2info_cache == "true"
tasks = []
start_time = time.time()
for i in range(num_tasks):
if args.mode == "offline":
task = asyncio.create_task(
send(
manifest_item_list[i],
name=f"task-{i}",
triton_client=triton_client,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=1,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
use_spk2info_cache=args.use_spk2info_cache,
)
)
elif args.mode == "streaming":
task = asyncio.create_task(
send_streaming(
manifest_item_list[i],
name=f"task-{i}",
server_url=url,
protocol_client=protocol_client,
log_interval=args.log_interval,
model_name=args.model_name,
audio_save_dir=args.log_dir,
padding_duration=10,
save_sample_rate=16000 if args.model_name == "spark_tts" else 24000,
chunk_overlap_duration=args.chunk_overlap_duration,
use_spk2info_cache=args.use_spk2info_cache,
)
)
tasks.append(task)
ans_list = await asyncio.gather(*tasks)
end_time = time.time()
elapsed = end_time - start_time
total_duration = 0.0
latency_data = []
for ans in ans_list:
if ans:
total_duration += ans[0]
latency_data.extend(ans[1])
else:
print("Warning: A task returned None, possibly due to an error.")
if total_duration == 0:
print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
rtf = float('inf')
else:
rtf = elapsed / total_duration
s = f"Mode: {args.mode}\n"
s += f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration / 3600:.2f} hours)\n"
s += f"processing time: {elapsed:.3f} seconds ({elapsed / 3600:.2f} hours)\n"
if latency_data:
if args.mode == "offline":
latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
if latency_list:
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n"
s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
s += f"average_latency_ms: {latency_ms:.2f}\n"
else:
s += "No latency data collected for offline mode.\n"
elif args.mode == "streaming":
total_latency_list = [total for (total, first, second, duration) in latency_data if total is not None]
first_chunk_latency_list = [first for (total, first, second, duration) in latency_data if first is not None]
second_chunk_latency_list = [second for (total, first, second, duration) in latency_data if second is not None]
s += "\n--- Total Request Latency ---\n"
if total_latency_list:
avg_total_latency_ms = sum(total_latency_list) / len(total_latency_list) * 1000.0
variance_total_latency = np.var(total_latency_list, dtype=np.float64) * 1000.0
s += f"total_request_latency_variance: {variance_total_latency:.2f}\n"
s += f"total_request_latency_50_percentile_ms: {np.percentile(total_latency_list, 50) * 1000.0:.2f}\n"
s += f"total_request_latency_90_percentile_ms: {np.percentile(total_latency_list, 90) * 1000.0:.2f}\n"
s += f"total_request_latency_95_percentile_ms: {np.percentile(total_latency_list, 95) * 1000.0:.2f}\n"
s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
else:
s += "No total request latency data collected.\n"
s += "\n--- First Chunk Latency ---\n"
if first_chunk_latency_list:
avg_first_chunk_latency_ms = sum(first_chunk_latency_list) / len(first_chunk_latency_list) * 1000.0
variance_first_chunk_latency = np.var(first_chunk_latency_list, dtype=np.float64) * 1000.0
s += f"first_chunk_latency_variance: {variance_first_chunk_latency:.2f}\n"
s += f"first_chunk_latency_50_percentile_ms: {np.percentile(first_chunk_latency_list, 50) * 1000.0:.2f}\n"
s += f"first_chunk_latency_90_percentile_ms: {np.percentile(first_chunk_latency_list, 90) * 1000.0:.2f}\n"
s += f"first_chunk_latency_95_percentile_ms: {np.percentile(first_chunk_latency_list, 95) * 1000.0:.2f}\n"
s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
else:
s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
s += "\n--- Second Chunk Latency ---\n"
if second_chunk_latency_list:
avg_second_chunk_latency_ms = sum(second_chunk_latency_list) / len(second_chunk_latency_list) * 1000.0
variance_second_chunk_latency = np.var(second_chunk_latency_list, dtype=np.float64) * 1000.0
s += f"second_chunk_latency_variance: {variance_second_chunk_latency:.2f}\n"
s += f"second_chunk_latency_50_percentile_ms: {np.percentile(second_chunk_latency_list, 50) * 1000.0:.2f}\n"
s += f"second_chunk_latency_90_percentile_ms: {np.percentile(second_chunk_latency_list, 90) * 1000.0:.2f}\n"
s += f"second_chunk_latency_95_percentile_ms: {np.percentile(second_chunk_latency_list, 95) * 1000.0:.2f}\n"
s += f"second_chunk_latency_99_percentile_ms: {np.percentile(second_chunk_latency_list, 99) * 1000.0:.2f}\n"
s += f"average_second_chunk_latency_ms: {avg_second_chunk_latency_ms:.2f}\n"
else:
s += "No second chunk latency data collected (check for errors or if all requests failed before second chunk).\n"
else:
s += "No latency data collected.\n"
print(s)
if args.manifest_path:
name = Path(args.manifest_path).stem
elif args.split_name:
name = args.split_name
elif args.reference_audio:
name = Path(args.reference_audio).stem
else:
name = "results"
with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
f.write(s)
try:
if stats_client and stats_before:
print("Fetching inference statistics after running tasks...")
stats_after = await stats_client.get_inference_statistics(model_name="", as_json=True)
print("Calculating statistics difference...")
stats = subtract_stats(stats_after, stats_before)
print("Fetching model config...")
metadata = await stats_client.get_model_config(model_name=args.model_name, as_json=True)
write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt")
with open(f"{args.log_dir}/model_config-{name}.json", "w") as f:
json.dump(metadata, f, indent=4)
else:
print("Stats client not available or initial stats were not fetched. Skipping stats reporting.")
except Exception as e:
print(f"Could not retrieve statistics or config: {e}")
finally:
if stats_client:
try:
print("Closing temporary async stats client...")
await stats_client.close()
except Exception as e:
print(f"Error closing async stats client: {e}")
if __name__ == "__main__":
async def run_main():
try:
await main()
except Exception as e:
print(f"An error occurred in main: {e}")
import traceback
traceback.print_exc()
asyncio.run(run_main())

View File

@ -0,0 +1,172 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import requests
import soundfile as sf
import numpy as np
import argparse
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--server-url",
type=str,
default="localhost:8000",
help="Address of the server",
)
parser.add_argument(
"--reference-audio",
type=str,
default="../../example/prompt_audio.wav",
help="Path to a single audio file. It can't be specified at the same time with --manifest-dir",
)
parser.add_argument(
"--reference-text",
type=str,
default="吃燕窝就选燕之屋本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝营养更均衡本节目由豆本豆豆奶特约播出。",
help="",
)
parser.add_argument(
"--target-text",
type=str,
default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。",
help="",
)
parser.add_argument(
"--model-name",
type=str,
default="spark_tts",
choices=[
"f5_tts",
"spark_tts",
"cosyvoice2"],
help="triton model_repo module name to request",
)
parser.add_argument(
"--output-audio",
type=str,
default="output.wav",
help="Path to save the output audio",
)
return parser.parse_args()
def prepare_request(
waveform,
reference_text,
target_text,
sample_rate=16000,
padding_duration: int = None,
audio_save_dir: str = "./",
):
assert len(waveform.shape) == 1, "waveform should be 1D"
lengths = np.array([[len(waveform)]], dtype=np.int32)
if padding_duration:
# padding to nearset 10 seconds
samples = np.zeros(
(
1,
padding_duration
* sample_rate
* ((int(len(waveform) / sample_rate) // padding_duration) + 1),
),
dtype=np.float32,
)
samples[0, : len(waveform)] = waveform
else:
samples = waveform
samples = samples.reshape(1, -1).astype(np.float32)
data = {
"inputs": [
{
"name": "reference_wav",
"shape": samples.shape,
"datatype": "FP32",
"data": samples.tolist()
},
{
"name": "reference_wav_len",
"shape": lengths.shape,
"datatype": "INT32",
"data": lengths.tolist(),
},
{
"name": "reference_text",
"shape": [1, 1],
"datatype": "BYTES",
"data": [reference_text]
},
{
"name": "target_text",
"shape": [1, 1],
"datatype": "BYTES",
"data": [target_text]
}
]
}
return data
if __name__ == "__main__":
args = get_args()
server_url = args.server_url
if not server_url.startswith(("http://", "https://")):
server_url = f"http://{server_url}"
url = f"{server_url}/v2/models/{args.model_name}/infer"
waveform, sr = sf.read(args.reference_audio)
assert sr == 16000, "sample rate hardcoded in server"
samples = np.array(waveform, dtype=np.float32)
data = prepare_request(samples, args.reference_text, args.target_text)
rsp = requests.post(
url,
headers={"Content-Type": "application/json"},
json=data,
verify=False,
params={"request_id": '0'}
)
result = rsp.json()
audio = result["outputs"][0]["data"]
audio = np.array(audio, dtype=np.float32)
if args.model_name == "spark_tts":
sample_rate = 16000
else:
sample_rate = 24000
sf.write(args.output_audio, audio, sample_rate, "PCM_16")

View File

@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-cosyvoice:25.06
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/yuekaizhang/Step-Audio2.git -b trt && git clone https://github.com/yuekaizhang/CosyVoice.git -b streaming && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run_stepaudio2_dit_token2wav.sh 0 3"

View File

@ -0,0 +1,20 @@
services:
tts:
image: soar97/triton-cosyvoice:25.06
shm_size: '1gb'
ports:
- "8000:8000"
- "8001:8001"
- "8002:8002"
environment:
- PYTHONIOENCODING=utf-8
- MODEL_ID=${MODEL_ID}
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
command: >
/bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"

View File

@ -0,0 +1,97 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
import os
import numpy as np
import s3tokenizer
torch.set_num_threads(1)
ORIGINAL_VOCAB_SIZE = 151663
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx")
self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device)
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
mels = []
# Process each request in batch
for request in requests:
# Extract input tensors
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
wav_array = torch.from_numpy(wav_array).to(self.device)
# Prepare inputs
wav = wav_array[:, :wav_len].squeeze(0)
mels.append(s3tokenizer.log_mel_spectrogram(wav))
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device))
codes = codes.clone() + ORIGINAL_VOCAB_SIZE
responses = []
for i in range(len(requests)):
prompt_speech_tokens = codes[i, :codes_lens[i].item()]
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack(
"prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_speech_tokens_tensor])
responses.append(inference_response)
return responses

View File

@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "audio_tokenizer"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
}
]
output [
{
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@ -0,0 +1,454 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import threading
import time
from uuid import uuid4
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.logger.log_info(f"model_params:{model_params}")
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
spk_info_path = os.path.join(model_params["model_dir"], "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_params['model_dir']}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
def forward_llm(self, input_ids):
"""
Prepares the response from the language model based on the provided
inputs. Creates a `pb_utils.InferenceRequest` object with passed
`llm_request_inputs` to send to a decoupled TensorRTLLM model.
For each response from the language model:
- Checks for errors and raise an exception if any are found.
- Extracts the "output_ids" tensor from the response.
- Determines the finish reason based on the presence of the
end-of-sequence token or reaching the maximum length.
- Appends the generated token IDs to `output_ids`.
- If the finish reason is determined, decodes the output IDs to text
and prepares the final response.
The final response includes the generated text, finish reason,
completion tokens, prompt tokens, and total tokens.
Parameters
----------
- llm_request_inputs (dict): A dictionary containing the inputs for the language model.
Returns
-------
- pb_utils.InferenceResponse: The response object containing the generated text and additional metadata.
"""
# convert input_ids to numpy, with shape [1, sequence_length]
input_ids = input_ids.cpu().numpy()
max_tokens = 750
input_dict = {
"request_output_len": np.array([[max_tokens]], dtype=np.int32),
"end_id": np.array([[self.eos_token_id]], dtype=np.int32),
"pad_id": np.array([[self.eos_token_id]], dtype=np.int32),
"streaming": np.array([[self.decoupled]], dtype=np.bool_),
"runtime_top_p": np.array([[0.95]], dtype=np.float32),
"runtime_top_k": np.array([[50]], dtype=np.int32),
"temperature": np.array([[0.8]], dtype=np.float32),
"repetition_penalty": np.array([[1.1]], dtype=np.float32),
"random_seed": np.array([[42]], dtype=np.uint64),
"input_ids": input_ids,
"input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32),
}
# Convert inputs to Triton tensors
input_tensor_list = [
pb_utils.Tensor(k, v) for k, v in input_dict.items()
]
# Create and execute inference request
llm_request = pb_utils.InferenceRequest(
model_name="tensorrt_llm",
requested_output_names=["output_ids", "sequence_length"],
inputs=input_tensor_list,
)
llm_responses = llm_request.exec(decoupled=self.decoupled)
if self.decoupled:
for llm_response in llm_responses:
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
else:
llm_response = llm_responses
if llm_response.has_error():
raise pb_utils.TritonModelException(llm_response.error().message())
# Extract and process output
output_ids = pb_utils.get_output_tensor_by_name(
llm_response, "output_ids").as_numpy()
seq_lens = pb_utils.get_output_tensor_by_name(
llm_response, "sequence_length").as_numpy()
# Get actual output IDs up to the sequence length
actual_output_ids = output_ids[0][0][:seq_lens[0][0]]
yield actual_output_ids
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
def forward_token2wav(
self,
target_speech_tokens: torch.Tensor,
request_id: str,
prompt_speech_tokens: torch.Tensor = None,
prompt_speech_feat: torch.Tensor = None,
prompt_spk_embedding: torch.Tensor = None,
token_offset: int = None,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
prompt_speech_tokens: Prompt speech tokens tensor
prompt_speech_feat: Prompt speech feat tensor
prompt_spk_embedding: Prompt spk embedding tensor
target_speech_tokens: Target speech tokens tensor
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
inputs_tensor = [target_speech_tokens_tensor]
if token_offset is not None:
assert finalize is not None
token_offset_tensor = pb_utils.Tensor("token_offset", np.array([[token_offset]], dtype=np.int32))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor.append(token_offset_tensor)
inputs_tensor.append(finalize_tensor)
if prompt_spk_embedding is not None:
assert prompt_speech_feat is not None
prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_tokens", to_dlpack(prompt_speech_tokens))
prompt_speech_feat_tensor = pb_utils.Tensor.from_dlpack("prompt_speech_feat", to_dlpack(prompt_speech_feat))
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack("prompt_spk_embedding", to_dlpack(prompt_spk_embedding))
inputs_tensor.extend([prompt_speech_tokens_tensor, prompt_speech_feat_tensor, prompt_spk_embedding_tensor])
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav',
requested_output_names=['waveform'],
inputs=inputs_tensor,
request_id=request_id,
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def parse_input(self, text, prompt_text, prompt_speech_tokens):
total_text = f"{prompt_text}{text}"
prompt = self.prompt_template.format(input_text=total_text)
input_ids = self.tokenizer.encode(prompt)
input_ids = torch.tensor([input_ids], dtype=torch.int32)
input_ids = torch.cat([input_ids, prompt_speech_tokens], dim=1)
return input_ids
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
def _llm_gen_thread(self, generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag):
for generated_ids in generated_ids_iter:
generated_ids = generated_ids.tolist()
if len(generated_ids) == 0:
break
semantic_token_ids_arr.extend(generated_ids)
llm_is_done_flag[0] = True
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
responses = []
for request in requests:
request_id = request.request_id()
# Extract input tensors
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
# Process reference audio through audio tokenizer
if wav is not None:
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
prompt_speech_tokens = self.forward_audio_tokenizer(wav, wav_len)
prompt_speech_tokens = prompt_speech_tokens.unsqueeze(0)
wav_tensor = wav.as_numpy()
wav_tensor = torch.from_numpy(wav_tensor)[:, :wav_len.as_numpy()[0][0]]
prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=24000)(wav_tensor)
speech_feat = self._extract_speech_feat(prompt_speech_resample)
token_len = min(int(speech_feat.shape[1] / 2), prompt_speech_tokens.shape[-1])
prompt_speech_feat = speech_feat[:, :2 * token_len].contiguous().half()
prompt_speech_tokens = prompt_speech_tokens[:, :token_len].contiguous()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
prompt_spk_embedding = self.forward_speaker_embedding(wav_tensor)
else:
# using pre-cached reference text
reference_text = self.default_spk_info["prompt_text"]
prompt_speech_tokens = self.default_spk_info["speech_token"] + ORIGINAL_VOCAB_SIZE
prompt_speech_feat = None
prompt_spk_embedding = None
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
# Prepare prompt for LLM
input_ids = self.parse_input(
text=target_text,
prompt_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
)
# Generate semantic tokens with LLM
generated_ids_iter = self.forward_llm(input_ids)
token2wav_request_id = request_id or str(uuid4())
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
llm_is_done_flag = [False]
llm_thread = threading.Thread(
target=self._llm_gen_thread,
args=(generated_ids_iter, semantic_token_ids_arr, llm_is_done_flag)
)
llm_thread.start()
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if llm_is_done_flag[0]:
break
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(
this_tts_speech_token, token2wav_request_id, prompt_speech_tokens,
prompt_speech_feat, prompt_spk_embedding, token_offset, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
self.logger.log_info(f"chunk_index: {chunk_index}, current_token_hop_len: {this_token_hop_len}")
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
self.logger.log_info(f"multiples: {multiples}")
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
time.sleep(0.02)
this_tts_speech_token = torch.tensor(semantic_token_ids_arr).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = self.forward_token2wav(this_tts_speech_token, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding, token_offset, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
llm_thread.join()
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
self.logger.log_info("send tritonserver_response_complete_final to end")
else:
generated_ids = next(generated_ids_iter)
generated_ids = torch.tensor(generated_ids).unsqueeze(0).to(self.device)
if generated_ids is None or len(generated_ids) == 0:
raise pb_utils.TritonModelException("Generated IDs is None or empty")
audio = self.forward_token2wav(generated_ids, token2wav_request_id, prompt_speech_tokens, prompt_speech_feat, prompt_spk_embedding)
# Prepare response
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
responses.append(inference_response)
if not self.decoupled:
return responses

View File

@ -0,0 +1,73 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "cosyvoice2"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
parameters [
{
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: true
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
optional: true
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: ${bls_instance_num}
kind: KIND_CPU
}
]

View File

@ -0,0 +1,394 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import math
import os
import re
import time
from typing import Dict, List, Tuple, Optional, Union
import asyncio
import httpx
import numpy as np
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer
import torchaudio
from matcha.utils.audio import mel_spectrogram
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def parse_speech_token_string(response_text: str) -> List[int]:
"""
Parses a string of speech tokens (e.g., "<|s_123|><|s_456|>") into a list of integer IDs.
"""
speech_tokens = response_text.strip().split('><')
if len(speech_tokens) > 1:
# Add back the missing '<' and '>' for proper parsing
speech_tokens = ['<' + t if not t.startswith('<') else t for t in speech_tokens]
speech_tokens = [t + '>' if not t.endswith('>') else t for t in speech_tokens]
speech_ids = []
for token_str in speech_tokens:
match = re.match(r'<\|s_(\d+)\|>', token_str)
if match:
speech_ids.append(int(match.group(1)))
return speech_ids
class TritonPythonModel:
"""Triton Python model for Spark TTS.
This model orchestrates the end-to-end TTS pipeline by coordinating
between audio tokenizer, LLM, and vocoder components.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
self.logger = pb_utils.Logger
# Parse model parameters
self.model_config = json.loads(args['model_config'])
parameters = self.model_config['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.dynamic_chunk_strategy = model_params.get("dynamic_chunk_strategy", "exponential") # "exponential" or "time_based"
self.logger.log_info(f"Using dynamic chunk strategy: {self.dynamic_chunk_strategy}")
# Initialize tokenizer
llm_tokenizer_dir = model_params["llm_tokenizer_dir"]
self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir)
self.prompt_template = "<|sos|>{input_text}<|task_id|>"
self.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|eos1|>")
self.device = torch.device("cuda")
self.decoupled = pb_utils.using_decoupled_model_transaction_policy(self.model_config)
self.token_frame_rate = 25
self.flow_pre_lookahead_len = 3
self.token_hop_len = 15
self.http_client = httpx.AsyncClient()
self.api_base = "http://localhost:8000/v1/chat/completions"
self.speaker_cache = {}
def _convert_speech_tokens_to_str(self, speech_tokens: Union[torch.Tensor, List]) -> str:
"""Converts a tensor or list of speech token IDs to a string representation."""
if isinstance(speech_tokens, torch.Tensor):
# Ensure tensor is on CPU and flattened
speech_tokens = speech_tokens.cpu().numpy().flatten().tolist()
speech_id_str = ""
for token_id in speech_tokens:
# Convert token ID back to the speech number N
token_num = token_id - ORIGINAL_VOCAB_SIZE
speech_id_str += f"<|s_{token_num}|>"
return speech_id_str
async def forward_llm_async(self, target_text: str, reference_text: str, prompt_speech_tokens: Union[torch.Tensor, List]):
"""
Asynchronously sends a request to the TRTLLM-serve endpoint and processes the streaming response.
"""
full_text = f"{reference_text}{target_text}"
prompt_speech_tokens_str = self._convert_speech_tokens_to_str(prompt_speech_tokens)
chat = [
{"role": "user", "content": full_text},
{"role": "assistant", "content": prompt_speech_tokens_str}
]
payload = {
"model": "trt_engines_bfloat16",
"messages": chat,
"max_tokens": 750,
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.1,
"stop": ["<|eos1|>", "<|eos|>"],
"stream": True,
}
buffer = ""
async with self.http_client.stream("POST", self.api_base, json=payload, timeout=None) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data: "):
line_data = line[len("data: "):].strip()
if line_data == "[DONE]":
break
try:
json_data = json.loads(line_data)
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
buffer += content
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
except json.JSONDecodeError:
self.logger.log_info(f"Skipping non-JSON line: {line_data}")
continue
# Process any remaining complete tokens in the buffer after the stream ends
while True:
match = re.search(r"<\|s_(\d+)\|>", buffer)
if not match:
break
token_num = int(match.group(1))
final_id = token_num + ORIGINAL_VOCAB_SIZE
yield final_id
buffer = buffer[match.end():]
def forward_audio_tokenizer(self, wav, wav_len):
"""Forward pass through the audio tokenizer component.
Args:
wav: Input waveform tensor
wav_len: Waveform length tensor
Returns:
Tuple of global and semantic tokens
"""
inference_request = pb_utils.InferenceRequest(
model_name='audio_tokenizer',
requested_output_names=['prompt_speech_tokens'],
inputs=[wav, wav_len]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_speech_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_speech_tokens')
prompt_speech_tokens = torch.utils.dlpack.from_dlpack(prompt_speech_tokens.to_dlpack()).cpu()
return prompt_speech_tokens
def forward_speaker_embedding(self, wav):
"""Forward pass through the speaker embedding component.
Args:
wav: Input waveform tensor
Returns:
Prompt speaker embedding tensor
"""
inference_request = pb_utils.InferenceRequest(
model_name='speaker_embedding',
requested_output_names=['prompt_spk_embedding'],
inputs=[pb_utils.Tensor.from_dlpack("reference_wav", to_dlpack(wav))]
)
inference_response = inference_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output tensors
prompt_spk_embedding = pb_utils.get_output_tensor_by_name(inference_response, 'prompt_spk_embedding')
prompt_spk_embedding = torch.utils.dlpack.from_dlpack(prompt_spk_embedding.to_dlpack())
return prompt_spk_embedding
async def forward_token2wav(
self,
index: int,
target_speech_tokens: torch.Tensor,
request_id: str,
reference_wav: object,
reference_wav_len: object,
finalize: bool = None) -> torch.Tensor:
"""Forward pass through the vocoder component.
Args:
index: Index of the request
target_speech_tokens: Target speech tokens tensor
request_id: Request ID
reference_wav: Reference waveform tensor
reference_wav_len: Reference waveform length tensor
finalize: Whether to finalize the request
Returns:
Generated waveform tensor
"""
target_speech_tokens_tensor = pb_utils.Tensor.from_dlpack("target_speech_tokens", to_dlpack(target_speech_tokens))
finalize_tensor = pb_utils.Tensor("finalize", np.array([[finalize]], dtype=np.bool_))
inputs_tensor = [target_speech_tokens_tensor, reference_wav, reference_wav_len, finalize_tensor]
# Create and execute inference request
inference_request = pb_utils.InferenceRequest(
model_name='token2wav_dit',
requested_output_names=[
"waveform",
],
inputs=inputs_tensor,
request_id=request_id,
parameters={"priority": index + 1},
)
inference_response = await inference_request.async_exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
# Extract and convert output waveform
waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform')
waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu()
return waveform
def _extract_speech_feat(self, speech):
speech_feat = mel_spectrogram(
speech,
n_fft=1920,
num_mels=80,
sampling_rate=24000,
hop_size=480,
win_size=1920,
fmin=0,
fmax=8000).squeeze(
dim=0).transpose(
0,
1).to(
self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
return speech_feat
async def _process_request(self, request):
request_id = request.request_id()
reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy()
reference_text = reference_text[0][0].decode('utf-8')
wav = pb_utils.get_input_tensor_by_name(request, "reference_wav")
wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len")
if reference_text not in self.speaker_cache:
self.speaker_cache[reference_text] = self.forward_audio_tokenizer(wav, wav_len).unsqueeze(0)
prompt_speech_tokens = self.speaker_cache[reference_text]
target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy()
target_text = target_text[0][0].decode('utf-8')
if self.decoupled:
response_sender = request.get_response_sender()
semantic_token_ids_arr = []
token_offset, chunk_index = 0, 0
start_time = time.time()
this_token_hop_len = self.token_hop_len
async for generated_ids in self.forward_llm_async(
target_text=target_text,
reference_text=reference_text,
prompt_speech_tokens=prompt_speech_tokens,
):
if not generated_ids:
break
semantic_token_ids_arr.append(generated_ids)
while True:
pending_num = len(semantic_token_ids_arr) - token_offset
if pending_num >= this_token_hop_len + self.flow_pre_lookahead_len:
this_tts_speech_token = semantic_token_ids_arr[token_offset:token_offset + this_token_hop_len + self.flow_pre_lookahead_len]
this_tts_speech_token = torch.tensor(this_tts_speech_token).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(
chunk_index,
this_tts_speech_token, request_id, wav, wav_len, False
)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
token_offset += this_token_hop_len
if self.dynamic_chunk_strategy == "exponential":
this_token_hop_len = self.token_frame_rate * (2 ** chunk_index)
elif self.dynamic_chunk_strategy == "equal":
this_token_hop_len = self.token_hop_len
elif self.dynamic_chunk_strategy == "time_based":
# see https://github.com/qi-hua/async_cosyvoice/blob/main/model.py#L306
cost_time = time.time() - start_time
duration = token_offset / self.token_frame_rate
if chunk_index > 0 and cost_time > 0:
avg_chunk_processing_time = cost_time / (chunk_index + 1)
if avg_chunk_processing_time > 0:
multiples = (duration - cost_time) / avg_chunk_processing_time
next_pending_num = len(semantic_token_ids_arr) - token_offset
if multiples > 4:
this_token_hop_len = (next_pending_num // self.token_hop_len + 1) * self.token_hop_len
elif multiples > 2:
this_token_hop_len = (next_pending_num // self.token_hop_len) * self.token_hop_len
else:
this_token_hop_len = self.token_hop_len
this_token_hop_len = max(self.token_hop_len, this_token_hop_len)
chunk_index += 1
else:
break
this_tts_speech_token = torch.tensor(semantic_token_ids_arr[token_offset:]).unsqueeze(dim=0).to(torch.int32).to(self.device)
sub_tts_speech = await self.forward_token2wav(chunk_index, this_tts_speech_token, request_id, wav, wav_len, True)
audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(sub_tts_speech))
inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor])
response_sender.send(inference_response)
response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
else:
raise NotImplementedError("Offline TTS mode is not supported")
async def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated audio
"""
tasks = [
asyncio.create_task(self._process_request(request))
for request in requests
]
await asyncio.gather(*tasks)
return None
def finalize(self):
self.logger.log_info("Finalizing CosyVoice DIT model")
if hasattr(self, "http_client"):
asyncio.run(self.http_client.aclose())

View File

@ -0,0 +1,73 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "cosyvoice2_dit"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
parameters [
{
key: "llm_tokenizer_dir",
value: {string_value:"${llm_tokenizer_dir}"}
},
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
optional: true
},
{
name: "reference_wav_len"
data_type: TYPE_INT32
dims: [1]
optional: true
},
{
name: "reference_text"
data_type: TYPE_STRING
dims: [1]
optional: true
},
{
name: "target_text"
data_type: TYPE_STRING
dims: [1]
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: ${bls_instance_num}
kind: KIND_CPU
}
]

View File

@ -0,0 +1,153 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import torch
from torch.utils.dlpack import to_dlpack
import triton_python_backend_utils as pb_utils
import os
import numpy as np
import torchaudio.compliance.kaldi as kaldi
from cosyvoice.utils.file_utils import convert_onnx_to_trt
from cosyvoice.utils.common import TrtContextWrapper
import onnxruntime
class TritonPythonModel:
"""Triton Python model for audio tokenization.
This model takes reference audio input and extracts semantic tokens
using s3tokenizer.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {k: v["string_value"] for k, v in parameters.items()}
self.device = torch.device("cuda")
model_dir = model_params["model_dir"]
gpu = "l20"
enable_trt = True
if enable_trt:
self.load_spk_trt(f'{model_dir}/campplus.{gpu}.fp32.trt',
f'{model_dir}/campplus.onnx',
1,
False)
else:
campplus_model = f'{model_dir}/campplus.onnx'
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.spk_model = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
def load_spk_trt(self, spk_model, spk_onnx_model, trt_concurrent=1, fp16=True):
if not os.path.exists(spk_model) or os.path.getsize(spk_model) == 0:
trt_kwargs = self.get_spk_trt_kwargs()
convert_onnx_to_trt(spk_model, trt_kwargs, spk_onnx_model, fp16)
import tensorrt as trt
with open(spk_model, 'rb') as f:
spk_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert spk_engine is not None, 'failed to load trt {}'.format(spk_model)
self.spk_model = TrtContextWrapper(spk_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_spk_trt_kwargs(self):
min_shape = [(1, 4, 80)]
opt_shape = [(1, 500, 80)]
max_shape = [(1, 3000, 80)]
input_names = ["input"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def _extract_spk_embedding(self, speech):
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
spk_feat = feat - feat.mean(dim=0, keepdim=True)
if isinstance(self.spk_model, onnxruntime.InferenceSession):
embedding = self.spk_model.run(
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
)[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device)
else:
[spk_model, stream], trt_engine = self.spk_model.acquire_estimator()
# NOTE need to synchronize when switching stream
with torch.cuda.device(self.device):
torch.cuda.current_stream().synchronize()
spk_feat = spk_feat.unsqueeze(dim=0).to(self.device)
batch_size = spk_feat.size(0)
with stream:
spk_model.set_input_shape('input', (batch_size, spk_feat.size(1), 80))
embedding = torch.empty((batch_size, 192), device=spk_feat.device)
data_ptrs = [spk_feat.contiguous().data_ptr(),
embedding.contiguous().data_ptr()]
for i, j in enumerate(data_ptrs):
spk_model.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert spk_model.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.spk_model.release_estimator(spk_model, stream)
return embedding.half()
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing tokenized outputs
"""
responses = []
# Process each request in batch
for request in requests:
# Extract input tensors
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_array = torch.from_numpy(wav_array).to(self.device)
embedding = self._extract_spk_embedding(wav_array)
prompt_spk_embedding_tensor = pb_utils.Tensor.from_dlpack(
"prompt_spk_embedding", to_dlpack(embedding))
inference_response = pb_utils.InferenceResponse(
output_tensors=[prompt_spk_embedding_tensor])
responses.append(inference_response)
return responses

View File

@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "speaker_embedding"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "reference_wav"
data_type: TYPE_FP32
dims: [-1]
}
]
output [
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@ -0,0 +1,857 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
name: "tensorrt_llm"
backend: "${triton_backend}"
max_batch_size: ${triton_max_batch_size}
model_transaction_policy {
decoupled: ${decoupled_mode}
}
dynamic_batching {
preferred_batch_size: [ ${triton_max_batch_size} ]
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
default_queue_policy: { max_queue_size: ${max_queue_size} }
}
input [
{
name: "input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
allow_ragged_batch: true
optional: true
},
{
name: "encoder_input_features"
data_type: ${encoder_input_features_data_type}
dims: [ -1, -1 ]
allow_ragged_batch: true
optional: true
},
{
name: "encoder_output_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "request_output_len"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
},
{
name: "num_return_sequences"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "draft_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_ids"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "decoder_input_lengths"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
reshape: { shape: [ ] }
},
{
name: "draft_logits"
data_type: ${logits_datatype}
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "draft_acceptance_threshold"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "end_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "pad_id"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "bad_words_list"
data_type: TYPE_INT32
dims: [ 2, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "embedding_bias"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "beam_width"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "temperature"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_k"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_min"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_decay"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "runtime_top_p_reset_ids"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "len_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "early_stopping"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "repetition_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "min_length"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "beam_search_diversity_rate"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "presence_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "frequency_penalty"
data_type: TYPE_FP32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "random_seed"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_log_probs"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_context_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_generation_logits"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "return_perf_metrics"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "exclude_input_in_output"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "stop"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "streaming"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "prompt_embedding_table"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_table_extra_ids"
data_type: TYPE_UINT64
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "prompt_vocab_size"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]`
{
name: "cross_attention_mask"
data_type: TYPE_BOOL
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# Mrope param when mrope is used
{
name: "mrope_rotary_cos_sin"
data_type: TYPE_FP32
dims: [ -1 ]
optional: true
},
{
name: "mrope_position_deltas"
data_type: TYPE_INT64
dims: [ 1 ]
optional: true
},
# the unique task ID for the given LoRA.
# To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
# The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
# If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached.
{
name: "lora_task_id"
data_type: TYPE_UINT64
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
# weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
# where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
# each of the in / out tensors are first flattened and then concatenated together in the format above.
# D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
{
name: "lora_weights"
data_type: TYPE_FP16
dims: [ -1, -1 ]
optional: true
allow_ragged_batch: true
},
# module identifier (same size a first dimension of lora_weights)
# See LoraModule::ModuleType for model id mapping
#
# "attn_qkv": 0 # compbined qkv adapter
# "attn_q": 1 # q adapter
# "attn_k": 2 # k adapter
# "attn_v": 3 # v adapter
# "attn_dense": 4 # adapter for the dense layer in attention
# "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
# "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
# "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
#
# last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
{
name: "lora_config"
data_type: TYPE_INT32
dims: [ -1, 3 ]
optional: true
allow_ragged_batch: true
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
# skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama
{
name: "skip_cross_attn_blocks"
data_type: TYPE_BOOL
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_starts"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_ends"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_priorities"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_token_range_durations_ms"
data_type: TYPE_INT32
dims: [ -1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_decode_priority"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "retention_decode_duration_ms"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "guided_decoding_guide_type"
data_type: TYPE_STRING
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "guided_decoding_guide"
data_type: TYPE_STRING
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_window_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_ngram_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
},
{
name: "lookahead_verification_set_size"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
allow_ragged_batch: true
}
]
output [
{
name: "output_ids"
data_type: TYPE_INT32
dims: [ -1, -1 ]
},
{
name: "sequence_length"
data_type: TYPE_INT32
dims: [ -1 ]
},
{
name: "cum_log_probs"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "output_log_probs"
data_type: TYPE_FP32
dims: [ -1, -1 ]
},
{
name: "context_logits"
data_type: ${logits_datatype}
dims: [ -1, -1 ]
},
{
name: "generation_logits"
data_type: ${logits_datatype}
dims: [ -1, -1, -1 ]
},
{
name: "batch_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "sequence_index"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "context_phase_params"
data_type: TYPE_UINT8
dims: [ -1 ]
},
{
name: "kv_cache_alloc_new_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_reused_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "kv_cache_alloc_total_blocks"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "arrival_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "first_scheduled_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "first_token_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "last_token_time_ns"
data_type: TYPE_INT64
dims: [ 1 ]
},
{
name: "acceptance_rate"
data_type: TYPE_FP32
dims: [ 1 ]
},
{
name: "total_accepted_draft_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
},
{
name: "total_draft_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
}
]
instance_group [
{
count: 1
kind : KIND_CPU
}
]
parameters: {
key: "max_beam_width"
value: {
string_value: "${max_beam_width}"
}
}
parameters: {
key: "FORCE_CPU_ONLY_INPUT_TENSORS"
value: {
string_value: "no"
}
}
parameters: {
key: "gpt_model_type"
value: {
string_value: "${batching_strategy}"
}
}
parameters: {
key: "gpt_model_path"
value: {
string_value: "${engine_dir}"
}
}
parameters: {
key: "encoder_model_path"
value: {
string_value: "${encoder_engine_dir}"
}
}
parameters: {
key: "max_tokens_in_paged_kv_cache"
value: {
string_value: "${max_tokens_in_paged_kv_cache}"
}
}
parameters: {
key: "max_attention_window_size"
value: {
string_value: "${max_attention_window_size}"
}
}
parameters: {
key: "sink_token_length"
value: {
string_value: "${sink_token_length}"
}
}
parameters: {
key: "batch_scheduler_policy"
value: {
string_value: "${batch_scheduler_policy}"
}
}
parameters: {
key: "kv_cache_free_gpu_mem_fraction"
value: {
string_value: "${kv_cache_free_gpu_mem_fraction}"
}
}
parameters: {
key: "cross_kv_cache_fraction"
value: {
string_value: "${cross_kv_cache_fraction}"
}
}
parameters: {
key: "kv_cache_host_memory_bytes"
value: {
string_value: "${kv_cache_host_memory_bytes}"
}
}
# kv_cache_onboard_blocks is for internal implementation.
parameters: {
key: "kv_cache_onboard_blocks"
value: {
string_value: "${kv_cache_onboard_blocks}"
}
}
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
# parameters: {
# key: "enable_trt_overlap"
# value: {
# string_value: "${enable_trt_overlap}"
# }
# }
parameters: {
key: "exclude_input_in_output"
value: {
string_value: "${exclude_input_in_output}"
}
}
parameters: {
key: "cancellation_check_period_ms"
value: {
string_value: "${cancellation_check_period_ms}"
}
}
parameters: {
key: "stats_check_period_ms"
value: {
string_value: "${stats_check_period_ms}"
}
}
parameters: {
key: "iter_stats_max_iterations"
value: {
string_value: "${iter_stats_max_iterations}"
}
}
parameters: {
key: "request_stats_max_iterations"
value: {
string_value: "${request_stats_max_iterations}"
}
}
parameters: {
key: "enable_kv_cache_reuse"
value: {
string_value: "${enable_kv_cache_reuse}"
}
}
parameters: {
key: "normalize_log_probs"
value: {
string_value: "${normalize_log_probs}"
}
}
parameters: {
key: "enable_chunked_context"
value: {
string_value: "${enable_chunked_context}"
}
}
parameters: {
key: "gpu_device_ids"
value: {
string_value: "${gpu_device_ids}"
}
}
parameters: {
key: "participant_ids"
value: {
string_value: "${participant_ids}"
}
}
parameters: {
key: "lora_cache_optimal_adapter_size"
value: {
string_value: "${lora_cache_optimal_adapter_size}"
}
}
parameters: {
key: "lora_cache_max_adapter_size"
value: {
string_value: "${lora_cache_max_adapter_size}"
}
}
parameters: {
key: "lora_cache_gpu_memory_fraction"
value: {
string_value: "${lora_cache_gpu_memory_fraction}"
}
}
parameters: {
key: "lora_cache_host_memory_bytes"
value: {
string_value: "${lora_cache_host_memory_bytes}"
}
}
parameters: {
key: "lora_prefetch_dir"
value: {
string_value: "${lora_prefetch_dir}"
}
}
parameters: {
key: "decoding_mode"
value: {
string_value: "${decoding_mode}"
}
}
parameters: {
key: "executor_worker_path"
value: {
string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
}
}
parameters: {
key: "lookahead_window_size"
value: {
string_value: "${lookahead_window_size}"
}
}
parameters: {
key: "lookahead_ngram_size"
value: {
string_value: "${lookahead_ngram_size}"
}
}
parameters: {
key: "lookahead_verification_set_size"
value: {
string_value: "${lookahead_verification_set_size}"
}
}
parameters: {
key: "medusa_choices"
value: {
string_value: "${medusa_choices}"
}
}
parameters: {
key: "eagle_choices"
value: {
string_value: "${eagle_choices}"
}
}
parameters: {
key: "gpu_weights_percent"
value: {
string_value: "${gpu_weights_percent}"
}
}
parameters: {
key: "enable_context_fmha_fp32_acc"
value: {
string_value: "${enable_context_fmha_fp32_acc}"
}
}
parameters: {
key: "multi_block_mode"
value: {
string_value: "${multi_block_mode}"
}
}
parameters: {
key: "cuda_graph_mode"
value: {
string_value: "${cuda_graph_mode}"
}
}
parameters: {
key: "cuda_graph_cache_size"
value: {
string_value: "${cuda_graph_cache_size}"
}
}
parameters: {
key: "speculative_decoding_fast_logits"
value: {
string_value: "${speculative_decoding_fast_logits}"
}
}
parameters: {
key: "tokenizer_dir"
value: {
string_value: "${tokenizer_dir}"
}
}
parameters: {
key: "guided_decoding_backend"
value: {
string_value: "${guided_decoding_backend}"
}
}
parameters: {
key: "xgrammar_tokenizer_info_path"
value: {
string_value: "${xgrammar_tokenizer_info_path}"
}
}

View File

@ -0,0 +1,277 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
class CosyVoice2:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1, device='cuda'):
self.model_dir = model_dir
self.fp16 = fp16
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
self.model = CosyVoice2Model(configs['flow'], configs['hift'], fp16, device)
self.model.load('{}/flow.pt'.format(model_dir), '{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
class CosyVoice2Model:
def __init__(self,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False,
device: str = 'cuda'):
self.device = device
self.flow = flow
self.hift = hift
self.fp16 = fp16
if self.fp16 is True:
self.flow.half()
# streaming tts config
self.token_hop_len = 25
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
self.speech_window = np.hamming(2 * self.source_cache_len)
self.hift_cache_dict = defaultdict(lambda: None)
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load(self, flow_model, hift_model):
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
self.token2wav_model = CosyVoice2(
model_dir, load_jit=False, load_trt=True, fp16=True, device=self.device
)
spk_info_path = os.path.join(model_dir, "spk2info.pt")
if not os.path.exists(spk_info_path):
raise ValueError(f"spk2info.pt not found in {model_dir}")
spk_info = torch.load(spk_info_path, map_location="cpu", weights_only=False)
self.default_spk_info = spk_info["001"]
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
responses = []
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor).to(self.device)
prompt_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_tokens")
if prompt_speech_tokens_tensor is not None:
prompt_speech_tokens_tensor = prompt_speech_tokens_tensor.as_numpy()
prompt_speech_feat_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_speech_feat").as_numpy()
prompt_spk_embedding_tensor = pb_utils.get_input_tensor_by_name(request, "prompt_spk_embedding").as_numpy()
prompt_speech_tokens = torch.from_numpy(prompt_speech_tokens_tensor).to(self.device)
prompt_speech_feat = torch.from_numpy(prompt_speech_feat_tensor).to(self.device)
prompt_spk_embedding = torch.from_numpy(prompt_spk_embedding_tensor).to(self.device)
prompt_speech_tokens = prompt_speech_tokens - ORIGINAL_VOCAB_SIZE
else:
prompt_speech_tokens = self.default_spk_info["speech_token"].to(self.device)
prompt_speech_feat = self.default_spk_info["speech_feat"].to(torch.float16).to(self.device)
prompt_spk_embedding = self.default_spk_info["embedding"].to(torch.float16).to(self.device)
# shift the speech tokens according to the original vocab size
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
# We set token_offset as an optional input to support streaming/offline tts. It has to be None when offline tts.
token_offset = pb_utils.get_input_tensor_by_name(request, "token_offset")
if token_offset is not None:
token_offset = token_offset.as_numpy().item()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
if not finalize:
stream = True
else:
stream = False
request_id = request.request_id()
audio_hat = self.token2wav_model.model.token2wav(token=target_speech_tokens,
prompt_token=prompt_speech_tokens,
prompt_feat=prompt_speech_feat,
embedding=prompt_spk_embedding,
token_offset=token_offset,
uuid=request_id,
stream=stream,
finalize=finalize)
if finalize:
self.token2wav_model.model.hift_cache_dict.pop(request_id)
else:
tts_mel, _ = self.token2wav_model.model.flow.inference(
token=target_speech_tokens,
token_len=torch.tensor([target_speech_tokens.shape[1]], dtype=torch.int32).to(
self.device
),
prompt_token=prompt_speech_tokens,
prompt_token_len=torch.tensor(
[prompt_speech_tokens.shape[1]], dtype=torch.int32
).to(self.device),
prompt_feat=prompt_speech_feat,
prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=prompt_spk_embedding,
streaming=False,
finalize=True,
)
audio_hat, _ = self.token2wav_model.model.hift.inference(
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
)
generated_wave = audio_hat.squeeze(0).cpu().numpy()
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor])
responses.append(inference_response)
return responses

View File

@ -0,0 +1,80 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: "token2wav"
backend: "python"
max_batch_size: ${triton_max_batch_size}
dynamic_batching {
max_queue_delay_microseconds: ${max_queue_delay_microseconds}
}
parameters [
{
key: "model_dir",
value: {string_value:"${model_dir}"}
}
]
input [
{
name: "target_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
},
{
name: "prompt_speech_tokens"
data_type: TYPE_INT32
dims: [-1]
optional: true
},
{
name: "prompt_speech_feat"
data_type: TYPE_FP16
dims: [-1, 80]
optional: true
},
{
name: "prompt_spk_embedding"
data_type: TYPE_FP16
dims: [-1]
optional: true
},
{
name: "token_offset"
data_type: TYPE_INT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "finalize"
data_type: TYPE_BOOL
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
}
]
output [
{
name: "waveform"
data_type: TYPE_FP32
dims: [ -1 ]
}
]
instance_group [
{
count: 1
kind: KIND_CPU
}
]

View File

@ -0,0 +1,142 @@
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict
import torch
from torch.utils.dlpack import to_dlpack
from torch.nn import functional as F
import triton_python_backend_utils as pb_utils
from hyperpyyaml import load_hyperpyyaml
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
from collections import defaultdict
import numpy as np
from .token2wav_dit import CosyVoice2_Token2Wav
import hashlib
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
ORIGINAL_VOCAB_SIZE = 151663
torch.set_num_threads(1)
def get_spk_id_from_prompt_audio(tensor: torch.Tensor) -> str:
"""
Generates a unique ID for a torch.Tensor.
Tensors with the same elements and properties will have the same ID.
"""
# Convert tensor to a byte string
tensor_bytes = tensor.numpy().tobytes()
# Create a SHA-256 hash of the byte string
hasher = hashlib.sha256()
hasher.update(tensor_bytes)
return hasher.hexdigest()
class TritonPythonModel:
"""Triton Python model for vocoder.
This model takes global and semantic tokens as input and generates audio waveforms
using the BiCodec vocoder.
"""
def initialize(self, args):
"""Initialize the model.
Args:
args: Dictionary containing model configuration
"""
# Parse model parameters
parameters = json.loads(args['model_config'])['parameters']
model_params = {key: value["string_value"] for key, value in parameters.items()}
model_dir = model_params["model_dir"]
# Initialize device and vocoder
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing vocoder from {model_dir} on {self.device}")
# FIXME: device id settings
self.token2wav_model = CosyVoice2_Token2Wav(
model_dir, enable_trt=True, streaming=True
)
logger.info("Token2Wav initialized successfully")
def execute(self, requests):
"""Execute inference on the batched requests.
Args:
requests: List of inference requests
Returns:
List of inference responses containing generated waveforms
"""
responses = []
# Process each request in batch
for request in requests:
target_speech_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "target_speech_tokens").as_numpy()
target_speech_tokens = torch.from_numpy(target_speech_tokens_tensor)
target_speech_tokens = target_speech_tokens - ORIGINAL_VOCAB_SIZE
target_speech_tokens = target_speech_tokens.squeeze().tolist()
finalize = pb_utils.get_input_tensor_by_name(request, "finalize").as_numpy().item()
request_id = request.request_id()
wav_array = pb_utils.get_input_tensor_by_name(
request, "reference_wav").as_numpy()
wav_len = pb_utils.get_input_tensor_by_name(
request, "reference_wav_len").as_numpy().item()
wav_array = torch.from_numpy(wav_array)
wav = wav_array[:, :wav_len].squeeze(0)
spk_id = get_spk_id_from_prompt_audio(wav)
audio_hat = self.token2wav_model.forward_streaming(
target_speech_tokens, finalize, request_id=request_id,
speaker_id=f"{spk_id}", prompt_audio=wav, prompt_audio_sample_rate=16000
)
outputs = []
wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio_hat))
outputs.append(wav_tensor)
inference_response = pb_utils.InferenceResponse(output_tensors=outputs)
responses.append(inference_response)
return responses

Some files were not shown because too many files have changed in this diff Show More