Compare commits

..

1 Commit

| Author | SHA1 | Message | Date |
| :--- | :--- | :--- | :--- |
| manmay nakhashi | 7cef1d6bea | remove src | 2025-09-04 22:50:17 +05:30 |
21 changed files with 374 additions and 1140 deletions

View File

@ -1,23 +0,0 @@
name: Test Installation

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"
      - name: Test Standard Install
        run: |
          pip install -e .

Binary file not shown.

Before: 763 KiB

README.md (104 lines changed)
View File

@ -1,36 +1,47 @@
![Chatterbox Turbo Image](./Chatterbox-Turbo.jpg)
<img width="1200" alt="Chatterbox Multilingual" src="https://github.com/user-attachments/assets/53b411b4-d7b8-4b18-87f3-335145837684" />
# Chatterbox TTS
[![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_turbo_demopage/)
[![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo)
[![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_demopage/)
[![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/Chatterbox)
[![Alt Text](https://static-public.podonos.com/badges/insight-on-pdns-sm-dark.svg)](https://podonos.com/resembleai/chatterbox)
[![Discord](https://img.shields.io/discord/1377773249798344776?label=join%20discord&logo=discord&style=flat)](https://discord.gg/rJq9cRJBJ6)
_Made with ♥️ by <a href="https://resemble.ai" target="_blank"><img width="100" alt="resemble-logo-horizontal" src="https://github.com/user-attachments/assets/35cf756b-3506-4943-9c72-c05ddfa4e525" /></a>
**Chatterbox** is a family of three state-of-the-art, open-source text-to-speech models by Resemble AI.
We're excited to introduce **Chatterbox Multilingual**, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model supporting **23 languages** out of the box. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.
We are excited to introduce **Chatterbox-Turbo**, our most efficient model yet. Built on a streamlined 350M parameter architecture, **Turbo** delivers high-quality speech with less compute and VRAM than our previous models. We have also distilled the speech-token-to-mel decoder, previously a bottleneck, reducing generation from 10 steps to just **one**, while retaining high-fidelity audio output.
**Paralinguistic tags** are now native to the Turbo model, allowing you to use `[cough]`, `[laugh]`, `[chuckle]`, and more to add distinct realism. While Turbo was built primarily for low-latency voice agents, it excels at narration and creative workflows.
Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life across languages. It's also the first open source TTS model to support **emotion exaggeration control** with robust **multilingual zero-shot voice cloning**. Try the English-only version now on our [English Hugging Face Gradio app](https://huggingface.co/spaces/ResembleAI/Chatterbox), or try the multilingual version on our [Multilingual Hugging Face Gradio app](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS).
If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (<a href="https://resemble.ai">link</a>). It delivers reliable performance with ultra-low latency of sub 200ms—ideal for production use in agents, applications, or interactive media.
<img width="1200" height="600" alt="Podonos Turbo Eval" src="https://storage.googleapis.com/chatterbox-demo-samples/turbo/podonos_turbo.png" />
# Key Details
- Multilingual, zero-shot TTS supporting 23 languages
- SoTA zero-shot English TTS
- 0.5B Llama backbone
- Unique exaggeration/intensity control
- Ultra-stable with alignment-informed inference
- Trained on 0.5M hours of cleaned data
- Watermarked outputs
- Easy voice conversion script
- [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox)
### ⚡ Model Zoo
# Supported Languages
Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)
# Tips
- **General Use (TTS and Voice Agents):**
- Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set `cfg_weight` to `0`.
- The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
- If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.
Choose the right model for your application.
- **Expressive or Dramatic Speech:**
- Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher.
- Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing.
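
A minimal sketch of these settings in code (assuming the `exaggeration` and `cfg_weight` keyword arguments of `ChatterboxTTS.generate` referenced in the tips above; the reference clip path is a placeholder):

```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

model = ChatterboxTTS.from_pretrained(device="cuda")

# Expressive read: lower cfg_weight slows the pacing, higher exaggeration adds intensity.
wav = model.generate(
    "Well, that is NOT what I expected to hear today!",
    audio_prompt_path="your_10s_ref_clip.wav",  # placeholder reference clip
    exaggeration=0.7,
    cfg_weight=0.3,
)
ta.save("expressive.wav", wav, model.sr)
```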
| Model | Size | Languages | Key Features | Best For | 🤗 | Examples |
|:----------------------------------------------------------------------------------------------------------------| :--- | :--- |:--------------------------------------------------------|:---------------------------------------------|:--------------------------------------------------------------------------| :--- |
| **Chatterbox-Turbo** | **350M** | **English** | Paralinguistic Tags (`[laugh]`), Lower Compute and VRAM | Zero-shot voice agents, Production | [Demo](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo) | [Listen](https://resemble-ai.github.io/chatterbox_turbo_demopage/) |
| Chatterbox-Multilingual [(Language list)](#supported-languages) | 500M | 23+ | Zero-shot cloning, Multiple Languages | Global applications, Localization | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
| Chatterbox [(Tips and Tricks)](#original-chatterbox-tips) | 500M | English | CFG & Exaggeration tuning | General zero-shot TTS with creative controls | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
## Installation
# Installation
```shell
pip install chatterbox-tts
```
@ -46,34 +57,11 @@ pip install -e .
```
We developed and tested Chatterbox on Python 3.11 on Debian 11 OS; the versions of the dependencies are pinned in `pyproject.toml` to ensure consistency. You can modify the code or dependencies in this installation mode.
## Usage
##### Chatterbox-Turbo
# Usage
```python
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS
# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")
# Generate with Paralinguistic Tags
text = "Hi there, Sarah here from MochaFone calling you back [chuckle], have you got one minute to chat about the billing issue?"
# Generate audio (requires a reference clip for voice cloning)
wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")
ta.save("test-turbo.wav", wav, model.sr)
```
##### Chatterbox and Chatterbox-Multilingual
```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
# English example
model = ChatterboxTTS.from_pretrained(device="cuda")
@ -100,21 +88,14 @@ ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` and `example_vc.py` for more examples.
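
The multilingual half of the example above is elided by this diff. Below is a minimal sketch of multilingual generation, assuming a `language_id` keyword argument like the one the Gradio demo in this repo passes through to the model (the text and reference clip path are placeholders):

```python
import torchaudio as ta
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

mtl_model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")

# French example; language_id should be one of the supported language codes listed above.
text = "Bonjour, comment ça va aujourd'hui ?"
wav = mtl_model.generate(
    text,
    language_id="fr",
    audio_prompt_path="your_10s_ref_clip.wav",  # placeholder reference clip
)
ta.save("test-mtl.wav", wav, mtl_model.sr)
```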
## Supported Languages
Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)
# Acknowledgements
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
- [Llama 3](https://github.com/meta-llama/llama3)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)
## Original Chatterbox Tips
- **General Use (TTS and Voice Agents):**
- Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set `cfg_weight` to `0`.
- The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
- If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.
- **Expressive or Dramatic Speech:**
- Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher.
- Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing.
## Built-in PerTh Watermarking for Responsible AI
# Built-in PerTh Watermarking for Responsible AI
Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.
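
The extraction snippet itself is elided in this diff (only its closing `print` survives in the hunk below); a minimal sketch of such a check, assuming the `PerthImplicitWatermarker.get_watermark` API of the `resemble-perth` package:

```python
import librosa
import perth

# Load any audio file generated by Chatterbox (path is a placeholder).
audio, sr = librosa.load("test-turbo.wav", sr=None)

# Instantiate the watermarker and read the watermark back out of the waveform.
watermarker = perth.PerthImplicitWatermarker()
watermark = watermarker.get_watermark(audio, sample_rate=sr)
print(f"Extracted watermark: {watermark}")  # expected: 1.0 for watermarked audio, 0.0 otherwise
```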
@ -142,18 +123,11 @@ print(f"Extracted watermark: {watermark}")
```
## Official Discord
# Official Discord
👋 Join us on [Discord](https://discord.gg/rJq9cRJBJ6) and let's build something awesome together!
## Acknowledgements
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
- [Llama 3](https://github.com/meta-llama/llama3)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)
## Citation
# Citation
If you find this model useful, please consider citing.
```
@misc{chatterboxtts2025,
@ -164,5 +138,5 @@ If you find this model useful, please consider citing.
note = {GitHub repository}
}
```
## Disclaimer
# Disclaimer
Don't use this model to do bad things. Prompts are sourced from freely available data on the internet.

View File

@ -1,14 +0,0 @@
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS
# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")
# Generate with Paralinguistic Tags
text = "Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?"
# Generate audio (requires a reference clip for voice cloning)
# wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")
wav = model.generate(text)
ta.save("test-turbo.wav", wav, model.sr)

View File

@ -1,186 +0,0 @@
import random

import numpy as np
import torch
import gradio as gr

from chatterbox.tts_turbo import ChatterboxTurboTTS

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EVENT_TAGS = [
    "[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
    "[sniff]", "[gasp]", "[chuckle]", "[laugh]"
]

# --- REFINED CSS ---
# 1. tag-container: Forces the row to wrap items instead of scrolling. Removes borders/backgrounds.
# 2. tag-btn: Sets the specific look (indigo theme) and stops them from stretching.
CUSTOM_CSS = """
.tag-container {
    display: flex !important;
    flex-wrap: wrap !important; /* This fixes the one-per-line issue */
    gap: 8px !important;
    margin-top: 5px !important;
    margin-bottom: 10px !important;
    border: none !important;
    background: transparent !important;
}
.tag-btn {
    min-width: fit-content !important;
    width: auto !important;
    height: 32px !important;
    font-size: 13px !important;
    background: #eef2ff !important;
    border: 1px solid #c7d2fe !important;
    color: #3730a3 !important;
    border-radius: 6px !important;
    padding: 0 10px !important;
    margin: 0 !important;
    box-shadow: none !important;
}
.tag-btn:hover {
    background: #c7d2fe !important;
    transform: translateY(-1px);
}
"""

INSERT_TAG_JS = """
(tag_val, current_text) => {
    const textarea = document.querySelector('#main_textbox textarea');
    if (!textarea) return current_text + " " + tag_val;
    const start = textarea.selectionStart;
    const end = textarea.selectionEnd;
    let prefix = " ";
    let suffix = " ";
    if (start === 0) prefix = "";
    else if (current_text[start - 1] === ' ') prefix = "";
    if (end < current_text.length && current_text[end] === ' ') suffix = "";
    return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
}
"""


def set_seed(seed: int):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def load_model():
    print(f"Loading Chatterbox-Turbo on {DEVICE}...")
    model = ChatterboxTurboTTS.from_pretrained(DEVICE)
    return model


def generate(
    model,
    text,
    audio_prompt_path,
    temperature,
    seed_num,
    min_p,
    top_p,
    top_k,
    repetition_penalty,
    norm_loudness
):
    if model is None:
        model = ChatterboxTurboTTS.from_pretrained(DEVICE)

    if seed_num != 0:
        set_seed(int(seed_num))

    wav = model.generate(
        text,
        audio_prompt_path=audio_prompt_path,
        temperature=temperature,
        min_p=min_p,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        norm_loudness=norm_loudness,
    )
    return (model.sr, wav.squeeze(0).numpy())


with gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo:
    gr.Markdown("# ⚡ Chatterbox Turbo")
    model_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and um all that jazz. Would you like me to get some prices for you?",
                label="Text to synthesize (max chars 300)",
                max_lines=5,
                elem_id="main_textbox"
            )

            # --- Event Tags ---
            # Switched back to Row, but applied specific CSS to force wrapping
            with gr.Row(elem_classes=["tag-container"]):
                for tag in EVENT_TAGS:
                    # elem_classes targets the button specifically
                    btn = gr.Button(tag, elem_classes=["tag-btn"])
                    btn.click(
                        fn=None,
                        inputs=[btn, text],
                        outputs=text,
                        js=INSERT_TAG_JS
                    )

            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav"
            )

            run_btn = gr.Button("Generate ⚡", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

            with gr.Accordion("Advanced Options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
                top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
                top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
                repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
                min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
                norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")

    demo.load(fn=load_model, inputs=[], outputs=model_state)

    run_btn.click(
        fn=generate,
        inputs=[
            model_state,
            text,
            ref_wav,
            temp,
            seed_num,
            min_p,
            top_p,
            top_k,
            repetition_penalty,
            norm_loudness,
        ],
        outputs=audio_output,
    )

if __name__ == "__main__":
    demo.queue(
        max_size=50,
        default_concurrency_limit=1,
    ).launch(share=True)

View File

@ -3,6 +3,7 @@ import numpy as np
import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
import gradio as gr
import spaces
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")
@ -175,6 +176,7 @@ def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | N
return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
@spaces.GPU
def generate_tts_audio(
text_input: str,
language_id: str,

View File

@ -1,15 +1,15 @@
[project]
name = "chatterbox-tts"
version = "0.1.6"
version = "0.1.3"
description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.9"
license = {file = "LICENSE"}
authors = [
{name = "resemble-ai", email = "engineering@resemble.ai"}
]
dependencies = [
"numpy>=1.24.0,<1.26.0",
"numpy>=1.26.0",
"librosa==0.11.0",
"s3tokenizer",
"torch==2.6.0",
@ -19,11 +19,8 @@ dependencies = [
"resemble-perth==1.0.1",
"conformer==0.3.2",
"safetensors==0.5.3",
"spacy-pkuseg",
"pykakasi==2.3.0",
"gradio==5.44.1",
"pyloudnorm",
"omegaconf"
"pkuseg ==0.0.25",
"pykakasi==2.3.0"
]
[project.urls]

View File

@ -1,2 +1 @@
S3GEN_SR = 24000
S3GEN_SIL = 4299

View File

@ -20,7 +20,6 @@ from .utils.mask import add_optional_chunk_mask
from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
TimestepEmbedding, Upsample1D
from .matcha.transformer import BasicTransformerBlock
from .utils.intmeanflow import get_intmeanflow_time_mixer
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
@ -96,6 +95,8 @@ class CausalConv1d(torch.nn.Conv1d):
x = F.pad(x, self.causal_padding)
x = super(CausalConv1d, self).forward(x)
return x
class ConditionalDecoder(nn.Module):
def __init__(
self,
@ -109,7 +110,6 @@ class ConditionalDecoder(nn.Module):
num_mid_blocks=12,
num_heads=8,
act_fn="gelu",
meanflow=False,
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
@ -117,7 +117,6 @@ class ConditionalDecoder(nn.Module):
"""
super().__init__()
channels = tuple(channels)
self.meanflow = meanflow
self.in_channels = in_channels
self.out_channels = out_channels
self.causal = causal
@ -128,7 +127,6 @@ class ConditionalDecoder(nn.Module):
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@ -217,14 +215,6 @@ class ConditionalDecoder(nn.Module):
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
self.time_embed_mixer = None
if self.meanflow:
self.time_embed_mixer = get_intmeanflow_time_mixer(time_embed_dim)
@property
def dtype(self):
return self.final_proj.weight.dtype
def initialize_weights(self):
for m in self.modules():
@ -240,16 +230,15 @@ class ConditionalDecoder(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None, r=None):
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x: (B, 80, T)
mask (_type_)
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional) Defaults to None.
cond (_type_, optional)
r: end time for meanflow mode (shape (1,) tensor)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
@ -258,15 +247,10 @@ class ConditionalDecoder(nn.Module):
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
if self.meanflow:
r = self.time_embeddings(r).to(t.dtype)
r = self.time_mlp(r)
concat_embed = torch.cat([t, r], dim=1)
t = self.time_embed_mixer(concat_embed)
x = pack([x, mu], "b * t")[0]
if spks is not None:

View File

@ -21,49 +21,205 @@ import torch.nn as nn
from torch.nn import functional as F
from .utils.mask import make_pad_mask
from .configs import CFM_PARAMS
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
class MaskedDiffWithXvec(torch.nn.Module):
def __init__(
self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {
'in_channels': 240,
'out_channel': 80,
'spk_emb_dim': 80,
'n_spks': 1,
'cfm_params': CFM_PARAMS,
'decoder_params': {
'channels': [256, 256],
'dropout': 0.0,
'attention_head_dim': 64,
'n_blocks': 4,
'num_mid_blocks': 12,
'num_heads': 8,
'act_fn': 'gelu',
}
},
mel_feat_conf: Dict = {
'n_fft': 1024,
'num_mels': 80,
'sampling_rate': 22050,
'hop_size': 256,
'win_size': 1024,
'fmin': 0,
'fmax': 8000
}
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
def _repeat_batch_dim(tnsr, B, ndim):
"repeat batch dimension if it's equal to 1"
if tnsr is not None:
# add missing batch dim if needed
while tnsr.ndim < ndim:
tnsr = tnsr[None]
# repeat batch dim as needed
if B > 1 and tnsr.size(0) == 1:
tnsr = tnsr.repeat(B, *([1] * (ndim - 1)))
assert tnsr.ndim == ndim, f"Expected {ndim=}, got {tnsr.ndim=}"
return tnsr
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
flow_cache):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
# Check for out-of-bounds token IDs
vocab_size = self.input_embedding.num_embeddings
if token.max() >= vocab_size or token.min() < 0:
logging.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}")
token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, flow_cache = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
prompt_len=mel_len1,
flow_cache=flow_cache
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), flow_cache
class CausalMaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 6561,
input_frame_rate: int = 25,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
encoder: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig(
{'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7,
'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0,
'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8,
'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
def __init__(
self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 6561,
input_frame_rate: int = 25,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
encoder: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {
'in_channels': 240,
'out_channel': 80,
'spk_emb_dim': 80,
'n_spks': 1,
'cfm_params': CFM_PARAMS,
'decoder_params': {
'channels': [256, 256],
'dropout': 0.0,
'attention_head_dim': 64,
'n_blocks': 4,
'num_mid_blocks': 12,
'num_heads': 8,
'act_fn': 'gelu',
}
},
mel_feat_conf: Dict = {
'n_fft': 1024,
'num_mels': 80,
'sampling_rate': 22050,
'hop_size': 256,
'win_size': 1024,
'fmin': 0,
'fmax': 8000
}
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
@ -82,51 +238,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len
# NOTE: copied in from cosyvoice repo
def compute_loss(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device) # (B, 80, T)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified training, static_chunk_size > 0 or = 0
# streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) # (B, T, 1)
token = self.input_embedding(torch.clamp(token, min=0)) * mask # (B, T, emb)
# text encode
h, h_lengths = self.encoder(token, token_len) # (B, T, C) -> (B, 2T, C)
h = self.encoder_proj(h)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :, :index] = feat[i, :, :index]
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
loss, _ = self.decoder.compute_loss(
feat.contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
# streaming=streaming,
)
return {'loss': loss}
# FIXME: this was missing - just putting it in as false
self.fp16 = False
@torch.inference_mode()
def inference(self,
@ -137,62 +250,41 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
prompt_feat,
prompt_feat_len,
embedding,
finalize,
n_timesteps=10,
noised_mels=None,
meanflow=False):
# token: (B, n_toks)
# token_len: (B,)
B = token.size(0)
finalize):
if self.fp16 is True:
prompt_feat = prompt_feat.half()
embedding = embedding.half()
assert token.shape[0] == 1
# xvec projection
embedding = torch.atleast_2d(embedding)
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding) # (1 or B, emb_dim)
# adjust shapes (batching logic)
prompt_token = _repeat_batch_dim(prompt_token, B, ndim=2) # (B, n_prompt)
prompt_token_len = _repeat_batch_dim(prompt_token_len, B, ndim=1) # (B,)
prompt_feat = _repeat_batch_dim(prompt_feat, B, ndim=3) # (B, n_feat, feat_dim=80)
prompt_feat_len = _repeat_batch_dim(prompt_feat_len, B, ndim=1) # (B,) or None
embedding = _repeat_batch_dim(embedding, B, ndim=2) # (B, emb_dim)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
if (token >= self.vocab_size).any():
logger.error(f"{token.max()}>{self.vocab_size}\n out-of-range special tokens found in flow, fix inputs!")
token = self.input_embedding(token.long()) * mask
token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
# text encode
h, h_masks = self.encoder(token, token_len)
h, h_lengths = self.encoder(token, token_len)
if finalize is False:
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
h_lengths = h_masks.sum(dim=-1).squeeze(dim=-1)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
# # get conditions
conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(h_lengths)).unsqueeze(1).to(h)
if mask.shape[0] != B:
mask = mask.repeat(B, 1, 1)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask,
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=n_timesteps,
noised_mels=noised_mels,
meanflow=meanflow,
n_timesteps=10
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat, None # NOTE jrm: why are they returning None here?
return feat.float(), None # NOTE jrm: why are they returning None here?

View File

@ -16,11 +16,6 @@ import torch
import torch.nn.functional as F
from .matcha.flow_matching import BASECFM
from .configs import CFM_PARAMS
from tqdm import tqdm
def cast_all(*args, dtype):
return [a if (not a.dtype.is_floating_point) or a.dtype == dtype else a.to(dtype) for a in args]
class ConditionalCFM(BASECFM):
@ -37,6 +32,7 @@ class ConditionalCFM(BASECFM):
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
self.lock = threading.Lock()
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@ -58,8 +54,6 @@ class ConditionalCFM(BASECFM):
shape: (batch_size, n_feats, mel_timesteps)
"""
raise NotImplementedError("unused, needs updating for meanflow model")
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
cache_size = flow_cache.shape[2]
# fix prompt and overlap part mu and z
@ -75,7 +69,7 @@ class ConditionalCFM(BASECFM):
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
def solve_euler(self, x, t_span, mu, mask, spks, cond, meanflow=False):
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
@ -89,60 +83,65 @@ class ConditionalCFM(BASECFM):
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
meanflow: meanflow mode
"""
in_dtype = x.dtype
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
# Duplicated batch dims are for CFG
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
B, T = mu.size(0), x.size(2)
x_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2 * B, 1, T], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2 * B, 80 ], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
r_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype) # (only used for meanflow)
for t, r in zip(t_span[:-1], t_span[1:]):
t = t.unsqueeze(dim=0)
r = r.unsqueeze(dim=0)
# Shapes:
# x_in ( 2B, 80, T )
# mask_in ( 2B, 1, T )
# mu_in ( 2B, 80, T )
# t_in ( 2B, )
# spks_in ( 2B, 80, )
# cond_in ( 2B, 80, T )
# r_in ( 2B, )
# x ( B, 80, T )
# mask ( B, 1, T )
# mu ( B, 80, T )
# t ( B, )
# spks ( B, 80, )
# cond ( B, 80, T )
# r ( B, )
x_in[:B] = x_in[B:] = x
mask_in[:B] = mask_in[B:] = mask
mu_in[:B] = mu
t_in[:B] = t_in[B:] = t
spks_in[:B] = spks
cond_in[:B] = cond
r_in[:B] = r_in[B:] = r # (only used for meanflow)
dxdt = self.estimator.forward(
x=x_in, mask=mask_in, mu=mu_in, t=t_in, spks=spks_in, cond=cond_in,
r=r_in if meanflow else None,
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in
)
dxdt, cfg_dxdt = torch.split(dxdt, [B, B], dim=0)
dxdt = ((1.0 + self.inference_cfg_rate) * dxdt - self.inference_cfg_rate * cfg_dxdt)
dt = r - t
x = x + dt * dxdt
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1].float()
return x.to(in_dtype)
def forward_estimator(self, x, mask, mu, t, spks, cond):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator.forward(x, mask, mu, t, spks, cond)
else:
with self.lock:
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
self.estimator.set_input_shape('t', (2,))
self.estimator.set_input_shape('spks', (2, 80))
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
# run trt engine
self.estimator.execute_v2([x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()])
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
@ -189,11 +188,10 @@ class ConditionalCFM(BASECFM):
class CausalConditionalCFM(ConditionalCFM):
def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
# TODO: BAD BAD IDEA - IT'LL MESS UP DISTILLATION - SETTING TO NONE
self.rand_noise = None
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, noised_mels=None, meanflow=False):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
@ -206,41 +204,15 @@ class CausalConditionalCFM(ConditionalCFM):
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
noised_mels: gt mels noised a time t
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
B = mu.size(0)
z = torch.randn_like(mu)
if noised_mels is not None:
prompt_len = mu.size(2) - noised_mels.size(2)
z[..., prompt_len:] = noised_mels
# time steps for reverse diffusion
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if (not meanflow) and (self.t_scheduler == 'cosine'):
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
# NOTE: right now, the only meanflow models are also distilled models, which don't need CFG
# because they were distilled with CFG outputs. We would need to add another hparam and
# change the conditional logic here if we want to use CFG inference with a meanflow model.
if meanflow:
return self.basic_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, meanflow=meanflow), None
def basic_euler(self, x, t_span, mu, mask, spks, cond):
in_dtype = x.dtype
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
print("S3 Token -> Mel Inference...")
for t, r in tqdm(zip(t_span[..., :-1], t_span[..., 1:]), total=t_span.shape[-1] - 1):
t, r = t[None], r[None]
dxdt = self.estimator.forward(x, mask=mask, mu=mu, t=t, spks=spks, cond=cond, r=r)
dt = r - t
x = x + dt * dxdt
return x.to(in_dtype)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None

View File

@ -46,20 +46,15 @@ def get_resampler(src_sr, dst_sr, device):
class S3Token2Mel(torch.nn.Module):
"""
S3Gen's CFM decoder maps S3 speech tokens to mel-spectrograms.
CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
TODO: make these modules configurable?
"""
def __init__(self, meanflow=False):
def __init__(self):
super().__init__()
self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
self.speaker_encoder = CAMPPlus(
# NOTE: This doesn't affect inference. It turns off activation checkpointing
# (a training optimization), which causes a crazy DDP error with accelerate
memory_efficient=False,
)
self.meanflow = meanflow
self.speaker_encoder = CAMPPlus() # use default args
encoder = UpsampleConformerEncoder(
output_size=512,
@ -89,7 +84,6 @@ class S3Token2Mel(torch.nn.Module):
num_mid_blocks=12,
num_heads=8,
act_fn='gelu',
meanflow=self.meanflow,
)
cfm_params = CFM_PARAMS
decoder = CausalConditionalCFM(
@ -110,11 +104,6 @@ class S3Token2Mel(torch.nn.Module):
params = self.tokenizer.parameters()
return next(params).device
@property
def dtype(self):
params = self.flow.parameters()
return next(params).dtype
def embed_ref(
self,
ref_wav: torch.Tensor,
@ -133,26 +122,23 @@ class S3Token2Mel(torch.nn.Module):
ref_wav = ref_wav.unsqueeze(0) # (B, L)
if ref_wav.size(1) > 10 * ref_sr:
print("WARNING: s3gen received ref longer than 10s")
print("WARNING: cosydec received ref longer than 10s")
ref_wav_24 = ref_wav
if ref_sr != S3GEN_SR:
ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
ref_wav_24 = ref_wav_24.to(device=device, dtype=self.dtype)
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(dtype=self.dtype)
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
ref_mels_24_len = None
# Resample to 16kHz
ref_wav_16 = ref_wav
if ref_sr != S3_SR:
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav)
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
# Speaker embedding
ref_x_vector = self.speaker_encoder.inference(ref_wav_16.to(dtype=self.dtype))
ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
# Tokenize 16khz reference
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16.float())
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
# Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
@ -178,10 +164,7 @@ class S3Token2Mel(torch.nn.Module):
ref_sr: Optional[int],
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
n_cfm_timesteps = None,
finalize: bool = False,
speech_token_lens=None,
noised_mels=None,
):
"""
Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
@ -209,21 +192,18 @@ class S3Token2Mel(torch.nn.Module):
if isinstance(ref_dict[rk], np.ndarray):
ref_dict[rk] = torch.from_numpy(ref_dict[rk])
if torch.is_tensor(ref_dict[rk]):
ref_dict[rk] = ref_dict[rk].to(device=self.device, dtype=self.dtype)
ref_dict[rk] = ref_dict[rk].to(self.device)
speech_tokens = torch.atleast_2d(speech_tokens)
if len(speech_tokens.shape) == 1:
speech_tokens = speech_tokens.unsqueeze(0)
# backcompat
if speech_token_lens is None:
speech_token_lens = torch.LongTensor([st.size(-1) for st in speech_tokens]).to(self.device)
# assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
output_mels, _ = self.flow.inference(
token=speech_tokens,
token_len=speech_token_lens,
finalize=finalize,
noised_mels=noised_mels,
n_timesteps=n_cfm_timesteps,
meanflow=self.meanflow,
**ref_dict,
)
return output_mels
@ -231,15 +211,13 @@ class S3Token2Mel(torch.nn.Module):
class S3Token2Wav(S3Token2Mel):
"""
The decoder of S3Gen is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
TODO: make these modules configurable?
"""
ignore_state_dict_missing = ("tokenizer._mel_filters", "tokenizer.window")
def __init__(self, meanflow=False):
super().__init__(meanflow)
def __init__(self):
super().__init__()
f0_predictor = ConvRNNF0Predictor()
self.mel2wav = HiFTGenerator(
@ -256,7 +234,6 @@ class S3Token2Wav(S3Token2Mel):
trim_fade = torch.zeros(2 * n_trim)
trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
self.estimator_dtype = "fp32"
def forward(
self,
@ -266,25 +243,9 @@ class S3Token2Wav(S3Token2Mel):
ref_sr: Optional[int],
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
finalize: bool = False,
speech_token_lens=None,
skip_vocoder=False,
n_cfm_timesteps=None,
noised_mels=None,
finalize: bool = False
):
"""
Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
NOTE: used for sync synthesis only. Please use `S3GenStreamer` for streaming synthesis.
"""
output_mels = super().forward(
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav,
ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize,
n_cfm_timesteps=n_cfm_timesteps, noised_mels=noised_mels,
)
if skip_vocoder:
return output_mels
output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
# TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
@ -306,24 +267,14 @@ class S3Token2Wav(S3Token2Mel):
ref_sr: Optional[int] = None,
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
n_cfm_timesteps = None,
finalize: bool = False,
speech_token_lens=None,
):
n_cfm_timesteps = n_cfm_timesteps or (2 if self.meanflow else 10)
noise = None
if self.meanflow:
noise = torch.randn(1, 80, speech_tokens.size(-1) * 2, dtype=self.dtype, device=self.device)
output_mels = super().forward(
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict,
n_cfm_timesteps=n_cfm_timesteps, finalize=finalize, noised_mels=noise,
)
return output_mels
return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
@torch.inference_mode()
def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
if cache_source is None:
cache_source = torch.zeros(1, 1, 0).to(device=self.device, dtype=self.dtype)
cache_source = torch.zeros(1, 1, 0).to(self.device)
return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
@torch.inference_mode()
@ -335,26 +286,11 @@ class S3Token2Wav(S3Token2Mel):
ref_sr: Optional[int] = None,
# pre-computed ref embedding (prod API)
ref_dict: Optional[dict] = None,
# left as a kwarg because this can change input/output size ratio
drop_invalid_tokens=True,
n_cfm_timesteps=None,
speech_token_lens=None,
cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
finalize: bool = True,
):
# hallucination prevention, drop special tokens
# if drop_invalid_tokens:
# speech_tokens, speech_token_lens = drop_invalid(speech_tokens, pad=S3_QUIET_PAD)
output_mels = self.flow_inference(
speech_tokens,
speech_token_lens=speech_token_lens,
ref_wav=ref_wav,
ref_sr=ref_sr,
ref_dict=ref_dict,
n_cfm_timesteps=n_cfm_timesteps,
finalize=True,
)
output_mels = output_mels.to(dtype=self.dtype) # FIXME (fp16 mode) is this still needed?
output_wavs, output_sources = self.hift_inference(output_mels, None)
output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

View File

@ -1,36 +0,0 @@
import torch
import torch.nn as nn
def get_intmeanflow_time_mixer(dims):
    """
    Diagonal init as described in 3.3 https://arxiv.org/pdf/2510.07979
    """
    layer = nn.Linear(dims * 2, dims, bias=False)
    with torch.no_grad():
        target_weight = torch.zeros(dims, 2 * dims)
        target_weight[:, 0:dims] = torch.eye(dims)
        layer.weight.data = target_weight
    return layer


if __name__ == '__main__':
    D_example = 6
    W_layer = get_intmeanflow_time_mixer(D_example)
    print(f"Layer weight (AFTER init):\n{W_layer.weight.data}\n")

    e_t = torch.tensor([0., 1., 2., 3., 4., 5.])
    e_r = torch.tensor([6., 7., 8., 9., 10., 11.])
    e_concat = torch.cat([e_t, e_r]).unsqueeze(0)  # Shape (1, 12)

    output = W_layer(e_concat)
    print(f"Test Input e_t: \n{e_t}")
    print(f"Test Input e_r: \n{e_r}")
    print(f"Test Input concat: \n{e_concat}")
    print(f"Forward Pass Output: \n{output.squeeze(0)}")

View File

@ -181,7 +181,6 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
lengths = lengths.long()
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,

View File

@ -32,42 +32,6 @@ LLAMA_520M_CONFIG_DICT = dict(
use_cache=True,
)
GPT2_MEDIUM_CONFIG = {
"activation_function": "gelu_new",
"architectures": [
"GPT2LMHeadModel"
],
"attn_pdrop": 0.1,
"bos_token_id": 50256,
"embd_pdrop": 0.1,
"eos_token_id": 50256,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"model_type": "gpt2",
"n_ctx": 8196,
"n_embd": 1024,
"hidden_size": 1024,
"n_head": 16,
"n_layer": 24,
"n_positions": 8196,
"n_special": 0,
"predict_special_tokens": True,
"resid_pdrop": 0.1,
"summary_activation": None,
"summary_first_dropout": 0.1,
"summary_proj_to_labels": True,
"summary_type": "cls_index",
"summary_use_proj": True,
"task_specific_params": {
"text-generation": {
"do_sample": True,
"max_length": 50
}
},
"vocab_size": 50276,
}
LLAMA_CONFIGS = {
"Llama_520M": LLAMA_520M_CONFIG_DICT,
"GPT2_medium": GPT2_MEDIUM_CONFIG,
}

View File

@ -28,7 +28,7 @@ class T3Config:
@property
def is_multilingual(self):
return self.text_tokens_dict_size == 2454
return self.text_tokens_dict_size == 2352
@classmethod
def english_only(cls):
@ -38,4 +38,4 @@ class T3Config:
@classmethod
def multilingual(cls):
"""Create configuration for multilingual TTS model."""
return cls(text_tokens_dict_size=2454)
return cls(text_tokens_dict_size=2352)

View File

@ -9,15 +9,9 @@ from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import LlamaModel, LlamaConfig, GPT2Config, GPT2Model
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
MinPLogitsWarper,
)
from transformers import LlamaModel, LlamaConfig
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper
from .modules.learned_pos_emb import LearnedPositionEmbeddings
from .modules.cond_enc import T3CondEnc, T3Cond
@ -49,20 +43,11 @@ class T3(nn.Module):
def __init__(self, hp=None):
if hp is None:
hp = T3Config.english_only()
hp = T3Config.english_only() # Default to English-only config for backward compatibility
super().__init__()
self.hp = hp
config_dict = LLAMA_CONFIGS[hp.llama_config_name]
self.is_gpt = config_dict.get("model_type") == "gpt2"
if self.is_gpt:
self.cfg = GPT2Config(**config_dict)
self.tfmr = GPT2Model(self.cfg)
else:
self.cfg = LlamaConfig(**config_dict)
self.tfmr = LlamaModel(self.cfg)
self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
self.tfmr = LlamaModel(self.cfg)
self.dim = self.cfg.hidden_size
self.deepspeed_patch_applied = False
@ -72,8 +57,6 @@ class T3(nn.Module):
self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
# custom position embedding
self.text_pos_emb = None
self.speech_pos_emb = None
if hp.input_pos_emb == "learned":
max_text_seq_len = hp.max_text_tokens + 2
self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
@ -83,7 +66,7 @@ class T3(nn.Module):
# logit projection
self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt)
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
self.compiled = False
@property
@ -95,9 +78,8 @@ class T3(nn.Module):
Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
"""
if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens)
if not self.is_gpt:
t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
return self.cond_enc(t3_cond) # (B, len_cond, dim)
def prepare_input_embeds(
@ -111,7 +93,7 @@ class T3(nn.Module):
# prepare input embeddings (skip backbone tranformer embeddings)
cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
if cfg_weight > 0.0 and not self.is_gpt:
if cfg_weight > 0.0:
text_emb[1].zero_() # CFG uncond
speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
@ -350,7 +332,7 @@ class T3(nn.Module):
# ---- Generation Loop using kv_cache ----
for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
logits_step = output.logits[:, -1, :]
logits_step = output.logits[:, -1, :]
# CFG combine → (1, V)
cond = logits_step[0:1, :]
uncond = logits_step[1:2, :]
@ -410,81 +392,3 @@ class T3(nn.Module):
# Concatenate all predicted tokens along the sequence dimension.
predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
return predicted_tokens
@torch.inference_mode()
def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2,
max_gen_len=1000):
logits_processors = LogitsProcessorList()
if temperature > 0 and temperature != 1.0:
logits_processors.append(TemperatureLogitsWarper(temperature))
if top_k > 0:
logits_processors.append(TopKLogitsWarper(top_k))
if top_p < 1.0:
logits_processors.append(TopPLogitsWarper(top_p))
if repetition_penalty != 1.0:
logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
embeds, _ = self.prepare_input_embeds(
t3_cond=t3_cond,
text_tokens=text_tokens,
speech_tokens=speech_start_token,
cfg_weight=0.0,
)
generated_speech_tokens = []
llm_outputs = self.tfmr(
inputs_embeds=embeds,
use_cache=True
)
hidden_states = llm_outputs[0]
past_key_values = llm_outputs.past_key_values
speech_hidden = hidden_states[:, -1:]
speech_logits = self.speech_head(speech_hidden)
processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :])
probs = F.softmax(processed_logits, dim=-1)
next_speech_token = torch.multinomial(probs, num_samples=1)
generated_speech_tokens.append(next_speech_token)
current_speech_token = next_speech_token
for _ in tqdm(range(max_gen_len)):
current_speech_embed = self.speech_emb(current_speech_token)
llm_outputs = self.tfmr(
inputs_embeds=current_speech_embed,
past_key_values=past_key_values,
use_cache=True
)
hidden_states = llm_outputs[0]
past_key_values = llm_outputs.past_key_values
speech_logits = self.speech_head(hidden_states)
input_ids = torch.cat(generated_speech_tokens, dim=1)
processed_logits = logits_processors(input_ids, speech_logits[:, -1, :])
if torch.all(processed_logits == -float("inf")):
print("Warning: All logits are -inf")
break
probs = F.softmax(processed_logits, dim=-1)
next_speech_token = torch.multinomial(probs, num_samples=1)
generated_speech_tokens.append(next_speech_token)
current_speech_token = next_speech_token
if torch.all(next_speech_token == self.hp.stop_speech_token):
break
all_tokens = torch.cat(generated_speech_tokens, dim=1)
# Remove EOS token if present
if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token:
all_tokens = all_tokens[:, :-1]
return all_tokens

View File

@ -1,9 +1,10 @@
import logging
import json
import re
import torch
from pathlib import Path
from unicodedata import category, normalize
from unicodedata import category
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
@ -32,7 +33,7 @@ class EnTokenizer:
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
return text_tokens
def encode(self, txt: str):
def encode( self, txt: str, verbose=False):
"""
clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
"""
@ -45,7 +46,8 @@ class EnTokenizer:
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
txt: str = self.tokenizer.decode(seq,
skip_special_tokens=False)
txt = txt.replace(' ', '')
txt = txt.replace(SPACE, ' ')
txt = txt.replace(EOT, '')
@ -59,7 +61,6 @@ REPO_ID = "ResembleAI/chatterbox"
# Global instances for optional dependencies
_kakasi = None
_dicta = None
_russian_stresser = None
def is_kanji(c: str) -> bool:
@ -190,7 +191,7 @@ class ChineseCangjieConverter:
def _init_segmenter(self):
"""Initialize pkuseg segmenter."""
try:
from spacy_pkuseg import pkuseg
from pkuseg import pkuseg
self.segmenter = pkuseg()
except ImportError:
logger.warning("pkuseg not available - Chinese segmentation will be skipped")
@ -234,25 +235,6 @@ class ChineseCangjieConverter:
return "".join(output)
def add_russian_stress(text: str) -> str:
"""Russian text normalization: adds stress marks to Russian text."""
global _russian_stresser
try:
if _russian_stresser is None:
from russian_text_stresser.text_stresser import RussianTextStresser
_russian_stresser = RussianTextStresser()
return _russian_stresser.stress_text(text)
except ImportError:
logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
return text
except Exception as e:
logger.warning(f"Russian stress labeling failed: {e}")
return text
class MTLTokenizer:
def __init__(self, vocab_file_path):
self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
@ -265,26 +247,12 @@ class MTLTokenizer:
assert SOT in voc
assert EOT in voc
def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
"""
Text preprocessor that handles lowercase conversion and NFKD normalization.
"""
preprocessed_text = raw_text
if lowercase:
preprocessed_text = preprocessed_text.lower()
if nfkd_normalize:
preprocessed_text = normalize("NFKD", preprocessed_text)
return preprocessed_text
def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
def text_to_tokens(self, text: str, language_id: str = None):
text_tokens = self.encode(text, language_id=language_id)
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
return text_tokens
def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
def encode(self, txt: str, language_id: str = None):
# Language-specific text processing
if language_id == 'zh':
txt = self.cangjie_converter(txt)
@ -294,8 +262,6 @@ class MTLTokenizer:
txt = add_hebrew_diacritics(txt)
elif language_id == 'ko':
txt = korean_normalize(txt)
elif language_id == 'ru':
txt = add_russian_stress(txt)
# Prepend language token
if language_id:

View File

@ -168,7 +168,7 @@ class ChatterboxMultilingualTTS:
ve.to(device).eval()
t3 = T3(T3Config.multilingual())
t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors")
if "model" in t3_state.keys():
t3_state = t3_state["model"][0]
t3.load_state_dict(t3_state)
@ -181,7 +181,7 @@ class ChatterboxMultilingualTTS:
s3gen.to(device).eval()
tokenizer = MTLTokenizer(
str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
str(ckpt_dir / "mtl_tokenizer.json")
)
conds = None
@ -197,7 +197,7 @@ class ChatterboxMultilingualTTS:
repo_id=REPO_ID,
repo_type="model",
revision="main",
allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
allow_patterns=["ve.pt", "t3_23lang.safetensors", "s3gen.pt", "mtl_tokenizer.json", "conds.pt", "Cangjie5_TC.json"],
token=os.getenv("HF_TOKEN"),
)
)

View File

@ -269,4 +269,4 @@ class ChatterboxTTS:
)
wav = wav.squeeze(0).detach().cpu().numpy()
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
return torch.from_numpy(watermarked_wav).unsqueeze(0)
return torch.from_numpy(watermarked_wav).unsqueeze(0)
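# The waveform returned above carries a Perth implicit watermark. A short
# extraction sketch (the file name is a placeholder; get_watermark is the
# resemble-perth API as documented upstream):
import perth
import librosa

audio, sr = librosa.load("generated.wav", sr=None)
watermarker = perth.PerthImplicitWatermarker()
watermark = watermarker.get_watermark(audio, sample_rate=sr)
print(f"Extracted watermark: {watermark}")  # 0.0 = no watermark, 1.0 = watermarked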

View File

@ -1,296 +0,0 @@
import os
import math
from dataclasses import dataclass
from pathlib import Path
import librosa
import torch
import perth
import pyloudnorm as ln
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from .models.t3 import T3
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
from .models.t3.modules.t3_config import T3Config
from .models.s3gen.const import S3GEN_SIL
import logging
logger = logging.getLogger(__name__)
REPO_ID = "ResembleAI/chatterbox-turbo"
def punc_norm(text: str) -> str:
"""
Quick cleanup of punctuation coming from LLMs or
containing characters not often seen in the dataset
"""
if len(text) == 0:
return "You need to add some text for me to talk."
# Capitalise first letter
if text[0].islower():
text = text[0].upper() + text[1:]
# Remove multiple space chars
text = " ".join(text.split())
# Replace uncommon/llm punc
punc_to_replace = [
("…", ", "),
(":", ","),
("—", "-"),
("–", "-"),
(" ,", ","),
("“", "\""),
("”", "\""),
("‘", "'"),
("’", "'"),
]
for old_char_sequence, new_char in punc_to_replace:
text = text.replace(old_char_sequence, new_char)
# Add full stop if no ending punc
text = text.rstrip(" ")
sentence_enders = {".", "!", "?", "-", ","}
if not any(text.endswith(p) for p in sentence_enders):
text += "."
return text
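# Quick sanity check of punc_norm's behaviour (illustrative input only):
#   punc_norm("hello   there:  how are you")
#   -> "Hello there, how are you."
# The first letter is capitalised, repeated spaces collapse, ":" becomes ",",
# and a trailing full stop is appended because no sentence ender was present.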
@dataclass
class Conditionals:
"""
Conditionals for T3 and S3Gen
- T3 conditionals:
- speaker_emb
- clap_emb
- cond_prompt_speech_tokens
- cond_prompt_speech_emb
- emotion_adv
- S3Gen conditionals:
- prompt_token
- prompt_token_len
- prompt_feat
- prompt_feat_len
- embedding
"""
t3: T3Cond
gen: dict
def to(self, device):
self.t3 = self.t3.to(device=device)
for k, v in self.gen.items():
if torch.is_tensor(v):
self.gen[k] = v.to(device=device)
return self
def save(self, fpath: Path):
arg_dict = dict(
t3=self.t3.__dict__,
gen=self.gen
)
torch.save(arg_dict, fpath)
@classmethod
def load(cls, fpath, map_location="cpu"):
if isinstance(map_location, str):
map_location = torch.device(map_location)
kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
return cls(T3Cond(**kwargs['t3']), kwargs['gen'])
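# Minimal round-trip sketch for Conditionals (assumes a conds.pt produced by
# prepare_conditionals, or the one shipped with the checkpoint, exists on disk):
conds = Conditionals.load("conds.pt", map_location="cpu")
conds = conds.to("cuda" if torch.cuda.is_available() else "cpu")
conds.save(Path("conds_copy.pt"))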
class ChatterboxTurboTTS:
ENC_COND_LEN = 15 * S3_SR
DEC_COND_LEN = 10 * S3GEN_SR
def __init__(
self,
t3: T3,
s3gen: S3Gen,
ve: VoiceEncoder,
tokenizer: EnTokenizer,
device: str,
conds: Conditionals = None,
):
self.sr = S3GEN_SR # sample rate of synthesized audio
self.t3 = t3
self.s3gen = s3gen
self.ve = ve
self.tokenizer = tokenizer
self.device = device
self.conds = conds
self.watermarker = perth.PerthImplicitWatermarker()
@classmethod
def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS':
ckpt_dir = Path(ckpt_dir)
# Always load to CPU first for non-CUDA devices to handle CUDA-saved models
if device in ["cpu", "mps"]:
map_location = torch.device('cpu')
else:
map_location = None
ve = VoiceEncoder()
ve.load_state_dict(
load_file(ckpt_dir / "ve.safetensors")
)
ve.to(device).eval()
# Turbo specific hp
hp = T3Config(text_tokens_dict_size=50276)
hp.llama_config_name = "GPT2_medium"
hp.speech_tokens_dict_size = 6563
hp.input_pos_emb = None
hp.speech_cond_prompt_len = 375
hp.use_perceiver_resampler = False
hp.emotion_adv = False
t3 = T3(hp)
t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors")
if "model" in t3_state.keys():
t3_state = t3_state["model"][0]
t3.load_state_dict(t3_state)
del t3.tfmr.wte
t3.to(device).eval()
s3gen = S3Gen(meanflow=True)
weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors")
s3gen.load_state_dict(
weights, strict=True
)
s3gen.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if len(tokenizer) != 50276:
print(f"WARNING: Tokenizer len {len(tokenizer)} != 50276")
conds = None
builtin_voice = ckpt_dir / "conds.pt"
if builtin_voice.exists():
conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)
return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
@classmethod
def from_pretrained(cls, device) -> 'ChatterboxTurboTTS':
# Check if MPS is available on macOS
if device == "mps" and not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
print("MPS not available because the current PyTorch install was not built with MPS enabled.")
else:
print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
device = "cpu"
local_path = snapshot_download(
repo_id=REPO_ID,
token=os.getenv("HF_TOKEN") or True,
# Optional: Filter to download only what you need
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
)
return cls.from_local(local_path, device)
def norm_loudness(self, wav, sr, target_lufs=-27):
try:
meter = ln.Meter(sr)
loudness = meter.integrated_loudness(wav)
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
if math.isfinite(gain_linear) and gain_linear > 0.0:
wav = wav * gain_linear
except Exception as e:
print(f"Warning: Error in norm_loudness, skipping: {e}")
return wav
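# Worked example of the gain computation in norm_loudness (illustrative numbers):
#   measured integrated loudness = -20.0 LUFS, target = -27.0 LUFS
#   gain_db     = -27.0 - (-20.0)        = -7.0 dB
#   gain_linear = 10.0 ** (-7.0 / 20.0)  ≈ 0.447
# so the waveform is attenuated to roughly 45% of its original amplitude.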
def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True):
## Load and norm reference wav
s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!"
if norm_loudness:
s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr)
ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
# Speech cond prompt tokens
if plen := self.t3.hp.speech_cond_prompt_len:
s3_tokzr = self.s3gen.tokenizer
t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)
# Voice-encoder speaker embedding
ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)
t3_cond = T3Cond(
speaker_emb=ve_embed,
cond_prompt_speech_tokens=t3_cond_prompt_tokens,
emotion_adv=exaggeration * torch.ones(1, 1, 1),
).to(device=self.device)
self.conds = Conditionals(t3_cond, s3gen_ref_dict)
def generate(
self,
text,
repetition_penalty=1.2,
min_p=0.00,
top_p=0.95,
audio_prompt_path=None,
exaggeration=0.0,
cfg_weight=0.0,
temperature=0.8,
top_k=1000,
norm_loudness=True,
):
if audio_prompt_path:
self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness)
else:
assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"
if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
logger.warning("CFG, min_p, and exaggeration are not supported by the Turbo model and will be ignored.")
# Norm and tokenize text
text = punc_norm(text)
text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
text_tokens = text_tokens.input_ids.to(self.device)
speech_tokens = self.t3.inference_turbo(
t3_cond=self.conds.t3,
text_tokens=text_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
# Remove OOV tokens and add silence to end
speech_tokens = speech_tokens[speech_tokens < 6561]
speech_tokens = speech_tokens.to(self.device)
silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device)
speech_tokens = torch.cat([speech_tokens, silence])
wav, _ = self.s3gen.inference(
speech_tokens=speech_tokens,
ref_dict=self.conds.gen,
n_cfm_timesteps=2,
)
wav = wav.squeeze(0).detach().cpu().numpy()
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
return torch.from_numpy(watermarked_wav).unsqueeze(0)
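# End-to-end usage sketch for the class above. The reference clip name is a
# placeholder, and prepare_conditionals asserts prompts longer than 5 seconds.
import torchaudio as ta

model = ChatterboxTurboTTS.from_pretrained(device="cuda")
wav = model.generate(
    "Hello from Chatterbox-Turbo! [chuckle] This is a quick synthesis test.",
    audio_prompt_path="reference_longer_than_5s.wav",
    temperature=0.8,
    top_k=1000,
    top_p=0.95,
    repetition_penalty=1.2,
)
ta.save("turbo-output.wav", wav, model.sr)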