Build the full tree of possible sequences: starting from a prompt, explore next tokens depth-first, branching whenever more than one token's probability exceeds a threshold, and log every completed path to sequences.txt.
import os
from collections import deque

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from scipy.special import softmax
from scipy.stats import entropy

# ANSI color codes
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RESET = "\033[0m"

model_id = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

# Remove sequences.txt if it exists; the run below appends to it.
if os.path.exists("sequences.txt"):
    os.remove("sequences.txt")
class Node:
    """One token in the generation tree."""

    def __init__(self, token_id, parent=None):
        self.token_id = token_id
        self.children = []
        self.parent = parent
        self.is_branch_point = False


def calculate_entropy(logits):
    # Entropy of the full next-token distribution (printed as a diagnostic).
    probs = softmax(logits.detach().cpu().numpy())
    return entropy(probs)
def generate_with_branching(prompt, max_length=50, probability_threshold=0.20, temperature=1.0):
    """Depth-first generation: whenever two or more next tokens exceed
    probability_threshold, the sequence branches into all of them."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    root = Node(input_ids[0][-1].item())
    stack = deque([(root, input_ids, 0)])
    leaf_counter = 0

    while stack:
        current_node, current_ids, step = stack.pop()

        if step >= max_length:
            # Depth limit reached: log this path as a truncated leaf.
            leaf_counter += 1
            log_sequence(current_ids[0], count_branch_points(current_node), leaf_counter)
            continue

        with torch.no_grad():
            outputs = model(input_ids=current_ids)
        next_token_logits = outputs.logits[0, -1, :] / temperature
        full_entropy = calculate_entropy(next_token_logits)
        probabilities = torch.softmax(next_token_logits, dim=-1)

        print(f"\nStep {step + 1}.")
        print(f"Text = {BLUE}{tokenizer.decode(current_ids[0])}{RESET}")
        print(f"Full distribution entropy: {full_entropy:.4f}")

        # Collect every token whose probability clears the threshold
        # (vectorized; a Python loop over the whole vocabulary is very slow).
        candidate_ids = (probabilities > probability_threshold).nonzero(as_tuple=True)[0]
        valid_continuations = [(t.item(), probabilities[t].item()) for t in candidate_ids]

        if len(valid_continuations) >= 2:
            # Two or more plausible tokens: branch into all of them.
            current_node.is_branch_point = True
            for token_id, token_prob in valid_continuations:
                new_node = Node(token_id, parent=current_node)
                current_node.children.append(new_node)
                new_ids = torch.cat([current_ids, torch.tensor([[token_id]])], dim=1)
                if token_id == tokenizer.eos_token_id:
                    print("End of text token reached.")
                    leaf_counter += 1
                    log_sequence(new_ids[0], count_branch_points(new_node), leaf_counter)
                else:
                    stack.append((new_node, new_ids, step + 1))
                print(f"Branching: {YELLOW}{tokenizer.decode([token_id])}{RESET}, "
                      f"probability = {token_prob:.4f}, total branches = {len(stack)}")
        else:
            # Linear generation if there's only one or no valid continuation.
            if valid_continuations:
                token_id, token_prob = valid_continuations[0]
            else:
                token_id = torch.argmax(next_token_logits).item()
                token_prob = probabilities[token_id].item()
            new_node = Node(token_id, parent=current_node)
            current_node.children.append(new_node)
            new_ids = torch.cat([current_ids, torch.tensor([[token_id]])], dim=1)
            if token_id == tokenizer.eos_token_id:
                print("End of text token reached.")
                leaf_counter += 1
                log_sequence(new_ids[0], count_branch_points(new_node), leaf_counter)
            else:
                stack.append((new_node, new_ids, step + 1))
            print(f"Linear: {GREEN}{tokenizer.decode([token_id])}{RESET}, probability = {token_prob:.4f}")

    return root
def count_branch_points(node):
    # Walk back up to the root, counting branch points along this path.
    count = 0
    current = node
    while current.parent is not None:
        if current.parent.is_branch_point:
            count += 1
        current = current.parent
    return count
def log_sequence(sequence, branch_points, leaf_number):
    # Append one completed path to sequences.txt.
    with open("sequences.txt", "a") as f:
        f.write(f"Leaf Node: {leaf_number}\n")
        f.write(f"Branch Points: {branch_points}\n")
        f.write(f"Sequence: {tokenizer.decode(sequence)}\n")
        f.write("-" * 50 + "\n")
def print_tree(node, prefix="", is_last=True):
    # Pretty-print the tree; [B] marks branch points.
    if node.parent is None:
        print(f"Root: {tokenizer.decode([node.token_id])}")
    else:
        connector = "└── " if is_last else "├── "
        print(f"{prefix}{connector}{'[B] ' if node.is_branch_point else ''}{tokenizer.decode([node.token_id])}")
        prefix += "    " if is_last else "│   "
    for i, child in enumerate(node.children):
        print_tree(child, prefix, i == len(node.children) - 1)
if __name__ == '__main__':
    prompt = """<|start_header_id|>system<|end_header_id|>
You are a helpful assistant with advanced analytical capabilities.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Which number is larger, 9.9 or 9.11?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
    result_tree = generate_with_branching(prompt, max_length=100, probability_threshold=0.20, temperature=0.7)
    print("\nGenerated Tree Structure:")
    print_tree(result_tree)
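
To inspect the results afterwards, here is a minimal sketch of a reader; it assumes only the sequences.txt layout written by log_sequence above:

# Read back the leaves logged by log_sequence and print a short summary.
with open("sequences.txt") as f:
    entries = [e for e in f.read().split("-" * 50 + "\n") if e.strip()]

print(f"{len(entries)} leaf sequences generated")
for entry in entries:
    lines = entry.strip().splitlines()
    print(lines[0], "|", lines[1])  # e.g. "Leaf Node: 1 | Branch Points: 2"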