From aeb070b32e3e58b8995d9aff8564ebee993663bb Mon Sep 17 00:00:00 2001
From: Chintella-Esq
Date: Thu, 25 Dec 2025 13:23:23 -0500
Subject: [PATCH] Enhance code readability with comments and formatting

Added comments for clarity and updated code for better readability.
---
 src/chatterbox/models/t3/t3.py | 227 +++++++++++++++++++++++++++++++--
 1 file changed, 213 insertions(+), 14 deletions(-)

diff --git a/src/chatterbox/models/t3/t3.py b/src/chatterbox/models/t3/t3.py
index c5d72c3..6253028 100644
--- a/src/chatterbox/models/t3/t3.py
+++ b/src/chatterbox/models/t3/t3.py
@@ -1,3 +1,5 @@
+# Copyright (c) 2025 Resemble AI
+# MIT License
 import logging
 from typing import Union, Optional, List
 
@@ -64,10 +66,12 @@ class T3(nn.Module):
         self.dim = self.cfg.hidden_size
         self.deepspeed_patch_applied = False
 
+        # conditioning / embedding
         self.cond_enc = T3CondEnc(hp)
         self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim)
         self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim)
 
+        # custom position embedding
         self.text_pos_emb = None
         self.speech_pos_emb = None
         if hp.input_pos_emb == "learned":
@@ -77,6 +81,7 @@
             max_mel_seq_len = hp.max_speech_tokens + 2 + 2
             self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim)
 
+        # logit projection
         self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False)
         self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=self.is_gpt)
         self.compiled = False
@@ -93,7 +98,7 @@
             t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens)
             if not self.is_gpt:
                 t3_cond.cond_prompt_speech_emb += self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens)
-        return self.cond_enc(t3_cond)
+        return self.cond_enc(t3_cond)  # (B, len_cond, dim)
 
     def prepare_input_embeds(
         self,
@@ -103,12 +108,13 @@
         speech_tokens: torch.LongTensor,
         cfg_weight: float = 0.0,
     ):
-        cond_emb = self.prepare_conditioning(t3_cond)
-        text_emb = self.text_emb(text_tokens)
+        # prepare input embeddings (skip backbone transformer embeddings)
+        cond_emb = self.prepare_conditioning(t3_cond)  # (B, len_cond, dim)
+        text_emb = self.text_emb(text_tokens)  # (B, len_text, dim)
         if cfg_weight > 0.0 and not self.is_gpt:
-            text_emb[1].zero_()
+            text_emb[1].zero_()  # CFG uncond
 
-        speech_emb = self.speech_emb(speech_tokens)
+        speech_emb = self.speech_emb(speech_tokens)  # (B, len_speech, dim)
         if self.hp.input_pos_emb == "learned":
             text_emb = text_emb + self.text_pos_emb(text_tokens)
             speech_emb = speech_emb + self.speech_pos_emb(speech_tokens)
@@ -117,10 +123,11 @@
         if cond_emb.size(0) != text_emb.size(0):
             cond_emb = cond_emb.expand(text_emb.size(0), -1, -1)
 
+        # concat
         embeds = torch.stack([
             torch.cat((ce, te, se))
             for ce, te, se in zip(cond_emb, text_emb, speech_emb)
-        ])
+        ])  # (B, length, dim)
         return embeds, len_cond
 
     def forward(
@@ -135,21 +142,25 @@
     ):
         _ensure_BOT_EOT(text_tokens, self.hp)
 
+        # prepare custom input embeds
        embeds, len_cond = self.prepare_input_embeds(
             t3_cond=t3_cond,
             text_tokens=text_tokens,
             speech_tokens=speech_tokens,
         )
 
+        # backbone transformer forward
         tfmr_out = self.tfmr.forward(
             input_ids=None,
+            # position_ids=position_ids,  # TODO? ROPE should be fine?
             inputs_embeds=embeds,
             output_hidden_states=True,
             return_dict=True,
             use_cache=(not training),
         )
-        hidden_states = tfmr_out.hidden_states[-1]
+        hidden_states = tfmr_out.hidden_states[-1]  # final tfmr layer output, (B, seq, dim)
 
+        # post-processing: splice out text and speech parts of hidden states
         len_text = text_tokens.size(1)
         len_speech = speech_tokens.size(1)
         B, _, dim = hidden_states.shape
@@ -164,6 +175,7 @@
             text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end]
             speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end]
 
+        # logit projection
         text_logits = self.text_head(text_latents)
         speech_logits = self.speech_head(speech_latents)
 
@@ -197,12 +209,13 @@
             speech_tokens=speech_tokens,
             speech_token_lens=speech_token_lens,
             training=True,
-        )
+        )  # (B, seq, vocab_size)
 
+        # Calc CCE losses
         IGNORE_ID = -100
         device = out.text_logits.device
-        mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]
-        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]
+        mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None]  # (B, len_text)
+        mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None]  # (B, len_speech)
         masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID)
         masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID)
         loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID)
@@ -218,8 +231,10 @@
         text_tokens: Tensor,
         initial_speech_tokens: Optional[Tensor]=None,
 
+        # misc conditioning
         prepend_prompt_speech_tokens: Optional[Tensor]=None,
 
+        # HF generate args
         num_return_sequences=1,
         max_new_tokens=None,
         stop_on_eos=True,
@@ -235,13 +250,16 @@
         Args:
             text_tokens: a 1D (unbatched) or 2D (batched) tensor.
         """
+        # Validate / sanitize inputs
         assert prepend_prompt_speech_tokens is None, "not implemented"
         _ensure_BOT_EOT(text_tokens, self.hp)
         text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device)
 
+        # Default initial speech to a single start-of-speech token
         if initial_speech_tokens is None:
             initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
 
+        # Prepare custom input embeds
         embeds, len_cond = self.prepare_input_embeds(
             t3_cond=t3_cond,
             text_tokens=text_tokens,
@@ -249,16 +267,22 @@
             cfg_weight=cfg_weight,
         )
 
+        # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
+        # Note the llama-specific logic. Other tfmr types can be added later.
+        self.compiled = False
+        # TODO? synchronize the expensive compile function
+        # with self.compile_lock:
        if not self.compiled:
+            # Default to None for English models, only create for multilingual
             alignment_stream_analyzer = None
             if self.hp.is_multilingual:
                 alignment_stream_analyzer = AlignmentStreamAnalyzer(
                     self.tfmr,
                     None,
                     text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
-                    alignment_layer_idx=9,
+                    alignment_layer_idx=9,  # TODO: hparam or something?
                     eos_idx=self.hp.stop_speech_token,
                 )
                 assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
@@ -273,19 +297,194 @@
             self.patched_model = patched_model
             self.compiled = True
 
+        # # Run normal generate method, which calls our custom extended methods
+        # return self.patched_model.generate(
+        #     inputs=initial_speech_tokens,
+        #     decoder_cond=embeds,
+        #     bos_token_id=self.hp.start_speech_token,
+        #     eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1),
+        #     pad_token_id=self.hp.stop_speech_token,
+        #     max_new_tokens=max_new_tokens or self.hp.max_speech_tokens,
+        #     num_return_sequences=num_return_sequences,
+        #     temperature=temperature,
+        #     min_p=min_p,
+        #     length_penalty=length_penalty,
+        #     repetition_penalty=repetition_penalty,
+        #     do_sample=do_sample,
+        #     # cache_implementation=None if not self.compiled else "static",
+        # )
 
         device = embeds.device
         bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device)
-        bos_embed = self.speech_emb(bos_token)
+        bos_embed = self.speech_emb(bos_token)  # shape: (B, 1, embed_dim)
         bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0)
 
+        # batch_size=2 for CFG
         bos_embed = torch.cat([bos_embed, bos_embed])
 
+        # Combine condition and BOS token for the initial input
         inputs_embeds = torch.cat([embeds, bos_embed], dim=1)
 
+        # Track generated token ids; start with the BOS token.
         generated_ids = bos_token.clone()
-        predicted = []
+        predicted = []  # To store the predicted tokens
 
+        # Instantiate the logits processors.
         top_p_warper = TopPLogitsWarper(top_p=top_p)
         min_p_warper = MinPLogitsWarper(min_p=min_p)
-        top_p_warper = TopPLogitsWarper(t
\ No newline at end of file
+        repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=float(repetition_penalty))
+
+        # ---- Initial Forward Pass (no kv_cache yet) ----
+        output = self.patched_model(
+            inputs_embeds=inputs_embeds,
+            past_key_values=None,
+            use_cache=True,
+            output_attentions=True,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+        # Initialize kv_cache with the full context.
+        past = output.past_key_values
+
+        # ---- Generation Loop using kv_cache ----
+        for i in tqdm(range(max_new_tokens or self.hp.max_speech_tokens), desc="Sampling", dynamic_ncols=True):
+            logits_step = output.logits[:, -1, :]
+            # CFG combine → (1, V)
+            cond = logits_step[0:1, :]
+            uncond = logits_step[1:2, :]
+            cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
+            logits = cond + cfg * (cond - uncond)
+
+            # Apply alignment stream analyzer integrity checks
+            if self.patched_model.alignment_stream_analyzer is not None:
+                if logits.dim() == 1:  # guard in case something upstream squeezed
+                    logits = logits.unsqueeze(0)  # (1, V)
+                # Pass the last generated token for repetition tracking
+                last_token = generated_ids[0, -1].item() if len(generated_ids[0]) > 0 else None
+                logits = self.patched_model.alignment_stream_analyzer.step(logits, next_token=last_token)  # (1, V)
+
+            # Apply repetition penalty
+            ids_for_proc = generated_ids[:1, ...]  # batch = 1
+            logits = repetition_penalty_processor(ids_for_proc, logits)  # expects (B, V)
+
+            # Apply temperature scaling.
+            if temperature != 1.0:
+                logits = logits / temperature
+
+            # Apply min_p and top_p filtering
+            logits = min_p_warper(ids_for_proc, logits)
+            logits = top_p_warper(ids_for_proc, logits)
+
+            # Convert logits to probabilities and sample the next token.
+            probs = torch.softmax(logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)  # shape: (B, 1)
+
+            predicted.append(next_token)
+            generated_ids = torch.cat([generated_ids, next_token], dim=1)
+
+            # Check for EOS token.
+            if next_token.view(-1) == self.hp.stop_speech_token:
+                logger.info(f"✅ EOS token detected! Stopping generation at step {i+1}")
+                break
+
+            # Get embedding for the new token.
+            next_token_embed = self.speech_emb(next_token)
+            next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1)
+
+            # For CFG
+            next_token_embed = torch.cat([next_token_embed, next_token_embed])
+
+            # Forward pass with only the new token and the cached past.
+            output = self.patched_model(
+                inputs_embeds=next_token_embed,
+                past_key_values=past,
+                output_attentions=True,
+                output_hidden_states=True,
+                return_dict=True,
+            )
+            # Update the kv_cache.
+            past = output.past_key_values
+
+        # Concatenate all predicted tokens along the sequence dimension.
+        predicted_tokens = torch.cat(predicted, dim=1)  # shape: (B, num_tokens)
+        return predicted_tokens
+
+    @torch.inference_mode()
+    def inference_turbo(self, t3_cond, text_tokens, temperature=0.8, top_k=1000, top_p=0.95, repetition_penalty=1.2,
+                        max_gen_len=1000):
+
+        logits_processors = LogitsProcessorList()
+        if temperature > 0 and temperature != 1.0:
+            logits_processors.append(TemperatureLogitsWarper(temperature))
+        if top_k > 0:
+            logits_processors.append(TopKLogitsWarper(top_k))
+        if top_p < 1.0:
+            logits_processors.append(TopPLogitsWarper(top_p))
+        if repetition_penalty != 1.0:
+            logits_processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+
+        speech_start_token = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1])
+        embeds, _ = self.prepare_input_embeds(
+            t3_cond=t3_cond,
+            text_tokens=text_tokens,
+            speech_tokens=speech_start_token,
+            cfg_weight=0.0,
+        )
+
+        generated_speech_tokens = []
+
+        llm_outputs = self.tfmr(
+            inputs_embeds=embeds,
+            use_cache=True
+        )
+
+        hidden_states = llm_outputs[0]
+        past_key_values = llm_outputs.past_key_values
+
+        speech_hidden = hidden_states[:, -1:]
+        speech_logits = self.speech_head(speech_hidden)
+
+        processed_logits = logits_processors(speech_start_token, speech_logits[:, -1, :])
+        probs = F.softmax(processed_logits, dim=-1)
+        next_speech_token = torch.multinomial(probs, num_samples=1)
+
+        generated_speech_tokens.append(next_speech_token)
+        current_speech_token = next_speech_token
+
+        for _ in tqdm(range(max_gen_len)):
+            current_speech_embed = self.speech_emb(current_speech_token)
+
+            llm_outputs = self.tfmr(
+                inputs_embeds=current_speech_embed,
+                past_key_values=past_key_values,
+                use_cache=True
+            )
+
+            hidden_states = llm_outputs[0]
+            past_key_values = llm_outputs.past_key_values
+            speech_logits = self.speech_head(hidden_states)
+
+            input_ids = torch.cat(generated_speech_tokens, dim=1)
+            processed_logits = logits_processors(input_ids, speech_logits[:, -1, :])
+            if torch.all(processed_logits == -float("inf")):
+                logger.warning("All logits are -inf; stopping generation early")
+                break
+
+            probs = F.softmax(processed_logits, dim=-1)
+            next_speech_token = torch.multinomial(probs, num_samples=1)
+
+            generated_speech_tokens.append(next_speech_token)
+            current_speech_token = next_speech_token
+            if torch.all(next_speech_token == self.hp.stop_speech_token):
+                break
+
+        all_tokens = torch.cat(generated_speech_tokens, dim=1)
+
+        # Remove EOS token if present
+        if all_tokens.size(1) > 0 and all_tokens[0, -1] == self.hp.stop_speech_token:
+            all_tokens = all_tokens[:, :-1]
+
+        return all_tokens
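
Illustrative note, not part of the patch: the per-step sampling that the new kv-cache loop performs (CFG mixing of conditional and unconditional logits, then repetition penalty, temperature, min-p, and top-p filtering) can be exercised in isolation with the same Hugging Face processor classes the patch uses. The sketch below is a standalone, assumption-laden toy: the vocabulary size, logits, token history, and hyperparameter values are made up, and it never loads the T3 model.

import torch
from transformers.generation.logits_process import (
    MinPLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
    TopPLogitsWarper,
)

torch.manual_seed(0)

# Toy stand-ins for one decoding step (two rows for CFG: conditional + unconditional).
vocab_size = 8                              # stands in for hp.speech_tokens_dict_size (made up)
generated_ids = torch.tensor([[3, 5, 5]])   # speech tokens sampled so far, batch = 1 (made up)
logits_step = torch.randn(2, vocab_size)    # row 0: conditional, row 1: unconditional
cfg_weight, temperature, min_p, top_p, repetition_penalty = 0.5, 0.8, 0.05, 0.95, 2.0

# 1) Classifier-free guidance: push conditional logits away from the unconditional ones.
cond, uncond = logits_step[0:1, :], logits_step[1:2, :]
logits = cond + cfg_weight * (cond - uncond)  # (1, V)

# 2) Same processor chain as the patched loop: repetition penalty -> temperature -> min-p -> top-p.
logits = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)(generated_ids, logits)
logits = logits / temperature
logits = MinPLogitsWarper(min_p=min_p)(generated_ids, logits)
logits = TopPLogitsWarper(top_p=top_p)(generated_ids, logits)

# 3) Sample the next speech token from the filtered distribution.
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
print("next speech token id:", next_token.item())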