# Adapted from github.com/HKUNLP/Dream and github.com/ML-GSAI/LLaDA
import torch
import numpy as np
import gradio as gr
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import time
import re
import traceback
import copy

# --- Outlines Imports ---
from outlines.processors.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# --- Model Configurations ---
MODELS = {
    "LLaDA-8B-Instruct": {
        "path": "GSAI-ML/LLaDA-8B-Instruct",
        "mask_id": 126336,
        "mask_token": "[MASK]",
        "type": "llada"
    },
    "Dream-v0-Instruct-7B": {
        "path": "Dream-org/Dream-v0-Instruct-7B",
        "mask_id": None,  # Will be determined from tokenizer
        "mask_token": "[MASK]",
        "type": "dream"
    }
}
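
# Each entry maps a display name to the Hugging Face repo path, the hard-coded mask
# token id (or None to read it from the tokenizer at load time), the string used to
# render still-masked positions, and a "type" tag that selects the model-specific
# forward-pass handling below. A hypothetical additional entry (illustrative only,
# not a tested configuration) would follow the same shape:
#
#   "My-Diffusion-LM": {
#       "path": "my-org/my-diffusion-lm",   # hypothetical repo id
#       "mask_id": None,                    # resolve from tokenizer.mask_token_id
#       "mask_token": "[MASK]",
#       "type": "dream"                     # reuse the Dream-style forward pass
#   }
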
# --- Global State ---
model = None
tokenizer = None
outlines_tokenizer = None
current_model_name = ""
MASK_ID = -1
MASK_TOKEN = "[MASK]"

# --- Model Loading Function ---
def load_model_and_tokenizer(model_name: str):
    """Loads the selected model and tokenizer and initializes components."""
    global model, tokenizer, outlines_tokenizer, current_model_name, MASK_ID, MASK_TOKEN
    if model_name == current_model_name:
        return
    print(f"\n--- Loading model: {model_name} ---")
    model_config = MODELS[model_name]
    model_path = model_config["path"]
    try:
        new_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        print("Tokenizer loaded.")
        print("Loading model...")
        new_model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                              torch_dtype=torch.bfloat16).to(device)
        print("Model loaded and moved to device.")
        # Assign to globals only after successful loading
        tokenizer = new_tokenizer
        model = new_model
        # Create Outlines-compatible tokenizer
        outlines_tokenizer = TransformerTokenizer(tokenizer)
        # Determine mask token ID based on model type
        if model_name == "Dream-v0-Instruct-7B":
            MASK_ID = tokenizer.mask_token_id if tokenizer.mask_token_id is not None else -100
            if MASK_ID == -100:
                print("Warning: .mask_token_id not found for Dream-v0. Using placeholder -100.")
        else:  # LLaDA-8B-Instruct
            MASK_ID = model_config["mask_id"]
        MASK_TOKEN = model_config["mask_token"]
        current_model_name = model_name
        print(f"Successfully loaded {model_name}. Mask ID: {MASK_ID}")
    except Exception as e:
        print(f"FATAL: Failed to load model {model_name}: {e}")
        traceback.print_exc()
        # Reset globals to prevent using a partially loaded model
        model = None
        tokenizer = None
        outlines_tokenizer = None
        current_model_name = ""
        raise gr.Error(f"Failed to load model {model_name}. Check console for details.")
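
# Usage sketch (assumes a CUDA-capable machine with enough memory for the 7B/8B
# checkpoints; the first call downloads weights from the Hugging Face Hub):
#
#   load_model_and_tokenizer("LLaDA-8B-Instruct")
#   print(current_model_name, MASK_ID)   # -> LLaDA-8B-Instruct 126336
#
# Calling it again with the same name is a no-op, so the Gradio handler can invoke
# it on every request without reloading weights.
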
def get_fsm_state_for_prefix(guide: RegexGuide, prefix_token_ids_list: list[int], cache: dict):
    """Calculates the FSM state for a given prefix of token IDs, using a cache."""
    prefix_tuple = tuple(prefix_token_ids_list)
    if prefix_tuple in cache:
        return cache[prefix_tuple]
    current_state = guide.initial_state
    for token_id in prefix_token_ids_list:
        current_state = guide.get_next_state(current_state, int(token_id))
    cache[prefix_tuple] = current_state
    return current_state
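
# Example (illustrative): with guide = RegexGuide.from_regex(r"\d{4}", outlines_tokenizer)
# and cache = {}, get_fsm_state_for_prefix(guide, [], cache) returns guide.initial_state,
# and a second call with the same prefix is answered from the cache. Walking the FSM
# token by token is what lets the denoiser ask "which tokens may legally follow this
# prefix?" at any masked position inside the current block.
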
def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints
    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        pos_str, word = part.split(':', 1)
        try:
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                constraints[pos] = word
        except ValueError:
            continue
    return constraints
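
# Example: parse_constraints("0:Once, 5:upon, 10:time")
#   -> {0: "Once", 5: "upon", 10: "time"}
# Malformed parts (missing ':' or a non-integer position) are silently skipped.
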
def format_chat_history(history):
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages
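
# Example: format_chat_history([["Hi", "Hello!"], ["Generate an IP address.", None]])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Generate an IP address."}]
# The trailing None (the turn currently being generated) is dropped so the chat
# template ends on the user message.
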
def add_gumbel_noise(logits, temperature):
    if temperature <= 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = -torch.log(-torch.log(noise)) * temperature
    return logits + gumbel_noise
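
# Gumbel-max trick: argmax(logits + temperature * g), with g standard Gumbel noise
# (-log(-log(U)), U ~ Uniform(0, 1)), is equivalent to sampling from
# softmax(logits / temperature). With temperature == 0 the logits are returned
# unchanged, so the later argmax reduces to greedy decoding.
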
def get_num_transfer_tokens(mask_index, steps):
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens
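
# Worked example: with 10 masked positions in a block and steps = 4, the per-step
# schedule is [3, 3, 2, 2] -- base 10 // 4 = 2 everywhere, plus one extra token for
# the first 10 % 4 = 2 steps -- so every mask is filled by the final step.
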
def unified_generate_with_visualization(model, tokenizer_hf, device, messages, gen_length=64, steps=32,
                                        constraints=None, temperature=0.0, cfg_scale=0.0, block_length=32,
                                        remasking='low_confidence', model_type="llada",
                                        current_regex_guide=None, current_fsm_cache=None):
    """Unified generation function that works for both LLaDA and Dream models"""
    if constraints is None:
        constraints = {}
    processed_constraints = {}
    for pos, word in constraints.items():
        tokens = tokenizer_hf.encode(" " + word, add_special_tokens=False)
        for i, token_id in enumerate(tokens):
            processed_constraints[pos + i] = token_id
    chat_input = tokenizer_hf.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    input_ids = tokenizer_hf(chat_input)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
    prompt_length = input_ids.shape[1]
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(device)
    x[:, :prompt_length] = input_ids.clone()
    visualization_states = []
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)
    for pos, token_id in processed_constraints.items():
        absolute_pos = prompt_length + pos
        if absolute_pos < x.shape[1]:
            x[:, absolute_pos] = token_id
    prompt_index = (x != MASK_ID)
    if block_length > gen_length:
        block_length = gen_length
    num_blocks = (gen_length + block_length - 1) // block_length
    steps_per_block = steps // num_blocks if num_blocks > 0 else steps
    if steps_per_block < 1:
        steps_per_block = 1
    for num_block in range(num_blocks):
        block_start = prompt_length + num_block * block_length
        block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])
        block_mask_indices_in_x = (x[:, block_start:block_end] == MASK_ID)
        if not block_mask_indices_in_x.any() and not processed_constraints:
            continue
        num_transfer_tokens_schedule = get_num_transfer_tokens(block_mask_indices_in_x, steps_per_block)
        for step_idx in range(steps_per_block):
            mask_index_full_seq = (x == MASK_ID)
            if not mask_index_full_seq.any() and not processed_constraints:
                break
            # Get model logits
            try:
                with torch.no_grad():
                    if cfg_scale > 0.0:
                        un_x = x.clone()
                        un_x[prompt_index] = MASK_ID
                        x_ = torch.cat([x, un_x], dim=0)
                        # Forward pass through model
                        if model_type == "dream":
                            outputs = model(x_)
                            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                        else:
                            logits = model(x_).logits
                        logits, un_logits = torch.chunk(logits, 2, dim=0)
                        logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                    else:
                        # Single forward pass
                        if model_type == "dream":
                            outputs = model(x)
                            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                        else:
                            logits = model(x).logits
            except Exception as e:
                print(f"Error during model forward pass: {e}")
                traceback.print_exc()
                raise
            x0 = torch.zeros_like(x)
            # Apply structured generation if regex guide is available
            if current_regex_guide and current_fsm_cache is not None:
                x0_so_far = x.clone()
                for j_pos in range(block_start, block_end):
                    if x[0, j_pos] == MASK_ID:
                        prefix_token_ids = []
                        for k_idx in range(prompt_length, j_pos):
                            token_at_k = x0_so_far[0, k_idx].item()
                            if token_at_k != MASK_ID:
                                prefix_token_ids.append(token_at_k)
                        fsm_state = get_fsm_state_for_prefix(current_regex_guide, prefix_token_ids, current_fsm_cache)
                        instruction = current_regex_guide.get_next_instruction(fsm_state)
                        allowed_tokens = instruction.tokens.to(logits.device, dtype=torch.long)
                        logits_for_pos = logits[0, j_pos].clone()
                        mask = torch.ones_like(logits_for_pos, dtype=torch.bool)
                        if allowed_tokens.numel() > 0:
                            mask[allowed_tokens] = False
                        logits_for_pos.masked_fill_(mask, -float('inf'))
                        noisy_logits = add_gumbel_noise(logits_for_pos, temperature)
                        next_token_id = torch.argmax(noisy_logits)
                        x0_so_far[0, j_pos] = next_token_id
                x0 = x0_so_far
            else:
                logits_with_noise = add_gumbel_noise(logits, temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)
            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            else:  # 'random'
                x0_p = torch.rand_like(logits[:, :, 0])
            x0_p[:, :block_start] = -float('inf')
            x0_p[:, block_end:] = -float('inf')
            old_x = x.clone()
            x0_final = torch.where(mask_index_full_seq, x0, x)
            confidence = torch.where(mask_index_full_seq, x0_p, -float('inf'))
            confidence_block_scoped = confidence.clone()
            confidence_block_scoped[:, :block_start] = -float('inf')
            confidence_block_scoped[:, block_end:] = -float('inf')
            transfer_index = torch.zeros_like(x0_final, dtype=torch.bool)
            for j_batch_idx in range(confidence_block_scoped.shape[0]):
                masked_positions_in_block_confidence = confidence_block_scoped[j_batch_idx, block_start:block_end].clone()
                masked_positions_in_block_confidence[~block_mask_indices_in_x[j_batch_idx]] = -float('inf')
                k_to_transfer = min(num_transfer_tokens_schedule[j_batch_idx, step_idx].item(),
                                    (masked_positions_in_block_confidence > -float('inf')).sum().item())
                if k_to_transfer > 0:
                    if step_idx < steps_per_block - 1:
                        _, select_indices_in_block = torch.topk(masked_positions_in_block_confidence, k=k_to_transfer)
                        select_indices_global = select_indices_in_block + block_start
                        transfer_index[j_batch_idx, select_indices_global] = True
                    else:
                        transfer_index[j_batch_idx, block_start:block_end] = block_mask_indices_in_x[j_batch_idx]
            x = torch.where(transfer_index, x0_final, x)
            for pos, token_id in processed_constraints.items():
                absolute_pos = prompt_length + pos
                if absolute_pos < x.shape[1]:
                    x[:, absolute_pos] = token_id
            current_state_vis = []
            for i_vis in range(gen_length):
                pos_vis = prompt_length + i_vis
                token_val, old_token_val = x[0, pos_vis].item(), old_x[0, pos_vis].item()
                if token_val == MASK_ID:
                    current_state_vis.append((MASK_TOKEN, "#444444"))
                elif old_token_val == MASK_ID:
                    token_str = tokenizer_hf.decode([token_val], skip_special_tokens=True)
                    confidence_val = float(x0_p[0, pos_vis].cpu()) if x0_p[0, pos_vis] != -float('inf') else 0.0
                    color = "#FF6666" if confidence_val < 0.3 else "#FFAA33" if confidence_val < 0.7 else "#66CC66"
                    current_state_vis.append((token_str, color))
                else:
                    token_str = tokenizer_hf.decode([token_val], skip_special_tokens=True)
                    current_state_vis.append((token_str, "#6699CC"))
            visualization_states.append(current_state_vis)
    response_tokens_ids = [tid for tid in x[0, prompt_length:].tolist() if tid != MASK_ID]
    final_text = tokenizer_hf.decode(
        response_tokens_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )
    return visualization_states, final_text
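
# Usage sketch (assumes load_model_and_tokenizer(...) has already populated the
# module-level model / tokenizer / outlines_tokenizer globals):
#
#   guide = RegexGuide.from_regex(r"\d{4}-\d{2}-\d{2}", outlines_tokenizer)
#   states, text = unified_generate_with_visualization(
#       model, tokenizer, device,
#       [{"role": "user", "content": "Invoice date (YYYY-MM-DD): "}],
#       gen_length=32, steps=16, model_type="llada",
#       current_regex_guide=guide, current_fsm_cache={})
#
# `states` is a list of per-step (token, color) lists for the HighlightedText
# widget, and `text` is the decoded response with any remaining masks dropped.
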
css = """
.category-legend{display:none}
button{height: 60px}
"""

def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Unified Diffusion Model Demo with Structured Generation")
        gr.Markdown("### A unified interface for LLaDA and Dream-v0 with Outlines regex guidance.")
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model",
                info="The selected model will be loaded on first generation."
            )
        chat_history = gr.State([])
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500)
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(label="Your Message", placeholder="Type your message here or select an example...", show_label=False)
                        send_btn = gr.Button("Send")
                constraints_input = gr.Textbox(label="Word Constraints (Positional)", info="Format: '0:Word, 5:Another'", placeholder="0:Once, 5:upon, 10:time")
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(label="Denoising Process Visualization", combine_adjacent=False, show_legend=True)
        with gr.Accordion("Regex Constraint", open=True):
            regex_input = gr.Textbox(
                label="Regex Pattern",
                value=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)",
                info="The model will be constrained to generate text matching this regex. Try an example!",
                interactive=True
            )
            gr.Examples(
                examples=[
                    [r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)", "Generate a random IP address."],
                    [r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", "The new user's email address is "],
                    [r"\d{4}-\d{2}-\d{2}", "Invoice date (YYYY-MM-DD): "],
                    [r'https?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', "The project's documentation is available at "],
                    [r'\{\s*"user":\s*".+",\s*"id":\s*\d+\s*\}', 'Create a JSON object for user "test" with id 123. '],
                ],
                inputs=[regex_input, user_input],
                label="Regex Examples",
            )
        with gr.Accordion("Generation Settings", open=False):
            with gr.Row():
                gen_length = gr.Slider(minimum=16, maximum=128, value=64, step=8, label="Generation Length")
                steps = gr.Slider(minimum=8, maximum=64, value=32, step=4, label="Denoising Steps")
            with gr.Row():
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature")
                cfg_scale = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale")
            with gr.Row():
                block_length = gr.Slider(minimum=8, maximum=128, value=32, step=8, label="Block Length")
                remasking_strategy = gr.Radio(choices=["low_confidence", "random"], value="low_confidence", label="Remasking Strategy")
            with gr.Row():
                visualization_delay = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="Visualization Delay (seconds)")
        current_response = gr.Textbox(label="Current Response", lines=3, visible=False)
        clear_btn = gr.Button("Clear Conversation")

        def add_message(history, message, response):
            return history + [[message, response]]

        def user_message_submitted(message, history):
            if not message.strip():
                return history, history, "", [], ""
            history = add_message(history, message, None)
            return history, history, "", [], ""

        def bot_response_stream(model_name, history, regex_str, gen_len, num_steps, const_text, delay, temp, cfg, block_len, remask):
            try:
                load_model_and_tokenizer(model_name)
            except Exception as e:
                error_msg = f"Failed to load model: {e}"
                if history:
                    history[-1][1] = error_msg
                else:
                    history = [["System", error_msg]]
                yield history, [(error_msg, "red")], error_msg
                return
            if not history or history[-1][1] is not None:
                yield history, [], ""
                return
            # Initialize regex guide for this specific generation
            current_regex_guide = None
            fsm_cache = {}
            if regex_str:
                try:
                    current_regex_guide = RegexGuide.from_regex(regex_str, outlines_tokenizer)
                    print(f"Successfully created RegexGuide for pattern: {regex_str}")
                except Exception as e:
                    error_msg = f"Invalid Regex Pattern: {e}"
                    print(error_msg)
                    history[-1][1] = error_msg
                    yield history, [(error_msg, "red")], error_msg
                    return
            messages = format_chat_history(history)
            try:
                model_config = MODELS[model_name]
                model_type = model_config["type"]
                parsed_constraints = parse_constraints(const_text)
                vis_states, response_text = unified_generate_with_visualization(
                    model, tokenizer, device, messages,
                    gen_length=gen_len, steps=num_steps, constraints=parsed_constraints,
                    temperature=temp, cfg_scale=cfg, block_length=block_len, remasking=remask,
                    model_type=model_type,
                    current_regex_guide=current_regex_guide, current_fsm_cache=fsm_cache
                )
                history[-1][1] = response_text
                if not vis_states:
                    yield history, [], response_text
                    return
                yield history, vis_states[0], response_text
                for state in vis_states[1:]:
                    time.sleep(delay)
                    yield history, state, response_text
            except Exception as e:
                error_msg = f"Error during generation: {e}"
                print(error_msg)
                traceback.print_exc()
                history[-1][1] = error_msg
                yield history, [(error_msg, "red")], error_msg

        def clear_conversation():
            return [], [], "", []

        clear_btn.click(fn=clear_conversation, outputs=[chat_history, chatbot_ui, current_response, output_vis])
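
        # Event wiring below: submitting a message (Enter or the Send button) first
        # runs user_message_submitted to append the pending turn and clear the
        # textbox, then chains into bot_response_stream, a generator that re-yields
        # the chatbot, visualization, and response outputs once per denoising step
        # so the HighlightedText panel animates while the answer is produced.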
        trigger_args = {
            "fn": user_message_submitted,
            "inputs": [user_input, chat_history],
            "outputs": [chat_history, chatbot_ui, user_input, output_vis, current_response]
        }
        response_args = {
            "fn": bot_response_stream,
            "inputs": [model_selector, chat_history, regex_input, gen_length, steps, constraints_input, visualization_delay, temperature, cfg_scale, block_length, remasking_strategy],
            "outputs": [chatbot_ui, output_vis, current_response]
        }
        user_input.submit(**trigger_args).then(**response_args)
        send_btn.click(**trigger_args).then(**response_args)

    return demo

if __name__ == "__main__":
    demo = create_chatbot_demo()
    demo.queue().launch(share=True)