Compare commits


No commits in common. "master" and "add-disc" have entirely different histories.

34 changed files with 508 additions and 2275 deletions

View File

@ -1,23 +0,0 @@
name: Test Installation
on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Test Standard Install
run: |
pip install -e .

Binary file not shown. (Before: 707 KiB)

Binary file not shown. (Before: 763 KiB)

178
README.md
View File

@ -1,112 +1,33 @@
![Chatterbox Turbo Image](./Chatterbox-Turbo.jpg)
<img width="1200" alt="cb-big2" src="https://github.com/user-attachments/assets/bd8c5f03-e91d-4ee5-b680-57355da204d1" />
# Chatterbox TTS
[![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_turbo_demopage/)
[![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo)
[![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_demopage/)
[![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/Chatterbox)
[![Alt Text](https://static-public.podonos.com/badges/insight-on-pdns-sm-dark.svg)](https://podonos.com/resembleai/chatterbox)
[![Discord](https://img.shields.io/discord/1377773249798344776?label=join%20discord&logo=discord&style=flat)](https://discord.gg/rJq9cRJBJ6)
_Made with ♥️ by <a href="https://resemble.ai" target="_blank"><img width="100" alt="resemble-logo-horizontal" src="https://github.com/user-attachments/assets/35cf756b-3506-4943-9c72-c05ddfa4e525" /></a>_
_Made with ♥️ by <img width="100" alt="resemble-logo-horizontal" src="https://github.com/user-attachments/assets/35cf756b-3506-4943-9c72-c05ddfa4e525" />_
**Chatterbox** is a family of three state-of-the-art, open-source text-to-speech models by Resemble AI.
We're excited to introduce Chatterbox, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.
We are excited to introduce **Chatterbox-Turbo**, our most efficient model yet. Built on a streamlined 350M parameter architecture, **Turbo** delivers high-quality speech with less compute and VRAM than our previous models. We have also distilled the speech-token-to-mel decoder, previously a bottleneck, reducing generation from 10 steps to just **one**, while retaining high-fidelity audio output.
**Paralinguistic tags** are now native to the Turbo model, allowing you to use `[cough]`, `[laugh]`, `[chuckle]`, and more to add distinct realism. While Turbo was built primarily for low-latency voice agents, it excels at narration and creative workflows.
Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life. It's also the first open source TTS model to support **emotion exaggeration control**, a powerful feature that makes your voices stand out. Try it now on our [Hugging Face Gradio app.](https://huggingface.co/spaces/ResembleAI/Chatterbox)
If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (<a href="https://resemble.ai">link</a>). It delivers reliable performance with ultra-low latency of sub 200ms—ideal for production use in agents, applications, or interactive media.
<img width="1200" height="600" alt="Podonos Turbo Eval" src="https://storage.googleapis.com/chatterbox-demo-samples/turbo/podonos_turbo.png" />
# Key Details
- SoTA zeroshot TTS
- 0.5B Llama backbone
- Unique exaggeration/intensity control
- Ultra-stable with alignment-informed inference
- Trained on 0.5M hours of cleaned data
- Watermarked outputs
- Easy voice conversion script
- [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox)
### ⚡ Model Zoo
Choose the right model for your application.
| Model | Size | Languages | Key Features | Best For | 🤗 | Examples |
|:----------------------------------------------------------------------------------------------------------------| :--- | :--- |:--------------------------------------------------------|:---------------------------------------------|:--------------------------------------------------------------------------| :--- |
| **Chatterbox-Turbo** | **350M** | **English** | Paralinguistic Tags (`[laugh]`), Lower Compute and VRAM | Zero-shot voice agents, Production | [Demo](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo) | [Listen](https://resemble-ai.github.io/chatterbox_turbo_demopage/) |
| Chatterbox-Multilingual [(Language list)](#supported-languages) | 500M | 23+ | Zero-shot cloning, Multiple Languages | Global applications, Localization | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
| Chatterbox [(Tips and Tricks)](#original-chatterbox-tips) | 500M | English | CFG & Exaggeration tuning | General zero-shot TTS with creative controls | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
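Each row of the table maps to its own class in the package; the sketch below collects the corresponding imports (module and class names as used in the usage examples further down) so you can switch models by swapping one line:
```python
# Minimal sketch: pick the class matching the row you chose in the table above.
from chatterbox.tts_turbo import ChatterboxTurboTTS        # Chatterbox-Turbo
from chatterbox.mtl_tts import ChatterboxMultilingualTTS   # Chatterbox-Multilingual
from chatterbox.tts import ChatterboxTTS                   # Original Chatterbox

# All three are loaded the same way; swap the class to switch models.
model = ChatterboxTurboTTS.from_pretrained(device="cuda")
```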
## Installation
```shell
pip install chatterbox-tts
```
Alternatively, you can install from source:
```shell
# conda create -yn chatterbox python=3.11
# conda activate chatterbox
git clone https://github.com/resemble-ai/chatterbox.git
cd chatterbox
pip install -e .
```
We developed and tested Chatterbox with Python 3.11 on Debian 11; dependency versions are pinned in `pyproject.toml` to ensure consistency. Installing from source in editable mode lets you modify the code or dependencies.
## Usage
##### Chatterbox-Turbo
```python
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS
# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")
# Generate with Paralinguistic Tags
text = "Hi there, Sarah here from MochaFone calling you back [chuckle], have you got one minute to chat about the billing issue?"
# Generate audio (requires a reference clip for voice cloning)
wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")
ta.save("test-turbo.wav", wav, model.sr)
```
##### Chatterbox and Chatterbox-Multilingual
```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
# English example
model = ChatterboxTTS.from_pretrained(device="cuda")
text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-english.wav", wav, model.sr)
# Multilingual examples
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")
french_text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
wav_french = multilingual_model.generate(french_text, language_id="fr")
ta.save("test-french.wav", wav_french, multilingual_model.sr)
chinese_text = "你好,今天天气真不错,希望你有一个愉快的周末。"
wav_chinese = multilingual_model.generate(chinese_text, language_id="zh")
ta.save("test-chinese.wav", wav_chinese, multilingual_model.sr)
# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` and `example_vc.py` for more examples.
## Supported Languages
Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)
## Original Chatterbox Tips
# Tips
- **General Use (TTS and Voice Agents):**
- Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set `cfg_weight` to `0`.
- The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
- The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts.
- If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.
- **Expressive or Dramatic Speech:**
@ -114,55 +35,40 @@ Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) •
- Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing.
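A minimal sketch that applies these tips for an expressive, dramatic read (parameter names follow the `generate` call used in the Usage section and the demo app; the reference clip path is a placeholder):
```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

model = ChatterboxTTS.from_pretrained(device="cuda")

# Dramatic delivery: raise exaggeration and lower cfg_weight so pacing stays deliberate.
wav = model.generate(
    "You will not believe what just happened!",
    audio_prompt_path="your_10s_ref_clip.wav",  # placeholder reference clip
    exaggeration=0.7,
    cfg_weight=0.3,
)
ta.save("test-expressive.wav", wav, model.sr)
```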
## Built-in PerTh Watermarking for Responsible AI
Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.
## Watermark extraction
You can look for the watermark using the following script.
```python
import perth
import librosa
AUDIO_PATH = "YOUR_FILE.wav"
# Load the watermarked audio
watermarked_audio, sr = librosa.load(AUDIO_PATH, sr=None)
# Initialize watermarker (same as used for embedding)
watermarker = perth.PerthImplicitWatermarker()
# Extract watermark
watermark = watermarker.get_watermark(watermarked_audio, sample_rate=sr)
print(f"Extracted watermark: {watermark}")
# Output: 0.0 (no watermark) or 1.0 (watermarked)
# Installation
```
pip install chatterbox-tts
```
## Official Discord
# Usage
```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
👋 Join us on [Discord](https://discord.gg/rJq9cRJBJ6) and let's build something awesome together!
model = ChatterboxTTS.from_pretrained(device="cuda")
## Acknowledgements
text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-1.wav", wav, model.sr)
# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH="YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` for more examples.
# Acknowledgements
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
- [Llama 3](https://github.com/meta-llama/llama3)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)
## Citation
If you find this model useful, please consider citing it.
```
@misc{chatterboxtts2025,
author = {{Resemble AI}},
title = {{Chatterbox-TTS}},
year = {2025},
howpublished = {\url{https://github.com/resemble-ai/chatterbox}},
note = {GitHub repository}
}
```
## Disclaimer
# Built-in PerTh Watermarking for Responsible AI
Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.
# Disclaimer
Don't use this model to do bad things. Prompts are sourced from freely available data on the internet.

View File

@ -1,7 +1,6 @@
import torchaudio as ta
import torch
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
# Automatically detect the best available device
if torch.cuda.is_available():
@ -17,15 +16,4 @@ model = ChatterboxTTS.from_pretrained(device=device)
text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-1.wav", wav, model.sr)
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)
text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
wav = multilingual_model.generate(text, language_id="fr")
ta.save("test-2.wav", wav, multilingual_model.sr)
# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-3.wav", wav, model.sr)
ta.save("test-1.wav", wav, model.sr)

View File

@ -1,14 +0,0 @@
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS
# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")
# Generate with Paralinguistic Tags
text = "Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?"
# Generate audio (requires a reference clip for voice cloning)
# wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")
wav = model.generate(text)
ta.save("test-turbo.wav", wav, model.sr)

View File

@ -1,24 +1,6 @@
import torch
import torchaudio as ta
from chatterbox.vc import ChatterboxVC
# Automatically detect the best available device
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
print(f"Using device: {device}")
AUDIO_PATH = "YOUR_FILE.wav"
TARGET_VOICE_PATH = "YOUR_FILE.wav"
model = ChatterboxVC.from_pretrained(device)
wav = model.generate(
audio=AUDIO_PATH,
target_voice_path=TARGET_VOICE_PATH,
)
model = ChatterboxVC.from_pretrained("cuda")
wav = model.generate("tests/trimmed_8b7f38b1.wav")
import torchaudio as ta
ta.save("testvc.wav", wav, model.sr)

View File

@ -21,7 +21,7 @@ def load_model():
return model
def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, min_p, top_p, repetition_penalty):
def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw):
if model is None:
model = ChatterboxTTS.from_pretrained(DEVICE)
@ -34,9 +34,6 @@ def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num
exaggeration=exaggeration,
temperature=temperature,
cfg_weight=cfgw,
min_p=min_p,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
return (model.sr, wav.squeeze(0).numpy())
@ -46,21 +43,14 @@ with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
text = gr.Textbox(
value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
label="Text to synthesize (max chars 300)",
max_lines=5
)
text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value=None)
exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5)
cfg_weight = gr.Slider(0.0, 1, step=.05, label="CFG/Pace", value=0.5)
cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
with gr.Accordion("More options", open=False):
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
min_p = gr.Slider(0.00, 1.00, step=0.01, label="min_p || Newer Sampler. Recommend 0.02 > 0.1. Handles Higher Temperatures better. 0.00 Disables", value=0.05)
top_p = gr.Slider(0.00, 1.00, step=0.01, label="top_p || Original Sampler. 1.0 Disables(recommended). Original 0.8", value=1.00)
repetition_penalty = gr.Slider(1.00, 2.00, step=0.1, label="repetition_penalty", value=1.2)
run_btn = gr.Button("Generate", variant="primary")
@ -79,9 +69,6 @@ with gr.Blocks() as demo:
temp,
seed_num,
cfg_weight,
min_p,
top_p,
repetition_penalty,
],
outputs=audio_output,
)

View File

@ -1,186 +0,0 @@
import random
import numpy as np
import torch
import gradio as gr
from chatterbox.tts_turbo import ChatterboxTurboTTS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EVENT_TAGS = [
"[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
"[sniff]", "[gasp]", "[chuckle]", "[laugh]"
]
# --- REFINED CSS ---
# 1. tag-container: Forces the row to wrap items instead of scrolling. Removes borders/backgrounds.
# 2. tag-btn: Sets the specific look (indigo theme) and stops them from stretching.
CUSTOM_CSS = """
.tag-container {
display: flex !important;
flex-wrap: wrap !important; /* This fixes the one-per-line issue */
gap: 8px !important;
margin-top: 5px !important;
margin-bottom: 10px !important;
border: none !important;
background: transparent !important;
}
.tag-btn {
min-width: fit-content !important;
width: auto !important;
height: 32px !important;
font-size: 13px !important;
background: #eef2ff !important;
border: 1px solid #c7d2fe !important;
color: #3730a3 !important;
border-radius: 6px !important;
padding: 0 10px !important;
margin: 0 !important;
box-shadow: none !important;
}
.tag-btn:hover {
background: #c7d2fe !important;
transform: translateY(-1px);
}
"""
INSERT_TAG_JS = """
(tag_val, current_text) => {
const textarea = document.querySelector('#main_textbox textarea');
if (!textarea) return current_text + " " + tag_val;
const start = textarea.selectionStart;
const end = textarea.selectionEnd;
let prefix = " ";
let suffix = " ";
if (start === 0) prefix = "";
else if (current_text[start - 1] === ' ') prefix = "";
if (end < current_text.length && current_text[end] === ' ') suffix = "";
return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
}
"""
def set_seed(seed: int):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
def load_model():
print(f"Loading Chatterbox-Turbo on {DEVICE}...")
model = ChatterboxTurboTTS.from_pretrained(DEVICE)
return model
def generate(
model,
text,
audio_prompt_path,
temperature,
seed_num,
min_p,
top_p,
top_k,
repetition_penalty,
norm_loudness
):
if model is None:
model = ChatterboxTurboTTS.from_pretrained(DEVICE)
if seed_num != 0:
set_seed(int(seed_num))
wav = model.generate(
text,
audio_prompt_path=audio_prompt_path,
temperature=temperature,
min_p=min_p,
top_p=top_p,
top_k=int(top_k),
repetition_penalty=repetition_penalty,
norm_loudness=norm_loudness,
)
return (model.sr, wav.squeeze(0).numpy())
with gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo:
gr.Markdown("# ⚡ Chatterbox Turbo")
model_state = gr.State(None)
with gr.Row():
with gr.Column():
text = gr.Textbox(
value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and um all that jazz. Would you like me to get some prices for you?",
label="Text to synthesize (max chars 300)",
max_lines=5,
elem_id="main_textbox"
)
# --- Event Tags ---
# Switched back to Row, but applied specific CSS to force wrapping
with gr.Row(elem_classes=["tag-container"]):
for tag in EVENT_TAGS:
# elem_classes targets the button specifically
btn = gr.Button(tag, elem_classes=["tag-btn"])
btn.click(
fn=None,
inputs=[btn, text],
outputs=text,
js=INSERT_TAG_JS
)
ref_wav = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Reference Audio File",
value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav"
)
run_btn = gr.Button("Generate ⚡", variant="primary")
with gr.Column():
audio_output = gr.Audio(label="Output Audio")
with gr.Accordion("Advanced Options", open=False):
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
demo.load(fn=load_model, inputs=[], outputs=model_state)
run_btn.click(
fn=generate,
inputs=[
model_state,
text,
ref_wav,
temp,
seed_num,
min_p,
top_p,
top_k,
repetition_penalty,
norm_loudness,
],
outputs=audio_output,
)
if __name__ == "__main__":
demo.queue(
max_size=50,
default_concurrency_limit=1,
).launch(share=True)

View File

@ -1,317 +0,0 @@
import random
import numpy as np
import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
import gradio as gr
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")
# --- Global Model Initialization ---
MODEL = None
LANGUAGE_CONFIG = {
"ar": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
"text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
},
"da": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac",
"text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal."
},
"de": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/de_f1.flac",
"text": "Letzten Monat haben wir einen neuen Meilenstein erreicht: zwei Milliarden Aufrufe auf unserem YouTube-Kanal."
},
"el": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/el_m.flac",
"text": "Τον περασμένο μήνα, φτάσαμε σε ένα νέο ορόσημο με δύο δισεκατομμύρια προβολές στο κανάλι μας στο YouTube."
},
"en": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
"text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
},
"es": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/es_f1.flac",
"text": "El mes pasado alcanzamos un nuevo hito: dos mil millones de visualizaciones en nuestro canal de YouTube."
},
"fi": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fi_m.flac",
"text": "Viime kuussa saavutimme uuden virstanpylvään kahden miljardin katselukerran kanssa YouTube-kanavallamme."
},
"fr": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
"text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."
},
"he": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac",
"text": "בחודש שעבר הגענו לאבן דרך חדשה עם שני מיליארד צפיות בערוץ היוטיוב שלנו."
},
"hi": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
"text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।"
},
"it": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/it_m1.flac",
"text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
},
"ja": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja/ja_prompts1.flac",
"text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
},
"ko": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ko_f.flac",
"text": "지난달 우리는 유튜브 채널에서 이십억 조회수라는 새로운 이정표에 도달했습니다."
},
"ms": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ms_f.flac",
"text": "Bulan lepas, kami mencapai pencapaian baru dengan dua bilion tontonan di saluran YouTube kami."
},
"nl": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/nl_m.flac",
"text": "Vorige maand bereikten we een nieuwe mijlpaal met twee miljard weergaven op ons YouTube-kanaal."
},
"no": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/no_f1.flac",
"text": "Forrige måned nådde vi en ny milepæl med to milliarder visninger på YouTube-kanalen vår."
},
"pl": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pl_m.flac",
"text": "W zeszłym miesiącu osiągnęliśmy nowy kamień milowy z dwoma miliardami wyświetleń na naszym kanale YouTube."
},
"pt": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pt_m1.flac",
"text": "No mês passado, alcançámos um novo marco: dois mil milhões de visualizações no nosso canal do YouTube."
},
"ru": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
"text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
},
"sv": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sv_f.flac",
"text": "Förra månaden nådde vi en ny milstolpe med två miljarder visningar på vår YouTube-kanal."
},
"sw": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sw_m.flac",
"text": "Mwezi uliopita, tulifika hatua mpya ya maoni ya bilioni mbili kweny kituo chetu cha YouTube."
},
"tr": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
"text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
},
"zh": {
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
"text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次这绝对令人难以置信。"
},
}
# --- UI Helpers ---
def default_audio_for_ui(lang: str) -> str | None:
return LANGUAGE_CONFIG.get(lang, {}).get("audio")
def default_text_for_ui(lang: str) -> str:
return LANGUAGE_CONFIG.get(lang, {}).get("text", "")
def get_supported_languages_display() -> str:
"""Generate a formatted display of all supported languages."""
language_items = []
for code, name in sorted(SUPPORTED_LANGUAGES.items()):
language_items.append(f"**{name}** (`{code}`)")
# Split into 2 lines
mid = len(language_items) // 2
line1 = "".join(language_items[:mid])
line2 = "".join(language_items[mid:])
return f"""
### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
{line1}
{line2}
"""
def get_or_load_model():
"""Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
and ensures it's on the correct device."""
global MODEL
if MODEL is None:
print("Model not loaded, initializing...")
try:
MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
MODEL.to(DEVICE)
print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
except Exception as e:
print(f"Error loading model: {e}")
raise
return MODEL
# Attempt to load the model at startup.
try:
get_or_load_model()
except Exception as e:
print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
def set_seed(seed: int):
"""Sets the random seed for reproducibility across torch, numpy, and random."""
torch.manual_seed(seed)
if DEVICE == "cuda":
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
"""
Decide which audio prompt to use:
- If user provided a path (upload/mic/url), use it.
- Else, fall back to language-specific default (if any).
"""
if provided_path and str(provided_path).strip():
return provided_path
return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
def generate_tts_audio(
text_input: str,
language_id: str,
audio_prompt_path_input: str = None,
exaggeration_input: float = 0.5,
temperature_input: float = 0.8,
seed_num_input: int = 0,
cfgw_input: float = 0.5
) -> tuple[int, np.ndarray]:
"""
Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling.
Supported languages: English, French, German, Spanish, Italian, Portuguese, and Hindi.
This tool synthesizes natural-sounding speech from input text. When a reference audio file
is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
Args:
text_input (str): The text to synthesize into speech (maximum 300 characters)
language_id (str): The language code for synthesis (eg. en, fr, de, es, it, pt, hi)
audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer.
Returns:
tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
"""
current_model = get_or_load_model()
if current_model is None:
raise RuntimeError("TTS model is not loaded.")
if seed_num_input != 0:
set_seed(int(seed_num_input))
print(f"Generating audio for text: '{text_input[:50]}...'")
# Handle optional audio prompt
chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
generate_kwargs = {
"exaggeration": exaggeration_input,
"temperature": temperature_input,
"cfg_weight": cfgw_input,
}
if chosen_prompt:
generate_kwargs["audio_prompt_path"] = chosen_prompt
print(f"Using audio prompt: {chosen_prompt}")
else:
print("No audio prompt provided; using default voice.")
wav = current_model.generate(
text_input[:300], # Truncate text to max chars
language_id=language_id,
**generate_kwargs
)
print("Audio generation complete.")
return (current_model.sr, wav.squeeze(0).numpy())
with gr.Blocks() as demo:
gr.Markdown(
"""
# Chatterbox Multilingual Demo
Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
"""
)
# Display supported languages
gr.Markdown(get_supported_languages_display())
with gr.Row():
with gr.Column():
initial_lang = "fr"
text = gr.Textbox(
value=default_text_for_ui(initial_lang),
label="Text to synthesize (max chars 300)",
max_lines=5
)
language_id = gr.Dropdown(
choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
value=initial_lang,
label="Language",
info="Select the language for text-to-speech synthesis"
)
ref_wav = gr.Audio(
sources=["upload", "microphone"],
type="filepath",
label="Reference Audio File (Optional)",
value=default_audio_for_ui(initial_lang)
)
gr.Markdown(
"💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
elem_classes=["audio-note"]
)
exaggeration = gr.Slider(
0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
)
cfg_weight = gr.Slider(
0.2, 1, step=.05, label="CFG/Pace", value=0.5
)
with gr.Accordion("More options", open=False):
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
run_btn = gr.Button("Generate", variant="primary")
with gr.Column():
audio_output = gr.Audio(label="Output Audio")
def on_language_change(lang, current_ref, current_text):
return default_audio_for_ui(lang), default_text_for_ui(lang)
language_id.change(
fn=on_language_change,
inputs=[language_id, ref_wav, text],
outputs=[ref_wav, text],
show_progress=False
)
run_btn.click(
fn=generate_tts_audio,
inputs=[
text,
language_id,
ref_wav,
exaggeration,
temp,
seed_num,
cfg_weight,
],
outputs=[audio_output],
)
demo.launch(mcp_server=True)

View File

@ -1,29 +1,25 @@
[project]
name = "chatterbox-tts"
version = "0.1.6"
version = "0.1.1"
description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.8"
license = {file = "LICENSE"}
authors = [
{name = "resemble-ai", email = "engineering@resemble.ai"}
]
dependencies = [
"numpy>=1.24.0,<1.26.0",
"librosa==0.11.0",
"numpy~=1.26.0",
"resampy==0.4.3",
"librosa==0.10.0",
"s3tokenizer",
"torch==2.6.0",
"torchaudio==2.6.0",
"transformers==4.46.3",
"diffusers==0.29.0",
"resemble-perth==1.0.1",
"omegaconf==2.3.0",
"conformer==0.3.2",
"safetensors==0.5.3",
"spacy-pkuseg",
"pykakasi==2.3.0",
"gradio==5.44.1",
"pyloudnorm",
"omegaconf"
]
[project.urls]

View File

@ -1,11 +1,2 @@
try:
from importlib.metadata import version
except ImportError:
from importlib_metadata import version # For Python <3.8
__version__ = version("chatterbox-tts")
from .tts import ChatterboxTTS
from .vc import ChatterboxVC
from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES

View File

@ -1,10 +0,0 @@
from ..utils import AttrDict
CFM_PARAMS = AttrDict({
"sigma_min": 1e-06,
"solver": "euler",
"t_scheduler": "cosine",
"training_cfg_rate": 0.2,
"inference_cfg_rate": 0.7,
"reg_loss_type": "l1"
})

View File

@ -1,2 +1 @@
S3GEN_SR = 24000
S3GEN_SIL = 4299

View File

@ -20,7 +20,6 @@ from .utils.mask import add_optional_chunk_mask
from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
TimestepEmbedding, Upsample1D
from .matcha.transformer import BasicTransformerBlock
from .utils.intmeanflow import get_intmeanflow_time_mixer
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
@ -96,6 +95,8 @@ class CausalConv1d(torch.nn.Conv1d):
x = F.pad(x, self.causal_padding)
x = super(CausalConv1d, self).forward(x)
return x
class ConditionalDecoder(nn.Module):
def __init__(
self,
@ -109,7 +110,6 @@ class ConditionalDecoder(nn.Module):
num_mid_blocks=12,
num_heads=8,
act_fn="gelu",
meanflow=False,
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
@ -117,7 +117,6 @@ class ConditionalDecoder(nn.Module):
"""
super().__init__()
channels = tuple(channels)
self.meanflow = meanflow
self.in_channels = in_channels
self.out_channels = out_channels
self.causal = causal
@ -128,7 +127,6 @@ class ConditionalDecoder(nn.Module):
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@ -217,14 +215,6 @@ class ConditionalDecoder(nn.Module):
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
self.time_embed_mixer = None
if self.meanflow:
self.time_embed_mixer = get_intmeanflow_time_mixer(time_embed_dim)
@property
def dtype(self):
return self.final_proj.weight.dtype
def initialize_weights(self):
for m in self.modules():
@ -240,16 +230,15 @@ class ConditionalDecoder(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None, r=None):
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x: (B, 80, T)
mask (_type_)
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional) Defaults to None.
cond (_type_, optional)
r: end time for meanflow mode (shape (1,) tensor)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
@ -258,15 +247,10 @@ class ConditionalDecoder(nn.Module):
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
if self.meanflow:
r = self.time_embeddings(r).to(t.dtype)
r = self.time_mlp(r)
concat_embed = torch.cat([t, r], dim=1)
t = self.time_embed_mixer(concat_embed)
x = pack([x, mu], "b * t")[0]
if spks is not None:

View File

@ -14,30 +14,142 @@
import logging
import random
from typing import Dict, Optional
logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
from torch.nn import functional as F
from .utils.mask import make_pad_mask
from .configs import CFM_PARAMS
from omegaconf import DictConfig
from .utils.mask import make_pad_mask
logger = logging.getLogger(__name__)
class MaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
def _repeat_batch_dim(tnsr, B, ndim):
"repeat batch dimension if it's equal to 1"
if tnsr is not None:
# add missing batch dim if needed
while tnsr.ndim < ndim:
tnsr = tnsr[None]
# repeat batch dim as needed
if B > 1 and tnsr.size(0) == 1:
tnsr = tnsr.repeat(B, *([1] * (ndim - 1)))
assert tnsr.ndim == ndim, f"Expected {ndim=}, got {tnsr.ndim=}"
return tnsr
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
flow_cache):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, flow_cache = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
prompt_len=mel_len1,
flow_cache=flow_cache
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), flow_cache
class CausalMaskedDiffWithXvec(torch.nn.Module):
@ -54,14 +166,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
encoder: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig(
{'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7,
'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0,
'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8,
'act_fn': 'gelu'}},
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
super().__init__()
@ -82,51 +190,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len
# NOTE: copied in from cosyvoice repo
def compute_loss(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device) # (B, 80, T)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified training, static_chunk_size > 0 or = 0
# streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) # (B, T, 1)
token = self.input_embedding(torch.clamp(token, min=0)) * mask # (B, T, emb)
# text encode
h, h_lengths = self.encoder(token, token_len) # (B, T, C) -> (B, 2T, C)
h = self.encoder_proj(h)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :, :index] = feat[i, :, :index]
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
loss, _ = self.decoder.compute_loss(
feat.contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
# streaming=streaming,
)
return {'loss': loss}
# FIXME: this was missing - just putting it in as false
self.fp16 = False
@torch.inference_mode()
def inference(self,
@ -137,62 +202,41 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
prompt_feat,
prompt_feat_len,
embedding,
finalize,
n_timesteps=10,
noised_mels=None,
meanflow=False):
# token: (B, n_toks)
# token_len: (B,)
B = token.size(0)
finalize):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = torch.atleast_2d(embedding)
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding) # (1 or B, emb_dim)
# adjust shapes (batching logic)
prompt_token = _repeat_batch_dim(prompt_token, B, ndim=2) # (B, n_prompt)
prompt_token_len = _repeat_batch_dim(prompt_token_len, B, ndim=1) # (B,)
prompt_feat = _repeat_batch_dim(prompt_feat, B, ndim=3) # (B, n_feat, feat_dim=80)
prompt_feat_len = _repeat_batch_dim(prompt_feat_len, B, ndim=1) # (B,) or None
embedding = _repeat_batch_dim(embedding, B, ndim=2) # (B, emb_dim)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
if (token >= self.vocab_size).any():
logger.error(f"{token.max()}>{self.vocab_size}\n out-of-range special tokens found in flow, fix inputs!")
token = self.input_embedding(token.long()) * mask
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_masks = self.encoder(token, token_len)
h, h_lengths = self.encoder(token, token_len)
if finalize is False:
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
h_lengths = h_masks.sum(dim=-1).squeeze(dim=-1)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
# # get conditions
conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(h_lengths)).unsqueeze(1).to(h)
if mask.shape[0] != B:
mask = mask.repeat(B, 1, 1)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask,
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=n_timesteps,
noised_mels=noised_mels,
meanflow=meanflow,
n_timesteps=10
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat, None # NOTE jrm: why are they returning None here?
return feat.float(), None # NOTE jrm: why are they returning None here?

View File

@ -15,12 +15,17 @@ import threading
import torch
import torch.nn.functional as F
from .matcha.flow_matching import BASECFM
from .configs import CFM_PARAMS
from tqdm import tqdm
from omegaconf import OmegaConf
def cast_all(*args, dtype):
return [a if (not a.dtype.is_floating_point) or a.dtype == dtype else a.to(dtype) for a in args]
CFM_PARAMS = OmegaConf.create({
"sigma_min": 1e-06,
"solver": "euler",
"t_scheduler": "cosine",
"training_cfg_rate": 0.2,
"inference_cfg_rate": 0.7,
"reg_loss_type": "l1"
})
class ConditionalCFM(BASECFM):
@ -37,6 +42,7 @@ class ConditionalCFM(BASECFM):
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
self.lock = threading.Lock()
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@ -58,8 +64,6 @@ class ConditionalCFM(BASECFM):
shape: (batch_size, n_feats, mel_timesteps)
"""
raise NotImplementedError("unused, needs updating for meanflow model")
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
cache_size = flow_cache.shape[2]
# fix prompt and overlap part mu and z
@ -75,7 +79,7 @@ class ConditionalCFM(BASECFM):
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
def solve_euler(self, x, t_span, mu, mask, spks, cond, meanflow=False):
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
@ -89,60 +93,65 @@ class ConditionalCFM(BASECFM):
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
meanflow: meanflow mode
"""
in_dtype = x.dtype
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
# Duplicated batch dims are for CFG
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
B, T = mu.size(0), x.size(2)
x_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2 * B, 1, T], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2 * B, 80 ], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
r_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype) # (only used for meanflow)
for t, r in zip(t_span[:-1], t_span[1:]):
t = t.unsqueeze(dim=0)
r = r.unsqueeze(dim=0)
# Shapes:
# x_in ( 2B, 80, T )
# mask_in ( 2B, 1, T )
# mu_in ( 2B, 80, T )
# t_in ( 2B, )
# spks_in ( 2B, 80, )
# cond_in ( 2B, 80, T )
# r_in ( 2B, )
# x ( B, 80, T )
# mask ( B, 1, T )
# mu ( B, 80, T )
# t ( B, )
# spks ( B, 80, )
# cond ( B, 80, T )
# r ( B, )
x_in[:B] = x_in[B:] = x
mask_in[:B] = mask_in[B:] = mask
mu_in[:B] = mu
t_in[:B] = t_in[B:] = t
spks_in[:B] = spks
cond_in[:B] = cond
r_in[:B] = r_in[B:] = r # (only used for meanflow)
dxdt = self.estimator.forward(
x=x_in, mask=mask_in, mu=mu_in, t=t_in, spks=spks_in, cond=cond_in,
r=r_in if meanflow else None,
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in
)
dxdt, cfg_dxdt = torch.split(dxdt, [B, B], dim=0)
dxdt = ((1.0 + self.inference_cfg_rate) * dxdt - self.inference_cfg_rate * cfg_dxdt)
dt = r - t
x = x + dt * dxdt
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1].float()
return x.to(in_dtype)
def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
@ -189,11 +198,10 @@ class ConditionalCFM(BASECFM):
class CausalConditionalCFM(ConditionalCFM):
def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
# TODO: BAD BAD IDEA - IT'LL MESS UP DISTILLATION - SETTING TO NONE
self.rand_noise = None
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, noised_mels=None, meanflow=False):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
@ -206,41 +214,15 @@ class CausalConditionalCFM(ConditionalCFM):
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
noised_mels: gt mels noised a time t
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
B = mu.size(0)
z = torch.randn_like(mu)
if noised_mels is not None:
prompt_len = mu.size(2) - noised_mels.size(2)
z[..., prompt_len:] = noised_mels
# time steps for reverse diffusion
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if (not meanflow) and (self.t_scheduler == 'cosine'):
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
# NOTE: right now, the only meanflow models are also distilled models, which don't need CFG
# because they were distilled with CFG outputs. We would need to add another hparam and
# change the conditional logic here if we want to use CFG inference with a meanflow model.
if meanflow:
return self.basic_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, meanflow=meanflow), None
def basic_euler(self, x, t_span, mu, mask, spks, cond):
in_dtype = x.dtype
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
print("S3 Token -> Mel Inference...")
for t, r in tqdm(zip(t_span[..., :-1], t_span[..., 1:]), total=t_span.shape[-1] - 1):
t, r = t[None], r[None]
dxdt = self.estimator.forward(x, mask=mask, mu=mu, t=t, spks=spks, cond=cond, r=r)
dt = r - t
x = x + dt * dxdt
return x.to(in_dtype)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None

View File

@ -19,6 +19,7 @@ import torch
import torchaudio as ta
from functools import lru_cache
from typing import Optional
from omegaconf import DictConfig
from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
from .const import S3GEN_SR
@ -30,7 +31,6 @@ from .hifigan import HiFTGenerator
from .transformer.upsample_encoder import UpsampleConformerEncoder
from .flow_matching import CausalConditionalCFM
from .decoder import ConditionalDecoder
from .configs import CFM_PARAMS
def drop_invalid_tokens(x):
@ -46,20 +46,15 @@ def get_resampler(src_sr, dst_sr, device):
class S3Token2Mel(torch.nn.Module):
"""
S3Gen's CFM decoder maps S3 speech tokens to mel-spectrograms.
CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
TODO: make these modules configurable?
"""
def __init__(self, meanflow=False):
def __init__(self):
super().__init__()
self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
self.speaker_encoder = CAMPPlus(
# NOTE: This doesn't affect inference. It turns off activation checkpointing
# (a training optimization), which causes a crazy DDP error with accelerate
memory_efficient=False,
)
self.meanflow = meanflow
self.speaker_encoder = CAMPPlus() # use default args
encoder = UpsampleConformerEncoder(
output_size=512,
@ -89,9 +84,15 @@ class S3Token2Mel(torch.nn.Module):
num_mid_blocks=12,
num_heads=8,
act_fn='gelu',
meanflow=self.meanflow,
)
cfm_params = CFM_PARAMS
cfm_params = DictConfig({
"sigma_min": 1e-06,
"solver": 'euler',
"t_scheduler": 'cosine',
"training_cfg_rate": 0.2,
"inference_cfg_rate": 0.7,
"reg_loss_type": 'l1',
})
decoder = CausalConditionalCFM(
spk_emb_dim=80,
cfm_params=cfm_params,
@ -110,11 +111,6 @@ class S3Token2Mel(torch.nn.Module):
params = self.tokenizer.parameters()
return next(params).device
@property
def dtype(self):
params = self.flow.parameters()
return next(params).dtype
def embed_ref(
self,
ref_wav: torch.Tensor,
@ -133,26 +129,23 @@ class S3Token2Mel(torch.nn.Module):
ref_wav = ref_wav.unsqueeze(0) # (B, L)
if ref_wav.size(1) > 10 * ref_sr:
print("WARNING: s3gen received ref longer than 10s")
print("WARNING: cosydec received ref longer than 10s")
ref_wav_24 = ref_wav
if ref_sr != S3GEN_SR:
ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
ref_wav_24 = ref_wav_24.to(device=device, dtype=self.dtype)
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(dtype=self.dtype)
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
ref_mels_24_len = None
# Resample to 16kHz
ref_wav_16 = ref_wav
if ref_sr != S3_SR:
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav)
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
# Speaker embedding
ref_x_vector = self.speaker_encoder.inference(ref_wav_16.to(dtype=self.dtype))
ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
# Tokenize 16khz reference
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16.float())
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
# Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
@ -178,10 +171,7 @@ class S3Token2Mel(torch.nn.Module):
ref_sr: Optional[int],
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
n_cfm_timesteps = None,
finalize: bool = False,
speech_token_lens=None,
noised_mels=None,
):
"""
Generate mel-spectrograms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.
@ -209,21 +199,18 @@ class S3Token2Mel(torch.nn.Module):
if isinstance(ref_dict[rk], np.ndarray):
ref_dict[rk] = torch.from_numpy(ref_dict[rk])
if torch.is_tensor(ref_dict[rk]):
ref_dict[rk] = ref_dict[rk].to(device=self.device, dtype=self.dtype)
ref_dict[rk] = ref_dict[rk].to(self.device)
speech_tokens = torch.atleast_2d(speech_tokens)
if len(speech_tokens.shape) == 1:
speech_tokens = speech_tokens.unsqueeze(0)
# backcompat
if speech_token_lens is None:
speech_token_lens = torch.LongTensor([st.size(-1) for st in speech_tokens]).to(self.device)
# assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
output_mels, _ = self.flow.inference(
token=speech_tokens,
token_len=speech_token_lens,
finalize=finalize,
noised_mels=noised_mels,
n_timesteps=n_cfm_timesteps,
meanflow=self.meanflow,
**ref_dict,
)
return output_mels
@ -231,15 +218,13 @@ class S3Token2Mel(torch.nn.Module):
class S3Token2Wav(S3Token2Mel):
"""
The decoder of S3Gen is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
TODO: make these modules configurable?
"""
ignore_state_dict_missing = ("tokenizer._mel_filters", "tokenizer.window")
def __init__(self, meanflow=False):
super().__init__(meanflow)
def __init__(self):
super().__init__()
f0_predictor = ConvRNNF0Predictor()
self.mel2wav = HiFTGenerator(
@ -256,7 +241,6 @@ class S3Token2Wav(S3Token2Mel):
trim_fade = torch.zeros(2 * n_trim)
trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
self.estimator_dtype = "fp32"
def forward(
self,
@ -266,25 +250,9 @@ class S3Token2Wav(S3Token2Mel):
ref_sr: Optional[int],
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
finalize: bool = False,
speech_token_lens=None,
skip_vocoder=False,
n_cfm_timesteps=None,
noised_mels=None,
finalize: bool = False
):
"""
Generate waveforms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.
NOTE: used for sync synthesis only. Please use `S3GenStreamer` for streaming synthesis.
"""
output_mels = super().forward(
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav,
ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize,
n_cfm_timesteps=n_cfm_timesteps, noised_mels=noised_mels,
)
if skip_vocoder:
return output_mels
output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
# TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
@ -306,24 +274,14 @@ class S3Token2Wav(S3Token2Mel):
ref_sr: Optional[int] = None,
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
n_cfm_timesteps = None,
finalize: bool = False,
speech_token_lens=None,
):
n_cfm_timesteps = n_cfm_timesteps or (2 if self.meanflow else 10)
noise = None
if self.meanflow:
noise = torch.randn(1, 80, speech_tokens.size(-1) * 2, dtype=self.dtype, device=self.device)
output_mels = super().forward(
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict,
n_cfm_timesteps=n_cfm_timesteps, finalize=finalize, noised_mels=noise,
)
return output_mels
return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
@torch.inference_mode()
def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
if cache_source is None:
cache_source = torch.zeros(1, 1, 0).to(device=self.device, dtype=self.dtype)
cache_source = torch.zeros(1, 1, 0).to(self.device)
return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
@torch.inference_mode()
@ -335,26 +293,11 @@ class S3Token2Wav(S3Token2Mel):
ref_sr: Optional[int] = None,
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
# left as a kwarg because this can change input/output size ratio
drop_invalid_tokens=True,
n_cfm_timesteps=None,
speech_token_lens=None,
cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
finalize: bool = True,
):
# hallucination prevention, drop special tokens
# if drop_invalid_tokens:
# speech_tokens, speech_token_lens = drop_invalid(speech_tokens, pad=S3_QUIET_PAD)
output_mels = self.flow_inference(
speech_tokens,
speech_token_lens=speech_token_lens,
ref_wav=ref_wav,
ref_sr=ref_sr,
ref_dict=ref_dict,
n_cfm_timesteps=n_cfm_timesteps,
finalize=True,
)
output_mels = output_mels.to(dtype=self.dtype) # FIXME (fp16 mode) is this still needed?
output_wavs, output_sources = self.hift_inference(output_mels, None)
output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

View File

@ -1,36 +0,0 @@
import torch
import torch.nn as nn
def get_intmeanflow_time_mixer(dims):
"""
Diagonal init as described in 3.3 https://arxiv.org/pdf/2510.07979
"""
layer = nn.Linear(dims * 2, dims, bias=False)
with torch.no_grad():
target_weight = torch.zeros(dims, 2 * dims)
target_weight[:, 0:dims] = torch.eye(dims)
layer.weight.data = target_weight
return layer
if __name__ == '__main__':
D_example = 6
W_layer = get_intmeanflow_time_mixer(D_example)
print(f"Layer weight (AFTER init):\n{W_layer.weight.data}\n")
e_t = torch.tensor([0., 1., 2., 3., 4., 5.])
e_r = torch.tensor([6., 7., 8., 9., 10., 11.])
e_concat = torch.cat([e_t, e_r]).unsqueeze(0) # Shape (1, 12)
output = W_layer(e_concat)
print(f"Test Input e_t: \n{e_t}")
print(f"Test Input e_r: \n{e_r}")
print(f"Test Input concat: \n{e_concat}")
print(f"Forward Pass Output: \n{output.squeeze(0)}")

View File

@ -181,7 +181,6 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
lengths = lengths.long()
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,

View File

@ -1,11 +1,8 @@
"""mel-spectrogram extraction in Matcha-TTS"""
import logging
from librosa.filters import mel as librosa_mel_fn
import torch
import numpy as np
logger = logging.getLogger(__name__)
# NOTE: they declared these global vars
mel_basis = {}
@ -45,11 +42,10 @@ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=48
if len(y.shape) == 1:
y = y[None, ]
# Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
min_val = torch.min(y)
max_val = torch.max(y)
if min_val < -1.0 or max_val > 1.0:
logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:

View File

@ -10,9 +10,6 @@ from types import MethodType
logger = logging.getLogger(__name__)
LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
@dataclass
class AlignmentAnalysisResult:
# was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
@ -52,22 +49,21 @@ class AlignmentStreamAnalyzer:
self.complete = False
self.completed_at = None
# Track generated tokens for repetition detection
self.generated_tokens = []
# Using `output_attentions=True` is incompatible with optimized attention kernels, so
# using it for all layers slows things down too much. We can apply it to just one layer
# by intercepting the kwargs and adding a forward hook (credit: jrm)
self.last_aligned_attns = []
for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
self.last_aligned_attns += [None]
self._add_attention_spy(tfmr, i, layer_idx, head_idx)
self.last_aligned_attn = None
self._add_attention_spy(tfmr, alignment_layer_idx)
def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
def _add_attention_spy(self, tfmr, alignment_layer_idx):
"""
Adds a forward hook to a specific attention layer to collect outputs.
Using `output_attentions=True` is incompatible with optimized attention kernels, so
using it for all layers slows things down too much.
(credit: jrm)
"""
def attention_forward_hook(module, input, output):
"""
See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
@ -75,23 +71,27 @@ class AlignmentStreamAnalyzer:
- When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
- `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
"""
if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
step_attention = output[1].cpu() # (B, n_heads, T0, Ti)
self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti)
step_attention = output[1].cpu() # (B, 16, N, N)
self.last_aligned_attn = step_attention[0].mean(0) # (N, N)
target_layer = tfmr.layers[layer_idx].self_attn
# Register hook and store the handle
target_layer.register_forward_hook(attention_forward_hook)
if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
self.original_output_attentions = tfmr.config.output_attentions
tfmr.config.output_attentions = True
target_layer = tfmr.layers[alignment_layer_idx].self_attn
hook_handle = target_layer.register_forward_hook(attention_forward_hook)
def step(self, logits, next_token=None):
# Backup original forward
original_forward = target_layer.forward
def patched_forward(self, *args, **kwargs):
kwargs['output_attentions'] = True
return original_forward(*args, **kwargs)
# TODO: how to unpatch it?
target_layer.forward = MethodType(patched_forward, target_layer)
def step(self, logits):
"""
Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
"""
# extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N)
aligned_attn = self.last_aligned_attn # (N, N)
i, j = self.text_tokens_slice
if self.curr_frame_pos == 0:
# first chunk has conditioning info, text tokens, and BOS token
@ -133,46 +133,22 @@ class AlignmentStreamAnalyzer:
last_text_token_duration = A[15:, -3:].sum()
# Activations for the final token that last too long are likely hallucinations.
long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms
long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms
# If there are activations in previous tokens after generation has completed, assume this is a repetition error.
alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
# Track generated tokens for repetition detection
if next_token is not None:
# Convert tensor to scalar if needed
if isinstance(next_token, torch.Tensor):
token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
else:
token_id = next_token
self.generated_tokens.append(token_id)
# Keep only last 8 tokens to prevent memory issues
if len(self.generated_tokens) > 8:
self.generated_tokens = self.generated_tokens[-8:]
# Check for excessive token repetition (2x same token in a row)
token_repetition = (
# self.complete and
len(self.generated_tokens) >= 3 and
len(set(self.generated_tokens[-2:])) == 1
)
if token_repetition:
repeated_token = self.generated_tokens[-1]
logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
# Suppress EoS to prevent early termination
if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens
logits[..., self.eos_idx] = -2**15
repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
# If a bad ending is detected, force emit EOS by modifying logits
# NOTE: this means logits may be inconsistent with latents!
if long_tail or alignment_repetition or token_repetition:
logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
if long_tail or repetition:
logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}")
# (±2**15 is safe for all dtypes >= 16bit)
logits = -(2**15) * torch.ones_like(logits)
logits[..., self.eos_idx] = 2**15
# Suppress EoS to prevent early termination
if cur_text_posn < S - 3: # FIXME: arbitrary
logits[..., self.eos_idx] = -2**15
self.curr_frame_pos += 1
return logits
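# Note (illustrative, not part of the diff): S3 speech tokens run at 25 Hz (see "speech_tokenizer_v2_25hz"
# elsewhere in this diff), i.e. one frame per 40 ms, so the 5- and 10-frame thresholds above correspond
# to the ~200 ms and ~400 ms figures mentioned in the long_tail comments.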

View File

@ -32,42 +32,6 @@ LLAMA_520M_CONFIG_DICT = dict(
use_cache=True,
)
GPT2_MEDIUM_CONFIG = {
"activation_function": "gelu_new",
"architectures": [
"GPT2LMHeadModel"
],
"attn_pdrop": 0.1,
"bos_token_id": 50256,
"embd_pdrop": 0.1,
"eos_token_id": 50256,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"model_type": "gpt2",
"n_ctx": 8196,
"n_embd": 1024,
"hidden_size": 1024,
"n_head": 16,
"n_layer": 24,
"n_positions": 8196,
"n_special": 0,
"predict_special_tokens": True,
"resid_pdrop": 0.1,
"summary_activation": None,
"summary_first_dropout": 0.1,
"summary_proj_to_labels": True,
"summary_type": "cls_index",
"summary_use_proj": True,
"task_specific_params": {
"text-generation": {
"do_sample": True,
"max_length": 50
}
},
"vocab_size": 50276,
}
LLAMA_CONFIGS = {
"Llama_520M": LLAMA_520M_CONFIG_DICT,
"GPT2_medium": GPT2_MEDIUM_CONFIG,
}

View File

@ -2,40 +2,26 @@ from ..llama_configs import LLAMA_CONFIGS
class T3Config:
def __init__(self, text_tokens_dict_size=704):
self.start_text_token = 255
self.stop_text_token = 0
self.text_tokens_dict_size = text_tokens_dict_size
self.max_text_tokens = 2048
start_text_token = 255
stop_text_token = 0
text_tokens_dict_size = 704
max_text_tokens = 2048
self.start_speech_token = 6561
self.stop_speech_token = 6562
self.speech_tokens_dict_size = 8194
self.max_speech_tokens = 4096
start_speech_token = 6561
stop_speech_token = 6562
speech_tokens_dict_size = 8194
max_speech_tokens = 4096
self.llama_config_name = "Llama_520M"
self.input_pos_emb = "learned"
self.speech_cond_prompt_len = 150
llama_config_name = "Llama_520M"
input_pos_emb = "learned"
speech_cond_prompt_len = 150
self.encoder_type = "voice_encoder"
self.speaker_embed_size = 256
self.use_perceiver_resampler = True
self.emotion_adv = True
# For T3CondEnc
encoder_type = "voice_encoder"
speaker_embed_size = 256
use_perceiver_resampler = True
emotion_adv = True
@property
def n_channels(self):
return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
@property
def is_multilingual(self):
return self.text_tokens_dict_size == 2454
@classmethod
def english_only(cls):
"""Create configuration for English-only TTS model."""
return cls(text_tokens_dict_size=704)
@classmethod
def multilingual(cls):
"""Create configuration for multilingual TTS model."""
return cls(text_tokens_dict_size=2454)
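A minimal usage sketch of the factory classmethods added above (import path assembled from the relative imports used elsewhere in this diff; illustrative only):

from chatterbox.models.t3.modules.t3_config import T3Config

hp_en = T3Config.english_only()    # 704-entry English text vocab
hp_mtl = T3Config.multilingual()   # 2454-entry multilingual text vocab
assert hp_mtl.is_multilingual and not hp_en.is_multilingual
print(hp_en.n_channels)            # hidden size pulled from LLAMA_CONFIGS["Llama_520M"]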

View File

@ -3,21 +3,13 @@
import logging
from typing import Union, Optional, List
logger = logging.getLogger(__name__)
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import LlamaModel, LlamaConfig, GPT2Config, GPT2Model
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
MinPLogitsWarper,
)
from transformers import LlamaModel, LlamaConfig
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
from .modules.learned_pos_emb import LearnedPositionEmbeddings
from .modules.cond_enc import T3CondEnc, T3Cond
@ -25,12 +17,17 @@ from .modules.t3_config import T3Config
from .llama_configs import LLAMA_CONFIGS
from .inference.t3_hf_backend import T3HuggingfaceBackend
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
from ..utils import AttrDict
logger = logging.getLogger(__name__)
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def _ensure_BOT_EOT(text_tokens: Tensor, hp):
B = text_tokens.size(0)
assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
@ -47,22 +44,11 @@ class T3(nn.Module):
different PE embedding space for speech.
"""
def __init__(self, hp=None):
if hp is None:
hp = T3Config.english_only()
def __init__(self, hp=T3Config()):
super().__init__()
self.hp = hp
config_dict = LLAMA_CONFIGS[hp.llama_config_name]
self.is_gpt = config_dict.get("model_type") == "gpt2"
if self.is_gpt:
self.cfg = GPT2Config(**config_dict)
self.tfmr = GPT2Model(self.cfg)
else:
self.cfg = LlamaConfig(**config_dict)
self.tfmr = LlamaModel(self.cfg)
self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
self.tfmr = LlamaModel(self.cfg)
self.dim = self.cfg.hidden_size
self.deepspeed_patch_applied = False
@ -72,8 +58,6 @@ class T3(nn.Module):
self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
# custom position embedding
self.text_pos_emb = None
self.speech_pos_emb = None
if hp.input_pos_emb == "learned":
max_text_seq_len = hp.max_text_tokens + 2
self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
@ -83,7 +67,7 @@ class T3(nn.Module):
# logit projection
self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt)
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
self.compiled = False
@property
@ -95,9 +79,8 @@ class T3(nn.Module):
Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
"""
if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens)
if not self.is_gpt:
t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
return self.cond_enc(t3_cond) # (B, len_cond, dim)
def prepare_input_embeds(
@ -106,13 +89,11 @@ class T3(nn.Module):
t3_cond: T3Cond,
text_tokens: torch.LongTensor,
speech_tokens: torch.LongTensor,
cfg_weight: float = 0.0,
):
# prepare input embeddings (skip backbone tranformer embeddings)
cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
if cfg_weight > 0.0 and not self.is_gpt:
text_emb[1].zero_() # CFG uncond
text_emb[1].zero_() # CFG uncond
speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
if self.hp.input_pos_emb == "learned":
@ -240,11 +221,10 @@ class T3(nn.Module):
stop_on_eos=True,
do_sample=True,
temperature=0.8,
top_p=0.95,
min_p=0.05,
top_p=0.8,
length_penalty=1.0,
repetition_penalty=1.2,
cfg_weight=0.5,
repetition_penalty=2.0,
cfg_weight=0,
):
"""
Args:
@ -264,7 +244,6 @@ class T3(nn.Module):
t3_cond=t3_cond,
text_tokens=text_tokens,
speech_tokens=initial_speech_tokens,
cfg_weight=cfg_weight,
)
# In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
@ -275,18 +254,13 @@ class T3(nn.Module):
# TODO? synchronize the expensive compile function
# with self.compile_lock:
if not self.compiled:
# Default to None for English models, only create for multilingual
alignment_stream_analyzer = None
if self.hp.is_multilingual:
alignment_stream_analyzer = AlignmentStreamAnalyzer(
self.tfmr,
None,
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
alignment_layer_idx=9, # TODO: hparam or something?
eos_idx=self.hp.stop_speech_token,
)
assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
alignment_stream_analyzer = AlignmentStreamAnalyzer(
self.tfmr,
None,
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
alignment_layer_idx=9, # TODO: hparam or something?
eos_idx=self.hp.stop_speech_token,
)
patched_model = T3HuggingfaceBackend(
config=self.cfg,
llama=self.tfmr,
@ -307,7 +281,7 @@ class T3(nn.Module):
# max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
# num_return_sequences=num_return_sequences,
# temperature=temperature,
# min_p=min_p,
# top_p=top_p,
# length_penalty=length_penalty,
# repetition_penalty=repetition_penalty,
# do_sample=do_sample,
@ -332,9 +306,7 @@ class T3(nn.Module):
# Instantiate the logits processors.
top_p_warper = TopPLogitsWarper(top_p=top_p)
min_p_warper = MinPLogitsWarper(min_p=min_p)
top_p_warper = TopPLogitsWarper(top_p=top_p)
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
# ---- Initial Forward Pass (no kv_cache yet) ----
output = self.patched_model(
@ -350,32 +322,21 @@ class T3(nn.Module):
# ---- Generation Loop using kv_cache ----
for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
logits_step = output.logits[:, -1, :]
# CFG combine → (1, V)
cond = logits_step[0:1, :]
uncond = logits_step[1:2, :]
cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
logits = cond + cfg * (cond - uncond)
# Apply alignment stream analyzer integrity checks
if self.patched_model.alignment_stream_analyzer is not None:
if logits.dim() == 1: # guard in case something upstream squeezed
logits = logits.unsqueeze(0) # (1, V)
# Pass the last generated token for repetition tracking
last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V)
logits = output.logits[:, -1, :]
# CFG
logits_cond = logits[0:1]
logits_uncond = logits[1:2]
logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
logits = logits.squeeze(1)
# Apply repetition penalty
ids_for_proc = generated_ids[:1, ...] # batch = 1
logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
# Apply temperature scaling.
if temperature != 1.0:
logits = logits / temperature
# Apply min_p and top_p filtering
logits = min_p_warper(ids_for_proc, logits)
logits = top_p_warper(ids_for_proc, logits)
# Apply repetition penalty and top-p filtering.
logits = repetition_penalty_processor(generated_ids, logits)
logits = top_p_warper(None, logits)
# Convert logits to probabilities and sample the next token.
probs = torch.softmax(logits, dim=-1)
@ -386,7 +347,6 @@ class T3(nn.Module):
# Check for EOS token.
if next_token.view(-1) == self.hp.stop_speech_token:
logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
break
# Get embedding for the new token.
@ -410,81 +370,3 @@ class T3(nn.Module):
# Concatenate all predicted tokens along the sequence dimension.
predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
return predicted_tokens
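# Illustration (not part of the diff): with cfg_weight = 0.5 the CFG combine above pushes the
# conditional logits away from the unconditional ones, e.g. cond = 2.0, uncond = 1.0 gives
# 2.0 + 0.5 * (2.0 - 1.0) = 2.5; with cfg_weight = 0 it reduces to the conditional logits alone.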
@torch.inference_mode()
def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2,
max_gen_len=1000):
logits_processors = LogitsProcessorList()
if temperature > 0 and temperature != 1.0:
logits_processors.append(TemperatureLogitsWarper(temperature))
if top_k > 0:
logits_processors.append(TopKLogitsWarper(top_k))
if top_p < 1.0:
logits_processors.append(TopPLogitsWarper(top_p))
if repetition_penalty != 1.0:
logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
embeds, _ = self.prepare_input_embeds(
t3_cond=t3_cond,
text_tokens=text_tokens,
speech_tokens=speech_start_token,
cfg_weight=0.0,
)
generated_speech_tokens = []
llm_outputs = self.tfmr(
inputs_embeds=embeds,
use_cache=True
)
hidden_states = llm_outputs[0]
past_key_values = llm_outputs.past_key_values
speech_hidden = hidden_states[:, -1:]
speech_logits = self.speech_head(speech_hidden)
processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :])
probs = F.softmax(processed_logits, dim=-1)
next_speech_token = torch.multinomial(probs, num_samples=1)
generated_speech_tokens.append(next_speech_token)
current_speech_token = next_speech_token
for _ in tqdm(range(max_gen_len)):
current_speech_embed = self.speech_emb(current_speech_token)
llm_outputs = self.tfmr(
inputs_embeds=current_speech_embed,
past_key_values=past_key_values,
use_cache=True
)
hidden_states = llm_outputs[0]
past_key_values = llm_outputs.past_key_values
speech_logits = self.speech_head(hidden_states)
input_ids = torch.cat(generated_speech_tokens, dim=1)
processed_logits = logits_processors(input_ids, speech_logits[:, -1, :])
if torch.all(processed_logits == -float("inf")):
print("Warning: All logits are -inf")
break
probs = F.softmax(processed_logits, dim=-1)
next_speech_token = torch.multinomial(probs, num_samples=1)
generated_speech_tokens.append(next_speech_token)
current_speech_token = next_speech_token
if torch.all(next_speech_token == self.hp.stop_speech_token):
break
all_tokens = torch.cat(generated_speech_tokens, dim=1)
# Remove EOS token if present
if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token:
all_tokens = all_tokens[:, :-1]
return all_tokens

View File

@ -1 +1 @@
from .tokenizer import EnTokenizer, MTLTokenizer
from .tokenizer import EnTokenizer

View File

@ -1,11 +1,7 @@
import logging
import json
import torch
from pathlib import Path
from unicodedata import category, normalize
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
# Special tokens
@ -32,7 +28,7 @@ class EnTokenizer:
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
return text_tokens
def encode(self, txt: str):
def encode( self, txt: str, verbose=False):
"""
clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
"""
@ -45,269 +41,10 @@ class EnTokenizer:
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
txt: str = self.tokenizer.decode(seq,
skip_special_tokens=False)
txt = txt.replace(' ', '')
txt = txt.replace(SPACE, ' ')
txt = txt.replace(EOT, '')
txt = txt.replace(UNK, '')
return txt
# Model repository
REPO_ID = "ResembleAI/chatterbox"
# Global instances for optional dependencies
_kakasi = None
_dicta = None
_russian_stresser = None
def is_kanji(c: str) -> bool:
"""Check if character is kanji."""
return 19968 <= ord(c) <= 40959
def is_katakana(c: str) -> bool:
"""Check if character is katakana."""
return 12449 <= ord(c) <= 12538
def hiragana_normalize(text: str) -> str:
"""Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
global _kakasi
try:
if _kakasi is None:
import pykakasi
_kakasi = pykakasi.kakasi()
result = _kakasi.convert(text)
out = []
for r in result:
inp = r['orig']
hira = r["hira"]
# Any kanji in the phrase
if any([is_kanji(c) for c in inp]):
if hira and hira[0] in ["", ""]: # Safety check for empty hira
hira = " " + hira
out.append(hira)
# All katakana
elif all([is_katakana(c) for c in inp]) if inp else False: # Safety check for empty inp
out.append(r['orig'])
else:
out.append(inp)
normalized_text = "".join(out)
# Decompose Japanese characters for tokenizer compatibility
import unicodedata
normalized_text = unicodedata.normalize('NFKD', normalized_text)
return normalized_text
except ImportError:
logger.warning("pykakasi not available - Japanese text processing skipped")
return text
def add_hebrew_diacritics(text: str) -> str:
"""Hebrew text normalization: adds diacritics to Hebrew text."""
global _dicta
try:
if _dicta is None:
from dicta_onnx import Dicta
_dicta = Dicta()
return _dicta.add_diacritics(text)
except ImportError:
logger.warning("dicta_onnx not available - Hebrew text processing skipped")
return text
except Exception as e:
logger.warning(f"Hebrew diacritization failed: {e}")
return text
def korean_normalize(text: str) -> str:
"""Korean text normalization: decompose syllables into Jamo for tokenization."""
def decompose_hangul(char):
"""Decompose Korean syllable into Jamo components."""
if not ('\uac00' <= char <= '\ud7af'):
return char
# Hangul decomposition formula
base = ord(char) - 0xAC00
initial = chr(0x1100 + base // (21 * 28))
medial = chr(0x1161 + (base % (21 * 28)) // 28)
final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
return initial + medial + final
# Decompose syllables and normalize punctuation
result = ''.join(decompose_hangul(char) for char in text)
return result.strip()
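# Worked example (illustrative, not part of the diff): decompose_hangul('한') follows the formula above:
#   base    = ord('한') - 0xAC00 = 10588
#   initial = chr(0x1100 + 10588 // 588)         -> 'ᄒ'
#   medial  = chr(0x1161 + (10588 % 588) // 28)  -> 'ᅡ'
#   final   = chr(0x11A7 + 10588 % 28)           -> 'ᆫ'
# so korean_normalize("한") == "한", three Jamo ready for character-level tokenization.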
class ChineseCangjieConverter:
"""Converts Chinese characters to Cangjie codes for tokenization."""
def __init__(self, model_dir=None):
self.word2cj = {}
self.cj2word = {}
self.segmenter = None
self._load_cangjie_mapping(model_dir)
self._init_segmenter()
def _load_cangjie_mapping(self, model_dir=None):
"""Load Cangjie mapping from HuggingFace model repository."""
try:
cangjie_file = hf_hub_download(
repo_id=REPO_ID,
filename="Cangjie5_TC.json",
cache_dir=model_dir
)
with open(cangjie_file, "r", encoding="utf-8") as fp:
data = json.load(fp)
for entry in data:
word, code = entry.split("\t")[:2]
self.word2cj[word] = code
if code not in self.cj2word:
self.cj2word[code] = [word]
else:
self.cj2word[code].append(word)
except Exception as e:
logger.warning(f"Could not load Cangjie mapping: {e}")
def _init_segmenter(self):
"""Initialize pkuseg segmenter."""
try:
from spacy_pkuseg import pkuseg
self.segmenter = pkuseg()
except ImportError:
logger.warning("pkuseg not available - Chinese segmentation will be skipped")
self.segmenter = None
def _cangjie_encode(self, glyph: str):
"""Encode a single Chinese glyph to Cangjie code."""
normed_glyph = glyph
code = self.word2cj.get(normed_glyph, None)
if code is None: # e.g. Japanese hiragana
return None
index = self.cj2word[code].index(normed_glyph)
index = str(index) if index > 0 else ""
return code + str(index)
def __call__(self, text):
"""Convert Chinese characters in text to Cangjie tokens."""
output = []
if self.segmenter is not None:
segmented_words = self.segmenter.cut(text)
full_text = " ".join(segmented_words)
else:
full_text = text
for t in full_text:
if category(t) == "Lo":
cangjie = self._cangjie_encode(t)
if cangjie is None:
output.append(t)
continue
code = []
for c in cangjie:
code.append(f"[cj_{c}]")
code.append("[cj_.]")
code = "".join(code)
output.append(code)
else:
output.append(t)
return "".join(output)
def add_russian_stress(text: str) -> str:
"""Russian text normalization: adds stress marks to Russian text."""
global _russian_stresser
try:
if _russian_stresser is None:
from russian_text_stresser.text_stresser import RussianTextStresser
_russian_stresser = RussianTextStresser()
return _russian_stresser.stress_text(text)
except ImportError:
logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
return text
except Exception as e:
logger.warning(f"Russian stress labeling failed: {e}")
return text
class MTLTokenizer:
def __init__(self, vocab_file_path):
self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
model_dir = Path(vocab_file_path).parent
self.cangjie_converter = ChineseCangjieConverter(model_dir)
self.check_vocabset_sot_eot()
def check_vocabset_sot_eot(self):
voc = self.tokenizer.get_vocab()
assert SOT in voc
assert EOT in voc
def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
"""
Text preprocessor that handles lowercase conversion and NFKD normalization.
"""
preprocessed_text = raw_text
if lowercase:
preprocessed_text = preprocessed_text.lower()
if nfkd_normalize:
preprocessed_text = normalize("NFKD", preprocessed_text)
return preprocessed_text
def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
return text_tokens
def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
# Language-specific text processing
if language_id == 'zh':
txt = self.cangjie_converter(txt)
elif language_id == 'ja':
txt = hiragana_normalize(txt)
elif language_id == 'he':
txt = add_hebrew_diacritics(txt)
elif language_id == 'ko':
txt = korean_normalize(txt)
elif language_id == 'ru':
txt = add_russian_stress(txt)
# Prepend language token
if language_id:
txt = f"[{language_id.lower()}]{txt}"
txt = txt.replace(' ', SPACE)
return self.tokenizer.encode(txt).ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt = self.tokenizer.decode(seq, skip_special_tokens=False)
txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
return txt
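A minimal usage sketch of the multilingual tokenizer above (the vocab path is an assumption, pointing at a local copy of the grapheme_mtl_merged_expanded_v1.json file referenced later in this diff):

from chatterbox.models.tokenizers import MTLTokenizer

tok = MTLTokenizer("checkpoints/grapheme_mtl_merged_expanded_v1.json")  # assumed local path
ids = tok.text_to_tokens("Bonjour tout le monde", language_id="fr")     # prepends the "[fr]" language token
print(ids.shape)                                                        # (1, num_tokens)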

View File

@ -1,4 +0,0 @@
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self

View File

@ -1,301 +0,0 @@
from dataclasses import dataclass
from pathlib import Path
import os
import librosa
import torch
import perth
import torch.nn.functional as F
from safetensors.torch import load_file as load_safetensors
from huggingface_hub import snapshot_download
from .models.t3 import T3
from .models.t3.modules.t3_config import T3Config
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import MTLTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
REPO_ID = "ResembleAI/chatterbox"
# Supported languages for the multilingual model
SUPPORTED_LANGUAGES = {
"ar": "Arabic",
"da": "Danish",
"de": "German",
"el": "Greek",
"en": "English",
"es": "Spanish",
"fi": "Finnish",
"fr": "French",
"he": "Hebrew",
"hi": "Hindi",
"it": "Italian",
"ja": "Japanese",
"ko": "Korean",
"ms": "Malay",
"nl": "Dutch",
"no": "Norwegian",
"pl": "Polish",
"pt": "Portuguese",
"ru": "Russian",
"sv": "Swedish",
"sw": "Swahili",
"tr": "Turkish",
"zh": "Chinese",
}
def punc_norm(text: str) -> str:
"""
Quick cleanup func for punctuation coming from LLMs
or for chars not seen often in the dataset
"""
if len(text) == 0:
return "You need to add some text for me to talk."
# Capitalise first letter
if text[0].islower():
text = text[0].upper() + text[1:]
# Remove multiple space chars
text = " ".join(text.split())
# Replace uncommon/llm punc
punc_to_replace = [
("...", ", "),
("…", ", "),
(":", ","),
(" - ", ", "),
(";", ", "),
("—", "-"),
("–", "-"),
(" ,", ","),
(""", "\""),
(""", "\""),
("'", "'"),
("'", "'"),
]
for old_char_sequence, new_char in punc_to_replace:
text = text.replace(old_char_sequence, new_char)
# Add full stop if no ending punc
text = text.rstrip(" ")
sentence_enders = {".", "!", "?", "-", ",", "。", "！", "？", "，", "、"}
if not any(text.endswith(p) for p in sentence_enders):
text += "."
return text
@dataclass
class Conditionals:
"""
Conditionals for T3 and S3Gen
- T3 conditionals:
- speaker_emb
- clap_emb
- cond_prompt_speech_tokens
- cond_prompt_speech_emb
- emotion_adv
- S3Gen conditionals:
- prompt_token
- prompt_token_len
- prompt_feat
- prompt_feat_len
- embedding
"""
t3: T3Cond
gen: dict
def to(self, device):
self.t3 = self.t3.to(device=device)
for k, v in self.gen.items():
if torch.is_tensor(v):
self.gen[k] = v.to(device=device)
return self
def save(self, fpath: Path):
arg_dict = dict(
t3=self.t3.__dict__,
gen=self.gen
)
torch.save(arg_dict, fpath)
@classmethod
def load(cls, fpath, map_location="cpu"):
kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
class ChatterboxMultilingualTTS:
ENC_COND_LEN = 6 * S3_SR
DEC_COND_LEN = 10 * S3GEN_SR
def __init__(
self,
t3: T3,
s3gen: S3Gen,
ve: VoiceEncoder,
tokenizer: MTLTokenizer,
device: str,
conds: Conditionals = None,
):
self.sr = S3GEN_SR # sample rate of synthesized audio
self.t3 = t3
self.s3gen = s3gen
self.ve = ve
self.tokenizer = tokenizer
self.device = device
self.conds = conds
self.watermarker = perth.PerthImplicitWatermarker()
@classmethod
def get_supported_languages(cls):
"""Return dictionary of supported language codes and names."""
return SUPPORTED_LANGUAGES.copy()
@classmethod
def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
ckpt_dir = Path(ckpt_dir)
ve = VoiceEncoder()
ve.load_state_dict(
torch.load(ckpt_dir / "ve.pt", weights_only=True)
)
ve.to(device).eval()
t3 = T3(T3Config.multilingual())
t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
if "model" in t3_state.keys():
t3_state = t3_state["model"][0]
t3.load_state_dict(t3_state)
t3.to(device).eval()
s3gen = S3Gen()
s3gen.load_state_dict(
torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
)
s3gen.to(device).eval()
tokenizer = MTLTokenizer(
str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
)
conds = None
if (builtin_voice := ckpt_dir / "conds.pt").exists():
conds = Conditionals.load(builtin_voice).to(device)
return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
@classmethod
def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
ckpt_dir = Path(
snapshot_download(
repo_id=REPO_ID,
repo_type="model",
revision="main",
allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
token=os.getenv("HF_TOKEN"),
)
)
return cls.from_local(ckpt_dir, device)
def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
## Load reference wav
s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
# Speech cond prompt tokens
t3_cond_prompt_tokens = None
if plen := self.t3.hp.speech_cond_prompt_len:
s3_tokzr = self.s3gen.tokenizer
t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
# Voice-encoder speaker embedding
ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
t3_cond = T3Cond(
speaker_emb=ve_embed,
cond_prompt_speech_tokens=t3_cond_prompt_tokens,
emotion_adv=exaggeration * torch.ones(1, 1, 1),
).to(device=self.device)
self.conds = Conditionals(t3_cond, s3gen_ref_dict)
def generate(
self,
text,
language_id,
audio_prompt_path=None,
exaggeration=0.5,
cfg_weight=0.5,
temperature=0.8,
repetition_penalty=2.0,
min_p=0.05,
top_p=1.0,
):
# Validate language_id
if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
raise ValueError(
f"Unsupported language_id '{language_id}'. "
f"Supported languages: {supported_langs}"
)
if audio_prompt_path:
self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
else:
assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
# Update exaggeration if needed
if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
_cond: T3Cond = self.conds.t3
self.conds.t3 = T3Cond(
speaker_emb=_cond.speaker_emb,
cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
emotion_adv=exaggeration * torch.ones(1, 1, 1),
).to(device=self.device)
# Norm and tokenize text
text = punc_norm(text)
text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
sot = self.t3.hp.start_text_token
eot = self.t3.hp.stop_text_token
text_tokens = F.pad(text_tokens, (1, 0), value=sot)
text_tokens = F.pad(text_tokens, (0, 1), value=eot)
with torch.inference_mode():
speech_tokens = self.t3.inference(
t3_cond=self.conds.t3,
text_tokens=text_tokens,
max_new_tokens=1000, # TODO: use the value in config
temperature=temperature,
cfg_weight=cfg_weight,
repetition_penalty=repetition_penalty,
min_p=min_p,
top_p=top_p,
)
# Extract only the conditional batch.
speech_tokens = speech_tokens[0]
# TODO: output becomes 1D
speech_tokens = drop_invalid_tokens(speech_tokens)
speech_tokens = speech_tokens.to(self.device)
wav, _ = self.s3gen.inference(
speech_tokens=speech_tokens,
ref_dict=self.conds.gen,
)
wav = wav.squeeze(0).detach().cpu().numpy()
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
return torch.from_numpy(watermarked_wav).unsqueeze(0)
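A minimal end-to-end sketch of the multilingual API defined above (module path, device, and file names are assumptions, not taken from this diff):

import torchaudio as ta
from chatterbox.mtl_tts import ChatterboxMultilingualTTS  # assumed module path

model = ChatterboxMultilingualTTS.from_pretrained("cuda")
wav = model.generate("Bonjour tout le monde", language_id="fr", audio_prompt_path="reference.wav")
ta.save("out-fr.wav", wav, model.sr)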

View File

@ -6,7 +6,6 @@ import torch
import perth
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from .models.t3 import T3
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
@ -137,12 +136,12 @@ class ChatterboxTTS:
ve = VoiceEncoder()
ve.load_state_dict(
load_file(ckpt_dir / "ve.safetensors")
torch.load(ckpt_dir / "ve.pt", map_location=map_location)
)
ve.to(device).eval()
t3 = T3()
t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
t3_state = torch.load(ckpt_dir / "t3_cfg.pt", map_location=map_location)
if "model" in t3_state.keys():
t3_state = t3_state["model"][0]
t3.load_state_dict(t3_state)
@ -150,7 +149,7 @@ class ChatterboxTTS:
s3gen = S3Gen()
s3gen.load_state_dict(
load_file(ckpt_dir / "s3gen.safetensors"), strict=False
torch.load(ckpt_dir / "s3gen.pt", map_location=map_location)
)
s3gen.to(device).eval()
@ -173,8 +172,8 @@ class ChatterboxTTS:
else:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]:
for fpath in ["ve.pt", "t3_cfg.pt", "s3gen.pt", "tokenizer.json", "conds.pt"]:
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
return cls.from_local(Path(local_path).parent, device)
@ -208,9 +207,6 @@ class ChatterboxTTS:
def generate(
self,
text,
repetition_penalty=1.2,
min_p=0.05,
top_p=1.0,
audio_prompt_path=None,
exaggeration=0.5,
cfg_weight=0.5,
@ -233,9 +229,7 @@ class ChatterboxTTS:
# Norm and tokenize text
text = punc_norm(text)
text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
if cfg_weight > 0.0:
text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG
sot = self.t3.hp.start_text_token
eot = self.t3.hp.stop_text_token
@ -249,18 +243,12 @@ class ChatterboxTTS:
max_new_tokens=1000, # TODO: use the value in config
temperature=temperature,
cfg_weight=cfg_weight,
repetition_penalty=repetition_penalty,
min_p=min_p,
top_p=top_p,
)
# Extract only the conditional batch.
speech_tokens = speech_tokens[0]
# TODO: output becomes 1D
speech_tokens = drop_invalid_tokens(speech_tokens)
speech_tokens = speech_tokens[speech_tokens < 6561]
speech_tokens = speech_tokens.to(self.device)
wav, _ = self.s3gen.inference(
@ -269,4 +257,4 @@ class ChatterboxTTS:
)
wav = wav.squeeze(0).detach().cpu().numpy()
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
return torch.from_numpy(watermarked_wav).unsqueeze(0)
return torch.from_numpy(watermarked_wav).unsqueeze(0)

View File

@ -1,296 +0,0 @@
import os
import math
from dataclasses import dataclass
from pathlib import Path
import librosa
import torch
import perth
import pyloudnorm as ln
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from .models.t3 import T3
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
from .models.t3.modules.t3_config import T3Config
from .models.s3gen.const import S3GEN_SIL
import logging
logger = logging.getLogger(__name__)
REPO_ID = "ResembleAI/chatterbox-turbo"
def punc_norm(text: str) -> str:
"""
Quick cleanup func for punctuation coming from LLMs
or for chars not seen often in the dataset
"""
if len(text) == 0:
return "You need to add some text for me to talk."
# Capitalise first letter
if text[0].islower():
text = text[0].upper() + text[1:]
# Remove multiple space chars
text = " ".join(text.split())
# Replace uncommon/llm punc
punc_to_replace = [
("…", ", "),
(":", ","),
("—", "-"),
("–", "-"),
(" ,", ","),
(""", "\""),
(""", "\""),
("'", "'"),
("'", "'"),
]
for old_char_sequence, new_char in punc_to_replace:
text = text.replace(old_char_sequence, new_char)
# Add full stop if no ending punc
text = text.rstrip(" ")
sentence_enders = {".", "!", "?", "-", ","}
if not any(text.endswith(p) for p in sentence_enders):
text += "."
return text
@dataclass
class Conditionals:
"""
Conditionals for T3 and S3Gen
- T3 conditionals:
- speaker_emb
- clap_emb
- cond_prompt_speech_tokens
- cond_prompt_speech_emb
- emotion_adv
- S3Gen conditionals:
- prompt_token
- prompt_token_len
- prompt_feat
- prompt_feat_len
- embedding
"""
t3: T3Cond
gen: dict
def to(self, device):
self.t3 = self.t3.to(device=device)
for k, v in self.gen.items():
if torch.is_tensor(v):
self.gen[k] = v.to(device=device)
return self
def save(self, fpath: Path):
arg_dict = dict(
t3=self.t3.__dict__,
gen=self.gen
)
torch.save(arg_dict, fpath)
@classmethod
def load(cls, fpath, map_location="cpu"):
if isinstance(map_location, str):
map_location = torch.device(map_location)
kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
class ChatterboxTurboTTS:
ENC_COND_LEN = 15 * S3_SR
DEC_COND_LEN = 10 * S3GEN_SR
def __init__(
self,
t3: T3,
s3gen: S3Gen,
ve: VoiceEncoder,
tokenizer: EnTokenizer,
device: str,
conds: Conditionals = None,
):
self.sr = S3GEN_SR # sample rate of synthesized audio
self.t3 = t3
self.s3gen = s3gen
self.ve = ve
self.tokenizer = tokenizer
self.device = device
self.conds = conds
self.watermarker = perth.PerthImplicitWatermarker()
@classmethod
def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS':
ckpt_dir = Path(ckpt_dir)
# Always load to CPU first for non-CUDA devices to handle CUDA-saved models
if device in ["cpu", "mps"]:
map_location = torch.device('cpu')
else:
map_location = None
ve = VoiceEncoder()
ve.load_state_dict(
load_file(ckpt_dir / "ve.safetensors")
)
ve.to(device).eval()
# Turbo specific hp
hp = T3Config(text_tokens_dict_size=50276)
hp.llama_config_name = "GPT2_medium"
hp.speech_tokens_dict_size = 6563
hp.input_pos_emb = None
hp.speech_cond_prompt_len = 375
hp.use_perceiver_resampler = False
hp.emotion_adv = False
t3 = T3(hp)
t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors")
if "model" in t3_state.keys():
t3_state = t3_state["model"][0]
t3.load_state_dict(t3_state)
del t3.tfmr.wte
t3.to(device).eval()
s3gen = S3Gen(meanflow=True)
weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors")
s3gen.load_state_dict(
weights, strict=True
)
s3gen.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if len(tokenizer) != 50276:
print(f"WARNING: Tokenizer len {len(tokenizer)} != 50276")
conds = None
builtin_voice = ckpt_dir / "conds.pt"
if builtin_voice.exists():
conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
@classmethod
def from_pretrained(cls, device) -> 'ChatterboxTurboTTS':
# Check if MPS is available on macOS
if device == "mps" and not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
local_path = snapshot_download(
repo_id=REPO_ID,
token=os.getenv("HF_TOKEN") or True,
# Optional: Filter to download only what you need
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
)
return cls.from_local(local_path, device)
def norm_loudness(self, wav, sr, target_lufs=-27):
try:
meter = ln.Meter(sr)
loudness = meter.integrated_loudness(wav)
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
if math.isfinite(gain_linear) and gain_linear > 0.0:
wav = wav * gain_linear
except Exception as e:
print(f"Warning: Error in norm_loudness, skipping: {e}")
return wav
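# For example (illustration only): a reference measured at -21 LUFS with target_lufs=-27 gives
# gain_db = -6 and gain_linear = 10 ** (-6 / 20) ≈ 0.501, i.e. the waveform is roughly halved.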
def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True):
## Load and norm reference wav
s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!"
if norm_loudness:
s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr)
ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
# Speech cond prompt tokens
if plen := self.t3.hp.speech_cond_prompt_len:
s3_tokzr = self.s3gen.tokenizer
t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
# Voice-encoder speaker embedding
ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
t3_cond = T3Cond(
speaker_emb=ve_embed,
cond_prompt_speech_tokens=t3_cond_prompt_tokens,
emotion_adv=exaggeration * torch.ones(1, 1, 1),
).to(device=self.device)
self.conds = Conditionals(t3_cond, s3gen_ref_dict)
def generate(
self,
text,
repetition_penalty=1.2,
min_p=0.00,
top_p=0.95,
audio_prompt_path=None,
exaggeration=0.0,
cfg_weight=0.0,
temperature=0.8,
top_k=1000,
norm_loudness=True,
):
if audio_prompt_path:
self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness)
else:
assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
logger.warning("CFG, min_p and exaggeration are not supported by the Turbo model and will be ignored.")
# Norm and tokenize text
text = punc_norm(text)
text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
text_tokens = text_tokens.input_ids.to(self.device)
speech_tokens = self.t3.inference_turbo(
t3_cond=self.conds.t3,
text_tokens=text_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
# Remove OOV tokens and add silence to end
speech_tokens = speech_tokens[speech_tokens < 6561]
speech_tokens = speech_tokens.to(self.device)
silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device)
speech_tokens = torch.cat([speech_tokens, silence])
wav, _ = self.s3gen.inference(
speech_tokens=speech_tokens,
ref_dict=self.conds.gen,
n_cfm_timesteps=2,
)
wav = wav.squeeze(0).detach().cpu().numpy()
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
return torch.from_numpy(watermarked_wav).unsqueeze(0)
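A minimal usage sketch of the Turbo API defined above (module path, device, and file names are assumptions, not taken from this diff):

import torchaudio as ta
from chatterbox.tts_turbo import ChatterboxTurboTTS  # assumed module path

model = ChatterboxTurboTTS.from_pretrained("cuda")
wav = model.generate(
    "Welcome back! Let me pull up your account details.",
    audio_prompt_path="reference.wav",  # prepare_conditionals requires a prompt longer than 5 seconds
)
ta.save("turbo-out.wav", wav, model.sr)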

View File

@ -4,7 +4,6 @@ import librosa
import torch
import perth
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
@ -52,7 +51,7 @@ class ChatterboxVC:
s3gen = S3Gen()
s3gen.load_state_dict(
load_file(ckpt_dir / "s3gen.safetensors"), strict=False
torch.load(ckpt_dir / "s3gen.pt", map_location=map_location)
)
s3gen.to(device).eval()
@ -68,7 +67,7 @@ class ChatterboxVC:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
for fpath in ["s3gen.safetensors", "conds.pt"]:
for fpath in ["s3gen.pt", "conds.pt"]:
local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
return cls.from_local(Path(local_path).parent, device)

91
voice_conversion.py Normal file
View File

@ -0,0 +1,91 @@
from tqdm import tqdm
import sys
import torch
import shutil
import perth
from pathlib import Path
import argparse
import os
import librosa
import soundfile as sf
from chatterbox.models.s3tokenizer import S3_SR
from chatterbox.models.s3gen import S3GEN_SR, S3Gen
AUDIO_EXTENSIONS = ["wav", "mp3", "flac", "opus"]
@torch.inference_mode()
def main():
parser = argparse.ArgumentParser(description="Voice Conversion")
parser.add_argument("input", type=str, help="Path to input (a sample or folder of samples).")
parser.add_argument("target_speaker", type=str, help="Path to the sample for the target speaker.")
parser.add_argument("-o", "--output_folder", type=str, default="vc_outputs")
parser.add_argument("-g", "--gpu_id", type=int, default=None)
parser.add_argument("-m", "--mps", action="store_true", help="Use MPS (Metal) on macOS")
parser.add_argument("--no-watermark", action="store_true", help="Skip watermarking")
args = parser.parse_args()
# Folders
input = Path(args.input)
output_folder = Path(args.output_folder)
output_orig_folder = output_folder / "input"
output_vc_folder = output_folder / "output"
ref_folder = output_vc_folder / "target"
output_orig_folder.mkdir(exist_ok=True, parents=True)
output_vc_folder.mkdir(exist_ok=True)
ref_folder.mkdir(exist_ok=True)
# Device selection with MPS support
if args.mps:
if torch.backends.mps.is_available():
device = torch.device("mps")
print("Using MPS (Metal) device")
else:
print("MPS not available, falling back to CPU")
device = torch.device("cpu")
elif args.gpu_id is not None:
device = torch.device(f"cuda:{args.gpu_id}")
else:
device = torch.device("cpu")
# Determine map_location for loading
map_location = torch.device('cpu') if device.type in ['cpu', 'mps'] else None
## s3gen
s3g_fp = "checkpoints/s3gen.pt"
s3gen = S3Gen()
s3gen.load_state_dict(torch.load(s3g_fp, map_location=map_location))
s3gen.to(device)
s3gen.eval()
wav_fpaths = []
if input.is_dir():
for ext in AUDIO_EXTENSIONS:
wav_fpaths += list(input.glob(f"*.{ext}"))
else:
wav_fpaths.append(input)
assert wav_fpaths, f"Didn't find any audio in '{input}'"
ref_24, _ = librosa.load(args.target_speaker, sr=S3GEN_SR, duration=10)
ref_24 = torch.tensor(ref_24).float()
shutil.copy(args.target_speaker, ref_folder / Path(args.target_speaker).name)
if not args.no_watermark:
watermarker = perth.PerthImplicitWatermarker()
for wav_fpath in tqdm(wav_fpaths):
shutil.copy(wav_fpath, output_orig_folder / wav_fpath.name)
audio_16, _ = librosa.load(str(wav_fpath), sr=S3_SR)
audio_16 = torch.tensor(audio_16).float().to(device)[None, ]
s3_tokens, _ = s3gen.tokenizer(audio_16)
wav = s3gen(s3_tokens.to(device), ref_24, S3GEN_SR)
wav = wav.view(-1).cpu().numpy()
if not args.no_watermark:
wav = watermarker.apply_watermark(wav, sample_rate=S3GEN_SR)
save_path = output_vc_folder / wav_fpath.name
sf.write(str(save_path), wav, samplerate=S3GEN_SR)
if __name__ == "__main__":
main()
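# Example invocation (illustrative): the script expects a local checkpoints/s3gen.pt, as hard-coded above.
#   python voice_conversion.py path/to/input_samples target_speaker.wav -o vc_outputs -g 0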