import queue
import time
import threading
from typing import Generator, List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers.generation.logits_process import (
    TopPLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
)


@torch.inference_mode()
def inference_v2(
    self,
    text: torch.Tensor,  # [1, T] only supports batch size 1
    text_len: torch.Tensor,  # [1]
    tone: torch.Tensor,  # [1, T]
    bert_feature: torch.Tensor,  # [1, D, T]
    speaker_embedding: Optional[torch.Tensor] = None,  # [1, D]
    prompt_text: Optional[torch.Tensor] = None,
    prompt_text_len: Optional[torch.Tensor] = None,
    prompt_tone: Optional[torch.Tensor] = None,
    prompt_bert_feature: Optional[torch.Tensor] = None,
    prompt_speech_token: Optional[torch.Tensor] = None,
    prompt_speech_token_len: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_p: float = 0.9,
    repetition_penalty: float = 2.0,
    max_token_text_ratio: float = 8,
    use_dp: bool = False,
) -> Generator[
    Tuple[int, List[torch.Tensor]], None, Tuple[List[int], List[torch.Tensor]]
]:
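    """Autoregressively generate speech tokens for the given text.

    Yields (speech_token_id, attn_scores) pairs as each token is sampled, and
    returns (predicted, attn_scores_list) once generation stops, either on the
    EOS token or on the max-length budget derived from max_token_text_ratio.
    Only batch size 1 is supported.
    """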
    device = text.device
    batch_size = text.shape[0]
    top_p_warper = TopPLogitsWarper(top_p=top_p)
    repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(
        penalty=repetition_penalty
    )
    assert batch_size == 1, "batch size should be 1 in inference mode"

    if prompt_text is not None and prompt_text_len is not None:
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len

    text = self.text_embedding(text)
    tone_embed = self.tone_embedding(tone)
    if prompt_tone is not None:
        prompt_tone_embed = self.tone_embedding(prompt_tone)
        tone_embed = torch.concat([prompt_tone_embed, tone_embed], dim=1)

    bert_feature = self.bert_embed_affine_layer(bert_feature).transpose(1, 2)
    if prompt_bert_feature is not None:
        prompt_bert_feature = self.bert_embed_affine_layer(
            prompt_bert_feature
        ).transpose(1, 2)
        bert_feature = torch.concat([prompt_bert_feature, bert_feature], dim=1)

    # add tone embeddings to text embeddings, then fuse with BERT features
    text = torch.concat([text, tone_embed], dim=2)
    text = self.encoder_affine_layer(text)
    text = self.text_fusion(text, bert_feature)

    T_text = text_len[0].item()
    text_start = 2  # 2 for sos + speaker
    text_end = text_start + T_text + 1
    if prompt_text_len is not None:
        text_start = 2 + prompt_text_len.item()
        text_end = text_start + T_text + 1 + prompt_text_len.item()

    dp_tracker = DPAlignmentTracker(
        aeam_heads=self.aeam_heads,
        text_start=text_start,
        text_end=text_end,
        rho=3,
    )
    # 1. encode text
    text, text_len = self.encode(text, text_len)

    # 2. encode embedding
    if speaker_embedding is not None and speaker_embedding.shape[0] != 0:
        speaker_embedding = F.normalize(speaker_embedding, dim=1)
        speaker_embedding = self.spk_embed_affine_layer(speaker_embedding)
        speaker_embedding = speaker_embedding.unsqueeze(dim=1)
    else:
        speaker_embedding = torch.zeros(
            1, 0, self.llm_input_size, dtype=text.dtype, device=device
        )

    # 3. concat llm_input
    sos_eos_emb = (
        self.llm_embedding.weight[self.sos_eos]
        .reshape(1, 1, -1)
        .repeat(batch_size, 1, 1)
    )
    task_id_emb = (
        self.llm_embedding.weight[self.task_id]
        .reshape(1, 1, -1)
        .repeat(batch_size, 1, 1)
    )

    # Calculate context_len (number of context tokens before text)
    context_len = 1  # sos_eos_emb
    if speaker_embedding is not None and speaker_embedding.shape[1] > 0:
        context_len += speaker_embedding.shape[1]
    if prompt_speech_token_len is not None:
        context_len += (
            prompt_speech_token_len.item()
            if hasattr(prompt_speech_token_len, "item")
            else int(prompt_speech_token_len)
        )
    context_len += text.shape[1]  # text tokens
    context_len += 1  # task_id_emb

    # 4. prepare llm input
    llm_input = [
        sos_eos_emb,
        speaker_embedding,
        text,
        task_id_emb,
    ]
    if prompt_speech_token_len is not None:
        prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        llm_input.append(prompt_speech_token_emb)
    lm_input = torch.concat(llm_input, dim=1)

    generated_id = torch.tensor([[self.task_id]], device=device, dtype=torch.int64)
    predicted = []
    attn_scores_list = []
    offset = 0

    # Initialize the first step
    y_pred, att_cache, cnn_cache, attn_scores = self.forward_one_step(
        lm_input,
        offset=offset,
        step=0,
        dp_tracker=dp_tracker if use_dp else None,
        use_dp=use_dp,
        device=device,
    )
    attn_scores_list.append(attn_scores)
    # Calculate the maximum number of speech tokens to generate
    prompt_len = prompt_text_len if prompt_text_len is not None else 0
    max_len = int((text_len - prompt_len) * max_token_text_ratio)
    for i in range(1, max_len):
        offset += lm_input.size(1)
        logits = self.llm_decoder(y_pred[:, -1])  # [1, D]
        if temperature != 1.0:
            logits = logits / temperature
        # Apply repetition penalty and top-p filtering
        logits = repetition_penalty_processor(generated_id, logits)
        logits = top_p_warper(None, logits)
        # Convert logits to probabilities and sample the next token
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # shape: (B, 1)
        # Check for EOS token
        if next_token.view(-1) == self.speech_token_size:
            break
        yield next_token.item(), attn_scores
        predicted.append(next_token)
        generated_id = torch.cat([generated_id, next_token], dim=1)  # [1, T+1]
        # Get embedding for the new token
        lm_input = self.speech_embedding.weight[next_token].reshape(1, 1, -1)
        y_pred, att_cache, cnn_cache, attn_scores = self.forward_one_step(
            lm_input,
            offset=offset,
            step=i,
            att_cache=att_cache,
            cnn_cache=cnn_cache,
            dp_tracker=dp_tracker if use_dp else None,
            use_dp=use_dp,
            device=device,
        )
        attn_scores_list.append(attn_scores)

    return predicted, attn_scores_list
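

# Example usage (a minimal sketch, not part of the original gist): it assumes
# `inference_v2` is bound as a method of the TTS language-model module this
# snippet is excerpted from; `model`, `text_ids`, `text_len`, `tone_ids`,
# `bert_feat`, and `spk_emb` are hypothetical placeholders for inputs prepared
# by that model's text frontend.
#
#     speech_tokens = []
#     for token_id, attn_scores in model.inference_v2(
#         text=text_ids,            # [1, T] text token ids
#         text_len=text_len,        # [1]
#         tone=tone_ids,            # [1, T]
#         bert_feature=bert_feat,   # [1, D, T]
#         speaker_embedding=spk_emb,
#     ):
#         speech_tokens.append(token_id)  # stream each token out as it is sampled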