Compare commits
8 Commits
fix-multil ... master

| Author | SHA1 | Date |
|---|---|---|
| | ed27b95ee4 | |
| | bc58894bb8 | |
| | c375f02ca1 | |
| | c61ec831bd | |
| | bf169fe5f5 | |
| | 1b5ae50585 | |
| | a0434ffb3d | |
| | 1798729e3a | |
@@ -0,0 +1,23 @@
name: Test Installation

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3

      - name: Set up Python 3.10
        uses: actions/setup-python@v4
        with:
          python-version: "3.10"

      - name: Test Standard Install
        run: |
          pip install -e .

Binary file not shown.

After Width: | Height: | Size: 763 KiB
README.md
@@ -1,47 +1,36 @@

<img width="1200" alt="Chatterbox Multilingual" src="https://github.com/user-attachments/assets/53b411b4-d7b8-4b18-87f3-335145837684" />



# Chatterbox TTS

[](https://resemble-ai.github.io/chatterbox_demopage/)
[](https://huggingface.co/spaces/ResembleAI/Chatterbox)
[](https://resemble-ai.github.io/chatterbox_turbo_demopage/)
[](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo)
[](https://podonos.com/resembleai/chatterbox)
[](https://discord.gg/rJq9cRJBJ6)

_Made with ♥️ by <a href="https://resemble.ai" target="_blank"><img width="100" alt="resemble-logo-horizontal" src="https://github.com/user-attachments/assets/35cf756b-3506-4943-9c72-c05ddfa4e525" /></a>

We're excited to introduce **Chatterbox Multilingual**, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model supporting **23 languages** out of the box. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.
**Chatterbox** is a family of three state-of-the-art, open-source text-to-speech models by Resemble AI.

Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life across languages. It's also the first open source TTS model to support **emotion exaggeration control** with robust **multilingual zero-shot voice cloning**. Try the English-only version now on our [English Hugging Face Gradio app](https://huggingface.co/spaces/ResembleAI/Chatterbox), or try the multilingual version on our [Multilingual Hugging Face Gradio app](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS).
We are excited to introduce **Chatterbox-Turbo**, our most efficient model yet. Built on a streamlined 350M parameter architecture, **Turbo** delivers high-quality speech with less compute and VRAM than our previous models. We have also distilled the speech-token-to-mel decoder, previously a bottleneck, reducing generation from 10 steps to just **one**, while retaining high-fidelity audio output.

**Paralinguistic tags** are now native to the Turbo model, allowing you to use `[cough]`, `[laugh]`, `[chuckle]`, and more to add distinct realism. While Turbo was built primarily for low-latency voice agents, it excels at narration and creative workflows.

If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (<a href="https://resemble.ai">link</a>). It delivers reliable performance with ultra-low latency of sub-200 ms, ideal for production use in agents, applications, or interactive media.

# Key Details
- Multilingual, zero-shot TTS supporting 23 languages
- SoTA zero-shot English TTS
- 0.5B Llama backbone
- Unique exaggeration/intensity control
- Ultra-stable with alignment-informed inference
- Trained on 0.5M hours of cleaned data
- Watermarked outputs
- Easy voice conversion script
- [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox)
<img width="1200" height="600" alt="Podonos Turbo Eval" src="https://storage.googleapis.com/chatterbox-demo-samples/turbo/podonos_turbo.png" />

# Supported Languages
Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)
# Tips
- **General Use (TTS and Voice Agents):**
  - Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip’s language. To mitigate this, set `cfg_weight` to `0`.
  - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
  - If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.
### ⚡ Model Zoo

- **Expressive or Dramatic Speech:**
  - Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher.
  - Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing.
Choose the right model for your application.

| Model | Size | Languages | Key Features | Best For | 🤗 | Examples |
|:---|:---|:---|:---|:---|:---|:---|
| **Chatterbox-Turbo** | **350M** | **English** | Paralinguistic Tags (`[laugh]`), Lower Compute and VRAM | Zero-shot voice agents, Production | [Demo](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo) | [Listen](https://resemble-ai.github.io/chatterbox_turbo_demopage/) |
| Chatterbox-Multilingual [(Language list)](#supported-languages) | 500M | 23+ | Zero-shot cloning, Multiple Languages | Global applications, Localization | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
| Chatterbox [(Tips and Tricks)](#original-chatterbox-tips) | 500M | English | CFG & Exaggeration tuning | General zero-shot TTS with creative controls | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |

# Installation
## Installation
```shell
pip install chatterbox-tts
```
@@ -57,11 +46,34 @@ pip install -e .
```
We developed and tested Chatterbox on Python 3.11 on Debian 11 OS; the versions of the dependencies are pinned in `pyproject.toml` to ensure consistency. You can modify the code or dependencies in this installation mode.

# Usage
## Usage

##### Chatterbox-Turbo

```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS

# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")

# Generate with Paralinguistic Tags
text = "Hi there, Sarah here from MochaFone calling you back [chuckle], have you got one minute to chat about the billing issue?"

# Generate audio (requires a reference clip for voice cloning)
wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")

ta.save("test-turbo.wav", wav, model.sr)
```

##### Chatterbox and Chatterbox-Multilingual

```python

import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

# English example
model = ChatterboxTTS.from_pretrained(device="cuda")
@@ -88,14 +100,21 @@ ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` and `example_vc.py` for more examples.
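The hunk above cuts off before the multilingual call itself, so here is a minimal sketch of that variant. It reuses the `ChatterboxMultilingualTTS` import shown above; the `language_id` keyword mirrors the Gradio app in this PR, and the French sample text is an illustrative assumption rather than a line from this diff.

```python
import torchaudio as ta
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

# Multilingual example (sketch): pick one of the 23 supported language codes
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")

french_text = "Bonjour, comment ça va aujourd'hui ?"
# language_id selects the target language; a reference clip in the same
# language keeps the accent consistent (see the Tips section)
wav = multilingual_model.generate(
    french_text,
    language_id="fr",  # assumed keyword, matching the multilingual Gradio app
    audio_prompt_path="your_10s_ref_clip.wav",
)
ta.save("test-fr.wav", wav, multilingual_model.sr)
```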

# Acknowledgements
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
- [Llama 3](https://github.com/meta-llama/llama3)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)
## Supported Languages
Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)

# Built-in PerTh Watermarking for Responsible AI
## Original Chatterbox Tips
- **General Use (TTS and Voice Agents):**
  - Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip’s language. To mitigate this, set `cfg_weight` to `0`.
  - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
  - If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.

- **Expressive or Dramatic Speech:**
  - Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher.
  - Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing (see the sketch below).
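As a concrete illustration of these tips, here is a minimal sketch of an expressive render. It assumes `exaggeration` and `cfg_weight` are accepted as `generate()` keyword arguments, as the default values quoted above suggest; the numbers are simply the ones recommended in the tips.

```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

model = ChatterboxTTS.from_pretrained(device="cuda")

text = "You have got to be kidding me. This is the third time this week!"
# Dramatic delivery: raise exaggeration, lower cfg_weight to keep pacing deliberate
wav = model.generate(
    text,
    audio_prompt_path="your_10s_ref_clip.wav",
    exaggeration=0.7,   # higher values intensify and speed up delivery
    cfg_weight=0.3,     # lower values slow pacing to compensate
)
ta.save("test-expressive.wav", wav, model.sr)
```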

## Built-in PerTh Watermarking for Responsible AI

Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.
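A minimal sketch of checking a generated file for the watermark follows (the full extraction script appears later in the README; the hunk below only shows its tail). The `perth` class and method names here are assumptions to verify against the `resemble-perth` repository.

```python
import perth
import librosa

# Load audio that may have been generated (and therefore watermarked) by Chatterbox
audio, sr = librosa.load("test-turbo.wav", sr=None)

# Assumed name for the watermarker Chatterbox applies at generation time
watermarker = perth.PerthImplicitWatermarker()

# Returns a confidence value: close to 1.0 for watermarked audio, 0.0 otherwise
watermark = watermarker.get_watermark(audio, sample_rate=sr)
print(f"Extracted watermark: {watermark}")
```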

@@ -123,11 +142,18 @@ print(f"Extracted watermark: {watermark}")
```

# Official Discord
## Official Discord

👋 Join us on [Discord](https://discord.gg/rJq9cRJBJ6) and let's build something awesome together!

# Citation
## Acknowledgements
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
- [Llama 3](https://github.com/meta-llama/llama3)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)

## Citation
If you find this model useful, please consider citing.
```
@misc{chatterboxtts2025,
@@ -138,5 +164,5 @@ If you find this model useful, please consider citing.
  note = {GitHub repository}
}
```
# Disclaimer
## Disclaimer
Don't use this model to do bad things. Prompts are sourced from freely available data on the internet.
@@ -0,0 +1,14 @@
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS

# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")

# Generate with Paralinguistic Tags
text = "Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?"

# Generate audio (requires a reference clip for voice cloning)
# wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")
wav = model.generate(text)
ta.save("test-turbo.wav", wav, model.sr)
|
@ -0,0 +1,186 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
from chatterbox.tts_turbo import ChatterboxTurboTTS
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
EVENT_TAGS = [
|
||||
"[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
|
||||
"[sniff]", "[gasp]", "[chuckle]", "[laugh]"
|
||||
]
|
||||
|
||||
# --- REFINED CSS ---
|
||||
# 1. tag-container: Forces the row to wrap items instead of scrolling. Removes borders/backgrounds.
|
||||
# 2. tag-btn: Sets the specific look (indigo theme) and stops them from stretching.
|
||||
CUSTOM_CSS = """
|
||||
.tag-container {
|
||||
display: flex !important;
|
||||
flex-wrap: wrap !important; /* This fixes the one-per-line issue */
|
||||
gap: 8px !important;
|
||||
margin-top: 5px !important;
|
||||
margin-bottom: 10px !important;
|
||||
border: none !important;
|
||||
background: transparent !important;
|
||||
}
|
||||
|
||||
.tag-btn {
|
||||
min-width: fit-content !important;
|
||||
width: auto !important;
|
||||
height: 32px !important;
|
||||
font-size: 13px !important;
|
||||
background: #eef2ff !important;
|
||||
border: 1px solid #c7d2fe !important;
|
||||
color: #3730a3 !important;
|
||||
border-radius: 6px !important;
|
||||
padding: 0 10px !important;
|
||||
margin: 0 !important;
|
||||
box-shadow: none !important;
|
||||
}
|
||||
|
||||
.tag-btn:hover {
|
||||
background: #c7d2fe !important;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
"""
|
||||
|
||||
INSERT_TAG_JS = """
|
||||
(tag_val, current_text) => {
|
||||
const textarea = document.querySelector('#main_textbox textarea');
|
||||
if (!textarea) return current_text + " " + tag_val;
|
||||
|
||||
const start = textarea.selectionStart;
|
||||
const end = textarea.selectionEnd;
|
||||
|
||||
let prefix = " ";
|
||||
let suffix = " ";
|
||||
|
||||
if (start === 0) prefix = "";
|
||||
else if (current_text[start - 1] === ' ') prefix = "";
|
||||
|
||||
if (end < current_text.length && current_text[end] === ' ') suffix = "";
|
||||
|
||||
return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
def load_model():
|
||||
print(f"Loading Chatterbox-Turbo on {DEVICE}...")
|
||||
model = ChatterboxTurboTTS.from_pretrained(DEVICE)
|
||||
return model
|
||||
|
||||
|
||||
def generate(
|
||||
model,
|
||||
text,
|
||||
audio_prompt_path,
|
||||
temperature,
|
||||
seed_num,
|
||||
min_p,
|
||||
top_p,
|
||||
top_k,
|
||||
repetition_penalty,
|
||||
norm_loudness
|
||||
):
|
||||
if model is None:
|
||||
model = ChatterboxTurboTTS.from_pretrained(DEVICE)
|
||||
|
||||
if seed_num != 0:
|
||||
set_seed(int(seed_num))
|
||||
|
||||
wav = model.generate(
|
||||
text,
|
||||
audio_prompt_path=audio_prompt_path,
|
||||
temperature=temperature,
|
||||
min_p=min_p,
|
||||
top_p=top_p,
|
||||
top_k=int(top_k),
|
||||
repetition_penalty=repetition_penalty,
|
||||
norm_loudness=norm_loudness,
|
||||
)
|
||||
return (model.sr, wav.squeeze(0).numpy())
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo:
|
||||
gr.Markdown("# ⚡ Chatterbox Turbo")
|
||||
|
||||
model_state = gr.State(None)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
text = gr.Textbox(
|
||||
value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and um all that jazz. Would you like me to get some prices for you?",
|
||||
label="Text to synthesize (max chars 300)",
|
||||
max_lines=5,
|
||||
elem_id="main_textbox"
|
||||
)
|
||||
|
||||
# --- Event Tags ---
|
||||
# Switched back to Row, but applied specific CSS to force wrapping
|
||||
with gr.Row(elem_classes=["tag-container"]):
|
||||
for tag in EVENT_TAGS:
|
||||
# elem_classes targets the button specifically
|
||||
btn = gr.Button(tag, elem_classes=["tag-btn"])
|
||||
|
||||
btn.click(
|
||||
fn=None,
|
||||
inputs=[btn, text],
|
||||
outputs=text,
|
||||
js=INSERT_TAG_JS
|
||||
)
|
||||
|
||||
ref_wav = gr.Audio(
|
||||
sources=["upload", "microphone"],
|
||||
type="filepath",
|
||||
label="Reference Audio File",
|
||||
value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav"
|
||||
)
|
||||
|
||||
run_btn = gr.Button("Generate ⚡", variant="primary")
|
||||
|
||||
with gr.Column():
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
with gr.Accordion("Advanced Options", open=False):
|
||||
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
|
||||
temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
|
||||
top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
|
||||
top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
|
||||
repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
|
||||
min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
|
||||
norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
|
||||
|
||||
demo.load(fn=load_model, inputs=[], outputs=model_state)
|
||||
|
||||
run_btn.click(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
model_state,
|
||||
text,
|
||||
ref_wav,
|
||||
temp,
|
||||
seed_num,
|
||||
min_p,
|
||||
top_p,
|
||||
top_k,
|
||||
repetition_penalty,
|
||||
norm_loudness,
|
||||
],
|
||||
outputs=audio_output,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue(
|
||||
max_size=50,
|
||||
default_concurrency_limit=1,
|
||||
).launch(share=True)
|
||||
|
|
@@ -3,7 +3,6 @@ import numpy as np
import torch
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
import gradio as gr
import spaces

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")

@@ -176,7 +175,6 @@ def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | N
    return LANGUAGE_CONFIG.get(language_id, {}).get("audio")


@spaces.GPU
def generate_tts_audio(
    text_input: str,
    language_id: str,
@@ -1,15 +1,15 @@
[project]
name = "chatterbox-tts"
version = "0.1.3"
version = "0.1.6"
description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
    {name = "resemble-ai", email = "engineering@resemble.ai"}
]
dependencies = [
    "numpy>=1.26.0",
    "numpy>=1.24.0,<1.26.0",
    "librosa==0.11.0",
    "s3tokenizer",
    "torch==2.6.0",

@@ -19,8 +19,11 @@ dependencies = [
    "resemble-perth==1.0.1",
    "conformer==0.3.2",
    "safetensors==0.5.3",
    "pkuseg==0.0.25",
    "pykakasi==2.3.0"
    "spacy-pkuseg",
    "pykakasi==2.3.0",
    "gradio==5.44.1",
    "pyloudnorm",
    "omegaconf"
]

[project.urls]
@@ -1 +1,2 @@
S3GEN_SR = 24000
S3GEN_SIL = 4299
@ -20,6 +20,7 @@ from .utils.mask import add_optional_chunk_mask
|
|||
from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
|
||||
TimestepEmbedding, Upsample1D
|
||||
from .matcha.transformer import BasicTransformerBlock
|
||||
from .utils.intmeanflow import get_intmeanflow_time_mixer
|
||||
|
||||
|
||||
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
|
|
@ -95,8 +96,6 @@ class CausalConv1d(torch.nn.Conv1d):
|
|||
x = F.pad(x, self.causal_padding)
|
||||
x = super(CausalConv1d, self).forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConditionalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -110,6 +109,7 @@ class ConditionalDecoder(nn.Module):
|
|||
num_mid_blocks=12,
|
||||
num_heads=8,
|
||||
act_fn="gelu",
|
||||
meanflow=False,
|
||||
):
|
||||
"""
|
||||
This decoder requires an input with the same shape of the target. So, if your text content
|
||||
|
|
@ -117,6 +117,7 @@ class ConditionalDecoder(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
channels = tuple(channels)
|
||||
self.meanflow = meanflow
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.causal = causal
|
||||
|
|
@ -127,6 +128,7 @@ class ConditionalDecoder(nn.Module):
|
|||
time_embed_dim=time_embed_dim,
|
||||
act_fn="silu",
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
|
@ -215,6 +217,14 @@ class ConditionalDecoder(nn.Module):
|
|||
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
||||
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||
self.initialize_weights()
|
||||
self.time_embed_mixer = None
|
||||
if self.meanflow:
|
||||
self.time_embed_mixer = get_intmeanflow_time_mixer(time_embed_dim)
|
||||
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.final_proj.weight.dtype
|
||||
|
||||
def initialize_weights(self):
|
||||
for m in self.modules():
|
||||
|
|
@ -230,15 +240,16 @@ class ConditionalDecoder(nn.Module):
|
|||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None, r=None):
|
||||
"""Forward pass of the UNet1DConditional model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): shape (batch_size, in_channels, time)
|
||||
mask (_type_): shape (batch_size, 1, time)
|
||||
x: (B, 80, T)
|
||||
mask (_type_)
|
||||
t (_type_): shape (batch_size)
|
||||
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
||||
cond (_type_, optional): placeholder for future use. Defaults to None.
|
||||
spks (_type_, optional) Defaults to None.
|
||||
cond (_type_, optional)
|
||||
r: end time for meanflow mode (shape (1,) tensor)
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
|
|
@ -247,10 +258,15 @@ class ConditionalDecoder(nn.Module):
|
|||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
|
||||
t = self.time_embeddings(t).to(t.dtype)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
if self.meanflow:
|
||||
r = self.time_embeddings(r).to(t.dtype)
|
||||
r = self.time_mlp(r)
|
||||
concat_embed = torch.cat([t, r], dim=1)
|
||||
t = self.time_embed_mixer(concat_embed)
|
||||
|
||||
x = pack([x, mu], "b * t")[0]
|
||||
|
||||
if spks is not None:
|
||||
|
|
|
|||
|
|
@ -21,205 +21,49 @@ import torch.nn as nn
|
|||
from torch.nn import functional as F
|
||||
from .utils.mask import make_pad_mask
|
||||
from .configs import CFM_PARAMS
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
class MaskedDiffWithXvec(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 4096,
|
||||
input_frame_rate: int = 50,
|
||||
only_mask_loss: bool = True,
|
||||
encoder: torch.nn.Module = None,
|
||||
length_regulator: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {
|
||||
'in_channels': 240,
|
||||
'out_channel': 80,
|
||||
'spk_emb_dim': 80,
|
||||
'n_spks': 1,
|
||||
'cfm_params': CFM_PARAMS,
|
||||
'decoder_params': {
|
||||
'channels': [256, 256],
|
||||
'dropout': 0.0,
|
||||
'attention_head_dim': 64,
|
||||
'n_blocks': 4,
|
||||
'num_mid_blocks': 12,
|
||||
'num_heads': 8,
|
||||
'act_fn': 'gelu',
|
||||
}
|
||||
},
|
||||
mel_feat_conf: Dict = {
|
||||
'n_fft': 1024,
|
||||
'num_mels': 80,
|
||||
'sampling_rate': 22050,
|
||||
'hop_size': 256,
|
||||
'win_size': 1024,
|
||||
'fmin': 0,
|
||||
'fmax': 8000
|
||||
}
|
||||
):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.mel_feat_conf = mel_feat_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
||||
self.decoder = decoder
|
||||
self.length_regulator = length_regulator
|
||||
self.only_mask_loss = only_mask_loss
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||
token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
h, h_lengths = self.length_regulator(h, feat_len)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :index] = feat[i, :index]
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds
|
||||
)
|
||||
return {'loss': loss}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
token,
|
||||
token_len,
|
||||
prompt_token,
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
flow_cache):
|
||||
if self.fp16 is True:
|
||||
prompt_feat = prompt_feat.half()
|
||||
embedding = embedding.half()
|
||||
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
|
||||
# Check for out-of-bounds token IDs
|
||||
vocab_size = self.input_embedding.num_embeddings
|
||||
if token.max() >= vocab_size or token.min() < 0:
|
||||
logging.warning(f"S3Gen: Token IDs out of bounds: min={token.min().item()}, max={token.max().item()}, vocab_size={vocab_size}")
|
||||
|
||||
token = self.input_embedding(torch.clamp(token, min=0, max=vocab_size-1)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, flow_cache = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
prompt_len=mel_len1,
|
||||
flow_cache=flow_cache
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), flow_cache
|
||||
def _repeat_batch_dim(tnsr, B, ndim):
|
||||
"repeat batch dimension if it's equal to 1"
|
||||
if tnsr is not None:
|
||||
# add missing batch dim if needed
|
||||
while tnsr.ndim < ndim:
|
||||
tnsr = tnsr[None]
|
||||
# repeat batch dim as needed
|
||||
if B > 1 and tnsr.size(0) == 1:
|
||||
tnsr = tnsr.repeat(B, *([1] * (ndim - 1)))
|
||||
assert tnsr.ndim == ndim, f"Expected {ndim=}, got {tnsr.ndim=}"
|
||||
return tnsr
|
||||
|
||||
|
||||
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 6561,
|
||||
input_frame_rate: int = 25,
|
||||
only_mask_loss: bool = True,
|
||||
token_mel_ratio: int = 2,
|
||||
pre_lookahead_len: int = 3,
|
||||
encoder: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {
|
||||
'in_channels': 240,
|
||||
'out_channel': 80,
|
||||
'spk_emb_dim': 80,
|
||||
'n_spks': 1,
|
||||
'cfm_params': CFM_PARAMS,
|
||||
'decoder_params': {
|
||||
'channels': [256, 256],
|
||||
'dropout': 0.0,
|
||||
'attention_head_dim': 64,
|
||||
'n_blocks': 4,
|
||||
'num_mid_blocks': 12,
|
||||
'num_heads': 8,
|
||||
'act_fn': 'gelu',
|
||||
}
|
||||
},
|
||||
mel_feat_conf: Dict = {
|
||||
'n_fft': 1024,
|
||||
'num_mels': 80,
|
||||
'sampling_rate': 22050,
|
||||
'hop_size': 256,
|
||||
'win_size': 1024,
|
||||
'fmin': 0,
|
||||
'fmax': 8000
|
||||
}
|
||||
):
|
||||
def __init__(self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 6561,
|
||||
input_frame_rate: int = 25,
|
||||
only_mask_loss: bool = True,
|
||||
token_mel_ratio: int = 2,
|
||||
pre_lookahead_len: int = 3,
|
||||
encoder: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
'cfm_params': DictConfig(
|
||||
{'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7,
|
||||
'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0,
|
||||
'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8,
|
||||
'act_fn': 'gelu'}},
|
||||
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
||||
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
|
|
@ -238,8 +82,51 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||
self.token_mel_ratio = token_mel_ratio
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
|
||||
# FIXME: this was missing - just putting it in as false
|
||||
self.fp16 = False
|
||||
# NOTE: copied in from cosyvoice repo
|
||||
def compute_loss(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device) # (B, 80, T)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# NOTE unified training, static_chunk_size > 0 or = 0
|
||||
# streaming = True if random.random() < 0.5 else False
|
||||
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) # (B, T, 1)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask # (B, T, emb)
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len) # (B, T, C) -> (B, 2T, C)
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :, :index] = feat[i, :, :index]
|
||||
|
||||
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds,
|
||||
# streaming=streaming,
|
||||
)
|
||||
return {'loss': loss}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
|
|
@ -250,41 +137,62 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
finalize):
|
||||
if self.fp16 is True:
|
||||
prompt_feat = prompt_feat.half()
|
||||
embedding = embedding.half()
|
||||
finalize,
|
||||
n_timesteps=10,
|
||||
noised_mels=None,
|
||||
meanflow=False):
|
||||
# token: (B, n_toks)
|
||||
# token_len: (B,)
|
||||
B = token.size(0)
|
||||
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = torch.atleast_2d(embedding)
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = self.spk_embed_affine_layer(embedding) # (1 or B, emb_dim)
|
||||
|
||||
# adjust shapes (batching logic)
|
||||
prompt_token = _repeat_batch_dim(prompt_token, B, ndim=2) # (B, n_prompt)
|
||||
prompt_token_len = _repeat_batch_dim(prompt_token_len, B, ndim=1) # (B,)
|
||||
prompt_feat = _repeat_batch_dim(prompt_feat, B, ndim=3) # (B, n_feat, feat_dim=80)
|
||||
prompt_feat_len = _repeat_batch_dim(prompt_feat_len, B, ndim=1) # (B,) or None
|
||||
embedding = _repeat_batch_dim(embedding, B, ndim=2) # (B, emb_dim)
|
||||
|
||||
# concat text and prompt_text
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0, max=self.input_embedding.num_embeddings-1)) * mask
|
||||
|
||||
if (token >= self.vocab_size).any():
|
||||
logger.error(f"{token.max()}>{self.vocab_size}\n out-of-range special tokens found in flow, fix inputs!")
|
||||
token = self.input_embedding(token.long()) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h, h_masks = self.encoder(token, token_len)
|
||||
if finalize is False:
|
||||
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
||||
|
||||
h_lengths = h_masks.sum(dim=-1).squeeze(dim=-1)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
# # get conditions
|
||||
conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
mask = (~make_pad_mask(h_lengths)).unsqueeze(1).to(h)
|
||||
|
||||
if mask.shape[0] != B:
|
||||
mask = mask.repeat(B, 1, 1)
|
||||
|
||||
feat, _ = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
mask=mask,
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10
|
||||
n_timesteps=n_timesteps,
|
||||
noised_mels=noised_mels,
|
||||
meanflow=meanflow,
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), None # NOTE jrm: why are they returning None here?
|
||||
return feat, None # NOTE jrm: why are they returning None here?
|
||||
|
|
|
|||
|
|
@ -16,6 +16,11 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from .matcha.flow_matching import BASECFM
|
||||
from .configs import CFM_PARAMS
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def cast_all(*args, dtype):
|
||||
return [a if (not a.dtype.is_floating_point) or a.dtype == dtype else a.to(dtype) for a in args]
|
||||
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
|
|
@ -32,7 +37,6 @@ class ConditionalCFM(BASECFM):
|
|||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
||||
|
|
@ -54,6 +58,8 @@ class ConditionalCFM(BASECFM):
|
|||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
raise NotImplementedError("unused, needs updating for meanflow model")
|
||||
|
||||
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
||||
cache_size = flow_cache.shape[2]
|
||||
# fix prompt and overlap part mu and z
|
||||
|
|
@ -69,7 +75,7 @@ class ConditionalCFM(BASECFM):
|
|||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, meanflow=False):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
|
|
@ -83,65 +89,60 @@ class ConditionalCFM(BASECFM):
|
|||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
meanflow: meanflow mode
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||
t = t.unsqueeze(dim=0)
|
||||
|
||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||
# Or in future might add like a return_all_steps flag
|
||||
sol = []
|
||||
in_dtype = x.dtype
|
||||
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
|
||||
|
||||
# Duplicated batch dims are for CFG
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
x_in[:] = x
|
||||
mask_in[:] = mask
|
||||
mu_in[0] = mu
|
||||
t_in[:] = t.unsqueeze(0)
|
||||
spks_in[0] = spks
|
||||
cond_in[0] = cond
|
||||
dphi_dt = self.forward_estimator(
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in
|
||||
B, T = mu.size(0), x.size(2)
|
||||
x_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2 * B, 1, T], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2 * B, 80 ], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
|
||||
r_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype) # (only used for meanflow)
|
||||
|
||||
for t, r in zip(t_span[:-1], t_span[1:]):
|
||||
t = t.unsqueeze(dim=0)
|
||||
r = r.unsqueeze(dim=0)
|
||||
# Shapes:
|
||||
# x_in ( 2B, 80, T )
|
||||
# mask_in ( 2B, 1, T )
|
||||
# mu_in ( 2B, 80, T )
|
||||
# t_in ( 2B, )
|
||||
# spks_in ( 2B, 80, )
|
||||
# cond_in ( 2B, 80, T )
|
||||
# r_in ( 2B, )
|
||||
# x ( B, 80, T )
|
||||
# mask ( B, 1, T )
|
||||
# mu ( B, 80, T )
|
||||
# t ( B, )
|
||||
# spks ( B, 80, )
|
||||
# cond ( B, 80, T )
|
||||
# r ( B, )
|
||||
|
||||
x_in[:B] = x_in[B:] = x
|
||||
mask_in[:B] = mask_in[B:] = mask
|
||||
mu_in[:B] = mu
|
||||
t_in[:B] = t_in[B:] = t
|
||||
spks_in[:B] = spks
|
||||
cond_in[:B] = cond
|
||||
r_in[:B] = r_in[B:] = r # (only used for meanflow)
|
||||
dxdt = self.estimator.forward(
|
||||
x=x_in, mask=mask_in, mu=mu_in, t=t_in, spks=spks_in, cond=cond_in,
|
||||
r=r_in if meanflow else None,
|
||||
)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
x = x + dt * dphi_dt
|
||||
t = t + dt
|
||||
sol.append(x)
|
||||
if step < len(t_span) - 1:
|
||||
dt = t_span[step + 1] - t
|
||||
dxdt, cfg_dxdt = torch.split(dxdt, [B, B], dim=0)
|
||||
dxdt = ((1.0 + self.inference_cfg_rate) * dxdt - self.inference_cfg_rate * cfg_dxdt)
|
||||
dt = r - t
|
||||
x = x + dt * dxdt
|
||||
|
||||
return sol[-1].float()
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x
|
||||
|
||||
return x.to(in_dtype)
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
|
@ -188,10 +189,11 @@ class ConditionalCFM(BASECFM):
|
|||
class CausalConditionalCFM(ConditionalCFM):
|
||||
def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
|
||||
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
# TODO: BAD BAD IDEA - IT'LL MESS UP DISTILLATION - SETTING TO NONE
|
||||
self.rand_noise = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, noised_mels=None, meanflow=False):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
|
|
@ -204,15 +206,41 @@ class CausalConditionalCFM(ConditionalCFM):
|
|||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
|
||||
noised_mels: gt mels noised a time t
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||
# fix prompt and overlap part mu and z
|
||||
B = mu.size(0)
|
||||
z = torch.randn_like(mu)
|
||||
|
||||
if noised_mels is not None:
|
||||
prompt_len = mu.size(2) - noised_mels.size(2)
|
||||
z[..., prompt_len:] = noised_mels
|
||||
|
||||
# time steps for reverse diffusion
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
if (not meanflow) and (self.t_scheduler == 'cosine'):
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
||||
|
||||
# NOTE: right now, the only meanflow models are also distilled models, which don't need CFG
|
||||
# because they were distilled with CFG outputs. We would need to add another hparam and
|
||||
# change the conditional logic here if we want to use CFG inference with a meanflow model.
|
||||
if meanflow:
|
||||
return self.basic_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
||||
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, meanflow=meanflow), None
|
||||
|
||||
def basic_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
in_dtype = x.dtype
|
||||
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
|
||||
|
||||
print("S3 Token -> Mel Inference...")
|
||||
for t, r in tqdm(zip(t_span[..., :-1], t_span[..., 1:]), total=t_span.shape[-1] - 1):
|
||||
t, r = t[None], r[None]
|
||||
dxdt = self.estimator.forward(x, mask=mask, mu=mu, t=t, spks=spks, cond=cond, r=r)
|
||||
dt = r - t
|
||||
x = x + dt * dxdt
|
||||
|
||||
return x.to(in_dtype)
|
||||
|
|
|
|||
|
|
@ -46,15 +46,20 @@ def get_resampler(src_sr, dst_sr, device):
|
|||
|
||||
class S3Token2Mel(torch.nn.Module):
|
||||
"""
|
||||
CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
|
||||
S3Gen's CFM decoder maps S3 speech tokens to mel-spectrograms.
|
||||
|
||||
TODO: make these modules configurable?
|
||||
"""
|
||||
def __init__(self):
|
||||
def __init__(self, meanflow=False):
|
||||
super().__init__()
|
||||
self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
|
||||
self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
|
||||
self.speaker_encoder = CAMPPlus() # use default args
|
||||
self.speaker_encoder = CAMPPlus(
|
||||
# NOTE: This doesn't affect inference. It turns off activation checkpointing
|
||||
# (a training optimization), which causes a crazy DDP error with accelerate
|
||||
memory_efficient=False,
|
||||
)
|
||||
self.meanflow = meanflow
|
||||
|
||||
encoder = UpsampleConformerEncoder(
|
||||
output_size=512,
|
||||
|
|
@ -84,6 +89,7 @@ class S3Token2Mel(torch.nn.Module):
|
|||
num_mid_blocks=12,
|
||||
num_heads=8,
|
||||
act_fn='gelu',
|
||||
meanflow=self.meanflow,
|
||||
)
|
||||
cfm_params = CFM_PARAMS
|
||||
decoder = CausalConditionalCFM(
|
||||
|
|
@ -104,6 +110,11 @@ class S3Token2Mel(torch.nn.Module):
|
|||
params = self.tokenizer.parameters()
|
||||
return next(params).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
params = self.flow.parameters()
|
||||
return next(params).dtype
|
||||
|
||||
def embed_ref(
|
||||
self,
|
||||
ref_wav: torch.Tensor,
|
||||
|
|
@ -122,23 +133,26 @@ class S3Token2Mel(torch.nn.Module):
|
|||
ref_wav = ref_wav.unsqueeze(0) # (B, L)
|
||||
|
||||
if ref_wav.size(1) > 10 * ref_sr:
|
||||
print("WARNING: cosydec received ref longer than 10s")
|
||||
print("WARNING: s3gen received ref longer than 10s")
|
||||
|
||||
ref_wav_24 = ref_wav
|
||||
if ref_sr != S3GEN_SR:
|
||||
ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
|
||||
ref_wav_24 = ref_wav_24.to(device=device, dtype=self.dtype)
|
||||
|
||||
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
|
||||
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(dtype=self.dtype)
|
||||
ref_mels_24_len = None
|
||||
|
||||
# Resample to 16kHz
|
||||
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
|
||||
ref_wav_16 = ref_wav
|
||||
if ref_sr != S3_SR:
|
||||
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav)
|
||||
|
||||
# Speaker embedding
|
||||
ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
|
||||
ref_x_vector = self.speaker_encoder.inference(ref_wav_16.to(dtype=self.dtype))
|
||||
|
||||
# Tokenize 16khz reference
|
||||
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
|
||||
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16.float())
|
||||
|
||||
# Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
|
||||
if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
|
||||
|
|
@ -164,7 +178,10 @@ class S3Token2Mel(torch.nn.Module):
|
|||
ref_sr: Optional[int],
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
n_cfm_timesteps = None,
|
||||
finalize: bool = False,
|
||||
speech_token_lens=None,
|
||||
noised_mels=None,
|
||||
):
|
||||
"""
|
||||
Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
|
||||
|
|
@ -192,18 +209,21 @@ class S3Token2Mel(torch.nn.Module):
|
|||
if isinstance(ref_dict[rk], np.ndarray):
|
||||
ref_dict[rk] = torch.from_numpy(ref_dict[rk])
|
||||
if torch.is_tensor(ref_dict[rk]):
|
||||
ref_dict[rk] = ref_dict[rk].to(self.device)
|
||||
ref_dict[rk] = ref_dict[rk].to(device=self.device, dtype=self.dtype)
|
||||
|
||||
if len(speech_tokens.shape) == 1:
|
||||
speech_tokens = speech_tokens.unsqueeze(0)
|
||||
speech_tokens = torch.atleast_2d(speech_tokens)
|
||||
|
||||
# assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
|
||||
speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
|
||||
# backcompat
|
||||
if speech_token_lens is None:
|
||||
speech_token_lens = torch.LongTensor([st.size(-1) for st in speech_tokens]).to(self.device)
|
||||
|
||||
output_mels, _ = self.flow.inference(
|
||||
token=speech_tokens,
|
||||
token_len=speech_token_lens,
|
||||
finalize=finalize,
|
||||
noised_mels=noised_mels,
|
||||
n_timesteps=n_cfm_timesteps,
|
||||
meanflow=self.meanflow,
|
||||
**ref_dict,
|
||||
)
|
||||
return output_mels
|
||||
|
|
@ -211,13 +231,15 @@ class S3Token2Mel(torch.nn.Module):
|
|||
|
||||
class S3Token2Wav(S3Token2Mel):
|
||||
"""
|
||||
The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
|
||||
The decoder of S3Gen is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
|
||||
|
||||
TODO: make these modules configurable?
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
ignore_state_dict_missing = ("tokenizer._mel_filters", "tokenizer.window")
|
||||
|
||||
def __init__(self, meanflow=False):
|
||||
super().__init__(meanflow)
|
||||
|
||||
f0_predictor = ConvRNNF0Predictor()
|
||||
self.mel2wav = HiFTGenerator(
|
||||
|
|
@ -234,6 +256,7 @@ class S3Token2Wav(S3Token2Mel):
|
|||
trim_fade = torch.zeros(2 * n_trim)
|
||||
trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
|
||||
self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
|
||||
self.estimator_dtype = "fp32"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -243,9 +266,25 @@ class S3Token2Wav(S3Token2Mel):
|
|||
ref_sr: Optional[int],
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
finalize: bool = False
|
||||
finalize: bool = False,
|
||||
speech_token_lens=None,
|
||||
skip_vocoder=False,
|
||||
n_cfm_timesteps=None,
|
||||
noised_mels=None,
|
||||
|
||||
):
|
||||
output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
||||
"""
|
||||
Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
|
||||
NOTE: used for sync synthesis only. Please use `S3GenStreamer` for streaming synthesis.
|
||||
"""
|
||||
output_mels = super().forward(
|
||||
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav,
|
||||
ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize,
|
||||
n_cfm_timesteps=n_cfm_timesteps, noised_mels=noised_mels,
|
||||
)
|
||||
|
||||
if skip_vocoder:
|
||||
return output_mels
|
||||
|
||||
# TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
|
||||
hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
|
||||
|
|
@ -267,14 +306,24 @@ class S3Token2Wav(S3Token2Mel):
|
|||
ref_sr: Optional[int] = None,
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
n_cfm_timesteps = None,
|
||||
finalize: bool = False,
|
||||
speech_token_lens=None,
|
||||
):
|
||||
return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
||||
n_cfm_timesteps = n_cfm_timesteps or (2 if self.meanflow else 10)
|
||||
noise = None
|
||||
if self.meanflow:
|
||||
noise = torch.randn(1, 80, speech_tokens.size(-1) * 2, dtype=self.dtype, device=self.device)
|
||||
output_mels = super().forward(
|
||||
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict,
|
||||
n_cfm_timesteps=n_cfm_timesteps, finalize=finalize, noised_mels=noise,
|
||||
)
|
||||
return output_mels
|
||||
|
||||
@torch.inference_mode()
|
||||
def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
|
||||
if cache_source is None:
|
||||
cache_source = torch.zeros(1, 1, 0).to(self.device)
|
||||
cache_source = torch.zeros(1, 1, 0).to(device=self.device, dtype=self.dtype)
|
||||
return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
@ -286,11 +335,26 @@ class S3Token2Wav(S3Token2Mel):
|
|||
ref_sr: Optional[int] = None,
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
|
||||
finalize: bool = True,
|
||||
# left as a kwarg because this can change input/output size ratio
|
||||
drop_invalid_tokens=True,
|
||||
n_cfm_timesteps=None,
|
||||
speech_token_lens=None,
|
||||
):
|
||||
output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
||||
output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
|
||||
# hallucination prevention, drop special tokens
|
||||
# if drop_invalid_tokens:
|
||||
# speech_tokens, speech_token_lens = drop_invalid(speech_tokens, pad=S3_QUIET_PAD)
|
||||
|
||||
output_mels = self.flow_inference(
|
||||
speech_tokens,
|
||||
speech_token_lens=speech_token_lens,
|
||||
ref_wav=ref_wav,
|
||||
ref_sr=ref_sr,
|
||||
ref_dict=ref_dict,
|
||||
n_cfm_timesteps=n_cfm_timesteps,
|
||||
finalize=True,
|
||||
)
|
||||
output_mels = output_mels.to(dtype=self.dtype) # FIXME (fp16 mode) is this still needed?
|
||||
output_wavs, output_sources = self.hift_inference(output_mels, None)
|
||||
|
||||
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
|
||||
output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn


def get_intmeanflow_time_mixer(dims):
    """
    Diagonal init as described in 3.3 https://arxiv.org/pdf/2510.07979
    """
    layer = nn.Linear(dims * 2, dims, bias=False)

    with torch.no_grad():
        target_weight = torch.zeros(dims, 2 * dims)
        target_weight[:, 0:dims] = torch.eye(dims)
        layer.weight.data = target_weight

    return layer

if __name__ == '__main__':

    D_example = 6

    W_layer = get_intmeanflow_time_mixer(D_example)

    print(f"Layer weight (AFTER init):\n{W_layer.weight.data}\n")

    e_t = torch.tensor([0., 1., 2., 3., 4., 5.])
    e_r = torch.tensor([6., 7., 8., 9., 10., 11.])
    e_concat = torch.cat([e_t, e_r]).unsqueeze(0)  # Shape (1, 12)

    output = W_layer(e_concat)

    print(f"Test Input e_t: \n{e_t}")
    print(f"Test Input e_r: \n{e_r}")
    print(f"Test Input concat: \n{e_concat}")

    print(f"Forward Pass Output: \n{output.squeeze(0)}")
|
@@ -181,6 +181,7 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
             [0, 0, 0, 1, 1],
             [0, 0, 1, 1, 1]]
    """
    lengths = lengths.long()
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
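The hunk above is truncated mid-expression; for context, a self-contained sketch of the padding-mask logic it belongs to (the example lengths `[5, 3, 2]` are inferred from the docstring rows shown, so treat them as an assumption):

```python
import torch

def make_pad_mask_sketch(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # True where a position is padding, False where it is real content.
    lengths = lengths.long()
    max_len = max_len if max_len > 0 else int(lengths.max().item())
    seq_range = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, max_len)
    return seq_range >= lengths.unsqueeze(1)                               # (B, max_len)

print(make_pad_mask_sketch(torch.tensor([5, 3, 2])).int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```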
@@ -32,6 +32,42 @@ LLAMA_520M_CONFIG_DICT = dict(
    use_cache=True,
)

GPT2_MEDIUM_CONFIG = {
    "activation_function": "gelu_new",
    "architectures": [
        "GPT2LMHeadModel"
    ],
    "attn_pdrop": 0.1,
    "bos_token_id": 50256,
    "embd_pdrop": 0.1,
    "eos_token_id": 50256,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-05,
    "model_type": "gpt2",
    "n_ctx": 8196,
    "n_embd": 1024,
    "hidden_size": 1024,
    "n_head": 16,
    "n_layer": 24,
    "n_positions": 8196,
    "n_special": 0,
    "predict_special_tokens": True,
    "resid_pdrop": 0.1,
    "summary_activation": None,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": True,
    "summary_type": "cls_index",
    "summary_use_proj": True,
    "task_specific_params": {
        "text-generation": {
            "do_sample": True,
            "max_length": 50
        }
    },
    "vocab_size": 50276,
}

LLAMA_CONFIGS = {
    "Llama_520M": LLAMA_520M_CONFIG_DICT,
    "GPT2_medium": GPT2_MEDIUM_CONFIG,
}
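As a quick illustration of how this registry is consumed (the actual wiring appears in the `T3.__init__` hunk further down), a dict like `GPT2_MEDIUM_CONFIG` can be unpacked straight into `transformers`' `GPT2Config`; this assumes the dict above is in scope:

```python
from transformers import GPT2Config, GPT2Model

cfg = GPT2Config(**GPT2_MEDIUM_CONFIG)   # keys outside the signature become plain attributes
backbone = GPT2Model(cfg)
print(cfg.model_type, cfg.n_layer, cfg.hidden_size)  # gpt2 24 1024
```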
@@ -28,7 +28,7 @@ class T3Config:

    @property
    def is_multilingual(self):
        return self.text_tokens_dict_size == 2352
        return self.text_tokens_dict_size == 2454

    @classmethod
    def english_only(cls):

@@ -38,4 +38,4 @@ class T3Config:
    @classmethod
    def multilingual(cls):
        """Create configuration for multilingual TTS model."""
        return cls(text_tokens_dict_size=2352)
        return cls(text_tokens_dict_size=2454)
@@ -9,9 +9,15 @@ from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import LlamaModel, LlamaConfig
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, MinPLogitsWarper

from transformers import LlamaModel, LlamaConfig, GPT2Config, GPT2Model
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    MinPLogitsWarper,
)
from .modules.learned_pos_emb import LearnedPositionEmbeddings

from .modules.cond_enc import T3CondEnc, T3Cond

@@ -43,11 +49,20 @@ class T3(nn.Module):

    def __init__(self, hp=None):
        if hp is None:
            hp = T3Config.english_only()  # Default to English-only config for backward compatibility
            hp = T3Config.english_only()
        super().__init__()
        self.hp = hp
        self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
        self.tfmr = LlamaModel(self.cfg)

        config_dict = LLAMA_CONFIGS[hp.llama_config_name]
        self.is_gpt = config_dict.get("model_type") == "gpt2"

        if self.is_gpt:
            self.cfg = GPT2Config(**config_dict)
            self.tfmr = GPT2Model(self.cfg)
        else:
            self.cfg = LlamaConfig(**config_dict)
            self.tfmr = LlamaModel(self.cfg)

        self.dim = self.cfg.hidden_size
        self.deepspeed_patch_applied = False

@@ -57,6 +72,8 @@ class T3(nn.Module):
        self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)

        # custom position embedding
        self.text_pos_emb = None
        self.speech_pos_emb = None
        if hp.input_pos_emb == "learned":
            max_text_seq_len = hp.max_text_tokens + 2
            self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)

@@ -66,7 +83,7 @@ class T3(nn.Module):

        # logit projection
        self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
        self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
        self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt)
        self.compiled = False

    @property

@@ -78,8 +95,9 @@ class T3(nn.Module):
        Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
        """
        if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
            t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
                self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
            t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens)
            if not self.is_gpt:
                t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
        return self.cond_enc(t3_cond)  # (B, len_cond, dim)

    def prepare_input_embeds(

@@ -93,7 +111,7 @@ class T3(nn.Module):
        # prepare input embeddings (skip backbone tranformer embeddings)
        cond_emb = self.prepare_conditioning(t3_cond)  # (B, len_cond, dim)
        text_emb = self.text_emb(text_tokens)  # (B, len_text, dim)
        if cfg_weight > 0.0:
        if cfg_weight > 0.0 and not self.is_gpt:
            text_emb[1].zero_()  # CFG uncond

        speech_emb = self.speech_emb(speech_tokens)  # (B, len_speech, dim)

@@ -332,7 +350,7 @@ class T3(nn.Module):

        # ---- Generation Loop using kv_cache ----
        for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
            logits_step = output.logits[:, -1, :]
            logits_step = output.logits[:, -1, :]
            # CFG combine → (1, V)
            cond = logits_step[0:1, :]
            uncond = logits_step[1:2, :]

@@ -392,3 +410,81 @@ class T3(nn.Module):
        # Concatenate all predicted tokens along the sequence dimension.
        predicted_tokens = torch.cat(predicted, dim=1)  # shape: (B, num_tokens)
        return predicted_tokens

    @torch.inference_mode()
    def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2,
                        max_gen_len=1000):

        logits_processors = LogitsProcessorList()
        if temperature > 0 and temperature != 1.0:
            logits_processors.append(TemperatureLogitsWarper(temperature))
        if top_k > 0:
            logits_processors.append(TopKLogitsWarper(top_k))
        if top_p < 1.0:
            logits_processors.append(TopPLogitsWarper(top_p))
        if repetition_penalty != 1.0:
            logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))

        speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
        embeds, _ = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=speech_start_token,
            cfg_weight=0.0,
        )

        generated_speech_tokens = []

        llm_outputs = self.tfmr(
            inputs_embeds=embeds,
            use_cache=True
        )

        hidden_states = llm_outputs[0]
        past_key_values = llm_outputs.past_key_values

        speech_hidden = hidden_states[:, -1:]
        speech_logits = self.speech_head(speech_hidden)

        processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :])
        probs = F.softmax(processed_logits, dim=-1)
        next_speech_token = torch.multinomial(probs, num_samples=1)

        generated_speech_tokens.append(next_speech_token)
        current_speech_token = next_speech_token

        for _ in tqdm(range(max_gen_len)):
            current_speech_embed = self.speech_emb(current_speech_token)

            llm_outputs = self.tfmr(
                inputs_embeds=current_speech_embed,
                past_key_values=past_key_values,
                use_cache=True
            )

            hidden_states = llm_outputs[0]
            past_key_values = llm_outputs.past_key_values
            speech_logits = self.speech_head(hidden_states)

            input_ids = torch.cat(generated_speech_tokens, dim=1)
            processed_logits = logits_processors(input_ids, speech_logits[:, -1, :])
            if torch.all(processed_logits == -float("inf")):
                print("Warning: All logits are -inf")
                break

            probs = F.softmax(processed_logits, dim=-1)
            next_speech_token = torch.multinomial(probs, num_samples=1)

            generated_speech_tokens.append(next_speech_token)
            current_speech_token = next_speech_token
            if torch.all(next_speech_token == self.hp.stop_speech_token):
                break

        all_tokens = torch.cat(generated_speech_tokens, dim=1)

        # Remove EOS token if present
        if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token:
            all_tokens = all_tokens[:, :-1]

        return all_tokens
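The sampling setup in `inference_turbo` leans entirely on the standard `transformers` logits processors. A stand-alone sketch of that chain is below; the vocabulary size 6563 is taken from the Turbo `speech_tokens_dict_size` later in this compare, and the dummy inputs are assumptions made only for the example:

```python
import torch
import torch.nn.functional as F
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Same order as the method above: temperature, top-k, top-p, repetition penalty.
procs = LogitsProcessorList([
    TemperatureLogitsWarper(0.8),
    TopKLogitsWarper(1000),
    TopPLogitsWarper(0.95),
    RepetitionPenaltyLogitsProcessor(1.2),
])

generated = torch.zeros(1, 1, dtype=torch.long)   # speech tokens sampled so far (dummy)
logits = torch.randn(1, 6563)                     # last-step logits from speech_head (dummy)
probs = F.softmax(procs(generated, logits), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```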
@@ -1,10 +1,9 @@
import logging
import json
import re

import torch
from pathlib import Path
from unicodedata import category
from unicodedata import category, normalize
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

@@ -33,7 +32,7 @@ class EnTokenizer:
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode( self, txt: str, verbose=False):
    def encode(self, txt: str):
        """
        clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
        """

@@ -46,8 +45,7 @@ class EnTokenizer:
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()

        txt: str = self.tokenizer.decode(seq,
                                         skip_special_tokens=False)
        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')

@@ -61,6 +59,7 @@ REPO_ID = "ResembleAI/chatterbox"
# Global instances for optional dependencies
_kakasi = None
_dicta = None
_russian_stresser = None


def is_kanji(c: str) -> bool:

@@ -191,7 +190,7 @@ class ChineseCangjieConverter:
    def _init_segmenter(self):
        """Initialize pkuseg segmenter."""
        try:
            from pkuseg import pkuseg
            from spacy_pkuseg import pkuseg
            self.segmenter = pkuseg()
        except ImportError:
            logger.warning("pkuseg not available - Chinese segmentation will be skipped")

@@ -235,6 +234,25 @@ class ChineseCangjieConverter:
        return "".join(output)


def add_russian_stress(text: str) -> str:
    """Russian text normalization: adds stress marks to Russian text."""
    global _russian_stresser

    try:
        if _russian_stresser is None:
            from russian_text_stresser.text_stresser import RussianTextStresser
            _russian_stresser = RussianTextStresser()

        return _russian_stresser.stress_text(text)

    except ImportError:
        logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
        return text
    except Exception as e:
        logger.warning(f"Russian stress labeling failed: {e}")
        return text


class MTLTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)

@@ -247,12 +265,26 @@ class MTLTokenizer:
        assert SOT in voc
        assert EOT in voc

    def text_to_tokens(self, text: str, language_id: str = None):
        text_tokens = self.encode(text, language_id=language_id)
    def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        """
        Text preprocessor that handles lowercase conversion and NFKD normalization.
        """
        preprocessed_text = raw_text
        if lowercase:
            preprocessed_text = preprocessed_text.lower()
        if nfkd_normalize:
            preprocessed_text = normalize("NFKD", preprocessed_text)

        return preprocessed_text

    def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str, language_id: str = None):
    def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)

        # Language-specific text processing
        if language_id == 'zh':
            txt = self.cangjie_converter(txt)

@@ -262,6 +294,8 @@ class MTLTokenizer:
            txt = add_hebrew_diacritics(txt)
        elif language_id == 'ko':
            txt = korean_normalize(txt)
        elif language_id == 'ru':
            txt = add_russian_stress(txt)

        # Prepend language token
        if language_id:
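The `preprocess_text` step added above lowercases and NFKD-normalizes the input before the grapheme tokenizer sees it. A tiny standalone illustration of what NFKD does to composed characters, using only the standard library:

```python
from unicodedata import normalize

print(len("Café"), len(normalize("NFKD", "Café")))   # 4 5 -- 'é' becomes 'e' + combining acute
print(normalize("NFKD", "ﬁle") == "file")            # True -- compatibility form splits the ligature
```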
@@ -168,7 +168,7 @@ class ChatterboxMultilingualTTS:
        ve.to(device).eval()

        t3 = T3(T3Config.multilingual())
        t3_state = load_safetensors(ckpt_dir / "t3_23lang.safetensors")
        t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
        if "model" in t3_state.keys():
            t3_state = t3_state["model"][0]
        t3.load_state_dict(t3_state)

@@ -181,7 +181,7 @@ class ChatterboxMultilingualTTS:
        s3gen.to(device).eval()

        tokenizer = MTLTokenizer(
            str(ckpt_dir / "mtl_tokenizer.json")
            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
        )

        conds = None

@@ -197,7 +197,7 @@ class ChatterboxMultilingualTTS:
                repo_id=REPO_ID,
                repo_type="model",
                revision="main",
                allow_patterns=["ve.pt", "t3_23lang.safetensors", "s3gen.pt", "mtl_tokenizer.json", "conds.pt", "Cangjie5_TC.json"],
                allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
                token=os.getenv("HF_TOKEN"),
            )
        )
@@ -269,4 +269,4 @@ class ChatterboxTTS:
        )
        wav = wav.squeeze(0).detach().cpu().numpy()
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
@@ -0,0 +1,296 @@
import os
import math
from dataclasses import dataclass
from pathlib import Path

import librosa
import torch
import perth
import pyloudnorm as ln

from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from .models.t3 import T3
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
from .models.t3.modules.t3_config import T3Config
from .models.s3gen.const import S3GEN_SIL
import logging
logger = logging.getLogger(__name__)

REPO_ID = "ResembleAI/chatterbox-turbo"


def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset
    """
    if len(text) == 0:
        return "You need to add some text for me to talk."

    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]

    # Remove multiple space chars
    text = " ".join(text.split())

    # Replace uncommon/llm punc
    punc_to_replace = [
        ("…", ", "),
        (":", ","),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", "\""),
        ("”", "\""),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add full stop if no ending punc
    text = text.rstrip(" ")
    sentence_enders = {".", "!", "?", "-", ","}
    if not any(text.endswith(p) for p in sentence_enders):
        text += "."

    return text
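A quick example of what `punc_norm` does to a typical LLM-style string, traced by hand from the replacement table above:

```python
print(punc_norm("well — here's   the thing"))
# Well - here's the thing.
```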
@dataclass
class Conditionals:
    """
    Conditionals for T3 and S3Gen
    - T3 conditionals:
        - speaker_emb
        - clap_emb
        - cond_prompt_speech_tokens
        - cond_prompt_speech_emb
        - emotion_adv
    - S3Gen conditionals:
        - prompt_token
        - prompt_token_len
        - prompt_feat
        - prompt_feat_len
        - embedding
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        self.t3 = self.t3.to(device=device)
        for k, v in self.gen.items():
            if torch.is_tensor(v):
                self.gen[k] = v.to(device=device)
        return self

    def save(self, fpath: Path):
        arg_dict = dict(
            t3=self.t3.__dict__,
            gen=self.gen
        )
        torch.save(arg_dict, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        if isinstance(map_location, str):
            map_location = torch.device(map_location)
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(T3Cond(**kwargs['t3']), kwargs['gen'])


class ChatterboxTurboTTS:
    ENC_COND_LEN = 15 * S3_SR
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer: EnTokenizer,
        device: str,
        conds: Conditionals = None,
    ):
        self.sr = S3GEN_SR  # sample rate of synthesized audio
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.device = device
        self.conds = conds
        self.watermarker = perth.PerthImplicitWatermarker()

    @classmethod
    def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS':
        ckpt_dir = Path(ckpt_dir)

        # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
        if device in ["cpu", "mps"]:
            map_location = torch.device('cpu')
        else:
            map_location = None

        ve = VoiceEncoder()
        ve.load_state_dict(
            load_file(ckpt_dir / "ve.safetensors")
        )
        ve.to(device).eval()

        # Turbo specific hp
        hp = T3Config(text_tokens_dict_size=50276)
        hp.llama_config_name = "GPT2_medium"
        hp.speech_tokens_dict_size = 6563
        hp.input_pos_emb = None
        hp.speech_cond_prompt_len = 375
        hp.use_perceiver_resampler = False
        hp.emotion_adv = False

        t3 = T3(hp)
        t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors")
        if "model" in t3_state.keys():
            t3_state = t3_state["model"][0]
        t3.load_state_dict(t3_state)
        del t3.tfmr.wte
        t3.to(device).eval()

        s3gen = S3Gen(meanflow=True)
        weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors")
        s3gen.load_state_dict(
            weights, strict=True
        )
        s3gen.to(device).eval()

        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if len(tokenizer) != 50276:
            print(f"WARNING: Tokenizer len {len(tokenizer)} != 50276")

        conds = None
        builtin_voice = ckpt_dir / "conds.pt"
        if builtin_voice.exists():
            conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)

        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)

    @classmethod
    def from_pretrained(cls, device) -> 'ChatterboxTurboTTS':
        # Check if MPS is available on macOS
        if device == "mps" and not torch.backends.mps.is_available():
            if not torch.backends.mps.is_built():
                print("MPS not available because the current PyTorch install was not built with MPS enabled.")
            else:
                print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
            device = "cpu"

        local_path = snapshot_download(
            repo_id=REPO_ID,
            token=os.getenv("HF_TOKEN") or True,
            # Optional: Filter to download only what you need
            allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
        )

        return cls.from_local(local_path, device)

    def norm_loudness(self, wav, sr, target_lufs=-27):
        try:
            meter = ln.Meter(sr)
            loudness = meter.integrated_loudness(wav)
            gain_db = target_lufs - loudness
            gain_linear = 10.0 ** (gain_db / 20.0)
            if math.isfinite(gain_linear) and gain_linear > 0.0:
                wav = wav * gain_linear
        except Exception as e:
            print(f"Warning: Error in norm_loudness, skipping: {e}")

        return wav

    def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True):
        ## Load and norm reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!"

        if norm_loudness:
            s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr)

        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)

        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

        # Speech cond prompt tokens
        if plen := self.t3.hp.speech_cond_prompt_len:
            s3_tokzr = self.s3gen.tokenizer
            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)

        # Voice-encoder speaker embedding
        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)

        t3_cond = T3Cond(
            speaker_emb=ve_embed,
            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)
        self.conds = Conditionals(t3_cond, s3gen_ref_dict)

    def generate(
        self,
        text,
        repetition_penalty=1.2,
        min_p=0.00,
        top_p=0.95,
        audio_prompt_path=None,
        exaggeration=0.0,
        cfg_weight=0.0,
        temperature=0.8,
        top_k=1000,
        norm_loudness=True,
    ):
        if audio_prompt_path:
            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness)
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

        if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
            logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.")

        # Norm and tokenize text
        text = punc_norm(text)
        text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        text_tokens = text_tokens.input_ids.to(self.device)

        speech_tokens = self.t3.inference_turbo(
            t3_cond=self.conds.t3,
            text_tokens=text_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # Remove OOV tokens and add silence to end
        speech_tokens = speech_tokens[speech_tokens < 6561]
        speech_tokens = speech_tokens.to(self.device)
        silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device)
        speech_tokens = torch.cat([speech_tokens, silence])

        wav, _ = self.s3gen.inference(
            speech_tokens=speech_tokens,
            ref_dict=self.conds.gen,
            n_cfm_timesteps=2,
        )
        wav = wav.squeeze(0).detach().cpu().numpy()
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
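Finally, a hedged end-to-end sketch of how this new `ChatterboxTurboTTS` entry point is meant to be driven. The import path and the reference-clip filename are assumptions; the method names, the 5-second minimum prompt length, and the paralinguistic tag come from the code above and the model description.

```python
import torchaudio
from chatterbox.tts_turbo import ChatterboxTurboTTS  # module path assumed

model = ChatterboxTurboTTS.from_pretrained(device="cuda")
wav = model.generate(
    "Hello there [chuckle], nice to meet you.",
    audio_prompt_path="reference_clip.wav",   # any clip longer than 5 seconds
)
torchaudio.save("output.wav", wav, model.sr)  # wav is (1, T); model.sr is the S3Gen rate
```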