Compare commits
No commits in common. "master" and "add-disc" have entirely different histories.
@@ -1,23 +0,0 @@
name: Test Installation

on:
  push:
    branches: [ "master" ]
  pull_request:
    branches: [ "master" ]

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3

    - name: Set up Python 3.10
      uses: actions/setup-python@v4
      with:
        python-version: "3.10"

    - name: Test Standard Install
      run: |
        pip install -e .
Binary file not shown. (Before: 707 KiB)
Binary file not shown. (Before: 763 KiB)
README.md
@@ -1,112 +1,33 @@

<img width="1200" alt="cb-big2" src="https://github.com/user-attachments/assets/bd8c5f03-e91d-4ee5-b680-57355da204d1" />

# Chatterbox TTS

[Turbo demo page](https://resemble-ai.github.io/chatterbox_turbo_demopage/)
[Turbo Hugging Face demo](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo)
[Demo page](https://resemble-ai.github.io/chatterbox_demopage/)
[Hugging Face Space](https://huggingface.co/spaces/ResembleAI/Chatterbox)
[Podonos evaluation](https://podonos.com/resembleai/chatterbox)
[Discord](https://discord.gg/rJq9cRJBJ6)

_Made with ♥️ by <a href="https://resemble.ai" target="_blank"><img width="100" alt="resemble-logo-horizontal" src="https://github.com/user-attachments/assets/35cf756b-3506-4943-9c72-c05ddfa4e525" /></a>_
_Made with ♥️ by <img width="100" alt="resemble-logo-horizontal" src="https://github.com/user-attachments/assets/35cf756b-3506-4943-9c72-c05ddfa4e525" />_

**Chatterbox** is a family of three state-of-the-art, open-source text-to-speech models by Resemble AI.
We're excited to introduce Chatterbox, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations.

We are excited to introduce **Chatterbox-Turbo**, our most efficient model yet. Built on a streamlined 350M parameter architecture, **Turbo** delivers high-quality speech with less compute and VRAM than our previous models. We have also distilled the speech-token-to-mel decoder, previously a bottleneck, reducing generation from 10 steps to just **one**, while retaining high-fidelity audio output.

**Paralinguistic tags** are now native to the Turbo model, allowing you to use `[cough]`, `[laugh]`, `[chuckle]`, and more to add distinct realism. While Turbo was built primarily for low-latency voice agents, it excels at narration and creative workflows.
Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life. It's also the first open source TTS model to support **emotion exaggeration control**, a powerful feature that makes your voices stand out. Try it now on our [Hugging Face Gradio app.](https://huggingface.co/spaces/ResembleAI/Chatterbox)

If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (<a href="https://resemble.ai">link</a>). It delivers reliable performance with ultra-low latency of sub 200ms—ideal for production use in agents, applications, or interactive media.

<img width="1200" height="600" alt="Podonos Turbo Eval" src="https://storage.googleapis.com/chatterbox-demo-samples/turbo/podonos_turbo.png" />
# Key Details
- SoTA zeroshot TTS
- 0.5B Llama backbone
- Unique exaggeration/intensity control
- Ultra-stable with alignment-informed inference
- Trained on 0.5M hours of cleaned data
- Watermarked outputs
- Easy voice conversion script
- [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox)

### ⚡ Model Zoo

Choose the right model for your application.

| Model | Size | Languages | Key Features | Best For | 🤗 | Examples |
|:---|:---|:---|:---|:---|:---|:---|
| **Chatterbox-Turbo** | **350M** | **English** | Paralinguistic Tags (`[laugh]`), Lower Compute and VRAM | Zero-shot voice agents, Production | [Demo](https://huggingface.co/spaces/ResembleAI/chatterbox-turbo-demo) | [Listen](https://resemble-ai.github.io/chatterbox_turbo_demopage/) |
| Chatterbox-Multilingual [(Language list)](#supported-languages) | 500M | 23+ | Zero-shot cloning, Multiple Languages | Global applications, Localization | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox-Multilingual-TTS) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |
| Chatterbox [(Tips and Tricks)](#original-chatterbox-tips) | 500M | English | CFG & Exaggeration tuning | General zero-shot TTS with creative controls | [Demo](https://huggingface.co/spaces/ResembleAI/Chatterbox) | [Listen](https://resemble-ai.github.io/chatterbox_demopage/) |

## Installation
```shell
pip install chatterbox-tts
```

Alternatively, you can install from source:
```shell
# conda create -yn chatterbox python=3.11
# conda activate chatterbox

git clone https://github.com/resemble-ai/chatterbox.git
cd chatterbox
pip install -e .
```
We developed and tested Chatterbox on Python 3.11 on Debian 11; the dependency versions are pinned in `pyproject.toml` to ensure consistency. In this installation mode you can modify the code or the dependencies.

## Usage

##### Chatterbox-Turbo

```python
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS

# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")

# Generate with Paralinguistic Tags
text = "Hi there, Sarah here from MochaFone calling you back [chuckle], have you got one minute to chat about the billing issue?"

# Generate audio (requires a reference clip for voice cloning)
wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")

ta.save("test-turbo.wav", wav, model.sr)
```

##### Chatterbox and Chatterbox-Multilingual

```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

# English example
model = ChatterboxTTS.from_pretrained(device="cuda")

text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-english.wav", wav, model.sr)

# Multilingual examples
multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")

french_text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
wav_french = multilingual_model.generate(french_text, language_id="fr")
ta.save("test-french.wav", wav_french, multilingual_model.sr)

chinese_text = "你好,今天天气真不错,希望你有一个愉快的周末。"
wav_chinese = multilingual_model.generate(chinese_text, language_id="zh")
ta.save("test-chinese.wav", wav_chinese, multilingual_model.sr)

# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` and `example_vc.py` for more examples.

## Supported Languages
Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) • Spanish (es) • Finnish (fi) • French (fr) • Hebrew (he) • Hindi (hi) • Italian (it) • Japanese (ja) • Korean (ko) • Malay (ms) • Dutch (nl) • Norwegian (no) • Polish (pl) • Portuguese (pt) • Russian (ru) • Swedish (sv) • Swahili (sw) • Turkish (tr) • Chinese (zh)

## Original Chatterbox Tips
# Tips
- **General Use (TTS and Voice Agents):**
  - Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set `cfg_weight` to `0`.
  - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts across all languages.
  - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts.
  - If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing.

- **Expressive or Dramatic Speech:**
@@ -114,55 +35,40 @@ Arabic (ar) • Danish (da) • German (de) • Greek (el) • English (en) •
  - Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing.
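To make the tips above concrete, here is a minimal sketch (not part of the diff) that combines them. The `generate()` keyword arguments (`language_id`, `audio_prompt_path`, `exaggeration`, `cfg_weight`) are the ones used elsewhere in this README and in the Gradio demos; the clip paths, example sentences, and exact values are illustrative assumptions only.

```python
import torchaudio as ta
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")

# Language transfer: the reference clip is English but the target language is French,
# so cfg_weight is set to 0 to avoid inheriting the clip's accent (per the tip above).
wav = model.generate(
    "Bonjour, comment ça va?",
    language_id="fr",
    audio_prompt_path="english_ref_clip.wav",  # hypothetical reference clip
    cfg_weight=0.0,
)
ta.save("transfer-fr.wav", wav, model.sr)

# Dramatic delivery: raise exaggeration and lower cfg_weight to keep the pacing deliberate.
wav = model.generate(
    "Vous ne passerez pas!",
    language_id="fr",
    audio_prompt_path="french_ref_clip.wav",  # hypothetical reference clip
    exaggeration=0.7,
    cfg_weight=0.3,
)
ta.save("dramatic-fr.wav", wav, model.sr)
```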
## Built-in PerTh Watermarking for Responsible AI

Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.

## Watermark extraction

You can look for the watermark using the following script.

```python
import perth
import librosa

AUDIO_PATH = "YOUR_FILE.wav"

# Load the watermarked audio
watermarked_audio, sr = librosa.load(AUDIO_PATH, sr=None)

# Initialize watermarker (same as used for embedding)
watermarker = perth.PerthImplicitWatermarker()

# Extract watermark
watermark = watermarker.get_watermark(watermarked_audio, sample_rate=sr)
print(f"Extracted watermark: {watermark}")
# Output: 0.0 (no watermark) or 1.0 (watermarked)
```
# Installation
```
pip install chatterbox-tts
```

## Official Discord
# Usage
```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

👋 Join us on [Discord](https://discord.gg/rJq9cRJBJ6) and let's build something awesome together!
model = ChatterboxTTS.from_pretrained(device="cuda")

## Acknowledgements
text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-1.wav", wav, model.sr)

# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-2.wav", wav, model.sr)
```
See `example_tts.py` for more examples.

# Acknowledgements
- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice)
- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning)
- [HiFT-GAN](https://github.com/yl4579/HiFTNet)
- [Llama 3](https://github.com/meta-llama/llama3)
- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer)

## Citation
If you find this model useful, please consider citing.
```
@misc{chatterboxtts2025,
  author = {{Resemble AI}},
  title = {{Chatterbox-TTS}},
  year = {2025},
  howpublished = {\url{https://github.com/resemble-ai/chatterbox}},
  note = {GitHub repository}
}
```
## Disclaimer
# Built-in PerTh Watermarking for Responsible AI

Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy.

# Disclaimer
Don't use this model to do bad things. Prompts are sourced from freely available data on the internet.
@@ -1,7 +1,6 @@
import torchaudio as ta
import torch
from chatterbox.tts import ChatterboxTTS
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

# Automatically detect the best available device
if torch.cuda.is_available():

@@ -17,15 +16,4 @@ model = ChatterboxTTS.from_pretrained(device=device)

text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill."
wav = model.generate(text)
ta.save("test-1.wav", wav, model.sr)

multilingual_model = ChatterboxMultilingualTTS.from_pretrained(device=device)
text = "Bonjour, comment ça va? Ceci est le modèle de synthèse vocale multilingue Chatterbox, il prend en charge 23 langues."
wav = multilingual_model.generate(text, language_id="fr")
ta.save("test-2.wav", wav, multilingual_model.sr)


# If you want to synthesize with a different voice, specify the audio prompt
AUDIO_PROMPT_PATH = "YOUR_FILE.wav"
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
ta.save("test-3.wav", wav, model.sr)
ta.save("test-1.wav", wav, model.sr)
@@ -1,14 +0,0 @@
import torchaudio as ta
import torch
from chatterbox.tts_turbo import ChatterboxTurboTTS

# Load the Turbo model
model = ChatterboxTurboTTS.from_pretrained(device="cuda")

# Generate with Paralinguistic Tags
text = "Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and all that jazz. Would you like me to get some prices for you?"

# Generate audio (requires a reference clip for voice cloning)
# wav = model.generate(text, audio_prompt_path="your_10s_ref_clip.wav")
wav = model.generate(text)
ta.save("test-turbo.wav", wav, model.sr)
@@ -1,24 +1,6 @@
import torch
import torchaudio as ta

from chatterbox.vc import ChatterboxVC

# Automatically detect the best available device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using device: {device}")

AUDIO_PATH = "YOUR_FILE.wav"
TARGET_VOICE_PATH = "YOUR_FILE.wav"

model = ChatterboxVC.from_pretrained(device)
wav = model.generate(
    audio=AUDIO_PATH,
    target_voice_path=TARGET_VOICE_PATH,
)
model = ChatterboxVC.from_pretrained("cuda")
wav = model.generate("tests/trimmed_8b7f38b1.wav")
import torchaudio as ta
ta.save("testvc.wav", wav, model.sr)
@@ -21,7 +21,7 @@ def load_model():
    return model


def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw, min_p, top_p, repetition_penalty):
def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num, cfgw):
    if model is None:
        model = ChatterboxTTS.from_pretrained(DEVICE)

@@ -34,9 +34,6 @@ def generate(model, text, audio_prompt_path, exaggeration, temperature, seed_num
        exaggeration=exaggeration,
        temperature=temperature,
        cfg_weight=cfgw,
        min_p=min_p,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    return (model.sr, wav.squeeze(0).numpy())

@@ -46,21 +43,14 @@ with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            text = gr.Textbox(value="What does the fox say?", label="Text to synthesize")
            ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value=None)
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5)
            cfg_weight = gr.Slider(0.0, 1, step=.05, label="CFG/Pace", value=0.5)
            cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)

            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
                min_p = gr.Slider(0.00, 1.00, step=0.01, label="min_p || Newer Sampler. Recommend 0.02 > 0.1. Handles Higher Temperatures better. 0.00 Disables", value=0.05)
                top_p = gr.Slider(0.00, 1.00, step=0.01, label="top_p || Original Sampler. 1.0 Disables(recommended). Original 0.8", value=1.00)
                repetition_penalty = gr.Slider(1.00, 2.00, step=0.1, label="repetition_penalty", value=1.2)

            run_btn = gr.Button("Generate", variant="primary")

@@ -79,9 +69,6 @@ with gr.Blocks() as demo:
            temp,
            seed_num,
            cfg_weight,
            min_p,
            top_p,
            repetition_penalty,
        ],
        outputs=audio_output,
    )
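For readers who want the sampling controls added above without the Gradio UI, here is a minimal sketch of calling `generate()` directly with the same keyword arguments this demo passes through (`min_p`, `top_p`, `repetition_penalty`, `cfg_weight`, `exaggeration`, `temperature`). The values simply mirror the slider defaults shown above; they are illustrative, not tuned recommendations, and assume the branch where these parameters exist.

```python
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

model = ChatterboxTTS.from_pretrained(device="cuda")

# Same knobs the Gradio demo exposes, using its default slider values.
wav = model.generate(
    "What does the fox say?",
    exaggeration=0.5,
    temperature=0.8,
    cfg_weight=0.5,
    min_p=0.05,              # newer sampler; 0.0 disables it
    top_p=1.0,               # 1.0 disables the original top-p sampler
    repetition_penalty=1.2,
)
ta.save("test-sampling.wav", wav, model.sr)
```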
|
|
|||
|
|
@ -1,186 +0,0 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
from chatterbox.tts_turbo import ChatterboxTurboTTS
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
EVENT_TAGS = [
|
||||
"[clear throat]", "[sigh]", "[shush]", "[cough]", "[groan]",
|
||||
"[sniff]", "[gasp]", "[chuckle]", "[laugh]"
|
||||
]
|
||||
|
||||
# --- REFINED CSS ---
|
||||
# 1. tag-container: Forces the row to wrap items instead of scrolling. Removes borders/backgrounds.
|
||||
# 2. tag-btn: Sets the specific look (indigo theme) and stops them from stretching.
|
||||
CUSTOM_CSS = """
|
||||
.tag-container {
|
||||
display: flex !important;
|
||||
flex-wrap: wrap !important; /* This fixes the one-per-line issue */
|
||||
gap: 8px !important;
|
||||
margin-top: 5px !important;
|
||||
margin-bottom: 10px !important;
|
||||
border: none !important;
|
||||
background: transparent !important;
|
||||
}
|
||||
|
||||
.tag-btn {
|
||||
min-width: fit-content !important;
|
||||
width: auto !important;
|
||||
height: 32px !important;
|
||||
font-size: 13px !important;
|
||||
background: #eef2ff !important;
|
||||
border: 1px solid #c7d2fe !important;
|
||||
color: #3730a3 !important;
|
||||
border-radius: 6px !important;
|
||||
padding: 0 10px !important;
|
||||
margin: 0 !important;
|
||||
box-shadow: none !important;
|
||||
}
|
||||
|
||||
.tag-btn:hover {
|
||||
background: #c7d2fe !important;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
"""
|
||||
|
||||
INSERT_TAG_JS = """
|
||||
(tag_val, current_text) => {
|
||||
const textarea = document.querySelector('#main_textbox textarea');
|
||||
if (!textarea) return current_text + " " + tag_val;
|
||||
|
||||
const start = textarea.selectionStart;
|
||||
const end = textarea.selectionEnd;
|
||||
|
||||
let prefix = " ";
|
||||
let suffix = " ";
|
||||
|
||||
if (start === 0) prefix = "";
|
||||
else if (current_text[start - 1] === ' ') prefix = "";
|
||||
|
||||
if (end < current_text.length && current_text[end] === ' ') suffix = "";
|
||||
|
||||
return current_text.slice(0, start) + prefix + tag_val + suffix + current_text.slice(end);
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
def load_model():
|
||||
print(f"Loading Chatterbox-Turbo on {DEVICE}...")
|
||||
model = ChatterboxTurboTTS.from_pretrained(DEVICE)
|
||||
return model
|
||||
|
||||
|
||||
def generate(
|
||||
model,
|
||||
text,
|
||||
audio_prompt_path,
|
||||
temperature,
|
||||
seed_num,
|
||||
min_p,
|
||||
top_p,
|
||||
top_k,
|
||||
repetition_penalty,
|
||||
norm_loudness
|
||||
):
|
||||
if model is None:
|
||||
model = ChatterboxTurboTTS.from_pretrained(DEVICE)
|
||||
|
||||
if seed_num != 0:
|
||||
set_seed(int(seed_num))
|
||||
|
||||
wav = model.generate(
|
||||
text,
|
||||
audio_prompt_path=audio_prompt_path,
|
||||
temperature=temperature,
|
||||
min_p=min_p,
|
||||
top_p=top_p,
|
||||
top_k=int(top_k),
|
||||
repetition_penalty=repetition_penalty,
|
||||
norm_loudness=norm_loudness,
|
||||
)
|
||||
return (model.sr, wav.squeeze(0).numpy())
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatterbox Turbo", css=CUSTOM_CSS) as demo:
|
||||
gr.Markdown("# ⚡ Chatterbox Turbo")
|
||||
|
||||
model_state = gr.State(None)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
text = gr.Textbox(
|
||||
value="Oh, that's hilarious! [chuckle] Um anyway, we do have a new model in store. It's the SkyNet T-800 series and it's got basically everything. Including AI integration with ChatGPT and um all that jazz. Would you like me to get some prices for you?",
|
||||
label="Text to synthesize (max chars 300)",
|
||||
max_lines=5,
|
||||
elem_id="main_textbox"
|
||||
)
|
||||
|
||||
# --- Event Tags ---
|
||||
# Switched back to Row, but applied specific CSS to force wrapping
|
||||
with gr.Row(elem_classes=["tag-container"]):
|
||||
for tag in EVENT_TAGS:
|
||||
# elem_classes targets the button specifically
|
||||
btn = gr.Button(tag, elem_classes=["tag-btn"])
|
||||
|
||||
btn.click(
|
||||
fn=None,
|
||||
inputs=[btn, text],
|
||||
outputs=text,
|
||||
js=INSERT_TAG_JS
|
||||
)
|
||||
|
||||
ref_wav = gr.Audio(
|
||||
sources=["upload", "microphone"],
|
||||
type="filepath",
|
||||
label="Reference Audio File",
|
||||
value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_random_podcast.wav"
|
||||
)
|
||||
|
||||
run_btn = gr.Button("Generate ⚡", variant="primary")
|
||||
|
||||
with gr.Column():
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
with gr.Accordion("Advanced Options", open=False):
|
||||
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
|
||||
temp = gr.Slider(0.05, 2.0, step=.05, label="Temperature", value=0.8)
|
||||
top_p = gr.Slider(0.00, 1.00, step=0.01, label="Top P", value=0.95)
|
||||
top_k = gr.Slider(0, 1000, step=10, label="Top K", value=1000)
|
||||
repetition_penalty = gr.Slider(1.00, 2.00, step=0.05, label="Repetition Penalty", value=1.2)
|
||||
min_p = gr.Slider(0.00, 1.00, step=0.01, label="Min P (Set to 0 to disable)", value=0.00)
|
||||
norm_loudness = gr.Checkbox(value=True, label="Normalize Loudness (-27 LUFS)")
|
||||
|
||||
demo.load(fn=load_model, inputs=[], outputs=model_state)
|
||||
|
||||
run_btn.click(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
model_state,
|
||||
text,
|
||||
ref_wav,
|
||||
temp,
|
||||
seed_num,
|
||||
min_p,
|
||||
top_p,
|
||||
top_k,
|
||||
repetition_penalty,
|
||||
norm_loudness,
|
||||
],
|
||||
outputs=audio_output,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue(
|
||||
max_size=50,
|
||||
default_concurrency_limit=1,
|
||||
).launch(share=True)
|
||||
|
|
@ -1,317 +0,0 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
|
||||
import gradio as gr
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"🚀 Running on device: {DEVICE}")
|
||||
|
||||
# --- Global Model Initialization ---
|
||||
MODEL = None
|
||||
|
||||
LANGUAGE_CONFIG = {
|
||||
"ar": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
|
||||
"text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
|
||||
},
|
||||
"da": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/da_m1.flac",
|
||||
"text": "Sidste måned nåede vi en ny milepæl med to milliarder visninger på vores YouTube-kanal."
|
||||
},
|
||||
"de": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/de_f1.flac",
|
||||
"text": "Letzten Monat haben wir einen neuen Meilenstein erreicht: zwei Milliarden Aufrufe auf unserem YouTube-Kanal."
|
||||
},
|
||||
"el": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/el_m.flac",
|
||||
"text": "Τον περασμένο μήνα, φτάσαμε σε ένα νέο ορόσημο με δύο δισεκατομμύρια προβολές στο κανάλι μας στο YouTube."
|
||||
},
|
||||
"en": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
|
||||
"text": "Last month, we reached a new milestone with two billion views on our YouTube channel."
|
||||
},
|
||||
"es": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/es_f1.flac",
|
||||
"text": "El mes pasado alcanzamos un nuevo hito: dos mil millones de visualizaciones en nuestro canal de YouTube."
|
||||
},
|
||||
"fi": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fi_m.flac",
|
||||
"text": "Viime kuussa saavutimme uuden virstanpylvään kahden miljardin katselukerran kanssa YouTube-kanavallamme."
|
||||
},
|
||||
"fr": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
|
||||
"text": "Le mois dernier, nous avons atteint un nouveau jalon avec deux milliards de vues sur notre chaîne YouTube."
|
||||
},
|
||||
"he": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac",
|
||||
"text": "בחודש שעבר הגענו לאבן דרך חדשה עם שני מיליארד צפיות בערוץ היוטיוב שלנו."
|
||||
},
|
||||
"hi": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
|
||||
"text": "पिछले महीने हमने एक नया मील का पत्थर छुआ: हमारे YouTube चैनल पर दो अरब व्यूज़।"
|
||||
},
|
||||
"it": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/it_m1.flac",
|
||||
"text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
|
||||
},
|
||||
"ja": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja/ja_prompts1.flac",
|
||||
"text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
|
||||
},
|
||||
"ko": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ko_f.flac",
|
||||
"text": "지난달 우리는 유튜브 채널에서 이십억 조회수라는 새로운 이정표에 도달했습니다."
|
||||
},
|
||||
"ms": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ms_f.flac",
|
||||
"text": "Bulan lepas, kami mencapai pencapaian baru dengan dua bilion tontonan di saluran YouTube kami."
|
||||
},
|
||||
"nl": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/nl_m.flac",
|
||||
"text": "Vorige maand bereikten we een nieuwe mijlpaal met twee miljard weergaven op ons YouTube-kanaal."
|
||||
},
|
||||
"no": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/no_f1.flac",
|
||||
"text": "Forrige måned nådde vi en ny milepæl med to milliarder visninger på YouTube-kanalen vår."
|
||||
},
|
||||
"pl": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pl_m.flac",
|
||||
"text": "W zeszłym miesiącu osiągnęliśmy nowy kamień milowy z dwoma miliardami wyświetleń na naszym kanale YouTube."
|
||||
},
|
||||
"pt": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/pt_m1.flac",
|
||||
"text": "No mês passado, alcançámos um novo marco: dois mil milhões de visualizações no nosso canal do YouTube."
|
||||
},
|
||||
"ru": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ru_m.flac",
|
||||
"text": "В прошлом месяце мы достигли нового рубежа: два миллиарда просмотров на нашем YouTube-канале."
|
||||
},
|
||||
"sv": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sv_f.flac",
|
||||
"text": "Förra månaden nådde vi en ny milstolpe med två miljarder visningar på vår YouTube-kanal."
|
||||
},
|
||||
"sw": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/sw_m.flac",
|
||||
"text": "Mwezi uliopita, tulifika hatua mpya ya maoni ya bilioni mbili kweny kituo chetu cha YouTube."
|
||||
},
|
||||
"tr": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/tr_m.flac",
|
||||
"text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
|
||||
},
|
||||
"zh": {
|
||||
"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
|
||||
"text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
|
||||
},
|
||||
}
|
||||
|
||||
# --- UI Helpers ---
|
||||
def default_audio_for_ui(lang: str) -> str | None:
|
||||
return LANGUAGE_CONFIG.get(lang, {}).get("audio")
|
||||
|
||||
|
||||
def default_text_for_ui(lang: str) -> str:
|
||||
return LANGUAGE_CONFIG.get(lang, {}).get("text", "")
|
||||
|
||||
|
||||
def get_supported_languages_display() -> str:
|
||||
"""Generate a formatted display of all supported languages."""
|
||||
language_items = []
|
||||
for code, name in sorted(SUPPORTED_LANGUAGES.items()):
|
||||
language_items.append(f"**{name}** (`{code}`)")
|
||||
|
||||
# Split into 2 lines
|
||||
mid = len(language_items) // 2
|
||||
line1 = " • ".join(language_items[:mid])
|
||||
line2 = " • ".join(language_items[mid:])
|
||||
|
||||
return f"""
|
||||
### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)
|
||||
{line1}
|
||||
|
||||
{line2}
|
||||
"""
|
||||
|
||||
|
||||
def get_or_load_model():
|
||||
"""Loads the ChatterboxMultilingualTTS model if it hasn't been loaded already,
|
||||
and ensures it's on the correct device."""
|
||||
global MODEL
|
||||
if MODEL is None:
|
||||
print("Model not loaded, initializing...")
|
||||
try:
|
||||
MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
|
||||
if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
|
||||
MODEL.to(DEVICE)
|
||||
print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
|
||||
except Exception as e:
|
||||
print(f"Error loading model: {e}")
|
||||
raise
|
||||
return MODEL
|
||||
|
||||
# Attempt to load the model at startup.
|
||||
try:
|
||||
get_or_load_model()
|
||||
except Exception as e:
|
||||
print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""Sets the random seed for reproducibility across torch, numpy, and random."""
|
||||
torch.manual_seed(seed)
|
||||
if DEVICE == "cuda":
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
|
||||
"""
|
||||
Decide which audio prompt to use:
|
||||
- If user provided a path (upload/mic/url), use it.
|
||||
- Else, fall back to language-specific default (if any).
|
||||
"""
|
||||
if provided_path and str(provided_path).strip():
|
||||
return provided_path
|
||||
return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
|
||||
|
||||
|
||||
def generate_tts_audio(
|
||||
text_input: str,
|
||||
language_id: str,
|
||||
audio_prompt_path_input: str = None,
|
||||
exaggeration_input: float = 0.5,
|
||||
temperature_input: float = 0.8,
|
||||
seed_num_input: int = 0,
|
||||
cfgw_input: float = 0.5
|
||||
) -> tuple[int, np.ndarray]:
|
||||
"""
|
||||
Generate high-quality speech audio from text using Chatterbox Multilingual model with optional reference audio styling.
|
||||
Supported languages: English, French, German, Spanish, Italian, Portuguese, and Hindi.
|
||||
|
||||
This tool synthesizes natural-sounding speech from input text. When a reference audio file
|
||||
is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
|
||||
maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
|
||||
|
||||
Args:
|
||||
text_input (str): The text to synthesize into speech (maximum 300 characters)
|
||||
language_id (str): The language code for synthesis (eg. en, fr, de, es, it, pt, hi)
|
||||
audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
|
||||
exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
|
||||
temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
|
||||
seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
|
||||
cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5, 0 for language transfer.
|
||||
|
||||
Returns:
|
||||
tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
|
||||
"""
|
||||
current_model = get_or_load_model()
|
||||
|
||||
if current_model is None:
|
||||
raise RuntimeError("TTS model is not loaded.")
|
||||
|
||||
if seed_num_input != 0:
|
||||
set_seed(int(seed_num_input))
|
||||
|
||||
print(f"Generating audio for text: '{text_input[:50]}...'")
|
||||
|
||||
# Handle optional audio prompt
|
||||
chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
|
||||
|
||||
generate_kwargs = {
|
||||
"exaggeration": exaggeration_input,
|
||||
"temperature": temperature_input,
|
||||
"cfg_weight": cfgw_input,
|
||||
}
|
||||
if chosen_prompt:
|
||||
generate_kwargs["audio_prompt_path"] = chosen_prompt
|
||||
print(f"Using audio prompt: {chosen_prompt}")
|
||||
else:
|
||||
print("No audio prompt provided; using default voice.")
|
||||
|
||||
wav = current_model.generate(
|
||||
text_input[:300], # Truncate text to max chars
|
||||
language_id=language_id,
|
||||
**generate_kwargs
|
||||
)
|
||||
print("Audio generation complete.")
|
||||
return (current_model.sr, wav.squeeze(0).numpy())
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Chatterbox Multilingual Demo
|
||||
Generate high-quality multilingual speech from text with reference audio styling, supporting 23 languages.
|
||||
"""
|
||||
)
|
||||
|
||||
# Display supported languages
|
||||
gr.Markdown(get_supported_languages_display())
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
initial_lang = "fr"
|
||||
text = gr.Textbox(
|
||||
value=default_text_for_ui(initial_lang),
|
||||
label="Text to synthesize (max chars 300)",
|
||||
max_lines=5
|
||||
)
|
||||
|
||||
language_id = gr.Dropdown(
|
||||
choices=list(ChatterboxMultilingualTTS.get_supported_languages().keys()),
|
||||
value=initial_lang,
|
||||
label="Language",
|
||||
info="Select the language for text-to-speech synthesis"
|
||||
)
|
||||
|
||||
ref_wav = gr.Audio(
|
||||
sources=["upload", "microphone"],
|
||||
type="filepath",
|
||||
label="Reference Audio File (Optional)",
|
||||
value=default_audio_for_ui(initial_lang)
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"💡 **Note**: Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language. To mitigate this, set the CFG weight to 0.",
|
||||
elem_classes=["audio-note"]
|
||||
)
|
||||
|
||||
exaggeration = gr.Slider(
|
||||
0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
|
||||
)
|
||||
cfg_weight = gr.Slider(
|
||||
0.2, 1, step=.05, label="CFG/Pace", value=0.5
|
||||
)
|
||||
|
||||
with gr.Accordion("More options", open=False):
|
||||
seed_num = gr.Number(value=0, label="Random seed (0 for random)")
|
||||
temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
|
||||
|
||||
run_btn = gr.Button("Generate", variant="primary")
|
||||
|
||||
with gr.Column():
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
def on_language_change(lang, current_ref, current_text):
|
||||
return default_audio_for_ui(lang), default_text_for_ui(lang)
|
||||
|
||||
language_id.change(
|
||||
fn=on_language_change,
|
||||
inputs=[language_id, ref_wav, text],
|
||||
outputs=[ref_wav, text],
|
||||
show_progress=False
|
||||
)
|
||||
|
||||
run_btn.click(
|
||||
fn=generate_tts_audio,
|
||||
inputs=[
|
||||
text,
|
||||
language_id,
|
||||
ref_wav,
|
||||
exaggeration,
|
||||
temp,
|
||||
seed_num,
|
||||
cfg_weight,
|
||||
],
|
||||
outputs=[audio_output],
|
||||
)
|
||||
|
||||
demo.launch(mcp_server=True)
|
||||
|
|
@@ -1,29 +1,25 @@
[project]
name = "chatterbox-tts"
version = "0.1.6"
version = "0.1.1"
description = "Chatterbox: Open Source TTS and Voice Conversion by Resemble AI"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.8"
license = {file = "LICENSE"}
authors = [
    {name = "resemble-ai", email = "engineering@resemble.ai"}
]
dependencies = [
    "numpy>=1.24.0,<1.26.0",
    "librosa==0.11.0",
    "numpy~=1.26.0",
    "resampy==0.4.3",
    "librosa==0.10.0",
    "s3tokenizer",
    "torch==2.6.0",
    "torchaudio==2.6.0",
    "transformers==4.46.3",
    "diffusers==0.29.0",
    "resemble-perth==1.0.1",
    "omegaconf==2.3.0",
    "conformer==0.3.2",
    "safetensors==0.5.3",
    "spacy-pkuseg",
    "pykakasi==2.3.0",
    "gradio==5.44.1",
    "pyloudnorm",
    "omegaconf"
]

[project.urls]
@@ -1,11 +1,2 @@
try:
    from importlib.metadata import version
except ImportError:
    from importlib_metadata import version  # For Python <3.8

__version__ = version("chatterbox-tts")


from .tts import ChatterboxTTS
from .vc import ChatterboxVC
from .mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
@@ -1,10 +0,0 @@
from ..utils import AttrDict

CFM_PARAMS = AttrDict({
    "sigma_min": 1e-06,
    "solver": "euler",
    "t_scheduler": "cosine",
    "training_cfg_rate": 0.2,
    "inference_cfg_rate": 0.7,
    "reg_loss_type": "l1"
})
@@ -1,2 +1 @@
S3GEN_SR = 24000
S3GEN_SIL = 4299
|||
|
|
@ -20,7 +20,6 @@ from .utils.mask import add_optional_chunk_mask
|
|||
from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
|
||||
TimestepEmbedding, Upsample1D
|
||||
from .matcha.transformer import BasicTransformerBlock
|
||||
from .utils.intmeanflow import get_intmeanflow_time_mixer
|
||||
|
||||
|
||||
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
|
|
@ -96,6 +95,8 @@ class CausalConv1d(torch.nn.Conv1d):
|
|||
x = F.pad(x, self.causal_padding)
|
||||
x = super(CausalConv1d, self).forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConditionalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -109,7 +110,6 @@ class ConditionalDecoder(nn.Module):
|
|||
num_mid_blocks=12,
|
||||
num_heads=8,
|
||||
act_fn="gelu",
|
||||
meanflow=False,
|
||||
):
|
||||
"""
|
||||
This decoder requires an input with the same shape of the target. So, if your text content
|
||||
|
|
@ -117,7 +117,6 @@ class ConditionalDecoder(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
channels = tuple(channels)
|
||||
self.meanflow = meanflow
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.causal = causal
|
||||
|
|
@ -128,7 +127,6 @@ class ConditionalDecoder(nn.Module):
|
|||
time_embed_dim=time_embed_dim,
|
||||
act_fn="silu",
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
|
@ -217,14 +215,6 @@ class ConditionalDecoder(nn.Module):
|
|||
self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
|
||||
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||
self.initialize_weights()
|
||||
self.time_embed_mixer = None
|
||||
if self.meanflow:
|
||||
self.time_embed_mixer = get_intmeanflow_time_mixer(time_embed_dim)
|
||||
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.final_proj.weight.dtype
|
||||
|
||||
def initialize_weights(self):
|
||||
for m in self.modules():
|
||||
|
|
@ -240,16 +230,15 @@ class ConditionalDecoder(nn.Module):
|
|||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None, r=None):
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
||||
"""Forward pass of the UNet1DConditional model.
|
||||
|
||||
Args:
|
||||
x: (B, 80, T)
|
||||
mask (_type_)
|
||||
x (torch.Tensor): shape (batch_size, in_channels, time)
|
||||
mask (_type_): shape (batch_size, 1, time)
|
||||
t (_type_): shape (batch_size)
|
||||
spks (_type_, optional) Defaults to None.
|
||||
cond (_type_, optional)
|
||||
r: end time for meanflow mode (shape (1,) tensor)
|
||||
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
||||
cond (_type_, optional): placeholder for future use. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
|
|
@ -258,15 +247,10 @@ class ConditionalDecoder(nn.Module):
|
|||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
|
||||
t = self.time_embeddings(t).to(t.dtype)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
if self.meanflow:
|
||||
r = self.time_embeddings(r).to(t.dtype)
|
||||
r = self.time_mlp(r)
|
||||
concat_embed = torch.cat([t, r], dim=1)
|
||||
t = self.time_embed_mixer(concat_embed)
|
||||
|
||||
x = pack([x, mu], "b * t")[0]
|
||||
|
||||
if spks is not None:
|
||||
|
|
|
|||
|
|
@ -14,30 +14,142 @@
|
|||
import logging
|
||||
import random
|
||||
from typing import Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from .utils.mask import make_pad_mask
|
||||
from .configs import CFM_PARAMS
|
||||
from omegaconf import DictConfig
|
||||
from .utils.mask import make_pad_mask
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class MaskedDiffWithXvec(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 4096,
|
||||
input_frame_rate: int = 50,
|
||||
only_mask_loss: bool = True,
|
||||
encoder: torch.nn.Module = None,
|
||||
length_regulator: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
||||
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
||||
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.mel_feat_conf = mel_feat_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
||||
self.decoder = decoder
|
||||
self.length_regulator = length_regulator
|
||||
self.only_mask_loss = only_mask_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
def _repeat_batch_dim(tnsr, B, ndim):
|
||||
"repeat batch dimension if it's equal to 1"
|
||||
if tnsr is not None:
|
||||
# add missing batch dim if needed
|
||||
while tnsr.ndim < ndim:
|
||||
tnsr = tnsr[None]
|
||||
# repeat batch dim as needed
|
||||
if B > 1 and tnsr.size(0) == 1:
|
||||
tnsr = tnsr.repeat(B, *([1] * (ndim - 1)))
|
||||
assert tnsr.ndim == ndim, f"Expected {ndim=}, got {tnsr.ndim=}"
|
||||
return tnsr
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
h, h_lengths = self.length_regulator(h, feat_len)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :index] = feat[i, :index]
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds
|
||||
)
|
||||
return {'loss': loss}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
token,
|
||||
token_len,
|
||||
prompt_token,
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
flow_cache):
|
||||
if self.fp16 is True:
|
||||
prompt_feat = prompt_feat.half()
|
||||
embedding = embedding.half()
|
||||
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, flow_cache = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
prompt_len=mel_len1,
|
||||
flow_cache=flow_cache
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), flow_cache
|
||||
|
||||
|
||||
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
|
|
@ -54,14 +166,10 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||
encoder: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
'cfm_params': DictConfig(
|
||||
{'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7,
|
||||
'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0,
|
||||
'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8,
|
||||
'act_fn': 'gelu'}},
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
|
||||
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
|
||||
'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
|
||||
super().__init__()
|
||||
|
|
@ -82,51 +190,8 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||
self.token_mel_ratio = token_mel_ratio
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
|
||||
# NOTE: copied in from cosyvoice repo
|
||||
def compute_loss(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device) # (B, 80, T)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# NOTE unified training, static_chunk_size > 0 or = 0
|
||||
# streaming = True if random.random() < 0.5 else False
|
||||
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) # (B, T, 1)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask # (B, T, emb)
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len) # (B, T, C) -> (B, 2T, C)
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :, :index] = feat[i, :, :index]
|
||||
|
||||
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds,
|
||||
# streaming=streaming,
|
||||
)
|
||||
return {'loss': loss}
|
||||
# FIXME: this was missing - just putting it in as false
|
||||
self.fp16 = False
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
|
|
@ -137,62 +202,41 @@ class CausalMaskedDiffWithXvec(torch.nn.Module):
|
|||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
finalize,
|
||||
n_timesteps=10,
|
||||
noised_mels=None,
|
||||
meanflow=False):
|
||||
# token: (B, n_toks)
|
||||
# token_len: (B,)
|
||||
B = token.size(0)
|
||||
finalize):
|
||||
if self.fp16 is True:
|
||||
prompt_feat = prompt_feat.half()
|
||||
embedding = embedding.half()
|
||||
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = torch.atleast_2d(embedding)
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding) # (1 or B, emb_dim)
|
||||
|
||||
# adjust shapes (batching logic)
|
||||
prompt_token = _repeat_batch_dim(prompt_token, B, ndim=2) # (B, n_prompt)
|
||||
prompt_token_len = _repeat_batch_dim(prompt_token_len, B, ndim=1) # (B,)
|
||||
prompt_feat = _repeat_batch_dim(prompt_feat, B, ndim=3) # (B, n_feat, feat_dim=80)
|
||||
prompt_feat_len = _repeat_batch_dim(prompt_feat_len, B, ndim=1) # (B,) or None
|
||||
embedding = _repeat_batch_dim(embedding, B, ndim=2) # (B, emb_dim)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
|
||||
if (token >= self.vocab_size).any():
|
||||
logger.error(f"{token.max()}>{self.vocab_size}\n out-of-range special tokens found in flow, fix inputs!")
|
||||
token = self.input_embedding(token.long()) * mask
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_masks = self.encoder(token, token_len)
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
if finalize is False:
|
||||
h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
|
||||
|
||||
h_lengths = h_masks.sum(dim=-1).squeeze(dim=-1)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# # get conditions
|
||||
conds = torch.zeros([B, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(h_lengths)).unsqueeze(1).to(h)
|
||||
|
||||
if mask.shape[0] != B:
|
||||
mask = mask.repeat(B, 1, 1)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, _ = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask,
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=n_timesteps,
|
||||
noised_mels=noised_mels,
|
||||
meanflow=meanflow,
|
||||
n_timesteps=10
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat, None # NOTE jrm: why are they returning None here?
|
||||
return feat.float(), None # NOTE jrm: why are they returning None here?
|
||||
|
|
|
|||
|
|
@ -15,12 +15,17 @@ import threading
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .matcha.flow_matching import BASECFM
|
||||
from .configs import CFM_PARAMS
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def cast_all(*args, dtype):
|
||||
return [a if (not a.dtype.is_floating_point) or a.dtype == dtype else a.to(dtype) for a in args]
|
||||
CFM_PARAMS = OmegaConf.create({
|
||||
"sigma_min": 1e-06,
|
||||
"solver": "euler",
|
||||
"t_scheduler": "cosine",
|
||||
"training_cfg_rate": 0.2,
|
||||
"inference_cfg_rate": 0.7,
|
||||
"reg_loss_type": "l1"
|
||||
})
|
||||
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
|
|
@ -37,6 +42,7 @@ class ConditionalCFM(BASECFM):
|
|||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
|
||||
|
|
@ -58,8 +64,6 @@ class ConditionalCFM(BASECFM):
|
|||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
raise NotImplementedError("unused, needs updating for meanflow model")
|
||||
|
||||
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
||||
cache_size = flow_cache.shape[2]
|
||||
# fix prompt and overlap part mu and z
|
||||
|
|
@ -75,7 +79,7 @@ class ConditionalCFM(BASECFM):
|
|||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, meanflow=False):
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
|
|
@ -89,60 +93,65 @@ class ConditionalCFM(BASECFM):
|
|||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
meanflow: meanflow mode
|
||||
"""
|
||||
in_dtype = x.dtype
|
||||
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||
t = t.unsqueeze(dim=0)
|
||||
|
||||
# Stored so intermediate solutions can be inspected (e.g. from a debugger) and plotted later
|
||||
# A return_all_steps flag could be added in the future
|
||||
sol = []
|
||||
|
||||
# Duplicated batch dims are for CFG
|
||||
# Do not use concat; it can change the memory format and make TensorRT inference return wrong results!
|
||||
B, T = mu.size(0), x.size(2)
|
||||
x_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2 * B, 1, T], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2 * B, 80 ], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * B, 80, T], device=x.device, dtype=x.dtype)
|
||||
r_in = torch.zeros([2 * B ], device=x.device, dtype=x.dtype) # (only used for meanflow)
|
||||
|
||||
for t, r in zip(t_span[:-1], t_span[1:]):
|
||||
t = t.unsqueeze(dim=0)
|
||||
r = r.unsqueeze(dim=0)
|
||||
# Shapes:
|
||||
# x_in ( 2B, 80, T )
|
||||
# mask_in ( 2B, 1, T )
|
||||
# mu_in ( 2B, 80, T )
|
||||
# t_in ( 2B, )
|
||||
# spks_in ( 2B, 80, )
|
||||
# cond_in ( 2B, 80, T )
|
||||
# r_in ( 2B, )
|
||||
# x ( B, 80, T )
|
||||
# mask ( B, 1, T )
|
||||
# mu ( B, 80, T )
|
||||
# t ( B, )
|
||||
# spks ( B, 80, )
|
||||
# cond ( B, 80, T )
|
||||
# r ( B, )
|
||||
|
||||
x_in[:B] = x_in[B:] = x
|
||||
mask_in[:B] = mask_in[B:] = mask
|
||||
mu_in[:B] = mu
|
||||
t_in[:B] = t_in[B:] = t
|
||||
spks_in[:B] = spks
|
||||
cond_in[:B] = cond
|
||||
r_in[:B] = r_in[B:] = r # (only used for meanflow)
|
||||
dxdt = self.estimator.forward(
|
||||
x=x_in, mask=mask_in, mu=mu_in, t=t_in, spks=spks_in, cond=cond_in,
|
||||
r=r_in if meanflow else None,
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
x_in[:] = x
|
||||
mask_in[:] = mask
|
||||
mu_in[0] = mu
|
||||
t_in[:] = t.unsqueeze(0)
|
||||
spks_in[0] = spks
|
||||
cond_in[0] = cond
|
||||
dphi_dt = self.forward_estimator(
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in
|
||||
)
|
||||
dxdt, cfg_dxdt = torch.split(dxdt, [B, B], dim=0)
|
||||
dxdt = ((1.0 + self.inference_cfg_rate) * dxdt - self.inference_cfg_rate * cfg_dxdt)
|
||||
dt = r - t
|
||||
x = x + dt * dxdt
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
x = x + dt * dphi_dt
|
||||
t = t + dt
|
||||
sol.append(x)
|
||||
if step < len(t_span) - 1:
|
||||
dt = t_span[step + 1] - t
|
||||
|
||||
return sol[-1].float()
|
||||
|
||||
|
||||
return x.to(in_dtype)
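For reference, the classifier-free guidance merge in the solver loop above follows the usual (1 + w) * cond - w * uncond pattern. A minimal sketch with toy tensors, where w stands in for self.inference_cfg_rate (not part of the diff):

# Hedged sketch of the CFG combination, toy shapes only.
w = 0.7
cond, uncond = torch.randn(1, 80, 10), torch.randn(1, 80, 10)
guided = (1.0 + w) * cond - w * uncond  # same formula as the dxdt / dphi_dt merge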
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
||||
else:
|
||||
with self.lock:
|
||||
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
self.estimator.set_input_shape('t', (2,))
|
||||
self.estimator.set_input_shape('spks', (2, 80))
|
||||
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
# run trt engine
|
||||
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()])
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
|
@ -189,11 +198,10 @@ class ConditionalCFM(BASECFM):
|
|||
class CausalConditionalCFM(ConditionalCFM):
|
||||
def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
|
||||
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
||||
# TODO: BAD BAD IDEA - IT'LL MESS UP DISTILLATION - SETTING TO NONE
|
||||
self.rand_noise = None
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, noised_mels=None, meanflow=False):
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
|
|
@ -206,41 +214,15 @@ class CausalConditionalCFM(ConditionalCFM):
|
|||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
noised_mels: gt mels noised a time t
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
B = mu.size(0)
|
||||
z = torch.randn_like(mu)
|
||||
|
||||
if noised_mels is not None:
|
||||
prompt_len = mu.size(2) - noised_mels.size(2)
|
||||
z[..., prompt_len:] = noised_mels
|
||||
|
||||
# time steps for reverse diffusion
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||
# fix prompt and overlap part mu and z
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if (not meanflow) and (self.t_scheduler == 'cosine'):
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
|
||||
# NOTE: right now, the only meanflow models are also distilled models, which don't need CFG
|
||||
# because they were distilled with CFG outputs. We would need to add another hparam and
|
||||
# change the conditional logic here if we want to use CFG inference with a meanflow model.
|
||||
if meanflow:
|
||||
return self.basic_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
||||
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, meanflow=meanflow), None
|
||||
|
||||
def basic_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
in_dtype = x.dtype
|
||||
x, t_span, mu, mask, spks, cond = cast_all(x, t_span, mu, mask, spks, cond, dtype=self.estimator.dtype)
|
||||
|
||||
print("S3 Token -> Mel Inference...")
|
||||
for t, r in tqdm(zip(t_span[..., :-1], t_span[..., 1:]), total=t_span.shape[-1] - 1):
|
||||
t, r = t[None], r[None]
|
||||
dxdt = self.estimator.forward(x, mask=mask, mu=mu, t=t, spks=spks, cond=cond, r=r)
|
||||
dt = r - t
|
||||
x = x + dt * dxdt
|
||||
|
||||
return x.to(in_dtype)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
|
||||
@ -19,6 +19,7 @@ import torch
|
|||
import torchaudio as ta
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
|
||||
from .const import S3GEN_SR
|
||||
|
|
@ -30,7 +31,6 @@ from .hifigan import HiFTGenerator
|
|||
from .transformer.upsample_encoder import UpsampleConformerEncoder
|
||||
from .flow_matching import CausalConditionalCFM
|
||||
from .decoder import ConditionalDecoder
|
||||
from .configs import CFM_PARAMS
|
||||
|
||||
|
||||
def drop_invalid_tokens(x):
|
||||
|
|
@ -46,20 +46,15 @@ def get_resampler(src_sr, dst_sr, device):
|
|||
|
||||
class S3Token2Mel(torch.nn.Module):
|
||||
"""
|
||||
S3Gen's CFM decoder maps S3 speech tokens to mel-spectrograms.
|
||||
CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
|
||||
|
||||
TODO: make these modules configurable?
|
||||
"""
|
||||
def __init__(self, meanflow=False):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
|
||||
self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
|
||||
self.speaker_encoder = CAMPPlus(
|
||||
# NOTE: This doesn't affect inference. It turns off activation checkpointing
|
||||
# (a training optimization), which causes a crazy DDP error with accelerate
|
||||
memory_efficient=False,
|
||||
)
|
||||
self.meanflow = meanflow
|
||||
self.speaker_encoder = CAMPPlus() # use default args
|
||||
|
||||
encoder = UpsampleConformerEncoder(
|
||||
output_size=512,
|
||||
|
|
@ -89,9 +84,15 @@ class S3Token2Mel(torch.nn.Module):
|
|||
num_mid_blocks=12,
|
||||
num_heads=8,
|
||||
act_fn='gelu',
|
||||
meanflow=self.meanflow,
|
||||
)
|
||||
cfm_params = CFM_PARAMS
|
||||
cfm_params = DictConfig({
|
||||
"sigma_min": 1e-06,
|
||||
"solver": 'euler',
|
||||
"t_scheduler": 'cosine',
|
||||
"training_cfg_rate": 0.2,
|
||||
"inference_cfg_rate": 0.7,
|
||||
"reg_loss_type": 'l1',
|
||||
})
|
||||
decoder = CausalConditionalCFM(
|
||||
spk_emb_dim=80,
|
||||
cfm_params=cfm_params,
|
||||
|
|
@ -110,11 +111,6 @@ class S3Token2Mel(torch.nn.Module):
|
|||
params = self.tokenizer.parameters()
|
||||
return next(params).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
params = self.flow.parameters()
|
||||
return next(params).dtype
|
||||
|
||||
def embed_ref(
|
||||
self,
|
||||
ref_wav: torch.Tensor,
|
||||
|
|
@ -133,26 +129,23 @@ class S3Token2Mel(torch.nn.Module):
|
|||
ref_wav = ref_wav.unsqueeze(0) # (B, L)
|
||||
|
||||
if ref_wav.size(1) > 10 * ref_sr:
|
||||
print("WARNING: s3gen received ref longer than 10s")
|
||||
print("WARNING: cosydec received ref longer than 10s")
|
||||
|
||||
ref_wav_24 = ref_wav
|
||||
if ref_sr != S3GEN_SR:
|
||||
ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
|
||||
ref_wav_24 = ref_wav_24.to(device=device, dtype=self.dtype)
|
||||
|
||||
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(dtype=self.dtype)
|
||||
ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
|
||||
ref_mels_24_len = None
|
||||
|
||||
# Resample to 16kHz
|
||||
ref_wav_16 = ref_wav
|
||||
if ref_sr != S3_SR:
|
||||
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav)
|
||||
ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
|
||||
|
||||
# Speaker embedding
|
||||
ref_x_vector = self.speaker_encoder.inference(ref_wav_16.to(dtype=self.dtype))
|
||||
ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
|
||||
|
||||
# Tokenize 16khz reference
|
||||
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16.float())
|
||||
ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
|
||||
|
||||
# Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
|
||||
if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
|
||||
|
|
@ -178,10 +171,7 @@ class S3Token2Mel(torch.nn.Module):
|
|||
ref_sr: Optional[int],
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
n_cfm_timesteps = None,
|
||||
finalize: bool = False,
|
||||
speech_token_lens=None,
|
||||
noised_mels=None,
|
||||
):
|
||||
"""
|
||||
Generate waveforms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.
|
||||
|
|
@ -209,21 +199,18 @@ class S3Token2Mel(torch.nn.Module):
|
|||
if isinstance(ref_dict[rk], np.ndarray):
|
||||
ref_dict[rk] = torch.from_numpy(ref_dict[rk])
|
||||
if torch.is_tensor(ref_dict[rk]):
|
||||
ref_dict[rk] = ref_dict[rk].to(device=self.device, dtype=self.dtype)
|
||||
ref_dict[rk] = ref_dict[rk].to(self.device)
|
||||
|
||||
speech_tokens = torch.atleast_2d(speech_tokens)
|
||||
if len(speech_tokens.shape) == 1:
|
||||
speech_tokens = speech_tokens.unsqueeze(0)
|
||||
|
||||
# backcompat
|
||||
if speech_token_lens is None:
|
||||
speech_token_lens = torch.LongTensor([st.size(-1) for st in speech_tokens]).to(self.device)
|
||||
# assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
|
||||
speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
|
||||
|
||||
output_mels, _ = self.flow.inference(
|
||||
token=speech_tokens,
|
||||
token_len=speech_token_lens,
|
||||
finalize=finalize,
|
||||
noised_mels=noised_mels,
|
||||
n_timesteps=n_cfm_timesteps,
|
||||
meanflow=self.meanflow,
|
||||
**ref_dict,
|
||||
)
|
||||
return output_mels
|
||||
|
|
@ -231,15 +218,13 @@ class S3Token2Mel(torch.nn.Module):
|
|||
|
||||
class S3Token2Wav(S3Token2Mel):
|
||||
"""
|
||||
The decoder of S3Gen is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
|
||||
The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
|
||||
|
||||
TODO: make these modules configurable?
|
||||
"""
|
||||
|
||||
ignore_state_dict_missing = ("tokenizer._mel_filters", "tokenizer.window")
|
||||
|
||||
def __init__(self, meanflow=False):
|
||||
super().__init__(meanflow)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
f0_predictor = ConvRNNF0Predictor()
|
||||
self.mel2wav = HiFTGenerator(
|
||||
|
|
@ -256,7 +241,6 @@ class S3Token2Wav(S3Token2Mel):
|
|||
trim_fade = torch.zeros(2 * n_trim)
|
||||
trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
|
||||
self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
|
||||
self.estimator_dtype = "fp32"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
@ -266,25 +250,9 @@ class S3Token2Wav(S3Token2Mel):
|
|||
ref_sr: Optional[int],
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
finalize: bool = False,
|
||||
speech_token_lens=None,
|
||||
skip_vocoder=False,
|
||||
n_cfm_timesteps=None,
|
||||
noised_mels=None,
|
||||
|
||||
finalize: bool = False
|
||||
):
|
||||
"""
|
||||
Generate waveforms from S3 speech tokens and a reference waveform, from which the speaker timbre is inferred.
|
||||
NOTE: used for sync synthesis only. Please use `S3GenStreamer` for streaming synthesis.
|
||||
"""
|
||||
output_mels = super().forward(
|
||||
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav,
|
||||
ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize,
|
||||
n_cfm_timesteps=n_cfm_timesteps, noised_mels=noised_mels,
|
||||
)
|
||||
|
||||
if skip_vocoder:
|
||||
return output_mels
|
||||
output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
||||
|
||||
# TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
|
||||
hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
|
||||
|
|
@ -306,24 +274,14 @@ class S3Token2Wav(S3Token2Mel):
|
|||
ref_sr: Optional[int] = None,
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
n_cfm_timesteps = None,
|
||||
finalize: bool = False,
|
||||
speech_token_lens=None,
|
||||
):
|
||||
n_cfm_timesteps = n_cfm_timesteps or (2 if self.meanflow else 10)
|
||||
noise = None
|
||||
if self.meanflow:
|
||||
noise = torch.randn(1, 80, speech_tokens.size(-1) * 2, dtype=self.dtype, device=self.device)
|
||||
output_mels = super().forward(
|
||||
speech_tokens, speech_token_lens=speech_token_lens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict,
|
||||
n_cfm_timesteps=n_cfm_timesteps, finalize=finalize, noised_mels=noise,
|
||||
)
|
||||
return output_mels
|
||||
return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
||||
|
||||
@torch.inference_mode()
|
||||
def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
|
||||
if cache_source is None:
|
||||
cache_source = torch.zeros(1, 1, 0).to(device=self.device, dtype=self.dtype)
|
||||
cache_source = torch.zeros(1, 1, 0).to(self.device)
|
||||
return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
@ -335,26 +293,11 @@ class S3Token2Wav(S3Token2Mel):
|
|||
ref_sr: Optional[int] = None,
|
||||
# pre-computed ref embedding (prod API)
|
||||
ref_dict: Optional[dict] = None,
|
||||
# left as a kwarg because this can change input/output size ratio
|
||||
drop_invalid_tokens=True,
|
||||
n_cfm_timesteps=None,
|
||||
speech_token_lens=None,
|
||||
cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
|
||||
finalize: bool = True,
|
||||
):
|
||||
# hallucination prevention, drop special tokens
|
||||
# if drop_invalid_tokens:
|
||||
# speech_tokens, speech_token_lens = drop_invalid(speech_tokens, pad=S3_QUIET_PAD)
|
||||
|
||||
output_mels = self.flow_inference(
|
||||
speech_tokens,
|
||||
speech_token_lens=speech_token_lens,
|
||||
ref_wav=ref_wav,
|
||||
ref_sr=ref_sr,
|
||||
ref_dict=ref_dict,
|
||||
n_cfm_timesteps=n_cfm_timesteps,
|
||||
finalize=True,
|
||||
)
|
||||
output_mels = output_mels.to(dtype=self.dtype) # FIXME (fp16 mode) is this still needed?
|
||||
output_wavs, output_sources = self.hift_inference(output_mels, None)
|
||||
output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
|
||||
output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
|
||||
|
||||
# NOTE: ad-hoc method to reduce "spillover" from the reference clip.
|
||||
output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
|
||||
|
@ -1,36 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_intmeanflow_time_mixer(dims):
|
||||
""""
|
||||
Diagonal init as described in 3.3 https://arxiv.org/pdf/2510.07979
|
||||
"""
|
||||
layer = nn.Linear(dims * 2, dims, bias=False)
|
||||
|
||||
with torch.no_grad():
|
||||
target_weight = torch.zeros(dims, 2 * dims)
|
||||
target_weight[:, 0:dims] = torch.eye(dims)
|
||||
layer.weight.data = target_weight
|
||||
|
||||
return layer
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
D_example = 6
|
||||
|
||||
W_layer = get_intmeanflow_time_mixer(D_example)
|
||||
|
||||
print(f"Layer weight (AFTER init):\n{W_layer.weight.data}\n")
|
||||
|
||||
e_t = torch.tensor([0., 1., 2., 3., 4., 5.])
|
||||
e_r = torch.tensor([6., 7., 8., 9., 10., 11.])
|
||||
e_concat = torch.cat([e_t, e_r]).unsqueeze(0) # Shape (1, 12)
|
||||
|
||||
output = W_layer(e_concat)
|
||||
|
||||
print(f"Test Input e_t: \n{e_t}")
|
||||
print(f"Test Input e_r: \n{e_r}")
|
||||
print(f"Test Input concat: \n{e_concat}")
|
||||
|
||||
print(f"Forward Pass Output: \n{output.squeeze(0)}")
|
||||
|
|
@ -181,7 +181,6 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
|||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
"""
|
||||
lengths = lengths.long()
|
||||
batch_size = lengths.size(0)
|
||||
max_len = max_len if max_len > 0 else lengths.max().item()
|
||||
seq_range = torch.arange(0,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,8 @@
|
|||
"""mel-spectrogram extraction in Matcha-TTS"""
|
||||
import logging
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# NOTE: they declared these global vars
|
||||
mel_basis = {}
|
||||
|
|
@ -45,11 +42,10 @@ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=48
|
|||
if len(y.shape) == 1:
|
||||
y = y[None, ]
|
||||
|
||||
# Debug: Check for audio clipping (values outside [-1.0, 1.0] range)
|
||||
min_val = torch.min(y)
|
||||
max_val = torch.max(y)
|
||||
if min_val < -1.0 or max_val > 1.0:
|
||||
logger.warning(f"Audio values outside normalized range: min={min_val.item():.4f}, max={max_val.item():.4f}")
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
|
||||
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
||||
|
|
|
|||
|
|
@ -10,9 +10,6 @@ from types import MethodType
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LLAMA_ALIGNED_HEADS = [(12, 15), (13, 11), (9, 2)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlignmentAnalysisResult:
|
||||
# was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
|
||||
|
|
@ -52,22 +49,21 @@ class AlignmentStreamAnalyzer:
|
|||
|
||||
self.complete = False
|
||||
self.completed_at = None
|
||||
|
||||
# Track generated tokens for repetition detection
|
||||
self.generated_tokens = []
|
||||
|
||||
# Using `output_attentions=True` is incompatible with optimized attention kernels, so
|
||||
# using it for all layers slows things down too much. We can apply it to just one layer
|
||||
# by intercepting the kwargs and adding a forward hook (credit: jrm)
|
||||
self.last_aligned_attns = []
|
||||
for i, (layer_idx, head_idx) in enumerate(LLAMA_ALIGNED_HEADS):
|
||||
self.last_aligned_attns += [None]
|
||||
self._add_attention_spy(tfmr, i, layer_idx, head_idx)
|
||||
self.last_aligned_attn = None
|
||||
self._add_attention_spy(tfmr, alignment_layer_idx)
|
||||
|
||||
def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
|
||||
def _add_attention_spy(self, tfmr, alignment_layer_idx):
|
||||
"""
|
||||
Adds a forward hook to a specific attention layer to collect outputs.
|
||||
Using `output_attentions=True` is incompatible with optimized attention kernels, so
|
||||
using it for all layers slows things down too much.
|
||||
(credit: jrm)
|
||||
"""
|
||||
|
||||
def attention_forward_hook(module, input, output):
|
||||
"""
|
||||
See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
|
||||
|
|
@ -75,23 +71,27 @@ class AlignmentStreamAnalyzer:
|
|||
- When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
|
||||
- `attn_output` has shape [B, H, T0, T0] for the first step, and [B, H, 1, T0+i] for each subsequent step i.
|
||||
"""
|
||||
if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
|
||||
step_attention = output[1].cpu() # (B, n_heads, T0, Ti)
|
||||
self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti)
|
||||
step_attention = output[1].cpu() # (B, 16, N, N)
|
||||
self.last_aligned_attn = step_attention[0].mean(0) # (N, N)
|
||||
|
||||
target_layer = tfmr.layers[layer_idx].self_attn
|
||||
# Register hook and store the handle
|
||||
target_layer.register_forward_hook(attention_forward_hook)
|
||||
if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
|
||||
self.original_output_attentions = tfmr.config.output_attentions
|
||||
tfmr.config.output_attentions = True
|
||||
target_layer = tfmr.layers[alignment_layer_idx].self_attn
|
||||
hook_handle = target_layer.register_forward_hook(attention_forward_hook)
|
||||
|
||||
def step(self, logits, next_token=None):
|
||||
# Backup original forward
|
||||
original_forward = target_layer.forward
|
||||
def patched_forward(self, *args, **kwargs):
|
||||
kwargs['output_attentions'] = True
|
||||
return original_forward(*args, **kwargs)
|
||||
|
||||
# TODO: how to unpatch it?
|
||||
target_layer.forward = MethodType(patched_forward, target_layer)
|
||||
|
||||
def step(self, logits):
|
||||
"""
|
||||
Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
|
||||
"""
|
||||
# extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
|
||||
aligned_attn = torch.stack(self.last_aligned_attns).mean(dim=0) # (N, N)
|
||||
aligned_attn = self.last_aligned_attn # (N, N)
|
||||
i, j = self.text_tokens_slice
|
||||
if self.curr_frame_pos == 0:
|
||||
# first chunk has conditioning info, text tokens, and BOS token
|
||||
|
|
@ -133,46 +133,22 @@ class AlignmentStreamAnalyzer:
|
|||
last_text_token_duration = A[15:, -3:].sum()
|
||||
|
||||
# Activations for the final token that last too long are likely hallucinations.
|
||||
long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 5) # 200ms
|
||||
long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms
|
||||
|
||||
# If there are activations in previous tokens after generation has completed, assume this is a repetition error.
|
||||
alignment_repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
|
||||
|
||||
# Track generated tokens for repetition detection
|
||||
if next_token is not None:
|
||||
# Convert tensor to scalar if needed
|
||||
if isinstance(next_token, torch.Tensor):
|
||||
token_id = next_token.item() if next_token.numel() == 1 else next_token.view(-1)[0].item()
|
||||
else:
|
||||
token_id = next_token
|
||||
self.generated_tokens.append(token_id)
|
||||
|
||||
# Keep only last 8 tokens to prevent memory issues
|
||||
if len(self.generated_tokens) > 8:
|
||||
self.generated_tokens = self.generated_tokens[-8:]
|
||||
|
||||
# Check for excessive token repetition (same token sampled twice in a row)
|
||||
token_repetition = (
|
||||
# self.complete and
|
||||
len(self.generated_tokens) >= 3 and
|
||||
len(set(self.generated_tokens[-2:])) == 1
|
||||
)
|
||||
|
||||
if token_repetition:
|
||||
repeated_token = self.generated_tokens[-1]
|
||||
logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
|
||||
|
||||
# Suppress EoS to prevent early termination
|
||||
if cur_text_posn < S - 3 and S > 5: # Only suppress if text is longer than 5 tokens
|
||||
logits[..., self.eos_idx] = -2**15
|
||||
repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
|
||||
|
||||
# If a bad ending is detected, force emit EOS by modifying logits
|
||||
# NOTE: this means logits may be inconsistent with latents!
|
||||
if long_tail or alignment_repetition or token_repetition:
|
||||
logger.warning(f"forcing EOS token, {long_tail=}, {alignment_repetition=}, {token_repetition=}")
|
||||
if long_tail or repetition:
|
||||
logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}")
|
||||
# (±2**15 is safe for all dtypes >= 16bit)
|
||||
logits = -(2**15) * torch.ones_like(logits)
|
||||
logits[..., self.eos_idx] = 2**15
|
||||
|
||||
# Suppress EoS to prevent early termination
|
||||
if cur_text_posn < S - 3: # FIXME: arbitrary
|
||||
logits[..., self.eos_idx] = -2**15
|
||||
|
||||
self.curr_frame_pos += 1
|
||||
return logits
|
||||
|
|
|
|||
|
|
@ -32,42 +32,6 @@ LLAMA_520M_CONFIG_DICT = dict(
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
GPT2_MEDIUM_CONFIG = {
|
||||
"activation_function": "gelu_new",
|
||||
"architectures": [
|
||||
"GPT2LMHeadModel"
|
||||
],
|
||||
"attn_pdrop": 0.1,
|
||||
"bos_token_id": 50256,
|
||||
"embd_pdrop": 0.1,
|
||||
"eos_token_id": 50256,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_epsilon": 1e-05,
|
||||
"model_type": "gpt2",
|
||||
"n_ctx": 8196,
|
||||
"n_embd": 1024,
|
||||
"hidden_size": 1024,
|
||||
"n_head": 16,
|
||||
"n_layer": 24,
|
||||
"n_positions": 8196,
|
||||
"n_special": 0,
|
||||
"predict_special_tokens": True,
|
||||
"resid_pdrop": 0.1,
|
||||
"summary_activation": None,
|
||||
"summary_first_dropout": 0.1,
|
||||
"summary_proj_to_labels": True,
|
||||
"summary_type": "cls_index",
|
||||
"summary_use_proj": True,
|
||||
"task_specific_params": {
|
||||
"text-generation": {
|
||||
"do_sample": True,
|
||||
"max_length": 50
|
||||
}
|
||||
},
|
||||
"vocab_size": 50276,
|
||||
}
|
||||
|
||||
LLAMA_CONFIGS = {
|
||||
"Llama_520M": LLAMA_520M_CONFIG_DICT,
|
||||
"GPT2_medium": GPT2_MEDIUM_CONFIG,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,40 +2,26 @@ from ..llama_configs import LLAMA_CONFIGS
|
|||
|
||||
|
||||
class T3Config:
|
||||
def __init__(self, text_tokens_dict_size=704):
|
||||
self.start_text_token = 255
|
||||
self.stop_text_token = 0
|
||||
self.text_tokens_dict_size = text_tokens_dict_size
|
||||
self.max_text_tokens = 2048
|
||||
start_text_token = 255
|
||||
stop_text_token = 0
|
||||
text_tokens_dict_size = 704
|
||||
max_text_tokens = 2048
|
||||
|
||||
self.start_speech_token = 6561
|
||||
self.stop_speech_token = 6562
|
||||
self.speech_tokens_dict_size = 8194
|
||||
self.max_speech_tokens = 4096
|
||||
start_speech_token = 6561
|
||||
stop_speech_token = 6562
|
||||
speech_tokens_dict_size = 8194
|
||||
max_speech_tokens = 4096
|
||||
|
||||
self.llama_config_name = "Llama_520M"
|
||||
self.input_pos_emb = "learned"
|
||||
self.speech_cond_prompt_len = 150
|
||||
llama_config_name = "Llama_520M"
|
||||
input_pos_emb = "learned"
|
||||
speech_cond_prompt_len = 150
|
||||
|
||||
self.encoder_type = "voice_encoder"
|
||||
self.speaker_embed_size = 256
|
||||
self.use_perceiver_resampler = True
|
||||
self.emotion_adv = True
|
||||
# For T3CondEnc
|
||||
encoder_type = "voice_encoder"
|
||||
speaker_embed_size = 256
|
||||
use_perceiver_resampler = True
|
||||
emotion_adv = True
|
||||
|
||||
@property
|
||||
def n_channels(self):
|
||||
return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.text_tokens_dict_size == 2454
|
||||
|
||||
@classmethod
|
||||
def english_only(cls):
|
||||
"""Create configuration for English-only TTS model."""
|
||||
return cls(text_tokens_dict_size=704)
|
||||
|
||||
@classmethod
|
||||
def multilingual(cls):
|
||||
"""Create configuration for multilingual TTS model."""
|
||||
return cls(text_tokens_dict_size=2454)
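A minimal sketch of how the two factory methods above differ; the only distinction is the text vocabulary size, which also drives is_multilingual (not part of the diff):

# Hedged example against the new T3Config shown above.
en_cfg = T3Config.english_only()   # text_tokens_dict_size == 704
ml_cfg = T3Config.multilingual()   # text_tokens_dict_size == 2454
assert not en_cfg.is_multilingual and ml_cfg.is_multilingual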
|
||||
|
|
|
|||
|
|
@ -3,21 +3,13 @@
|
|||
import logging
|
||||
from typing import Union, Optional, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
from transformers import LlamaModel, LlamaConfig, GPT2Config, GPT2Model
|
||||
from transformers.generation.logits_process import (
|
||||
LogitsProcessorList,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
MinPLogitsWarper,
|
||||
)
|
||||
from transformers import LlamaModel, LlamaConfig
|
||||
from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||
|
||||
from .modules.learned_pos_emb import LearnedPositionEmbeddings
|
||||
|
||||
from .modules.cond_enc import T3CondEnc, T3Cond
|
||||
|
|
@ -25,12 +17,17 @@ from .modules.t3_config import T3Config
|
|||
from .llama_configs import LLAMA_CONFIGS
|
||||
from .inference.t3_hf_backend import T3HuggingfaceBackend
|
||||
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
|
||||
from ..utils import AttrDict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def _ensure_BOT_EOT(text_tokens: Tensor, hp):
|
||||
B = text_tokens.size(0)
|
||||
assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token"
|
||||
|
|
@ -47,22 +44,11 @@ class T3(nn.Module):
|
|||
different PE embedding space for speech.
|
||||
"""
|
||||
|
||||
def __init__(self, hp=None):
|
||||
if hp is None:
|
||||
hp = T3Config.english_only()
|
||||
def __init__(self, hp=T3Config()):
|
||||
super().__init__()
|
||||
self.hp = hp
|
||||
|
||||
config_dict = LLAMA_CONFIGS[hp.llama_config_name]
|
||||
self.is_gpt = config_dict.get("model_type") == "gpt2"
|
||||
|
||||
if self.is_gpt:
|
||||
self.cfg = GPT2Config(**config_dict)
|
||||
self.tfmr = GPT2Model(self.cfg)
|
||||
else:
|
||||
self.cfg = LlamaConfig(**config_dict)
|
||||
self.tfmr = LlamaModel(self.cfg)
|
||||
|
||||
self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name])
|
||||
self.tfmr = LlamaModel(self.cfg)
|
||||
self.dim = self.cfg.hidden_size
|
||||
self.deepspeed_patch_applied = False
|
||||
|
||||
|
|
@ -72,8 +58,6 @@ class T3(nn.Module):
|
|||
self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
|
||||
|
||||
# custom position embedding
|
||||
self.text_pos_emb = None
|
||||
self.speech_pos_emb = None
|
||||
if hp.input_pos_emb == "learned":
|
||||
max_text_seq_len = hp.max_text_tokens + 2
|
||||
self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim)
|
||||
|
|
@ -83,7 +67,7 @@ class T3(nn.Module):
|
|||
|
||||
# logit projection
|
||||
self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
|
||||
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt)
|
||||
self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False)
|
||||
self.compiled = False
|
||||
|
||||
@property
|
||||
|
|
@ -95,9 +79,8 @@ class T3(nn.Module):
|
|||
Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`.
|
||||
"""
|
||||
if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None:
|
||||
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens)
|
||||
if not self.is_gpt:
|
||||
t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
|
||||
t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \
|
||||
self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
|
||||
return self.cond_enc(t3_cond) # (B, len_cond, dim)
|
||||
|
||||
def prepare_input_embeds(
|
||||
|
|
@ -106,13 +89,11 @@ class T3(nn.Module):
|
|||
t3_cond: T3Cond,
|
||||
text_tokens: torch.LongTensor,
|
||||
speech_tokens: torch.LongTensor,
|
||||
cfg_weight: float = 0.0,
|
||||
):
|
||||
# prepare input embeddings (skip backbone transformer embeddings)
|
||||
cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim)
|
||||
text_emb = self.text_emb(text_tokens) # (B, len_text, dim)
|
||||
if cfg_weight > 0.0 and not self.is_gpt:
|
||||
text_emb[1].zero_() # CFG uncond
|
||||
text_emb[1].zero_() # CFG uncond
|
||||
|
||||
speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim)
|
||||
if self.hp.input_pos_emb == "learned":
|
||||
|
|
@ -240,11 +221,10 @@ class T3(nn.Module):
|
|||
stop_on_eos=True,
|
||||
do_sample=True,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
min_p=0.05,
|
||||
top_p=0.8,
|
||||
length_penalty=1.0,
|
||||
repetition_penalty=1.2,
|
||||
cfg_weight=0.5,
|
||||
repetition_penalty=2.0,
|
||||
cfg_weight=0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -264,7 +244,6 @@ class T3(nn.Module):
|
|||
t3_cond=t3_cond,
|
||||
text_tokens=text_tokens,
|
||||
speech_tokens=initial_speech_tokens,
|
||||
cfg_weight=cfg_weight,
|
||||
)
|
||||
|
||||
# In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
|
||||
|
|
@ -275,18 +254,13 @@ class T3(nn.Module):
|
|||
# TODO? synchronize the expensive compile function
|
||||
# with self.compile_lock:
|
||||
if not self.compiled:
|
||||
# Default to None for English models, only create for multilingual
|
||||
alignment_stream_analyzer = None
|
||||
if self.hp.is_multilingual:
|
||||
alignment_stream_analyzer = AlignmentStreamAnalyzer(
|
||||
self.tfmr,
|
||||
None,
|
||||
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
|
||||
alignment_layer_idx=9, # TODO: hparam or something?
|
||||
eos_idx=self.hp.stop_speech_token,
|
||||
)
|
||||
assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
|
||||
|
||||
alignment_stream_analyzer = AlignmentStreamAnalyzer(
|
||||
self.tfmr,
|
||||
None,
|
||||
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
|
||||
alignment_layer_idx=9, # TODO: hparam or something?
|
||||
eos_idx=self.hp.stop_speech_token,
|
||||
)
|
||||
patched_model = T3HuggingfaceBackend(
|
||||
config=self.cfg,
|
||||
llama=self.tfmr,
|
||||
|
|
@ -307,7 +281,7 @@ class T3(nn.Module):
|
|||
# max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
|
||||
# num_return_sequences=num_return_sequences,
|
||||
# temperature=temperature,
|
||||
# min_p=min_p,
|
||||
# top_p=top_p,
|
||||
# length_penalty=length_penalty,
|
||||
# repetition_penalty=repetition_penalty,
|
||||
# do_sample=do_sample,
|
||||
|
|
@ -332,9 +306,7 @@ class T3(nn.Module):
|
|||
|
||||
# Instantiate the logits processors.
|
||||
top_p_warper = TopPLogitsWarper(top_p=top_p)
|
||||
min_p_warper = MinPLogitsWarper(min_p=min_p)
|
||||
top_p_warper = TopPLogitsWarper(top_p=top_p)
|
||||
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
|
||||
repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
|
||||
|
||||
# ---- Initial Forward Pass (no kv_cache yet) ----
|
||||
output = self.patched_model(
|
||||
|
|
@ -350,32 +322,21 @@ class T3(nn.Module):
|
|||
|
||||
# ---- Generation Loop using kv_cache ----
|
||||
for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
|
||||
logits_step = output.logits[:, -1, :]
|
||||
# CFG combine → (1, V)
|
||||
cond = logits_step[0:1, :]
|
||||
uncond = logits_step[1:2, :]
|
||||
cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
|
||||
logits = cond + cfg * (cond - uncond)
|
||||
|
||||
# Apply alignment stream analyzer integrity checks
|
||||
if self.patched_model.alignment_stream_analyzer is not None:
|
||||
if logits.dim() == 1: # guard in case something upstream squeezed
|
||||
logits = logits.unsqueeze(0) # (1, V)
|
||||
# Pass the last generated token for repetition tracking
|
||||
last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
|
||||
logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token) # (1, V)
|
||||
logits = output.logits[:, -1, :]
|
||||
|
||||
# CFG
|
||||
logits_cond = logits[0:1]
|
||||
logits_uncond = logits[1:2]
|
||||
logits = logits_cond + cfg_weight * (logits_cond - logits_uncond)
|
||||
logits = logits.squeeze(1)
|
||||
|
||||
# Apply repetition penalty
|
||||
ids_for_proc = generated_ids[:1, ...] # batch = 1
|
||||
logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
|
||||
|
||||
# Apply temperature scaling.
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# Apply min_p and top_p filtering
|
||||
logits = min_p_warper(ids_for_proc, logits)
|
||||
logits = top_p_warper(ids_for_proc, logits)
|
||||
|
||||
# Apply repetition penalty and top-p filtering.
|
||||
logits = repetition_penalty_processor(generated_ids, logits)
|
||||
logits = top_p_warper(None, logits)
|
||||
|
||||
# Convert logits to probabilities and sample the next token.
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
|
|
@ -386,7 +347,6 @@ class T3(nn.Module):
|
|||
|
||||
# Check for EOS token.
|
||||
if next_token.view(-1) == self.hp.stop_speech_token:
|
||||
logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
|
||||
break
|
||||
|
||||
# Get embedding for the new token.
|
||||
|
|
@ -410,81 +370,3 @@ class T3(nn.Module):
|
|||
# Concatenate all predicted tokens along the sequence dimension.
|
||||
predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens)
|
||||
return predicted_tokens
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2,
|
||||
max_gen_len=1000):
|
||||
|
||||
logits_processors = LogitsProcessorList()
|
||||
if temperature > 0 and temperature != 1.0:
|
||||
logits_processors.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k > 0:
|
||||
logits_processors.append(TopKLogitsWarper(top_k))
|
||||
if top_p < 1.0:
|
||||
logits_processors.append(TopPLogitsWarper(top_p))
|
||||
if repetition_penalty != 1.0:
|
||||
logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
|
||||
|
||||
|
||||
speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
|
||||
embeds, _ = self.prepare_input_embeds(
|
||||
t3_cond=t3_cond,
|
||||
text_tokens=text_tokens,
|
||||
speech_tokens=speech_start_token,
|
||||
cfg_weight=0.0,
|
||||
)
|
||||
|
||||
generated_speech_tokens = []
|
||||
|
||||
llm_outputs = self.tfmr(
|
||||
inputs_embeds=embeds,
|
||||
use_cache=True
|
||||
)
|
||||
|
||||
hidden_states = llm_outputs[0]
|
||||
past_key_values = llm_outputs.past_key_values
|
||||
|
||||
speech_hidden = hidden_states[:, -1:]
|
||||
speech_logits = self.speech_head(speech_hidden)
|
||||
|
||||
processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :])
|
||||
probs = F.softmax(processed_logits, dim=-1)
|
||||
next_speech_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
generated_speech_tokens.append(next_speech_token)
|
||||
current_speech_token = next_speech_token
|
||||
|
||||
for _ in tqdm(range(max_gen_len)):
|
||||
current_speech_embed = self.speech_emb(current_speech_token)
|
||||
|
||||
llm_outputs = self.tfmr(
|
||||
inputs_embeds=current_speech_embed,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True
|
||||
)
|
||||
|
||||
hidden_states = llm_outputs[0]
|
||||
past_key_values = llm_outputs.past_key_values
|
||||
speech_logits = self.speech_head(hidden_states)
|
||||
|
||||
input_ids = torch.cat(generated_speech_tokens, dim=1)
|
||||
processed_logits = logits_processors(input_ids, speech_logits[:, -1, :])
|
||||
if torch.all(processed_logits == -float("inf")):
|
||||
print("Warning: All logits are -inf")
|
||||
break
|
||||
|
||||
probs = F.softmax(processed_logits, dim=-1)
|
||||
next_speech_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
generated_speech_tokens.append(next_speech_token)
|
||||
current_speech_token = next_speech_token
|
||||
if torch.all(next_speech_token == self.hp.stop_speech_token):
|
||||
break
|
||||
|
||||
all_tokens = torch.cat(generated_speech_tokens, dim=1)
|
||||
|
||||
# Remove EOS token if present
|
||||
if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token:
|
||||
all_tokens = all_tokens[:, :-1]
|
||||
|
||||
return all_tokens
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
from .tokenizer import EnTokenizer, MTLTokenizer
|
||||
from .tokenizer import EnTokenizer
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
import logging
|
||||
import json
|
||||
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from unicodedata import category, normalize
|
||||
from tokenizers import Tokenizer
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
# Special tokens
|
||||
|
|
@ -32,7 +28,7 @@ class EnTokenizer:
|
|||
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
|
||||
return text_tokens
|
||||
|
||||
def encode(self, txt: str):
|
||||
def encode( self, txt: str, verbose=False):
|
||||
"""
|
||||
clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
|
||||
"""
|
||||
|
|
@ -45,269 +41,10 @@ class EnTokenizer:
|
|||
if isinstance(seq, torch.Tensor):
|
||||
seq = seq.cpu().numpy()
|
||||
|
||||
txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
|
||||
txt: str = self.tokenizer.decode(seq,
|
||||
skip_special_tokens=False)
|
||||
txt = txt.replace(' ', '')
|
||||
txt = txt.replace(SPACE, ' ')
|
||||
txt = txt.replace(EOT, '')
|
||||
txt = txt.replace(UNK, '')
|
||||
return txt
|
||||
|
||||
|
||||
# Model repository
|
||||
REPO_ID = "ResembleAI/chatterbox"
|
||||
|
||||
# Global instances for optional dependencies
|
||||
_kakasi = None
|
||||
_dicta = None
|
||||
_russian_stresser = None
|
||||
|
||||
|
||||
def is_kanji(c: str) -> bool:
|
||||
"""Check if character is kanji."""
|
||||
return 19968 <= ord(c) <= 40959
|
||||
|
||||
|
||||
def is_katakana(c: str) -> bool:
|
||||
"""Check if character is katakana."""
|
||||
return 12449 <= ord(c) <= 12538
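A quick hedged sanity check of the two Unicode range tests above (not part of the diff):

assert is_kanji("漢") and not is_kanji("カ")
assert is_katakana("カ") and not is_katakana("漢")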
|
||||
|
||||
|
||||
def hiragana_normalize(text: str) -> str:
|
||||
"""Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
|
||||
global _kakasi
|
||||
|
||||
try:
|
||||
if _kakasi is None:
|
||||
import pykakasi
|
||||
_kakasi = pykakasi.kakasi()
|
||||
|
||||
result = _kakasi.convert(text)
|
||||
out = []
|
||||
|
||||
for r in result:
|
||||
inp = r['orig']
|
||||
hira = r["hira"]
|
||||
|
||||
# Any kanji in the phrase
|
||||
if any([is_kanji(c) for c in inp]):
|
||||
if hira and hira[0] in ["は", "へ"]: # Safety check for empty hira
|
||||
hira = " " + hira
|
||||
out.append(hira)
|
||||
|
||||
# All katakana
|
||||
elif all([is_katakana(c) for c in inp]) if inp else False: # Safety check for empty inp
|
||||
out.append(r['orig'])
|
||||
|
||||
else:
|
||||
out.append(inp)
|
||||
|
||||
normalized_text = "".join(out)
|
||||
|
||||
# Decompose Japanese characters for tokenizer compatibility
|
||||
import unicodedata
|
||||
normalized_text = unicodedata.normalize('NFKD', normalized_text)
|
||||
|
||||
return normalized_text
|
||||
|
||||
except ImportError:
|
||||
logger.warning("pykakasi not available - Japanese text processing skipped")
|
||||
return text
|
||||
|
||||
|
||||
def add_hebrew_diacritics(text: str) -> str:
|
||||
"""Hebrew text normalization: adds diacritics to Hebrew text."""
|
||||
global _dicta
|
||||
|
||||
try:
|
||||
if _dicta is None:
|
||||
from dicta_onnx import Dicta
|
||||
_dicta = Dicta()
|
||||
|
||||
return _dicta.add_diacritics(text)
|
||||
|
||||
except ImportError:
|
||||
logger.warning("dicta_onnx not available - Hebrew text processing skipped")
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.warning(f"Hebrew diacritization failed: {e}")
|
||||
return text
|
||||
|
||||
|
||||
def korean_normalize(text: str) -> str:
|
||||
"""Korean text normalization: decompose syllables into Jamo for tokenization."""
|
||||
|
||||
def decompose_hangul(char):
|
||||
"""Decompose Korean syllable into Jamo components."""
|
||||
if not ('\uac00' <= char <= '\ud7af'):
|
||||
return char
|
||||
|
||||
# Hangul decomposition formula
|
||||
base = ord(char) - 0xAC00
|
||||
initial = chr(0x1100 + base // (21 * 28))
|
||||
medial = chr(0x1161 + (base % (21 * 28)) // 28)
|
||||
final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
|
||||
|
||||
return initial + medial + final
|
||||
|
||||
# Decompose syllables and normalize punctuation
|
||||
result = ''.join(decompose_hangul(char) for char in text)
|
||||
return result.strip()
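As a worked example of the decomposition arithmetic above (standard Unicode Hangul math, not part of the diff): '한' (U+D55C) has base offset 0xD55C - 0xAC00 = 10588, giving initial index 18, medial index 0, and final index 4, i.e. the jamo ᄒ + ᅡ + ᆫ.

# Hedged example only.
assert korean_normalize("한") == "\u1112\u1161\u11ab"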
|
||||
|
||||
|
||||
class ChineseCangjieConverter:
|
||||
"""Converts Chinese characters to Cangjie codes for tokenization."""
|
||||
|
||||
def __init__(self, model_dir=None):
|
||||
self.word2cj = {}
|
||||
self.cj2word = {}
|
||||
self.segmenter = None
|
||||
self._load_cangjie_mapping(model_dir)
|
||||
self._init_segmenter()
|
||||
|
||||
def _load_cangjie_mapping(self, model_dir=None):
|
||||
"""Load Cangjie mapping from HuggingFace model repository."""
|
||||
try:
|
||||
cangjie_file = hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
filename="Cangjie5_TC.json",
|
||||
cache_dir=model_dir
|
||||
)
|
||||
|
||||
with open(cangjie_file, "r", encoding="utf-8") as fp:
|
||||
data = json.load(fp)
|
||||
|
||||
for entry in data:
|
||||
word, code = entry.split("\t")[:2]
|
||||
self.word2cj[word] = code
|
||||
if code not in self.cj2word:
|
||||
self.cj2word[code] = [word]
|
||||
else:
|
||||
self.cj2word[code].append(word)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load Cangjie mapping: {e}")
|
||||
|
||||
def _init_segmenter(self):
|
||||
"""Initialize pkuseg segmenter."""
|
||||
try:
|
||||
from spacy_pkuseg import pkuseg
|
||||
self.segmenter = pkuseg()
|
||||
except ImportError:
|
||||
logger.warning("pkuseg not available - Chinese segmentation will be skipped")
|
||||
self.segmenter = None
|
||||
|
||||
def _cangjie_encode(self, glyph: str):
|
||||
"""Encode a single Chinese glyph to Cangjie code."""
|
||||
normed_glyph = glyph
|
||||
code = self.word2cj.get(normed_glyph, None)
|
||||
if code is None: # e.g. Japanese hiragana
|
||||
return None
|
||||
index = self.cj2word[code].index(normed_glyph)
|
||||
index = str(index) if index > 0 else ""
|
||||
return code + str(index)
|
||||
|
||||
|
||||
|
||||
def __call__(self, text):
|
||||
"""Convert Chinese characters in text to Cangjie tokens."""
|
||||
output = []
|
||||
if self.segmenter is not None:
|
||||
segmented_words = self.segmenter.cut(text)
|
||||
full_text = " ".join(segmented_words)
|
||||
else:
|
||||
full_text = text
|
||||
|
||||
for t in full_text:
|
||||
if category(t) == "Lo":
|
||||
cangjie = self._cangjie_encode(t)
|
||||
if cangjie is None:
|
||||
output.append(t)
|
||||
continue
|
||||
code = []
|
||||
for c in cangjie:
|
||||
code.append(f"[cj_{c}]")
|
||||
code.append("[cj_.]")
|
||||
code = "".join(code)
|
||||
output.append(code)
|
||||
else:
|
||||
output.append(t)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def add_russian_stress(text: str) -> str:
|
||||
"""Russian text normalization: adds stress marks to Russian text."""
|
||||
global _russian_stresser
|
||||
|
||||
try:
|
||||
if _russian_stresser is None:
|
||||
from russian_text_stresser.text_stresser import RussianTextStresser
|
||||
_russian_stresser = RussianTextStresser()
|
||||
|
||||
return _russian_stresser.stress_text(text)
|
||||
|
||||
except ImportError:
|
||||
logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.warning(f"Russian stress labeling failed: {e}")
|
||||
return text
|
||||
|
||||
|
||||
class MTLTokenizer:
|
||||
def __init__(self, vocab_file_path):
|
||||
self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
|
||||
model_dir = Path(vocab_file_path).parent
|
||||
self.cangjie_converter = ChineseCangjieConverter(model_dir)
|
||||
self.check_vocabset_sot_eot()
|
||||
|
||||
def check_vocabset_sot_eot(self):
|
||||
voc = self.tokenizer.get_vocab()
|
||||
assert SOT in voc
|
||||
assert EOT in voc
|
||||
|
||||
def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
|
||||
"""
|
||||
Text preprocessor that handles lowercase conversion and NFKD normalization.
|
||||
"""
|
||||
preprocessed_text = raw_text
|
||||
if lowercase:
|
||||
preprocessed_text = preprocessed_text.lower()
|
||||
if nfkd_normalize:
|
||||
preprocessed_text = normalize("NFKD", preprocessed_text)
|
||||
|
||||
return preprocessed_text
|
||||
|
||||
def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
|
||||
text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
|
||||
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
|
||||
return text_tokens
|
||||
|
||||
def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
|
||||
txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
|
||||
|
||||
# Language-specific text processing
|
||||
if language_id == 'zh':
|
||||
txt = self.cangjie_converter(txt)
|
||||
elif language_id == 'ja':
|
||||
txt = hiragana_normalize(txt)
|
||||
elif language_id == 'he':
|
||||
txt = add_hebrew_diacritics(txt)
|
||||
elif language_id == 'ko':
|
||||
txt = korean_normalize(txt)
|
||||
elif language_id == 'ru':
|
||||
txt = add_russian_stress(txt)
|
||||
|
||||
# Prepend language token
|
||||
if language_id:
|
||||
txt = f"[{language_id.lower()}]{txt}"
|
||||
|
||||
txt = txt.replace(' ', SPACE)
|
||||
return self.tokenizer.encode(txt).ids
|
||||
|
||||
def decode(self, seq):
|
||||
if isinstance(seq, torch.Tensor):
|
||||
seq = seq.cpu().numpy()
|
||||
|
||||
txt = self.tokenizer.decode(seq, skip_special_tokens=False)
|
||||
txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
|
||||
return txt
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
|
@ -1,301 +0,0 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import perth
|
||||
import torch.nn.functional as F
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from .models.t3 import T3
|
||||
from .models.t3.modules.t3_config import T3Config
|
||||
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
|
||||
from .models.s3gen import S3GEN_SR, S3Gen
|
||||
from .models.tokenizers import MTLTokenizer
|
||||
from .models.voice_encoder import VoiceEncoder
|
||||
from .models.t3.modules.cond_enc import T3Cond
|
||||
|
||||
|
||||
REPO_ID = "ResembleAI/chatterbox"
|
||||
|
||||
# Supported languages for the multilingual model
|
||||
SUPPORTED_LANGUAGES = {
|
||||
"ar": "Arabic",
|
||||
"da": "Danish",
|
||||
"de": "German",
|
||||
"el": "Greek",
|
||||
"en": "English",
|
||||
"es": "Spanish",
|
||||
"fi": "Finnish",
|
||||
"fr": "French",
|
||||
"he": "Hebrew",
|
||||
"hi": "Hindi",
|
||||
"it": "Italian",
|
||||
"ja": "Japanese",
|
||||
"ko": "Korean",
|
||||
"ms": "Malay",
|
||||
"nl": "Dutch",
|
||||
"no": "Norwegian",
|
||||
"pl": "Polish",
|
||||
"pt": "Portuguese",
|
||||
"ru": "Russian",
|
||||
"sv": "Swedish",
|
||||
"sw": "Swahili",
|
||||
"tr": "Turkish",
|
||||
"zh": "Chinese",
|
||||
}
|
||||
|
||||
|
||||
def punc_norm(text: str) -> str:
|
||||
"""
|
||||
Quick cleanup for punctuation coming from LLMs, or for text
containing characters not often seen in the dataset
|
||||
"""
|
||||
if len(text) == 0:
|
||||
return "You need to add some text for me to talk."
|
||||
|
||||
# Capitalise first letter
|
||||
if text[0].islower():
|
||||
text = text[0].upper() + text[1:]
|
||||
|
||||
# Remove multiple space chars
|
||||
text = " ".join(text.split())
|
||||
|
||||
# Replace uncommon/llm punc
|
||||
punc_to_replace = [
|
||||
("...", ", "),
|
||||
("…", ", "),
|
||||
(":", ","),
|
||||
(" - ", ", "),
|
||||
(";", ", "),
|
||||
("—", "-"),
|
||||
("–", "-"),
|
||||
(" ,", ","),
|
||||
("“", "\""),
|
||||
("”", "\""),
|
||||
("‘", "'"),
|
||||
("’", "'"),
|
||||
]
|
||||
for old_char_sequence, new_char in punc_to_replace:
|
||||
text = text.replace(old_char_sequence, new_char)
|
||||
|
||||
# Add full stop if no ending punc
|
||||
text = text.rstrip(" ")
|
||||
sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
|
||||
if not any(text.endswith(p) for p in sentence_enders):
|
||||
text += "."
|
||||
|
||||
return text
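A hedged usage example of the cleanup above (illustrative input only, not part of the diff):

out = punc_norm("hello world…how are you")
assert out == "Hello world, how are you."  # capitalised, ellipsis replaced, full stop appended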
|
||||
|
||||
|
||||
@dataclass
class Conditionals:
    """
    Conditionals for T3 and S3Gen
    - T3 conditionals:
        - speaker_emb
        - clap_emb
        - cond_prompt_speech_tokens
        - cond_prompt_speech_emb
        - emotion_adv
    - S3Gen conditionals:
        - prompt_token
        - prompt_token_len
        - prompt_feat
        - prompt_feat_len
        - embedding
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        self.t3 = self.t3.to(device=device)
        for k, v in self.gen.items():
            if torch.is_tensor(v):
                self.gen[k] = v.to(device=device)
        return self

    def save(self, fpath: Path):
        arg_dict = dict(
            t3=self.t3.__dict__,
            gen=self.gen
        )
        torch.save(arg_dict, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(T3Cond(**kwargs['t3']), kwargs['gen'])

class ChatterboxMultilingualTTS:
    ENC_COND_LEN = 6 * S3_SR
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer: MTLTokenizer,
        device: str,
        conds: Conditionals = None,
    ):
        self.sr = S3GEN_SR  # sample rate of synthesized audio
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.device = device
        self.conds = conds
        self.watermarker = perth.PerthImplicitWatermarker()

    @classmethod
    def get_supported_languages(cls):
        """Return dictionary of supported language codes and names."""
        return SUPPORTED_LANGUAGES.copy()

    @classmethod
    def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
        ckpt_dir = Path(ckpt_dir)

        ve = VoiceEncoder()
        ve.load_state_dict(
            torch.load(ckpt_dir / "ve.pt", weights_only=True)
        )
        ve.to(device).eval()

        t3 = T3(T3Config.multilingual())
        t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
        if "model" in t3_state.keys():
            t3_state = t3_state["model"][0]
        t3.load_state_dict(t3_state)
        t3.to(device).eval()

        s3gen = S3Gen()
        s3gen.load_state_dict(
            torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
        )
        s3gen.to(device).eval()

        tokenizer = MTLTokenizer(
            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
        )

        conds = None
        if (builtin_voice := ckpt_dir / "conds.pt").exists():
            conds = Conditionals.load(builtin_voice).to(device)

        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)

    @classmethod
    def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
        ckpt_dir = Path(
            snapshot_download(
                repo_id=REPO_ID,
                repo_type="model",
                revision="main",
                allow_patterns=["ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt", "grapheme_mtl_merged_expanded_v1.json", "conds.pt", "Cangjie5_TC.json"],
                token=os.getenv("HF_TOKEN"),
            )
        )
        return cls.from_local(ckpt_dir, device)

    def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
        ## Load reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)

        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

        # Speech cond prompt tokens
        t3_cond_prompt_tokens = None
        if plen := self.t3.hp.speech_cond_prompt_len:
            s3_tokzr = self.s3gen.tokenizer
            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)

        # Voice-encoder speaker embedding
        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)

        t3_cond = T3Cond(
            speaker_emb=ve_embed,
            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)
        self.conds = Conditionals(t3_cond, s3gen_ref_dict)

    def generate(
        self,
        text,
        language_id,
        audio_prompt_path=None,
        exaggeration=0.5,
        cfg_weight=0.5,
        temperature=0.8,
        repetition_penalty=2.0,
        min_p=0.05,
        top_p=1.0,
    ):
        # Validate language_id
        if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
            supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
            raise ValueError(
                f"Unsupported language_id '{language_id}'. "
                f"Supported languages: {supported_langs}"
            )

        if audio_prompt_path:
            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

        # Update exaggeration if needed
        if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
            _cond: T3Cond = self.conds.t3
            self.conds.t3 = T3Cond(
                speaker_emb=_cond.speaker_emb,
                cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
                emotion_adv=exaggeration * torch.ones(1, 1, 1),
            ).to(device=self.device)

        # Norm and tokenize text
        text = punc_norm(text)
        text_tokens = self.tokenizer.text_to_tokens(text, language_id=language_id.lower() if language_id else None).to(self.device)
        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG

        sot = self.t3.hp.start_text_token
        eot = self.t3.hp.stop_text_token
        text_tokens = F.pad(text_tokens, (1, 0), value=sot)
        text_tokens = F.pad(text_tokens, (0, 1), value=eot)

        with torch.inference_mode():
            speech_tokens = self.t3.inference(
                t3_cond=self.conds.t3,
                text_tokens=text_tokens,
                max_new_tokens=1000,  # TODO: use the value in config
                temperature=temperature,
                cfg_weight=cfg_weight,
                repetition_penalty=repetition_penalty,
                min_p=min_p,
                top_p=top_p,
            )
            # Extract only the conditional batch.
            speech_tokens = speech_tokens[0]

            # TODO: output becomes 1D
            speech_tokens = drop_invalid_tokens(speech_tokens)
            speech_tokens = speech_tokens.to(self.device)

            wav, _ = self.s3gen.inference(
                speech_tokens=speech_tokens,
                ref_dict=self.conds.gen,
            )
            wav = wav.squeeze(0).detach().cpu().numpy()
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
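For context, a minimal usage sketch of the multilingual class removed above. The import path, device string, and reference clip are assumptions (placeholders), not part of this diff; the `generate` arguments mirror the signature shown in the hunk.

```python
# Hedged sketch: run the multilingual model end to end.
# Assumptions: the package is installed as `chatterbox`, the module keeps its
# mtl_tts path, a CUDA device is available, and "reference.wav" is a placeholder
# speaker prompt.
import torchaudio as ta
from chatterbox.mtl_tts import ChatterboxMultilingualTTS

model = ChatterboxMultilingualTTS.from_pretrained(device="cuda")
wav = model.generate(
    "Bonjour, comment ça va?",
    language_id="fr",                   # must be a key of SUPPORTED_LANGUAGES
    audio_prompt_path="reference.wav",  # optional; otherwise the built-in conds.pt is used
    exaggeration=0.5,
)
ta.save("output.wav", wav, model.sr)
```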
@@ -6,7 +6,6 @@ import torch
 import perth
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
 
 from .models.t3 import T3
 from .models.s3tokenizer import S3_SR, drop_invalid_tokens
@@ -137,12 +136,12 @@ class ChatterboxTTS:
 
         ve = VoiceEncoder()
         ve.load_state_dict(
-            load_file(ckpt_dir / "ve.safetensors")
+            torch.load(ckpt_dir / "ve.pt", map_location=map_location)
         )
         ve.to(device).eval()
 
         t3 = T3()
-        t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
+        t3_state = torch.load(ckpt_dir / "t3_cfg.pt", map_location=map_location)
         if "model" in t3_state.keys():
             t3_state = t3_state["model"][0]
         t3.load_state_dict(t3_state)
@@ -150,7 +149,7 @@ class ChatterboxTTS:
 
         s3gen = S3Gen()
         s3gen.load_state_dict(
-            load_file(ckpt_dir / "s3gen.safetensors"), strict=False
+            torch.load(ckpt_dir / "s3gen.pt", map_location=map_location)
         )
         s3gen.to(device).eval()
 
@@ -173,8 +172,8 @@ class ChatterboxTTS:
             else:
                 print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
             device = "cpu"
 
-        for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]:
+        for fpath in ["ve.pt", "t3_cfg.pt", "s3gen.pt", "tokenizer.json", "conds.pt"]:
             local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
 
         return cls.from_local(Path(local_path).parent, device)
@@ -208,9 +207,6 @@ class ChatterboxTTS:
     def generate(
         self,
         text,
-        repetition_penalty=1.2,
-        min_p=0.05,
-        top_p=1.0,
         audio_prompt_path=None,
         exaggeration=0.5,
         cfg_weight=0.5,
@@ -233,9 +229,7 @@ class ChatterboxTTS:
         # Norm and tokenize text
         text = punc_norm(text)
         text_tokens = self.tokenizer.text_to_tokens(text).to(self.device)
-
-        if cfg_weight > 0.0:
-            text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG
+        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG
 
         sot = self.t3.hp.start_text_token
         eot = self.t3.hp.stop_text_token
@@ -249,18 +243,12 @@ class ChatterboxTTS:
                 max_new_tokens=1000,  # TODO: use the value in config
                 temperature=temperature,
                 cfg_weight=cfg_weight,
-                repetition_penalty=repetition_penalty,
-                min_p=min_p,
-                top_p=top_p,
             )
             # Extract only the conditional batch.
             speech_tokens = speech_tokens[0]
 
             # TODO: output becomes 1D
             speech_tokens = drop_invalid_tokens(speech_tokens)
-
-            speech_tokens = speech_tokens[speech_tokens < 6561]
-
             speech_tokens = speech_tokens.to(self.device)
 
             wav, _ = self.s3gen.inference(
@@ -269,4 +257,4 @@ class ChatterboxTTS:
             )
             wav = wav.squeeze(0).detach().cpu().numpy()
             watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
-        return torch.from_numpy(watermarked_wav).unsqueeze(0)
+        return torch.from_numpy(watermarked_wav).unsqueeze(0)
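The hunks above swap the master branch's safetensors loading (`safetensors.torch.load_file`) for `torch.load` on `.pt` checkpoints. If you need to go the other way, a hedged sketch of converting a `.pt` state dict to `.safetensors`; the paths are placeholders and the checkpoint is assumed to be a flat dict of tensors.

```python
# Hedged sketch: convert a .pt checkpoint to .safetensors (placeholder paths).
import torch
from safetensors.torch import save_file

state_dict = torch.load("checkpoints/ve.pt", map_location="cpu", weights_only=True)
# safetensors requires plain, contiguous tensors with no shared storage between entries.
state_dict = {k: v.contiguous() for k, v in state_dict.items() if torch.is_tensor(v)}
save_file(state_dict, "checkpoints/ve.safetensors")
```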
@@ -1,296 +0,0 @@
import os
import math
from dataclasses import dataclass
from pathlib import Path

import librosa
import torch
import perth
import pyloudnorm as ln

from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from .models.t3 import T3
from .models.s3tokenizer import S3_SR
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import EnTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond
from .models.t3.modules.t3_config import T3Config
from .models.s3gen.const import S3GEN_SIL
import logging
logger = logging.getLogger(__name__)

REPO_ID = "ResembleAI/chatterbox-turbo"

def punc_norm(text: str) -> str:
    """
    Quick cleanup func for punctuation from LLMs or
    containing chars not seen often in the dataset
    """
    if len(text) == 0:
        return "You need to add some text for me to talk."

    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]

    # Remove multiple space chars
    text = " ".join(text.split())

    # Replace uncommon/llm punc
    punc_to_replace = [
        ("…", ", "),
        (":", ","),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", "\""),
        ("”", "\""),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add full stop if no ending punc
    text = text.rstrip(" ")
    sentence_enders = {".", "!", "?", "-", ","}
    if not any(text.endswith(p) for p in sentence_enders):
        text += "."

    return text

@dataclass
class Conditionals:
    """
    Conditionals for T3 and S3Gen
    - T3 conditionals:
        - speaker_emb
        - clap_emb
        - cond_prompt_speech_tokens
        - cond_prompt_speech_emb
        - emotion_adv
    - S3Gen conditionals:
        - prompt_token
        - prompt_token_len
        - prompt_feat
        - prompt_feat_len
        - embedding
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        self.t3 = self.t3.to(device=device)
        for k, v in self.gen.items():
            if torch.is_tensor(v):
                self.gen[k] = v.to(device=device)
        return self

    def save(self, fpath: Path):
        arg_dict = dict(
            t3=self.t3.__dict__,
            gen=self.gen
        )
        torch.save(arg_dict, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        if isinstance(map_location, str):
            map_location = torch.device(map_location)
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(T3Cond(**kwargs['t3']), kwargs['gen'])

class ChatterboxTurboTTS:
    ENC_COND_LEN = 15 * S3_SR
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer: EnTokenizer,
        device: str,
        conds: Conditionals = None,
    ):
        self.sr = S3GEN_SR  # sample rate of synthesized audio
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.device = device
        self.conds = conds
        self.watermarker = perth.PerthImplicitWatermarker()

    @classmethod
    def from_local(cls, ckpt_dir, device) -> 'ChatterboxTurboTTS':
        ckpt_dir = Path(ckpt_dir)

        # Always load to CPU first for non-CUDA devices to handle CUDA-saved models
        if device in ["cpu", "mps"]:
            map_location = torch.device('cpu')
        else:
            map_location = None

        ve = VoiceEncoder()
        ve.load_state_dict(
            load_file(ckpt_dir / "ve.safetensors")
        )
        ve.to(device).eval()

        # Turbo specific hp
        hp = T3Config(text_tokens_dict_size=50276)
        hp.llama_config_name = "GPT2_medium"
        hp.speech_tokens_dict_size = 6563
        hp.input_pos_emb = None
        hp.speech_cond_prompt_len = 375
        hp.use_perceiver_resampler = False
        hp.emotion_adv = False

        t3 = T3(hp)
        t3_state = load_file(ckpt_dir / "t3_turbo_v1.safetensors")
        if "model" in t3_state.keys():
            t3_state = t3_state["model"][0]
        t3.load_state_dict(t3_state)
        del t3.tfmr.wte
        t3.to(device).eval()

        s3gen = S3Gen(meanflow=True)
        weights = load_file(ckpt_dir / "s3gen_meanflow.safetensors")
        s3gen.load_state_dict(
            weights, strict=True
        )
        s3gen.to(device).eval()

        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if len(tokenizer) != 50276:
            print(f"WARNING: Tokenizer len {len(tokenizer)} != 50276")

        conds = None
        builtin_voice = ckpt_dir / "conds.pt"
        if builtin_voice.exists():
            conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)

        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)

    @classmethod
    def from_pretrained(cls, device) -> 'ChatterboxTurboTTS':
        # Check if MPS is available on macOS
        if device == "mps" and not torch.backends.mps.is_available():
            if not torch.backends.mps.is_built():
                print("MPS not available because the current PyTorch install was not built with MPS enabled.")
            else:
                print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
            device = "cpu"

        local_path = snapshot_download(
            repo_id=REPO_ID,
            token=os.getenv("HF_TOKEN") or True,
            # Optional: Filter to download only what you need
            allow_patterns=["*.safetensors", "*.json", "*.txt", "*.pt", "*.model"]
        )

        return cls.from_local(local_path, device)

    def norm_loudness(self, wav, sr, target_lufs=-27):
        try:
            meter = ln.Meter(sr)
            loudness = meter.integrated_loudness(wav)
            gain_db = target_lufs - loudness
            gain_linear = 10.0 ** (gain_db / 20.0)
            if math.isfinite(gain_linear) and gain_linear > 0.0:
                wav = wav * gain_linear
        except Exception as e:
            print(f"Warning: Error in norm_loudness, skipping: {e}")

        return wav

    def prepare_conditionals(self, wav_fpath, exaggeration=0.5, norm_loudness=True):
        ## Load and norm reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        assert len(s3gen_ref_wav) / _sr > 5.0, "Audio prompt must be longer than 5 seconds!"

        if norm_loudness:
            s3gen_ref_wav = self.norm_loudness(s3gen_ref_wav, _sr)

        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)

        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

        # Speech cond prompt tokens
        if plen := self.t3.hp.speech_cond_prompt_len:
            s3_tokzr = self.s3gen.tokenizer
            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)

        # Voice-encoder speaker embedding
        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)

        t3_cond = T3Cond(
            speaker_emb=ve_embed,
            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)
        self.conds = Conditionals(t3_cond, s3gen_ref_dict)

    def generate(
        self,
        text,
        repetition_penalty=1.2,
        min_p=0.00,
        top_p=0.95,
        audio_prompt_path=None,
        exaggeration=0.0,
        cfg_weight=0.0,
        temperature=0.8,
        top_k=1000,
        norm_loudness=True,
    ):
        if audio_prompt_path:
            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration, norm_loudness=norm_loudness)
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

        if cfg_weight > 0.0 or exaggeration > 0.0 or min_p > 0.0:
            logger.warning("CFG, min_p and exaggeration are not supported by Turbo version and will be ignored.")

        # Norm and tokenize text
        text = punc_norm(text)
        text_tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        text_tokens = text_tokens.input_ids.to(self.device)

        speech_tokens = self.t3.inference_turbo(
            t3_cond=self.conds.t3,
            text_tokens=text_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # Remove OOV tokens and add silence to end
        speech_tokens = speech_tokens[speech_tokens < 6561]
        speech_tokens = speech_tokens.to(self.device)
        silence = torch.tensor([S3GEN_SIL, S3GEN_SIL, S3GEN_SIL]).long().to(self.device)
        speech_tokens = torch.cat([speech_tokens, silence])

        wav, _ = self.s3gen.inference(
            speech_tokens=speech_tokens,
            ref_dict=self.conds.gen,
            n_cfm_timesteps=2,
        )
        wav = wav.squeeze(0).detach().cpu().numpy()
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
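For context, a minimal usage sketch of the Turbo class removed above. The module path is an assumption (it is not visible in this diff), and the prompt file is a placeholder that must be longer than 5 seconds, per the assertion in `prepare_conditionals`.

```python
# Hedged sketch: Turbo inference with an audio prompt.
# Assumptions: import path `chatterbox.tts_turbo` (not confirmed by this diff),
# a CUDA device, and a >5 s placeholder clip "reference.wav".
import torchaudio as ta
from chatterbox.tts_turbo import ChatterboxTurboTTS  # assumed module path

model = ChatterboxTurboTTS.from_pretrained(device="cuda")
wav = model.generate(
    "Well [chuckle], that went better than expected.",
    audio_prompt_path="reference.wav",
    temperature=0.8,
    top_k=1000,
)
ta.save("output-turbo.wav", wav, model.sr)
```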
@@ -4,7 +4,6 @@ import librosa
 import torch
 import perth
 from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
 
 from .models.s3tokenizer import S3_SR
 from .models.s3gen import S3GEN_SR, S3Gen
@@ -52,7 +51,7 @@ class ChatterboxVC:
 
         s3gen = S3Gen()
         s3gen.load_state_dict(
-            load_file(ckpt_dir / "s3gen.safetensors"), strict=False
+            torch.load(ckpt_dir / "s3gen.pt", map_location=map_location)
         )
         s3gen.to(device).eval()
 
@@ -68,7 +67,7 @@ class ChatterboxVC:
             print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.")
             device = "cpu"
 
-        for fpath in ["s3gen.safetensors", "conds.pt"]:
+        for fpath in ["s3gen.pt", "conds.pt"]:
             local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath)
 
         return cls.from_local(Path(local_path).parent, device)
@@ -0,0 +1,91 @@
from tqdm import tqdm
import sys
import torch
import shutil
import perth
from pathlib import Path
import argparse
import os
import librosa
import soundfile as sf
from chatterbox.models.s3tokenizer import S3_SR
from chatterbox.models.s3gen import S3GEN_SR, S3Gen

AUDIO_EXTENSIONS = ["wav", "mp3", "flac", "opus"]


@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser(description="Voice Conversion")
    parser.add_argument("input", type=str, help="Path to input (a sample or folder of samples).")
    parser.add_argument("target_speaker", type=str, help="Path to the sample for the target speaker.")
    parser.add_argument("-o", "--output_folder", type=str, default="vc_outputs")
    parser.add_argument("-g", "--gpu_id", type=int, default=None)
    parser.add_argument("-m", "--mps", action="store_true", help="Use MPS (Metal) on macOS")
    parser.add_argument("--no-watermark", action="store_true", help="Skip watermarking")
    args = parser.parse_args()

    # Folders
    input = Path(args.input)
    output_folder = Path(args.output_folder)
    output_orig_folder = output_folder / "input"
    output_vc_folder = output_folder / "output"
    ref_folder = output_vc_folder / "target"
    output_orig_folder.mkdir(exist_ok=True, parents=True)
    output_vc_folder.mkdir(exist_ok=True)
    ref_folder.mkdir(exist_ok=True)

    # Device selection with MPS support
    if args.mps:
        if torch.backends.mps.is_available():
            device = torch.device("mps")
            print("Using MPS (Metal) device")
        else:
            print("MPS not available, falling back to CPU")
            device = torch.device("cpu")
    elif args.gpu_id is not None:
        device = torch.device(f"cuda:{args.gpu_id}")
    else:
        device = torch.device("cpu")

    # Determine map_location for loading
    map_location = torch.device('cpu') if device.type in ['cpu', 'mps'] else None

    ## s3gen
    s3g_fp = "checkpoints/s3gen.pt"
    s3gen = S3Gen()
    s3gen.load_state_dict(torch.load(s3g_fp, map_location=map_location))
    s3gen.to(device)
    s3gen.eval()

    wav_fpaths = []
    if input.is_dir():
        for ext in AUDIO_EXTENSIONS:
            wav_fpaths += list(input.glob(f"*.{ext}"))
    else:
        wav_fpaths.append(input)

    assert wav_fpaths, f"Didn't find any audio in '{input}'"

    ref_24, _ = librosa.load(args.target_speaker, sr=S3GEN_SR, duration=10)
    ref_24 = torch.tensor(ref_24).float()
    shutil.copy(args.target_speaker, ref_folder / Path(args.target_speaker).name)
    if not args.no_watermark:
        watermarker = perth.PerthImplicitWatermarker()
    for wav_fpath in tqdm(wav_fpaths):
        shutil.copy(wav_fpath, output_orig_folder / wav_fpath.name)

        audio_16, _ = librosa.load(str(wav_fpath), sr=S3_SR)
        audio_16 = torch.tensor(audio_16).float().to(device)[None, ]
        s3_tokens, _ = s3gen.tokenizer(audio_16)

        wav = s3gen(s3_tokens.to(device), ref_24, S3GEN_SR)
        wav = wav.view(-1).cpu().numpy()
        if not args.no_watermark:
            wav = watermarker.apply_watermark(wav, sample_rate=S3GEN_SR)
        save_path = output_vc_folder / wav_fpath.name
        sf.write(str(save_path), wav, samplerate=S3GEN_SR)


if __name__ == "__main__":
    main()
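The new script above batch-converts one or more recordings to a target speaker by tokenizing the source audio at 16 kHz and resynthesizing it with S3Gen conditioned on the reference. A condensed single-file sketch of that same flow, assuming `checkpoints/s3gen.pt` is available locally and using placeholder file paths:

```python
# Hedged sketch: single-file voice conversion with S3Gen (placeholder paths;
# assumes checkpoints/s3gen.pt has been downloaded beforehand).
import librosa
import torch
import soundfile as sf
from chatterbox.models.s3tokenizer import S3_SR
from chatterbox.models.s3gen import S3GEN_SR, S3Gen

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
s3gen = S3Gen()
s3gen.load_state_dict(torch.load("checkpoints/s3gen.pt", map_location="cpu"))
s3gen.to(device).eval()

# Reference (target speaker) at 24 kHz, source audio at 16 kHz for the tokenizer.
ref_24 = torch.tensor(librosa.load("target_speaker.wav", sr=S3GEN_SR, duration=10)[0]).float()
audio_16 = torch.tensor(librosa.load("input.wav", sr=S3_SR)[0]).float().to(device)[None, ]

with torch.inference_mode():
    s3_tokens, _ = s3gen.tokenizer(audio_16)
    wav = s3gen(s3_tokens.to(device), ref_24, S3GEN_SR).view(-1).cpu().numpy()

sf.write("converted.wav", wav, samplerate=S3GEN_SR)
```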