@d0rc · Created October 1, 2024 03:15
Build the full tree of possible sequences
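"""
Explore a causal LM's output space by branching whenever more than one
next token clears a probability threshold. Each completed path (EOS or
max depth) is appended to sequences.txt together with the number of
branch points encountered along it.
"""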
import os
from collections import deque

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from scipy.special import softmax
from scipy.stats import entropy

# ANSI color codes (original used \033[96m, which is bright cyan; 92 is green)
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RESET = "\033[0m"

model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Remove sequences.txt if it exists
if os.path.exists("sequences.txt"):
    os.remove("sequences.txt")

class Node:
    """One generated token; children are the continuations explored from it."""

    def __init__(self, token_id, parent=None):
        self.token_id = token_id
        self.children = []
        self.parent = parent
        self.is_branch_point = False

def calculate_entropy(logits):
    """Shannon entropy of the softmax distribution over the full vocabulary."""
    probs = softmax(logits.detach().cpu().numpy())
    return entropy(probs)
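
# Note: the entropy computed above is printed for inspection only; the
# branching decision below is based purely on the per-token probability
# threshold.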

def generate_with_branching(prompt, max_length=50, probability_threshold=0.20, temperature=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    root = Node(input_ids[0][-1].item())
    stack = deque([(root, input_ids, 0)])
    leaf_counter = 0
    while stack:
        current_node, current_ids, step = stack.pop()
        if step >= max_length:
            # Depth limit reached: record the truncated path as a leaf.
            leaf_counter += 1
            log_sequence(current_ids[0], count_branch_points(current_node), leaf_counter)
            continue
        with torch.no_grad():
            outputs = model(input_ids=current_ids)
        next_token_logits = outputs.logits[0, -1, :] / temperature
        full_entropy = calculate_entropy(next_token_logits)
        probabilities = torch.softmax(next_token_logits, dim=-1)
        print(f"\nStep {step + 1}.")
        print(f"Text = {BLUE}{tokenizer.decode(current_ids[0])}{RESET}")
        print(f"Full distribution entropy: {full_entropy:.4f}")
        # Every token whose probability clears the threshold is a valid continuation.
        candidate_ids = torch.nonzero(probabilities > probability_threshold).squeeze(-1)
        valid_continuations = [(t.item(), probabilities[t].item()) for t in candidate_ids]
        if len(valid_continuations) >= 2:
            current_node.is_branch_point = True
            for token_id, token_prob in valid_continuations:
                new_node = Node(token_id, parent=current_node)
                current_node.children.append(new_node)
                new_ids = torch.cat([current_ids, torch.tensor([[token_id]])], dim=1)
                if token_id == tokenizer.eos_token_id:
                    print("End of text token reached.")
                    leaf_counter += 1
                    log_sequence(new_ids[0], count_branch_points(new_node), leaf_counter)
                else:
                    stack.append((new_node, new_ids, step + 1))
                print(f"Branching: {YELLOW}{tokenizer.decode([token_id])}{RESET}, "
                      f"probability = {token_prob:.4f}, total branches = {len(stack)}")
        else:
            # Linear generation: a single valid continuation, or the argmax
            # fallback when no token clears the threshold.
            if valid_continuations:
                token_id, token_prob = valid_continuations[0]
            else:
                token_id = torch.argmax(next_token_logits).item()
                token_prob = probabilities[token_id].item()
            new_node = Node(token_id, parent=current_node)
            current_node.children.append(new_node)
            new_ids = torch.cat([current_ids, torch.tensor([[token_id]])], dim=1)
            if token_id == tokenizer.eos_token_id:
                print("End of text token reached.")
                leaf_counter += 1
                log_sequence(new_ids[0], count_branch_points(new_node), leaf_counter)
            else:
                stack.append((new_node, new_ids, step + 1))
            print(f"Linear: {GREEN}{tokenizer.decode([token_id])}{RESET}, probability = {token_prob:.4f}")
    return root
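
# Since the probabilities sum to 1, fewer than 1 / probability_threshold
# tokens can strictly exceed the threshold at any step (at most 4 when it
# is 0.20), which bounds the branching factor; the tree can still grow
# quickly over many steps.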

def count_branch_points(node):
    """Count the branch points among this node's ancestors, back to the root."""
    count = 0
    current = node
    while current.parent is not None:
        if current.parent.is_branch_point:
            count += 1
        current = current.parent
    return count

def log_sequence(sequence, branch_points, leaf_number):
    with open("sequences.txt", "a") as f:
        f.write(f"Leaf Node: {leaf_number}\n")
        f.write(f"Branch Points: {branch_points}\n")
        f.write(f"Sequence: {tokenizer.decode(sequence)}\n")
        f.write("-" * 50 + "\n")

def print_tree(node, depth=0, prefix=""):
    """Render the tree, accumulating box-drawing glyphs in `prefix`; branch points are marked [B]."""
    if node.parent is None:
        print(f"{prefix}Root: {tokenizer.decode([node.token_id])}")
    else:
        print(f"{prefix}{'[B] ' if node.is_branch_point else ''}{tokenizer.decode([node.token_id])}")
    for i, child in enumerate(node.children):
        if i == len(node.children) - 1:
            print_tree(child, depth + 1, prefix + "└── ")
        else:
            print_tree(child, depth + 1, prefix + "├── ")

if __name__ == '__main__':
    prompt = """<|start_header_id|>system<|end_header_id|>
You are a helpful assistant with advanced analytical capabilities.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Which number is larger, 9.9 or 9.11?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
    result_tree = generate_with_branching(prompt, max_length=100, probability_threshold=0.20, temperature=0.7)
    print("\nGenerated Tree Structure:")
    print_tree(result_tree)
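
# Tuning notes: lowering probability_threshold, or raising temperature (which
# flattens the softmax over the logits), lets more tokens clear the threshold,
# so the tree branches more often and sequences.txt accumulates one entry per
# leaf.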