Small Transformer Model with GUI and Visualization

This project demonstrates a small Transformer model designed for natural language generation. It provides a graphical user interface (GUI) built with tkinter for entering seed text and viewing generated output, along with advanced real-time visualization of embedding vectors rendered with OpenCV.

Features

  1. Small Transformer Model:

    • A lightweight Transformer model optimized for environments with limited resources (e.g., GPUs with small memory capacities).
    • Features include a reduced embedding size, fewer Transformer layers, and fewer attention heads (see the configuration sketch after this list).
  2. Graphical User Interface (GUI):

    • Built using tkinter.
    • Users can input "seed" text and view the model's generated response in a scrollable text area.
    • Includes a button to generate text on demand.
  3. Advanced Vector Visualization:

    • Embedding vectors are visualized in real-time using OpenCV as the model processes the input.
    • Features include:
      • Color-coded points.
      • Connections between points to represent sequential relationships.
      • Labels for points to indicate their order in the sequence.
  4. Data Preparation:

    • Scrapes text data from Wikipedia pages to train or fine-tune the model.
    • Prepares tokenized datasets optimized for a small Transformer architecture.
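
The reduced configuration is small enough to spell out in a few lines. Below is a minimal sketch of instantiating the model with the same hyperparameters used in the main script at the end of this document (SmallTransformerModel itself is defined in the source listing):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SmallTransformerModel(
        vocab_size=tokenizer.vocab_size,  # ~30k WordPiece tokens for bert-base-uncased
        embed_dim=64,     # reduced embedding size
        num_heads=4,      # fewer attention heads
        num_layers=4,     # fewer Transformer layers
        ff_dim=128,       # smaller feed-forward layer
        max_seq_len=64,
    )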

Installation

Prerequisites

  • Python 3.8 or later
  • GPU with CUDA support (optional but recommended)
  • Required Python packages:
    • torch
    • transformers
    • scikit-learn
    • opencv-python
    • tkinter (comes pre-installed with Python on most systems)
    • beautifulsoup4
    • requests
    • numpy

Installation Steps

  1. Clone the repository:

    git clone https://github.com/your/repo.git
    cd repo
  2. Install required Python packages:

    pip install torch transformers scikit-learn opencv-python beautifulsoup4 requests numpy
  3. Run the script:

    python script_name.py

Usage

  1. Run the Script: Execute the script to launch the GUI:

    python script_name.py
  2. Enter Seed Text:

    • A GUI window will appear.
    • Enter your "seed" text in the input field and press "Generate".
  3. View Generated Text:

    • The model's response will appear in the scrollable text area below the input field.
  4. Visualize Embeddings:

    • Real-time embeddings generated by the model will be displayed in an OpenCV window.
    • The visualization includes:
      • Points representing tokens in the sequence.
      • Lines connecting points to represent token order.
      • Color gradients for points and lines.
  5. Repeat:

    • Enter a new seed text to generate responses and visualize embeddings.
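
The GUI is optional: the same functions can be driven from a short script. A minimal sketch, assuming the definitions from the source listing below are in scope and a checkpoint was saved to model/small_transformer.pth by a previous training run (the seed text here is an arbitrary example):

    import torch
    from transformers import BertTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SmallTransformerModel(tokenizer.vocab_size, embed_dim=64, num_heads=4,
                                  num_layers=4, ff_dim=128, max_seq_len=64).to(device)
    model.load_state_dict(torch.load("model/small_transformer.pth", map_location=device))
    print(generate_text(model, tokenizer, device, "the three laws of robotics"))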

Code Overview

Key Components

  1. Model:

    • SmallTransformerModel is a lightweight Transformer model with customizable embedding size, attention heads, and layers.
  2. GUI:

    • Built using tkinter.
    • Contains:
      • Input field for seed text.
      • Button to trigger text generation.
      • Scrollable text area to display the results.
  3. Visualization:

    • Embedding vectors are visualized using OpenCV and scikit-learn's PCA for dimensionality reduction.
  4. Data Preparation:

    • Scrapes Wikipedia pages to prepare training datasets.
    • Tokenizes and splits data into manageable sequences.

File Structure

repo/
├── script_name.py       # Main script with GUI, text generation, and visualization
└── README.md            # Project documentation

Advanced Visualization

  • OpenCV Window:

    • Displays points in a 2D space, reduced using PCA (see the standalone sketch after this section).
    • Points are connected by lines to indicate sequential relationships.
    • Color gradients represent token indices.
  • Example Visualization:

    • Tokens from the input sequence are plotted in a 2D space.
    • Connections and color gradients make it easier to interpret the model's internal representations.
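
Under the hood, the mapping from high-dimensional embeddings to screen coordinates is a PCA projection followed by min-max scaling. A self-contained sketch of that step on dummy data (the full drawing routine, advanced_visualize_vectors, appears in the source listing below):

    import numpy as np
    from sklearn.decomposition import PCA

    embeddings = np.random.randn(10, 64)      # dummy batch: 10 tokens, 64-dim embeddings
    points = PCA(n_components=2).fit_transform(embeddings)
    points -= points.min(axis=0)              # shift into the positive quadrant
    points /= points.max(axis=0) + 1e-8       # scale to [0, 1], guarding against zero range
    pixels = (points * 400).astype(np.int32)  # map to a 400x400 drawing region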

Future Work

  • Add support for training the model directly from the GUI.
  • Improve visualization to include interactive 3D plots.
  • Optimize model performance with quantization for deployment on edge devices.
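
As one illustration of the quantization item, PyTorch's dynamic quantization converts a model's Linear layers to int8 with a single call. This is a sketch only; accuracy and speed on this particular model are untested:

    import torch
    import torch.nn as nn

    # model: a trained SmallTransformerModel instance.
    # Quantize Linear layers (including the Transformer's internal projections) to int8.
    quantized_model = torch.ao.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )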

Source Code

import requests
from bs4 import BeautifulSoup
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.amp import GradScaler, autocast
from transformers import BertTokenizer
import os
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import cv2
import tkinter as tk
from tkinter import scrolledtext
from sklearn.decomposition import PCA

# Step 1: Scrape the Wikipedia page
def scrape_text(url):
    response = requests.get(url)
    soup = BeautifulSoup(response.text, 'html.parser')
    paragraphs = soup.find_all('p')
    text = " ".join([p.get_text(strip=True) for p in paragraphs])
    return text

# Step 2: Custom Dataset for Tokenized Data with Truncation
class TextDataset(Dataset):
    def __init__(self, tokenizer, text, block_size):
        tokens = tokenizer.encode(text, add_special_tokens=True)
        # Split the token stream into block_size chunks; keep only chunks with
        # at least two tokens so the input/target shift during training is valid
        self.examples = [
            tokens[i:i + block_size]
            for i in range(0, len(tokens), block_size)
            if len(tokens[i:i + block_size]) > 1
        ]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return torch.tensor(self.examples[idx])

# Step 3: Smaller Transformer Model
class SmallTransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, num_heads=2, num_layers=2, ff_dim=128, max_seq_len=128):
        super(SmallTransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional encoding, added to the token embeddings
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, ff_dim, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]
        transformed = self.transformer(embedded)
        logits = self.output_layer(transformed)
        # Return embeddings alongside logits so they can be visualized
        return logits, embedded

# Step 4: Train the Model
def train_model(model, dataloader, optimizer, criterion, scheduler, device, epochs=10):
    """
    Train the model using the provided DataLoader and optimizer.
    """
    model.train()
    scaler = GradScaler(enabled=torch.cuda.is_available())  # Mixed precision training
    for epoch in range(epochs):
        total_loss = 0
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            with autocast(device_type='cuda', enabled=torch.cuda.is_available()):
                # Predict token t+1 from tokens up to t
                output, _ = model(batch[:, :-1])
                loss = criterion(output.view(-1, output.size(-1)), batch[:, 1:].contiguous().view(-1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        # Step the LR scheduler once per epoch
        scheduler.step()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
    print("Training complete.")

# Step 5: Advanced visualization using OpenCV
def advanced_visualize_vectors(vectors, title="Advanced Vector Visualization"):
    """
    Advanced visualization of vectors in 2D space using OpenCV.
    Includes color coding, connections, and labels.
    """
    vectors = vectors.detach().cpu().numpy()
    vectors = vectors.reshape(-1, vectors.shape[-1])  # Flatten to (num_tokens, embed_dim)
    pca = PCA(n_components=2)
    reduced_vectors = pca.fit_transform(vectors)
    # Normalize to fit within visualization space
    reduced_vectors -= reduced_vectors.min(axis=0)
    reduced_vectors /= reduced_vectors.max(axis=0) + 1e-8  # Avoid division by zero
    reduced_vectors = (reduced_vectors * 400).astype(np.int32)
    # Create blank image for visualization
    img = np.zeros((500, 500, 3), dtype=np.uint8)
    # Draw connections between points
    for i in range(1, len(reduced_vectors)):
        pt1 = (int(reduced_vectors[i - 1][0]) + 50, int(reduced_vectors[i - 1][1]) + 50)
        pt2 = (int(reduced_vectors[i][0]) + 50, int(reduced_vectors[i][1]) + 50)
        color = (max(0, 255 - i * 5), min(255, i * 5), 200)  # Gradient color, clamped to [0, 255]
        cv2.line(img, pt1, pt2, color, 1)
    # Draw points with labels
    for i, point in enumerate(reduced_vectors):
        x, y = int(point[0]) + 50, int(point[1]) + 50
        color = (0, max(0, 255 - i * 5), min(255, i * 5))  # Gradient color, clamped to [0, 255]
        cv2.circle(img, (x, y), 5, color, -1)
        cv2.putText(img, f"{i}", (x + 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
    # Add title
    cv2.putText(img, title, (10, 470), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)
    # Show the visualization
    cv2.imshow(title, img)
    cv2.waitKey(500)

# Step 6: Generate Text
def generate_text(model, tokenizer, device, seed_text, max_len=50):
    """
    Generate text based on seed input using the model.
    Also visualizes embeddings in real-time.
    """
    model.eval()
    tokens = tokenizer.encode(seed_text, return_tensors="pt").to(device)
    generated = tokens
    for _ in range(max_len):
        # Stop before the sequence outgrows the positional encoding
        if generated.size(1) >= model.positional_encoding.size(1):
            break
        with torch.no_grad():
            output, embeddings = model(generated)
        # Greedy decoding: take the most likely next token
        next_token = torch.argmax(output[:, -1, :], dim=-1).unsqueeze(0)
        generated = torch.cat((generated, next_token), dim=1)
        # Advanced visualization of embeddings
        advanced_visualize_vectors(embeddings)
        # BERT's tokenizer has no EOS token; treat [SEP] as the stop token
        if next_token.item() == tokenizer.sep_token_id:
            break
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Step 7: Prepare Training Data
def prepare_training_data():
    urls = [
        "https://en.wikipedia.org/wiki/Three_Laws_of_Robotics",
        "https://en.wikipedia.org/wiki/Foundation_(book_series)",
        "https://en.wikipedia.org/wiki/The_Complete_Robot"
    ]
    text = ""
    for url in urls:
        print(f"Scraping text from: {url}")
        text += scrape_text(url) + " "
    return text

# Step 8: Input Text via GUI with Output Window
def get_user_input_and_display_output(model, tokenizer, device):
    """
    Opens a GUI to get text input from the user and display the generated text in a scrollable text area.
    """
    def generate_and_display():
        seed_text = input_text.get()
        if seed_text:
            generated_text = generate_text(model, tokenizer, device, seed_text)
            output_text.insert(tk.END, f"Input: {seed_text}\n")
            output_text.insert(tk.END, f"Generated: {generated_text}\n\n")
            output_text.see(tk.END)  # Auto-scroll to the end

    root = tk.Tk()
    root.title("Text Generator")
    # Input field
    input_label = tk.Label(root, text="Enter Seed Text:")
    input_label.pack()
    input_text = tk.Entry(root, width=50)
    input_text.pack()
    # Generate button
    generate_button = tk.Button(root, text="Generate", command=generate_and_display)
    generate_button.pack()
    # Output text area
    output_text = scrolledtext.ScrolledText(root, wrap=tk.WORD, width=60, height=20)
    output_text.pack()
    # Run the GUI
    root.mainloop()

# Main script
if __name__ == "__main__":
    model_path = "model/small_transformer.pth"
    tokenizer_path = "model"
    # Check if the model and tokenizer already exist
    model_exists = os.path.exists(model_path)
    tokenizer_exists = os.path.exists(tokenizer_path)
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # Model Parameters
    vocab_size = tokenizer.vocab_size
    embed_dim = 64    # Smaller embedding size
    num_heads = 4     # Fewer attention heads
    num_layers = 4    # Fewer Transformer layers
    ff_dim = 128      # Smaller feed-forward layer
    max_seq_len = 64
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SmallTransformerModel(vocab_size, embed_dim, num_heads, num_layers, ff_dim, max_seq_len)
    model.to(device)
    if not model_exists or not tokenizer_exists:
        print("No trained model found. Starting training process...")
        # Prepare training data
        text = prepare_training_data()
        dataset = TextDataset(tokenizer, text, block_size=64)
        # Pad variable-length blocks batch-first so batch[:, :-1] slices along the sequence dimension
        dataloader = DataLoader(
            dataset, batch_size=4, shuffle=True,
            collate_fn=lambda batch: pad_sequence(batch, batch_first=True)
        )
        # Training setup; ignore padding tokens when computing the loss
        optimizer = optim.AdamW(model.parameters(), lr=3e-4)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
        # Train the model
        train_model(model, dataloader, optimizer, criterion, scheduler, device, epochs=50)
        # Save the model and tokenizer
        os.makedirs("model", exist_ok=True)
        torch.save(model.state_dict(), model_path)
        tokenizer.save_pretrained(tokenizer_path)
        print("Training complete. Model and tokenizer saved.")
    else:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print("Model loaded for inference.")
    # Get user input and display output
    get_user_input_and_display_output(model, tokenizer, device)
    cv2.destroyAllWindows()  # Close OpenCV windows