import queue
import time
import threading
from typing import Generator, List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers.generation.logits_process import (
    TopPLogitsWarper,
    RepetitionPenaltyLogitsProcessor,
)


@torch.inference_mode()
def inference_v2(
    self,
    text: torch.Tensor,  # [1, T]; only batch size 1 is supported
    text_len: torch.Tensor,  # [1]
    tone: torch.Tensor,  # [1, T]
    bert_feature: torch.Tensor,  # [1, D, T]
    speaker_embedding: Optional[torch.Tensor] = None,  # [1, D]
    prompt_text: Optional[torch.Tensor] = None,
    prompt_text_len: Optional[torch.Tensor] = None,
    prompt_tone: Optional[torch.Tensor] = None,
    prompt_bert_feature: Optional[torch.Tensor] = None,
    prompt_speech_token: Optional[torch.Tensor] = None,
    prompt_speech_token_len: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
    top_p: float = 0.9,
    repetition_penalty: float = 2.0,
    max_token_text_ratio: float = 8.0,
    use_dp: bool = False,
) -> Generator[
    Tuple[int, List[torch.Tensor]], None, Tuple[List[torch.Tensor], List[torch.Tensor]]
]:
    device = text.device
    batch_size = text.shape[0]
    top_p_warper = TopPLogitsWarper(top_p=top_p)
    repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(
        penalty=repetition_penalty
    )
    assert batch_size == 1, "batch size should be 1 in inference mode"

    if prompt_text is not None and prompt_text_len is not None:
        text = torch.concat([prompt_text, text], dim=1)
        text_len += prompt_text_len

    text = self.text_embedding(text)
    tone_embed = self.tone_embedding(tone)
    if prompt_tone is not None:
        prompt_tone_embed = self.tone_embedding(prompt_tone)
        tone_embed = torch.concat([prompt_tone_embed, tone_embed], dim=1)

    bert_feature = self.bert_embed_affine_layer(bert_feature).transpose(1, 2)
    if prompt_bert_feature is not None:
        prompt_bert_feature = self.bert_embed_affine_layer(
            prompt_bert_feature
        ).transpose(1, 2)
        bert_feature = torch.concat([prompt_bert_feature, bert_feature], dim=1)

    # concatenate tone embeddings with text embeddings along the feature dim,
    # project back to the encoder width, then fuse with the BERT features
    text = torch.concat([text, tone_embed], dim=2)
    text = self.encoder_affine_layer(text)
    text = self.text_fusion(text, bert_feature)
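
    # Positional layout downstream: the LLM prefix is assembled as
    #   [sos_eos, speaker, (prompt_text +) text, task_id, (prompt speech tokens)]
    # so text_start/text_end below locate the text span inside that prefix
    # (the 2 covers sos_eos plus a 1-token speaker embedding).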
    T_text = text_len[0].item()
    text_start = 2  # 2 for sos + speaker
    text_end = text_start + T_text + 1
    if prompt_text_len is not None:
        text_start = 2 + prompt_text_len.item()
        text_end = text_start + T_text + 1 + prompt_text_len.item()
    dp_tracker = DPAlignmentTracker(
        aeam_heads=self.aeam_heads,
        text_start=text_start,
        text_end=text_end,
        rho=3,
    )

    # 1. encode text
    text, text_len = self.encode(text, text_len)

    # 2. encode speaker embedding
    if speaker_embedding is not None and speaker_embedding.shape[0] != 0:
        speaker_embedding = F.normalize(speaker_embedding, dim=1)
        speaker_embedding = self.spk_embed_affine_layer(speaker_embedding)
        speaker_embedding = speaker_embedding.unsqueeze(dim=1)
    else:
        # empty placeholder so the concat below still works without a speaker
        speaker_embedding = torch.zeros(
            1, 0, self.llm_input_size, dtype=text.dtype, device=device
        )

    # 3. sos/eos and task-id embeddings
    sos_eos_emb = (
        self.llm_embedding.weight[self.sos_eos]
        .reshape(1, 1, -1)
        .repeat(batch_size, 1, 1)
    )
    task_id_emb = (
        self.llm_embedding.weight[self.task_id]
        .reshape(1, 1, -1)
        .repeat(batch_size, 1, 1)
    )

    # Calculate context_len: the total prefix length fed to the LLM before
    # any new speech token is generated
    context_len = 1  # sos_eos_emb
    if speaker_embedding is not None and speaker_embedding.shape[1] > 0:
        context_len += speaker_embedding.shape[1]
    if prompt_speech_token_len is not None:
        context_len += (
            prompt_speech_token_len.item()
            if hasattr(prompt_speech_token_len, "item")
            else int(prompt_speech_token_len)
        )
    context_len += text.shape[1]  # text tokens
    context_len += 1  # task_id_emb

    # 4. prepare llm input
    llm_input = [
        sos_eos_emb,
        speaker_embedding,
        text,
        task_id_emb,
    ]
    if prompt_speech_token_len is not None:
        prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        llm_input.append(prompt_speech_token_emb)
    lm_input = torch.concat(llm_input, dim=1)

    generated_id = torch.tensor([[self.task_id]], device=device, dtype=torch.int64)
    predicted = []
    attn_scores_list = []
    offset = 0

    # Initialize the first step: run the whole prefix through the decoder once
    y_pred, att_cache, cnn_cache, attn_scores = self.forward_one_step(
        lm_input,
        offset=offset,
        step=0,
        dp_tracker=dp_tracker if use_dp else None,
        use_dp=use_dp,
        device=device,
    )
    attn_scores_list.append(attn_scores)

    # cap generation length relative to the new (non-prompt) text length;
    # guard against prompt_text_len being None
    effective_text_len = (
        text_len - prompt_text_len if prompt_text_len is not None else text_len
    )
    max_len = int(effective_text_len * max_token_text_ratio)
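
    # From here on, decoding is incremental: each subsequent forward_one_step
    # receives only the newly sampled token's embedding, while att_cache and
    # cnn_cache carry the attention/convolution state of all earlier positions.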
    for i in range(1, max_len):
        offset += lm_input.size(1)
        logits = self.llm_decoder(y_pred[:, -1])  # [1, D]
        if temperature != 1.0:
            logits = logits / temperature
        # Apply repetition penalty and top-p filtering
        logits = repetition_penalty_processor(generated_id, logits)
        logits = top_p_warper(None, logits)  # TopPLogitsWarper ignores input_ids
        # Convert logits to probabilities and sample the next token
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # shape: (B, 1)
        # Stop on the EOS token (id == speech_token_size)
        if next_token.view(-1) == self.speech_token_size:
            break
        yield next_token.item(), attn_scores
        predicted.append(next_token)
        generated_id = torch.cat([generated_id, next_token], dim=1)  # [1, T+1]
        # Get the embedding for the new token and decode one more step
        lm_input = self.speech_embedding.weight[next_token].reshape(1, 1, -1)
        y_pred, att_cache, cnn_cache, attn_scores = self.forward_one_step(
            lm_input,
            offset=offset,
            step=i,
            att_cache=att_cache,
            cnn_cache=cnn_cache,
            dp_tracker=dp_tracker if use_dp else None,
            use_dp=use_dp,
            device=device,
        )
        attn_scores_list.append(attn_scores)
    return predicted, attn_scores_list
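

# Minimal usage sketch, not part of the original gist's API: `model` is
# assumed to be whatever object inference_v2 is bound to, with the input
# tensors prepared by the caller. Because inference_v2 is a generator that
# also *returns* a value, the final (predicted, attn_scores_list) pair
# arrives via StopIteration.value when the generator is driven manually.
def stream_speech_tokens(model, text, text_len, tone, bert_feature):
    token_ids = []
    gen = model.inference_v2(text, text_len, tone, bert_feature)
    try:
        while True:
            token_id, _attn_scores = next(gen)
            token_ids.append(token_id)  # each speech token streams out as sampled
    except StopIteration as stop:
        predicted, attn_scores_list = stop.value
    return token_ids, predicted, attn_scores_list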