Enhance code readability with comments and formatting
Added comments for clarity and updated code for better readability.
parent 777c377ae6
commit aeb070b32e
@@ -1,3 +1,5 @@
# Copyright (c) 2025 Resemble AI
# MIT License
import logging
from typing import Union, Optional, List

@@ -64,10 +66,12 @@ class T3(nn.Module):
        self.dim = self.cfg.hidden_size
        self.deepspeed_patch_applied = False

        # conditioning / embedding
        self.cond_enc = T3CondEnc(hp)
        self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
        self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)

        # custom position embedding
        self.text_pos_emb = None
        self.speech_pos_emb = None
        if hp.input_pos_emb == "learned":
@@ -77,6 +81,7 @@ class T3(nn.Module):
            max_mel_seq_len = hp.max_speech_tokens + 2 + 2
            self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)

        # logit projection
        self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
        self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt)
        self.compiled = False
@@ -93,7 +98,7 @@ class T3(nn.Module):
            t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens)
            if not self.is_gpt:
                t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
-        return self.cond_enc(t3_cond)
+        return self.cond_enc(t3_cond)  # (B, len_cond, dim)

    def prepare_input_embeds(
        self,
@@ -103,12 +108,13 @@ class T3(nn.Module):
        speech_tokens: torch.LongTensor,
        cfg_weight: float = 0.0,
    ):
-        cond_emb = self.prepare_conditioning(t3_cond)
-        text_emb = self.text_emb(text_tokens)
+        # prepare input embeddings (skip backbone transformer embeddings)
+        cond_emb = self.prepare_conditioning(t3_cond)  # (B, len_cond, dim)
+        text_emb = self.text_emb(text_tokens)  # (B, len_text, dim)
        if cfg_weight > 0.0 and not self.is_gpt:
-            text_emb[1].zero_()
+            text_emb[1].zero_()  # CFG uncond

-        speech_emb = self.speech_emb(speech_tokens)
+        speech_emb = self.speech_emb(speech_tokens)  # (B, len_speech, dim)
        if self.hp.input_pos_emb == "learned":
            text_emb = text_emb + self.text_pos_emb(text_tokens)
            speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
@@ -117,10 +123,11 @@ class T3(nn.Module):
        if cond_emb.size(0) != text_emb.size(0):
            cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)

        # concat
        embeds = torch.stack([
            torch.cat((ce, te, se))
            for ce, te, se in zip(cond_emb, text_emb, speech_emb)
-        ])
+        ])  # (B, length, dim)
        return embeds, len_cond

    def forward(
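For context (an illustrative sketch, not part of this commit): the stack-of-cats above lays each sample out as [cond | text | speech] along the sequence axis. A minimal shape check with made-up sizes:

import torch

B, dim = 2, 8
len_cond, len_text, len_speech = 3, 5, 4   # made-up lengths

cond_emb = torch.randn(B, len_cond, dim)
text_emb = torch.randn(B, len_text, dim)
speech_emb = torch.randn(B, len_speech, dim)

# Concatenate per sample along the sequence axis, then re-stack the batch.
embeds = torch.stack([
    torch.cat((ce, te, se))
    for ce, te, se in zip(cond_emb, text_emb, speech_emb)
])
assert embeds.shape == (B, len_cond + len_text + len_speech, dim)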
@@ -135,21 +142,25 @@ class T3(nn.Module):
    ):
        _ensure_BOT_EOT(text_tokens, self.hp)

        # prepare custom input embeds
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=speech_tokens,
        )

        # backbone transformer forward
        tfmr_out = self.tfmr.forward(
            input_ids=None,
            # position_ids=position_ids, # TODO? ROPE should be fine?
            inputs_embeds=embeds,
            output_hidden_states=True,
            return_dict=True,
            use_cache=(not training),
        )
-        hidden_states = tfmr_out.hidden_states[-1]
+        hidden_states = tfmr_out.hidden_states[-1]  # final tfmr layer output, (B, seq, dim)

        # post-processing: splice out text and speech parts of hidden states
        len_text = text_tokens.size(1)
        len_speech = speech_tokens.size(1)
        B, _, dim = hidden_states.shape
@@ -164,6 +175,7 @@ class T3(nn.Module):
            text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
            speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]

        # logit projection
        text_logits = self.text_head(text_latents)
        speech_logits = self.speech_head(speech_latents)

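For context (an illustrative sketch, not part of this commit): the splice loop above relies on per-sample bounds (ttl, stl, text_end, speech_start, speech_end) defined earlier in the method and not visible in this hunk. A rough standalone sketch of the splice-then-project idea, with invented sizes and toy projection heads:

import torch
import torch.nn as nn

B, dim = 2, 8
len_cond, len_text, len_speech = 3, 5, 4          # padded lengths (made up)
hidden_states = torch.randn(B, len_cond + len_text + len_speech, dim)
ttl = torch.tensor([5, 3])                        # true text lengths per sample
stl = torch.tensor([4, 2])                        # true speech lengths per sample
text_head = nn.Linear(dim, 10, bias=False)        # toy vocab sizes
speech_head = nn.Linear(dim, 12, bias=False)

text_latents = torch.zeros(B, len_text, dim)
speech_latents = torch.zeros(B, len_speech, dim)
for i in range(B):
    text_end = len_cond + ttl[i].item()
    speech_start = len_cond + len_text            # speech follows the padded text block
    speech_end = speech_start + stl[i].item()
    text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
    speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]

text_logits = text_head(text_latents)             # (B, len_text, text_vocab)
speech_logits = speech_head(speech_latents)       # (B, len_speech, speech_vocab)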
@@ -197,12 +209,13 @@ class T3(nn.Module):
            speech_tokens=speech_tokens,
            speech_token_lens=speech_token_lens,
            training=True,
-        )
+        )  # (B, seq, vocab_size)

        # Calc CCE losses
        IGNORE_ID = -100
        device = out.text_logits.device
-        mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]
-        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]
+        mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]  # (B, len_text)
+        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]  # (B, len_speech)
        masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
        masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
        loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
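For context (an illustrative sketch, not part of this commit): the arange broadcast builds a per-sample padding mask, and ignore_index drops those positions from the loss. A toy example with invented shapes; the transpose is only there because F.cross_entropy expects the class dimension second for sequence inputs:

import torch
import torch.nn.functional as F

IGNORE_ID = -100
B, seq, vocab = 2, 6, 10
logits = torch.randn(B, seq, vocab)
tokens = torch.randint(0, vocab, (B, seq))
token_lens = torch.tensor([6, 4])                       # second sample has 2 padded positions

mask = torch.arange(seq)[None] >= token_lens[:, None]   # (B, seq), True at padding
masked_tokens = tokens.masked_fill(mask, IGNORE_ID)

loss = F.cross_entropy(logits.transpose(1, 2), masked_tokens, ignore_index=IGNORE_ID)
print(loss.item())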
@@ -218,8 +231,10 @@ class T3(nn.Module):
        text_tokens: Tensor,
        initial_speech_tokens: Optional[Tensor]=None,

        # misc conditioning
        prepend_prompt_speech_tokens: Optional[Tensor]=None,

        # HF generate args
        num_return_sequences=1,
        max_new_tokens=None,
        stop_on_eos=True,
@@ -235,13 +250,16 @@ class T3(nn.Module):
        Args:
            text_tokens: a 1D (unbatched) or 2D (batched) tensor.
        """
        # Validate / sanitize inputs
        assert prepend_prompt_speech_tokens is None, "not implemented"
        _ensure_BOT_EOT(text_tokens, self.hp)
        text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)

        # Default initial speech to a single start-of-speech token
        if initial_speech_tokens is None:
            initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])

        # Prepare custom input embeds
        embeds, len_cond = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
@@ -249,16 +267,22 @@ class T3(nn.Module):
            cfg_weight=cfg_weight,
        )

        # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
        # Note the llama-specific logic. Other tfmr types can be added later.

        self.compiled = False

        # TODO? synchronize the expensive compile function
        # with self.compile_lock:
        if not self.compiled:
            # Default to None for English models, only create for multilingual
            alignment_stream_analyzer = None
            if self.hp.is_multilingual:
                alignment_stream_analyzer = AlignmentStreamAnalyzer(
                    self.tfmr,
                    None,
                    text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
-                    alignment_layer_idx=9,
+                    alignment_layer_idx=9,  # TODO: hparam or something?
                    eos_idx=self.hp.stop_speech_token,
                )
                assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
@@ -273,19 +297,194 @@ class T3(nn.Module):
            self.patched_model = patched_model
            self.compiled = True

        # # Run normal generate method, which calls our custom extended methods
        # return self.patched_model.generate(
        #     inputs=initial_speech_tokens,
        #     decoder_cond=embeds,
        #     bos_token_id=self.hp.start_speech_token,
        #     eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
        #     pad_token_id=self.hp.stop_speech_token,
        #     max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
        #     num_return_sequences=num_return_sequences,
        #     temperature=temperature,
        #     min_p=min_p,
        #     length_penalty=length_penalty,
        #     repetition_penalty=repetition_penalty,
        #     do_sample=do_sample,
        #     # cache_implementation=None if not self.compiled else "static",
        # )

        device = embeds.device

        bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
-        bos_embed = self.speech_emb(bos_token)
+        bos_embed = self.speech_emb(bos_token)  # shape: (B, 1, embed_dim)
        bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)

        # batch_size=2 for CFG
        bos_embed = torch.cat([bos_embed, bos_embed])

        # Combine condition and BOS token for the initial input
        inputs_embeds = torch.cat([embeds, bos_embed], dim=1)

        # Track generated token ids; start with the BOS token.
        generated_ids = bos_token.clone()
-        predicted = []
+        predicted = []  # To store the predicted tokens

        # Instantiate the logits processors.
        top_p_warper = TopPLogitsWarper(top_p=top_p)
        min_p_warper = MinPLogitsWarper(min_p=min_p)
        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))

        # ---- Initial Forward Pass (no kv_cache yet) ----
        output = self.patched_model(
            inputs_embeds=inputs_embeds,
            past_key_values=None,
            use_cache=True,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )
        # Initialize kv_cache with the full context.
        past = output.past_key_values

        # ---- Generation Loop using kv_cache ----
        for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True):
            logits_step = output.logits[:, -1, :]
            # CFG combine → (1, V)
            cond = logits_step[0:1, :]
            uncond = logits_step[1:2, :]
            cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
            logits = cond + cfg * (cond - uncond)

            # Apply alignment stream analyzer integrity checks
            if self.patched_model.alignment_stream_analyzer is not None:
                if logits.dim() == 1:  # guard in case something upstream squeezed
                    logits = logits.unsqueeze(0)  # (1, V)
                # Pass the last generated token for repetition tracking
                last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
                logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token)  # (1, V)

            # Apply repetition penalty
            ids_for_proc = generated_ids[:1, ...]  # batch = 1
            logits = repetition_penalty_processor(ids_for_proc, logits)  # expects (B,V)

            # Apply temperature scaling.
            if temperature != 1.0:
                logits = logits / temperature

            # Apply min_p and top_p filtering
            logits = min_p_warper(ids_for_proc, logits)
            logits = top_p_warper(ids_for_proc, logits)

            # Convert logits to probabilities and sample the next token.
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape: (B, 1)

            predicted.append(next_token)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Check for EOS token.
            if next_token.view(-1) == self.hp.stop_speech_token:
                logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
                break

            # Get embedding for the new token.
            next_token_embed = self.speech_emb(next_token)
            next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)

            # For CFG
            next_token_embed = torch.cat([next_token_embed, next_token_embed])

            # Forward pass with only the new token and the cached past.
            output = self.patched_model(
                inputs_embeds=next_token_embed,
                past_key_values=past,
                output_attentions=True,
                output_hidden_states=True,
                return_dict=True,
            )
            # Update the kv_cache.
            past = output.past_key_values

        # Concatenate all predicted tokens along the sequence dimension.
        predicted_tokens = torch.cat(predicted, dim=1)  # shape: (B, num_tokens)
        return predicted_tokens

    @torch.inference_mode()
    def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2,
                        max_gen_len=1000):

        logits_processors = LogitsProcessorList()
        if temperature > 0 and temperature != 1.0:
            logits_processors.append(TemperatureLogitsWarper(temperature))
        if top_k > 0:
            logits_processors.append(TopKLogitsWarper(top_k))
        if top_p < 1.0:
            logits_processors.append(TopPLogitsWarper(top_p))
        if repetition_penalty != 1.0:
            logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))


        speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
        embeds, _ = self.prepare_input_embeds(
            t3_cond=t3_cond,
            text_tokens=text_tokens,
            speech_tokens=speech_start_token,
            cfg_weight=0.0,
        )

        generated_speech_tokens = []

        llm_outputs = self.tfmr(
            inputs_embeds=embeds,
            use_cache=True
        )

        hidden_states = llm_outputs[0]
        past_key_values = llm_outputs.past_key_values

        speech_hidden = hidden_states[:, -1:]
        speech_logits = self.speech_head(speech_hidden)

        processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :])
        probs = F.softmax(processed_logits, dim=-1)
        next_speech_token = torch.multinomial(probs, num_samples=1)

        generated_speech_tokens.append(next_speech_token)
        current_speech_token = next_speech_token

        for _ in tqdm(range(max_gen_len)):
            current_speech_embed = self.speech_emb(current_speech_token)

            llm_outputs = self.tfmr(
                inputs_embeds=current_speech_embed,
                past_key_values=past_key_values,
                use_cache=True
            )

            hidden_states = llm_outputs[0]
            past_key_values = llm_outputs.past_key_values
            speech_logits = self.speech_head(hidden_states)

            input_ids = torch.cat(generated_speech_tokens, dim=1)
            processed_logits = logits_processors(input_ids, speech_logits[:, -1, :])
            if torch.all(processed_logits == -float("inf")):
                print("Warning: All logits are -inf")
                break

            probs = F.softmax(processed_logits, dim=-1)
            next_speech_token = torch.multinomial(probs, num_samples=1)

            generated_speech_tokens.append(next_speech_token)
            current_speech_token = next_speech_token
            if torch.all(next_speech_token == self.hp.stop_speech_token):
                break

        all_tokens = torch.cat(generated_speech_tokens, dim=1)

        # Remove EOS token if present
        if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token:
            all_tokens = all_tokens[:, :-1]

        return all_tokens
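For context (an illustrative sketch, not part of this commit): the new inference loop combines three generic pieces, classifier-free guidance over a [cond, uncond] batch, Hugging Face logits processors, and incremental decoding with past_key_values. Below is a stripped-down sketch of that pattern on a randomly initialized GPT-2 stand-in; every size, token id, and the cfg weight is invented, and min_p filtering is omitted to keep the example minimal:

import torch
from transformers import (GPT2Config, GPT2LMHeadModel,
                          TopPLogitsWarper, RepetitionPenaltyLogitsProcessor)

torch.manual_seed(0)
cfg_weight, stop_token, max_new_tokens = 0.5, 3, 20           # made-up values
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=64)).eval()

top_p_warper = TopPLogitsWarper(top_p=0.95)
rep_processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)

prompt_embeds = torch.randn(2, 5, 32)        # rows: [cond, uncond], sharing the model weights
generated = torch.tensor([[1]])              # start token (hypothetical id)

with torch.inference_mode():
    out = model(inputs_embeds=prompt_embeds, use_cache=True)
    past = out.past_key_values
    for i in range(max_new_tokens):
        step = out.logits[:, -1, :]                      # (2, vocab)
        cond, uncond = step[0:1], step[1:2]
        logits = cond + cfg_weight * (cond - uncond)     # CFG combine, (1, vocab)
        logits = rep_processor(generated, logits)
        logits = top_p_warper(generated, logits)
        next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == stop_token:
            break
        # Feed only the new token back, duplicated for the cond/uncond rows.
        tok_embed = model.transformer.wte(next_token)    # (1, 1, n_embd)
        out = model(inputs_embeds=tok_embed.expand(2, -1, -1),
                    past_key_values=past, use_cache=True)
        past = out.past_key_values

print(generated)

The in-repo loop additionally applies the alignment stream analyzer, temperature, and min_p, and uses the model's own speech embedding plus learned position embeddings instead of GPT-2's word embeddings.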