

@Gal-Lahat
Last active May 14, 2025 18:42

The Latent Manipulator Cookbook.md

This guide explains the "Latent Manipulator," an experimental AI architecture designed to "think" in a latent space before generating text, contrasting with standard Transformer models that predict text sequentially. It includes the theory, code for implementation, and links to datasets and pretrained model checkpoints.

Based on the video exploring this concept: [https://www.youtube.com/watch?v=fWiieyG2zes]

⚠️ This project is free to use. If you find it helpful, please consider supporting it by checking out Peach Voice Typing: [https://peach-voice.com]

Note:

This project doesn’t have a dedicated integrated repository. It’s a compilation of multiple standalone scripts, patched together with the help of language models.

Why a Latent Manipulator? The Theory

Standard Large Language Models (LLMs) like ChatGPT are typically based on the Transformer architecture. Their core operation involves predicting the next word (or token) in a sequence, given all the preceding words. This means their process of "thinking" or reasoning is intertwined with the act of generating text word-by-word. If you ask ChatGPT, "Can you think quietly before writing?", it might say yes, but architecturally, it can't – its computation is the generation process.

Humans, however, can often form an "idea" or grasp the semantics of a concept before finding the exact words to express it. The Latent Manipulator architecture attempts to mimic this separation:

  1. Idea Space (Latent Space): We need a way to represent the meaning or idea of a piece of text numerically, separate from the text itself. This is achieved using an Autoencoder.

    • Encoder: Takes text as input and compresses it into a dense numerical vector (e.g., 1024 numbers). This vector lives in the "latent space" and represents the "idea" of the input text.
    • Decoder: Takes a vector from the latent space and reconstructs the original text (or a close approximation).
    • Bottleneck: The crucial part is the "bottleneck" in the middle of the autoencoder (the latent space itself), which forces the model to learn a compact, meaningful representation of the input.
  2. The Latent Manipulator (Thinking Engine): This is a separate model (which doesn't have to be a Transformer) that operates entirely within the latent space.

    • It takes the latent vector representing the question (generated by the Encoder).
    • It performs computations on this vector to transform it into a new latent vector representing the answer.
    • This transformation is the "thinking" process, happening without generating any text.
  3. Generating the Answer: The resulting latent vector (the "idea" of the answer) is then fed into the Decoder part of the autoencoder, which converts this "idea" back into human-readable text.

In essence: Text Question -> Encoder -> Latent Question -> Latent Manipulator -> Latent Answer -> Decoder -> Text Answer.
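To make the data flow concrete, here is a minimal sketch in which the three stages are plain Python callables. The toy components (`toy_encode`, `toy_manipulate`, `toy_decode`, and the 2-D `TOY_IDEAS` table) are hypothetical stand-ins, not the T5 autoencoder or the trained manipulator. The point is only that the middle stage receives and returns vectors and never sees text.

```python
from typing import Callable, Dict, List

Vector = List[float]


def answer(question: str,
           encode: Callable[[str], Vector],
           manipulate: Callable[[Vector], Vector],
           decode: Callable[[Vector], str]) -> str:
    """Text question -> latent question -> latent answer -> text answer."""
    latent_question = encode(question)           # Encoder: text to "idea"
    latent_answer = manipulate(latent_question)  # "Thinking": vector in, vector out, no text
    return decode(latent_answer)                 # Decoder: "idea" back to text


# --- Hypothetical toy components, for illustration only ---
TOY_IDEAS: Dict[str, Vector] = {"hi": [1.0, 0.0], "hello": [0.0, 1.0]}


def toy_encode(text: str) -> Vector:
    # Look up a fixed 2-D "idea" vector for a known word.
    return list(TOY_IDEAS[text.lower()])


def toy_manipulate(v: Vector) -> Vector:
    # Swap the coordinates, turning the "hi" idea into the "hello" idea and back.
    return [v[1], v[0]]


def toy_decode(v: Vector) -> str:
    # Pick the word whose idea vector is closest (squared Euclidean distance).
    return min(TOY_IDEAS, key=lambda w: sum((a - b) ** 2 for a, b in zip(TOY_IDEAS[w], v)))


print(answer("Hi", toy_encode, toy_manipulate, toy_decode))  # hello
```

In the real system the same shape holds: `encode` is the autoencoder's `embed`, `manipulate` is the trained network, and `decode` is `generate_from_latent`.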

This separation offers potential advantages:

  • True "Thinking": Allows computation on semantic meaning before articulation.
  • Multilingual Potential: The latent space could potentially become language-agnostic. You could train different Encoder/Decoder pairs for various languages but use the same Latent Manipulator for reasoning, promoting consistency across languages.
  • Efficiency & Control: Manipulating smaller latent vectors might be more efficient than full text generation for certain reasoning tasks.

Implementation Guide

Prerequisites

  • Python 3.12.2 (or compatible)
  • Transformers library (version 4.37.2 used here)
  • PyTorch
  • Pandas & PyArrow (for data preparation)
  • NumPy
  • tqdm (for progress bars)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import json
import numpy as np
from tqdm import tqdm
import pandas as pd
import glob
import pyarrow  # Not used directly; imported so a missing Parquet engine for pd.read_parquet fails early
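To compare your environment against the versions listed above, a small check using only the standard library's `importlib.metadata` (available since Python 3.8) can help. The helper name `installed_versions` is hypothetical, and the list assumes the PyPI distribution names of the packages above.

```python
from importlib.metadata import PackageNotFoundError, version
import sys

REQUIRED = ["torch", "transformers", "pandas", "pyarrow", "numpy", "tqdm"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print(f"Python {sys.version.split()[0]}")
    for name, ver in installed_versions(REQUIRED).items():
        print(f"{name:>12}: {ver or 'NOT INSTALLED'}")
```

Mismatched versions are not necessarily a problem, but the remote model code was exercised with transformers 4.37.2, so that is the first thing to check if loading fails.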

1. Setup: Loading the Autoencoder

We use a pre-trained T5 model that has been adapted into a bottlenecked autoencoder. The wrapper class below exposes the two operations we need: embed (text to latent) and generate_from_latent (latent to text).

# Define the Autoencoder Abstraction (Helper Class)
# Note: This requires the specific model code from 'thesephist/contra-bottleneck-t5-large-wikipedia'
# Ensure you have 'trust_remote_code=True' when loading if needed.
class BottleneckT5Autoencoder:
    def __init__(self, model_path: str, device='cpu'):
        self.device = device
        print(f"Using device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        # Ensure trust_remote_code=True if the model requires custom code
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device)
        self.model.eval() # Set to evaluation mode

    @torch.no_grad()
    def embed(self, text: str) -> torch.FloatTensor:
        """Encodes a single string into a latent embedding."""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512).to(self.device)
        # T5-style models use the pad token as their decoder start token
        decoder_input_ids = torch.tensor([[self.tokenizer.pad_token_id]], dtype=torch.long).to(self.device)

        # Generate the latent embedding
        outputs = self.model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            decoder_input_ids=decoder_input_ids, # Provide initial decoder input
            encode_only=True, # Flag to get the bottleneck representation
        )
        # The exact output structure might vary; inspect 'outputs' if needed.
        # Assuming the latent vector is the first element.
        return outputs[0]


    @torch.no_grad()
    def embed_batch(self, texts: list[str]) -> torch.FloatTensor:
        """Encodes a batch of strings into latent embeddings."""
        inputs = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        # Prepare decoder start tokens for the batch
        decoder_start_token_id = self.model.config.decoder_start_token_id
        if decoder_start_token_id is None:
             decoder_start_token_id = self.tokenizer.pad_token_id # Fallback if not defined

        decoder_input_ids = torch.full(
            (len(texts), 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=self.device
        )

        outputs = self.model(
            **inputs,
            decoder_input_ids=decoder_input_ids,
            encode_only=True,
        )
        return outputs[0]

    @torch.no_grad()
    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=0.4) -> str:
        """Decodes a latent embedding back into text."""
        latent = latent.to(self.device)
        # The model's generate() does not take the latent as an argument. Following the
        # usage example on the model card, embed a dummy string and store the offset from
        # its latent to the target latent in `model.perturb_vector`, which the model's
        # remote code applies while decoding from the dummy input.
        dummy_text = '.'
        dummy_latent = self.embed(dummy_text)
        self.model.perturb_vector = latent - dummy_latent
        input_ids = self.tokenizer(dummy_text, return_tensors='pt').input_ids.to(self.device)

        output_sequences = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            num_return_sequences=1,
        )
        # Decode the first (and only) sequence
        return self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
# --- Initialize the Autoencoder ---
# Set your Hugging Face token if needed
# os.environ["HF_TOKEN"] = "your_huggingface_token"

# Determine device (adjust as needed: 'cuda', 'mps', 'cpu')
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available(): # For Apple Silicon
    device = 'mps'
else:
    device = 'cpu'

# Load the pre-trained autoencoder model
model_path = 'thesephist/contra-bottleneck-t5-large-wikipedia'
autoencoder = BottleneckT5Autoencoder(model_path=model_path, device=device)
print("Autoencoder loaded successfully.")
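Before spending hours embedding a dataset, it can be worth checking that the loaded autoencoder reconstructs text sensibly. The sketch below encodes a sentence, decodes it, re-encodes the reconstruction, and reports the cosine similarity between the two latents. `round_trip_check`, `cosine_similarity`, and `to_flat_list` are hypothetical helpers, not part of the model's API; they assume only the `embed` and `generate_from_latent` methods of the wrapper class above. Because decoding samples (`do_sample=True`), repeated calls return different reconstructions, so treat the score as a rough signal.

```python
import math
from typing import List, Tuple


def to_flat_list(v) -> List[float]:
    """Convert a tensor, array, or nested list into a flat list of floats."""
    if hasattr(v, "tolist"):  # torch.Tensor / numpy.ndarray
        v = v.tolist()
    # Unwrap a leading batch dimension of size 1, e.g. shape (1, 1024).
    while len(v) == 1 and isinstance(v[0], (list, tuple)):
        v = v[0]
    return [float(x) for x in v]


def cosine_similarity(a, b) -> float:
    a, b = to_flat_list(a), to_flat_list(b)
    if len(a) != len(b):
        raise ValueError(f"dimension mismatch: {len(a)} vs {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def round_trip_check(autoencoder, text: str) -> Tuple[str, float]:
    """Encode, decode, and re-encode `text`; return the reconstruction and latent similarity."""
    latent = autoencoder.embed(text)
    reconstruction = autoencoder.generate_from_latent(latent)
    return reconstruction, cosine_similarity(latent, autoencoder.embed(reconstruction))
```

For example, `reconstruction, score = round_trip_check(autoencoder, "The Eiffel Tower is in Paris.")` prints nothing by itself; inspect both values. A reconstruction that is unrelated to the input usually points to a loading problem rather than a training one.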

2. Data Preparation

The training data for the Latent Manipulator consists of pairs of latent embeddings: one for the instruction (question) and one for the response (answer).

a) Combining Raw Data (if needed)

The example uses the LaMini dataset, provided as Parquet files. If your data is similar, you first need to combine these into a single file (e.g., JSONL).

The raw data used here is available at [https://huggingface.co/datasets/MBZUAI/LaMini-instruction/tree/main/data]

# --- Combine Parquet files into a single JSONL (Example) ---
parquet_dir = "/path/to/your/parquet/files/" # Directory containing the .parquet files
output_jsonl_file = "/path/to/save/merged_output.jsonl"
parquet_files = sorted(glob.glob(os.path.join(parquet_dir, "train-*.parquet")))  # glob order is arbitrary; sort for a reproducible record order

print(f"Found Parquet files: {parquet_files}")

# Write mode: re-running rebuilds the file from scratch. (Append mode does not resume;
# a second run would duplicate every record already written.)
with open(output_jsonl_file, "w", encoding="utf-8") as outfile:
    for file_path in tqdm(parquet_files, desc="Processing Parquet files"):
        try:
            df = pd.read_parquet(file_path)
            print(f"Read {len(df)} rows from {os.path.basename(file_path)}")
            # Iterate through DataFrame rows and write as JSON lines
            for _, row in tqdm(df.iterrows(), total=len(df), desc="Writing JSONL", leave=False):
                # Ensure columns 'instruction' and 'response' exist
                if 'instruction' in row and 'response' in row:
                     json_record = json.dumps({"instruction": row['instruction'], "response": row['response']})
                     outfile.write(json_record + '\n')
                else:
                     print(f"Skipping row due to missing columns: {row.to_dict()}")
        except Exception as e:
            print(f"Error processing {file_path}: {e}")

print(f"Finished merging Parquet files into {output_jsonl_file}")
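The embedding step below silently skips malformed lines, and its resume logic assumes exactly two embeddings per processed line, so a resumed run that has skipped lines finds fewer rows than it expects and starts over. A quick dry run over the merged file shows in advance how many usable pairs it contains. `validate_jsonl` is a hypothetical helper; treating empty strings as unusable is an assumption of this sketch (the embedding script itself accepts them).

```python
import json
from collections import Counter


def validate_jsonl(path, required=("instruction", "response")):
    """Count usable and unusable lines in a JSONL file, streaming it line by line."""
    stats = Counter()
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                stats["blank"] += 1
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                stats["invalid_json"] += 1
                continue
            # Usable only if every required field is a non-empty string.
            if isinstance(obj, dict) and all(isinstance(obj.get(k), str) and obj[k].strip() for k in required):
                stats["ok"] += 1
            else:
                stats["missing_or_empty_field"] += 1
    return stats
```

Calling `print(validate_jsonl(output_jsonl_file))` before generating embeddings tells you whether every line will be used.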

b) Generating Embeddings

Convert the text instruction/response pairs from the JSONL file into latent embeddings using the autoencoder and save them to a NumPy file. This uses checkpointing to handle large datasets and allow resuming.

# --- Generate Embeddings from JSONL ---

# Configuration
checkpoint_interval = 100_000 # Save progress every N lines
input_jsonl_file = "/path/to/save/merged_output.jsonl" # Input JSONL file from previous step
embeddings_file = "/path/to/save/embeddings.npy" # Output NumPy file
checkpoint_file = "/path/to/save/checkpoint.txt" # Checkpoint file

batch_size = 64 # Adjust based on your GPU memory

# --- Function to process batches ---
def process_batch(batch_texts_instructions, batch_texts_responses, autoencoder_model):
    if not batch_texts_instructions or not batch_texts_responses:
        return [], []
    try:
        # Ensure the model's embed_batch method handles lists of texts
        instr_embeddings = autoencoder_model.embed_batch(batch_texts_instructions)
        resp_embeddings = autoencoder_model.embed_batch(batch_texts_responses)
        # Move embeddings to CPU before converting to NumPy
        return instr_embeddings.cpu().numpy(), resp_embeddings.cpu().numpy()
    except Exception as e:
        print(f"Error processing batch: {e}")
        # Optionally, try processing item by item as a fallback
        instr_np, resp_np = [], []
        for instr, resp in zip(batch_texts_instructions, batch_texts_responses):
             try:
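                 # Flatten in case embed() returns a (1, dim) array, so rows match the batched path
                 instr_emb = autoencoder_model.embed(instr).cpu().numpy().reshape(-1)
                 resp_emb = autoencoder_model.embed(resp).cpu().numpy().reshape(-1)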
                 instr_np.append(instr_emb)
                 resp_np.append(resp_emb)
             except Exception as item_e:
                 print(f"Error processing item '{instr[:50]}...': {item_e}")
        return np.array(instr_np) if instr_np else [], np.array(resp_np) if resp_np else []


# --- Main Embedding Generation Logic ---
try:
    with open(input_jsonl_file, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)
    print(f"Total lines in file: {total_lines}")
except FileNotFoundError:
    print(f"Error: Input JSONL file not found at {input_jsonl_file}")
    raise SystemExit(1)


if os.path.exists(checkpoint_file):
    with open(checkpoint_file, "r") as f:
        last_processed_line = int(f.read().strip())
    print(f"Resuming from line: {last_processed_line + 1}")
else:
    last_processed_line = 0
    print("No checkpoint found. Starting from the beginning.")


if os.path.exists(embeddings_file) and last_processed_line > 0:
    # Load only if resuming and file exists
    try:
        existing_embeddings = np.load(embeddings_file)
         # Ensure we only load embeddings corresponding to processed lines
        num_expected_embeddings = last_processed_line * 2
        if existing_embeddings.shape[0] >= num_expected_embeddings:
             embeddings_list = list(existing_embeddings[:num_expected_embeddings])
             print(f"Loaded {len(embeddings_list)} embeddings from previous run (up to line {last_processed_line}).")
        else:
             print("Warning: Embedding file size doesn't match checkpoint. Starting embeddings list fresh.")
             embeddings_list = []
             last_processed_line = 0 # Reset checkpoint if mismatch
    except Exception as e:
        print(f"Error loading existing embeddings: {e}. Starting fresh.")
        embeddings_list = []
        last_processed_line = 0 # Reset checkpoint on load error
else:
    embeddings_list = []
    if last_processed_line > 0:
         print("Warning: Checkpoint found but no embeddings file. Resetting checkpoint.")
         last_processed_line = 0 # Reset checkpoint if no embedding file


overall_line_count = last_processed_line
last_checkpoint_line = last_processed_line  # line count at the most recent save
batch_instructions: list[str] = []
batch_responses: list[str] = []

try:
    with open(input_jsonl_file, "r", encoding="utf-8") as f:
        # Skip lines up to the checkpoint
        for _ in range(last_processed_line):
            next(f)

        # Process remaining lines with tqdm
        pbar = tqdm(f, total=total_lines, initial=last_processed_line, unit="line", desc="Processing lines")
        for line in pbar:
            overall_line_count += 1
            try:
                obj = json.loads(line)
                if "instruction" in obj and "response" in obj:
                    batch_instructions.append(obj["instruction"])
                    batch_responses.append(obj["response"])
                else:
                    print(f"Skipping line {overall_line_count}: Missing 'instruction' or 'response'.")
                    continue # Skip this line

                # Process batch when full
                if len(batch_instructions) == batch_size:
                    instr_np, resp_np = process_batch(batch_instructions, batch_responses, autoencoder)
                    # Interleave pairs: [instr, resp, instr, resp, ...]
                    for i in range(len(instr_np)):
                        embeddings_list.append(instr_np[i])
                        embeddings_list.append(resp_np[i])

                    # Clear batches
                    batch_instructions = []
                    batch_responses = []

                    # Checkpoint once at least `checkpoint_interval` lines have passed since the last save.
                    # (A plain `overall_line_count % checkpoint_interval == 0` test only fires when a batch
                    # happens to end exactly on a multiple of the interval, which with batch_size=64 and
                    # no skipped lines is every 200,000 lines rather than every 100,000.)
                    if overall_line_count - last_checkpoint_line >= checkpoint_interval:
                        # Save embeddings before the checkpoint so the checkpoint never points past them
                        np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
                        with open(checkpoint_file, "w") as cf:
                            cf.write(str(overall_line_count))
                        last_checkpoint_line = overall_line_count
                        pbar.set_postfix_str(f"Checkpoint saved at line {overall_line_count}")


            except json.JSONDecodeError:
                print(f"Skipping line {overall_line_count}: Invalid JSON.")
            except Exception as e:
                 print(f"Error on line {overall_line_count}: {e}")


        # Process any remaining items in the last batch
        if batch_instructions:
             instr_np, resp_np = process_batch(batch_instructions, batch_responses, autoencoder)
             for i in range(len(instr_np)):
                 embeddings_list.append(instr_np[i])
                 embeddings_list.append(resp_np[i])

    # Final save: write the embeddings first, then the checkpoint, so a crash in between
    # never leaves a checkpoint that points past the saved embeddings.
    np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
    with open(checkpoint_file, "w") as cf:
        cf.write(str(overall_line_count))
    print(f"Processing complete. Final line count: {overall_line_count}/{total_lines}. Embeddings saved.")

except FileNotFoundError:
    print(f"Error: Input JSONL file not found at {input_jsonl_file}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")
    # Save progress even on error (embeddings first, then the checkpoint)
    if embeddings_list:
        np.save(embeddings_file, np.array(embeddings_list, dtype=np.float32))
    with open(checkpoint_file, "w") as cf:
        cf.write(str(overall_line_count - len(batch_instructions))) # Approximate last fully processed line
    print("Saved progress before exiting due to error.")
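The script above keeps every embedding in a Python list and rewrites the whole .npy file at each checkpoint, so RAM use grows with the dataset (8 KiB per pair at 1024 float32 dimensions, i.e. 2 * 1024 * 4 bytes). One alternative, sketched below under the assumption that you know an upper bound on the number of pairs (for example `total_lines`), is to preallocate the file with `numpy.lib.format.open_memmap` and write each batch in place. `create_embedding_store` and `write_pairs` are hypothetical helpers; the row layout matches what the dataset class in section 4 expects.

```python
import numpy as np


def create_embedding_store(path: str, num_pairs: int, dim: int = 1024) -> np.memmap:
    """Preallocate a float32 .npy file with 2 rows (instruction, response) per pair."""
    return np.lib.format.open_memmap(path, mode="w+", dtype=np.float32, shape=(2 * num_pairs, dim))


def write_pairs(store: np.memmap, start_pair: int, instr, resp) -> int:
    """Write a batch of pairs interleaved as [instr_i, resp_i, ...]; return the next free pair index."""
    instr = np.asarray(instr, dtype=np.float32)
    resp = np.asarray(resp, dtype=np.float32)
    if instr.shape != resp.shape:
        raise ValueError(f"shape mismatch: {instr.shape} vs {resp.shape}")
    n = instr.shape[0]
    # (n, 2, dim) -> (2n, dim): row 2i is instruction i, row 2i+1 is response i.
    store[2 * start_pair: 2 * (start_pair + n)] = np.stack([instr, resp], axis=1).reshape(2 * n, -1)
    return start_pair + n
```

In the main loop you would call `next_pair = write_pairs(store, next_pair, instr_np, resp_np)` per batch, `store.flush()` before writing the checkpoint, and store `next_pair` in the checkpoint; `np.lib.format.open_memmap(path, mode="r+")` reopens the file on resume. Because skipped lines leave unused rows at the end, copy only the first `2 * next_pair` rows to the file you train on.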

3. Latent Manipulator Model Code

This defines the neural network that learns to map question embeddings to answer embeddings. The architecture shown uses multiple feed-forward layers with skip-connection-like features (concatenating intermediate "choked" outputs).

  • Input: 1024-dimensional latent vector (from Encoder).
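  • Architecture: A stack of Linear layers, each followed by BatchNorm, LeakyReLU, and Dropout, that widens from 1024 to 9216 dimensions and narrows back to 2048. Each of the seven intermediate outputs is "choked", i.e. projected to 2048 dimensions by its own small layer, and the choked outputs are concatenated with the original input (1024 + 7 * 2048 = 15360 dimensions) before the final aggregation layers. These shortcuts give the aggregation layers direct access to both early and late features, which is intended to ease optimization of such a deep stack by mitigating vanishing/exploding gradients.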
  • Output: 1024-dimensional latent vector (to be fed to Decoder).
# --- Latent Manipulator Model Definition ---

class LatentManipulator(nn.Module):
    """
    A Feed-Forward Network designed to manipulate latent embeddings.
    Takes a 1024-dim embedding and outputs a 1024-dim embedding.
    Uses intermediate layer outputs (choked) and concatenation for richness.
    """
    def __init__(self, dropout_rate=0.2): # Reduced dropout from original
        super(LatentManipulator, self).__init__()
        
        # --- Main Layers (Expand -> Contract) ---
        self.layer1 = nn.Sequential(nn.Linear(1024, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer2 = nn.Sequential(nn.Linear(2048, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer3 = nn.Sequential(nn.Linear(4096, 6144), nn.BatchNorm1d(6144), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer4 = nn.Sequential(nn.Linear(6144, 9216), nn.BatchNorm1d(9216), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate)) # Widest layer
        self.layer5 = nn.Sequential(nn.Linear(9216, 6144), nn.BatchNorm1d(6144), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer6 = nn.Sequential(nn.Linear(6144, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.layer7 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))

        # --- Choke Layers (Reduce intermediate outputs to 2048) ---
        # These act like shortcuts, bringing information from earlier layers forward.
        self.choke1 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke2 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke3 = nn.Sequential(nn.Linear(6144, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke4 = nn.Sequential(nn.Linear(9216, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke5 = nn.Sequential(nn.Linear(6144, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke6 = nn.Sequential(nn.Linear(4096, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.choke7 = nn.Sequential(nn.Linear(2048, 2048), nn.BatchNorm1d(2048), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))

        # --- Aggregation Layers (Combine concatenated features) ---
        # Total input size = 1024 (original input) + 7 * 2048 (choked outputs) = 15360
        self.aLayer1 = nn.Sequential(nn.Linear(15360, 8192), nn.BatchNorm1d(8192), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        self.aLayer2 = nn.Sequential(nn.Linear(8192, 4096), nn.BatchNorm1d(4096), nn.LeakyReLU(0.01, inplace=True), nn.Dropout(dropout_rate))
        
        # --- Final Output Layer ---
        self.output_layer = nn.Linear(4096, 1024) # Output matches input dimension

    def forward(self, x):
        # Pass through main layers
        x1 = self.layer1(x); x2 = self.layer2(x1); x3 = self.layer3(x2)
        x4 = self.layer4(x3); x5 = self.layer5(x4); x6 = self.layer6(x5)
        x7 = self.layer7(x6)

        # Apply choke layers
        c1 = self.choke1(x1); c2 = self.choke2(x2); c3 = self.choke3(x3)
        c4 = self.choke4(x4); c5 = self.choke5(x5); c6 = self.choke6(x6)
        c7 = self.choke7(x7)

        # Concatenate original input and all choked outputs
        concat = torch.cat([x, c1, c2, c3, c4, c5, c6, c7], dim=1) # Dim 1 for batch processing

        # Pass through aggregation layers
        out = self.aLayer1(concat)
        out = self.aLayer2(out)
        out = self.output_layer(out)
        return out

# --- Helper: Weight Initialization ---
def init_weights(m):
    """Applies Kaiming Normal initialization for LeakyReLU."""
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# --- Helper: Count Parameters ---
def count_parameters(model):
    """Counts the number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
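As a cross-check of `count_parameters`, the layer sizes above can be tallied by hand without instantiating the model. Each Linear(n_in, n_out) has n_in * n_out weights plus n_out biases; each BatchNorm1d(n_out) adds n_out weights and n_out biases (its running mean and variance are buffers, not parameters); LeakyReLU and Dropout have none. The helper names below are hypothetical.

```python
def block_params(n_in: int, n_out: int, batch_norm: bool = True) -> int:
    """Trainable parameters of Linear(n_in, n_out) plus an optional BatchNorm1d(n_out)."""
    return n_in * n_out + n_out + (2 * n_out if batch_norm else 0)


MAIN = [1024, 2048, 4096, 6144, 9216, 6144, 4096, 2048]  # input width, then layer1..layer7 outputs
CHOKE_OUT = 2048
AGG = [8192, 4096]
OUT = 1024


def latent_manipulator_params() -> int:
    main = sum(block_params(a, b) for a, b in zip(MAIN, MAIN[1:]))
    chokes = sum(block_params(w, CHOKE_OUT) for w in MAIN[1:])  # choke1..choke7 read x1..x7
    concat = MAIN[0] + CHOKE_OUT * (len(MAIN) - 1)              # 1024 + 7 * 2048 = 15360
    agg = block_params(concat, AGG[0]) + block_params(AGG[0], AGG[1])
    out = block_params(AGG[1], OUT, batch_norm=False)           # output_layer has no BatchNorm
    return main + chokes + agg + out


total = latent_manipulator_params()
print(f"{total:,} trainable parameters, about {total * 4 / 1e9:.2f} GB of float32 weights")
# 415,418,368 trainable parameters, about 1.66 GB of float32 weights
```

During training, the weights, their gradients, and AdamW's two moment buffers come to roughly four times that figure before counting activations, which is worth keeping in mind when choosing a GPU and batch size.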

4. Training the Latent Manipulator

This involves setting up a dataset to efficiently load the pre-computed embeddings and a training loop.

a) Dataset Class

Uses NumPy's memory mapping (mmap_mode='r') to avoid loading the entire (potentially huge) embeddings file into RAM. It loads only the required question-answer pair for each __getitem__ call.

# --- NumPy Embedding Dataset ---
class NPYEmbeddingDataset(Dataset):
    """
    Lazily loads pairs of embeddings (instruction, response) from a NumPy file
    using memory mapping for efficiency with large files.
    Assumes embeddings are stored sequentially: [instr1, resp1, instr2, resp2, ...].
    """
    def __init__(self, npy_file):
        self.npy_file = npy_file
        # np.load(..., mmap_mode='r') returns a numpy.memmap, which is not a context manager,
        # so read shape/dtype from it directly and then drop the reference.
        try:
            data = np.load(npy_file, mmap_mode='r')
            self.shape = data.shape
            self.dtype = data.dtype
            del data
        except FileNotFoundError:
            print(f"Error: NPY embedding file not found at {npy_file}")
            raise
        except Exception as e:
            print(f"Error loading NPY file: {e}")
            raise


        if len(self.shape) != 2 or self.shape[0] % 2 != 0:
            raise ValueError(f"Expected a 2D numpy array with an even number of rows (embeddings). Got shape: {self.shape}")
            
        self.num_pairs = self.shape[0] // 2
        self.embedding_dim = self.shape[1]
        print(f"Dataset initialized: {self.num_pairs} pairs, embedding dim {self.embedding_dim}")


    def __len__(self):
        """Returns the number of instruction-response pairs."""
        return self.num_pairs

    def __getitem__(self, idx):
        """Loads the idx-th instruction and response embedding pair."""
        if not 0 <= idx < self.num_pairs:
            raise IndexError("Index out of bounds")

        # Open the memory map lazily and reuse it. Each DataLoader worker holds its own
        # copy of the dataset, so each worker opens its own map on first access.
        data = getattr(self, "_data", None)
        if data is None:
            data = self._data = np.load(self.npy_file, mmap_mode='r')

        # Calculate row indices for the pair
        q_idx = idx * 2
        a_idx = idx * 2 + 1

        # Copy the rows out of the read-only map before converting to tensors
        q_emb = torch.from_numpy(data[q_idx].copy()).float()
        a_emb = torch.from_numpy(data[a_idx].copy()).float()

        return q_emb, a_emb

    def __getstate__(self):
        # Don't pickle an open memory map (e.g. when workers are started with 'spawn');
        # each worker reopens the file in __getitem__ instead.
        state = self.__dict__.copy()
        state["_data"] = None
        return state
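The training loop below steps ReduceLROnPlateau on the average training loss because no validation set is held out. If you want a held-out loss instead, one option is to split the pair indices once, up front, with a fixed seed. `split_pair_indices`, `val_fraction`, and `seed` are hypothetical names and illustrative defaults, not part of the original scripts.

```python
import random
from typing import List, Tuple


def split_pair_indices(num_pairs: int, val_fraction: float = 0.02, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle pair indices reproducibly and split them into (train, validation)."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be between 0 and 1")
    indices = list(range(num_pairs))
    random.Random(seed).shuffle(indices)  # same seed -> same split on every run
    n_val = max(1, int(num_pairs * val_fraction))
    return indices[n_val:], indices[:n_val]
```

With `torch.utils.data.Subset`, `Subset(dataset, train_idx)` and `Subset(dataset, val_idx)` give two datasets that share the same memory-mapped file; the validation DataLoader would use `shuffle=False`, and its average loss could be passed to `scheduler.step()` instead of the training loss.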

b) Training Loop

Standard PyTorch training loop using the defined dataset and model. Key features:

  • Loss: Mean Squared Error (MSELoss) because we are comparing output embeddings to target embeddings (regression).
  • Optimizer: AdamW is a good default choice.
  • Learning Rate Scheduling: ReduceLROnPlateau halves the learning rate (factor=0.5) once the monitored loss has failed to improve for more than one epoch (patience=1). Because no validation split is held out, the monitored value here is the average training loss. A warmup phase starts with a lower LR and gradually increases it to base_lr, improving stability early in training.
  • Gradient Clipping: Prevents exploding gradients, crucial for deep networks.
  • Checkpointing: Saves the model and optimizer state whenever the epoch's average loss improves on the best value so far.
# --- Utility to save checkpoints ---
def save_checkpoint(state, filename="checkpoint.pt"):
    """Saves model and optimizer state."""
    try:
        torch.save(state, filename)
        print(f"Checkpoint saved to {filename}")
    except Exception as e:
        print(f"Error saving checkpoint: {e}")


# --- Training Function ---
def train(model, dataloader, epochs=10, base_lr=1e-4, warmup_epochs=1, clip_value=5.0, device=None, checkpoint_dir="checkpoints"):
    """Trains the LatentManipulator model."""
    if device is None:
        if torch.cuda.is_available(): device = torch.device("cuda")
        elif torch.backends.mps.is_available(): device = torch.device("mps")
        else: device = torch.device("cpu")
    print(f"Training on device: {device}")
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.01)
    criterion = nn.MSELoss()
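    # Halve the LR after two consecutive epochs without improvement. (`verbose` is deprecated
    # in recent PyTorch releases; the LR is printed at the start of each epoch instead.)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)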

    os.makedirs(checkpoint_dir, exist_ok=True)
    best_loss = float('inf')

    print(f"Starting training for {epochs} epochs...")

    # Linear warmup is applied per optimizer step. (A per-epoch rule would give
    # lr = base_lr * (0 + 1) / 1 = base_lr in the first epoch when warmup_epochs=1,
    # i.e. no ramp at all.)
    warmup_steps = warmup_epochs * len(dataloader)
    global_step = 0
    if warmup_steps > 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = base_lr / warmup_steps

    for epoch in range(epochs):
        model.train() # Set model to training mode
        running_loss = 0.0

        current_lr = optimizer.param_groups[0]['lr']
        print(f"\n--- Epoch {epoch+1}/{epochs} --- LR: {current_lr:.2e}")

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1} Training", leave=True)
        for batch_idx, (q_emb, a_emb) in enumerate(pbar):
            # Learning Rate Warmup: ramp linearly up to base_lr over warmup_steps steps
            if global_step < warmup_steps:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = base_lr * (global_step + 1) / warmup_steps
            global_step += 1

            q_emb, a_emb = q_emb.to(device), a_emb.to(device)

            optimizer.zero_grad()
            outputs = model(q_emb)
            loss = criterion(outputs, a_emb)

            # Stop on NaN (or infinite) loss
            if not torch.isfinite(loss):
                 print(f"Non-finite loss detected at Epoch {epoch+1}, Batch {batch_idx}. Stopping training.")
                 # Save state before exiting
                 save_checkpoint({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'avg_loss': float('inf') }, os.path.join(checkpoint_dir, "checkpoint_error_nan.pt"))
                 return # Stop training

            loss.backward()
            # Gradient Clipping: rescale gradients so their global L2 norm is at most clip_value
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()

            running_loss += loss.item()

            # Update progress bar (grad_norm is the norm before clipping)
            if (batch_idx + 1) % 100 == 0: # Update less frequently
                 pbar.set_postfix_str(f"Loss: {loss.item():.4f}, GradNorm: {grad_norm.item():.4f}")


        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.6f}")

        # Step the scheduler based on the average loss for the epoch
        scheduler.step(avg_loss)

        # Save checkpoint if loss improved
        if avg_loss < best_loss:
            print(f"Loss improved from {best_loss:.6f} to {avg_loss:.6f}. Saving checkpoint...")
            best_loss = avg_loss
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}_best.pt")
            save_checkpoint({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'avg_loss': avg_loss
            }, filename=checkpoint_path)
        else:
             print(f"Loss did not improve from {best_loss:.6f}.")
             # Optional: Save checkpoint every epoch regardless
             # checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt")
             # save_checkpoint({ ... }, filename=checkpoint_path)


    print("\nTraining complete.")


# --- Main Training Execution Block ---
if __name__ == "__main__":
    npy_file = "/path/to/save/embeddings.npy" # Make sure this path is correct
    checkpoint_dir = "latent_manipulator_checkpoints" # Directory to save model checkpoints

    try:
        dataset = NPYEmbeddingDataset(npy_file)
        # Adjust batch_size and num_workers based on your system.
        # drop_last=True avoids a final batch of one sample, which BatchNorm1d rejects in training mode.
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4,
                                pin_memory=torch.cuda.is_available(), drop_last=True)

        model = LatentManipulator(dropout_rate=0.2) # Instantiate the model
        model.apply(init_weights) # Initialize weights

        print("Model Architecture:\n", model)
        total_params = count_parameters(model)
        print(f"\nTotal Trainable Parameters: {total_params:,}")

        # Start training (device=None lets train() pick CUDA, MPS, or CPU)
        train(model, dataloader, epochs=10, base_lr=1e-4, clip_value=5.0, device=None, checkpoint_dir=checkpoint_dir)

    except FileNotFoundError:
         print(f"Error: Embeddings file not found at {npy_file}. Please generate embeddings first.")
    except Exception as e:
         print(f"An error occurred during training setup or execution: {e}")
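The learning-rate schedule described in the training bullets can be checked without a GPU. This sketch reimplements it in plain Python: `warmup_lr` is a per-step linear ramp (with a per-epoch rule, warmup_epochs=1 gives base_lr * (0 + 1) / 1 = base_lr in the first epoch, so nothing ramps), and `PlateauHalver` is a simplified, hypothetical stand-in for `torch.optim.lr_scheduler.ReduceLROnPlateau` with the factor=0.5, patience=1 settings that `train()` uses. It omits PyTorch's threshold, cooldown, and min_lr options.

```python
def warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Linear per-step warmup: base_lr * (step + 1) / warmup_steps, capped at base_lr."""
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, (step + 1) / warmup_steps)


class PlateauHalver:
    """Simplified ReduceLROnPlateau(mode='min', factor=0.5, patience=1).

    Any strictly lower loss counts as an improvement (no threshold, cooldown, or min_lr).
    """

    def __init__(self, lr: float, factor: float = 0.5, patience: int = 1):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, epoch_loss: float) -> float:
        if epoch_loss < self.best:
            self.best, self.bad_epochs = epoch_loss, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:  # patience=1: cut on the 2nd bad epoch in a row
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr


if __name__ == "__main__":
    print([warmup_lr(s, 1e-4, 4) for s in range(6)])  # tiny warmup_steps so the ramp is visible
    sched = PlateauHalver(1e-4)
    print([sched.step(loss) for loss in [0.9, 0.8, 0.8, 0.85, 0.7]])
```

With patience=1, a single flat epoch does not change the LR; the halving happens only when a second consecutive epoch also fails to improve.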

5. Inference: Using the Trained Model

To get an answer for a new question:

  1. Load the trained LatentManipulator model from a checkpoint.
  2. Load the BottleneckT5Autoencoder (needed for encoding the question and decoding the answer).
  3. Encode the input question text into its latent vector using the autoencoder.
  4. Pass this latent vector through the loaded LatentManipulator to get the predicted latent vector for the answer.
  5. Decode this answer latent vector back into text using the autoencoder.
# --- Inference Script ---
# Make sure latent_manipulator.py (containing the model definition) is accessible
# and you have the BottleneckT5Autoencoder class defined/imported

# --- Function to load the Latent Manipulator model ---
def load_manipulator(checkpoint_path, device):
    """Loads the trained LatentManipulator from a checkpoint file."""
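    # Instantiate the model architecture. Dropout has no learnable weights and is disabled by
    # model.eval(), so only the layer names and sizes must match training; 0.2 mirrors the training script.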
    model = LatentManipulator(dropout_rate=0.2)
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval() # Set to evaluation mode
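        # Format the loss separately: applying ':.6f' to the 'N/A' fallback string would raise a ValueError.
        loss = checkpoint.get('avg_loss')
        loss_str = f"{float(loss):.6f}" if loss is not None else "N/A"
        print(f"Loaded LatentManipulator from epoch {checkpoint.get('epoch', 'N/A')} with loss {loss_str}")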
        return model
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_path}")
        raise
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise

# --- Main Inference Execution ---
if __name__ == "__main__":
    # --- Configuration ---
    autoencoder_model_path = 'thesephist/contra-bottleneck-t5-large-wikipedia'
    manipulator_checkpoint_path = "latent_manipulator_checkpoints/checkpoint_epoch_10_best.pt" # Path to your best saved checkpoint

    # Determine device
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    print(f"Using device for inference: {device}")

    # --- Load Models ---
    try:
        # Load the Autoencoder (needed for embed/generate)
        autoencoder = BottleneckT5Autoencoder(model_path=autoencoder_model_path, device=device)
        
        # Load the trained Latent Manipulator
        manipulator_model = load_manipulator(manipulator_checkpoint_path, device)
    except Exception as e:
        print(f"Failed to load models: {e}")
        raise SystemExit(1)  # exit() is meant for the interactive shell; exit with a non-zero status instead


    # --- Get Input and Generate ---
    while True:
        try:
            input_text = input("Enter your question (or type 'quit' to exit): ")
            if input_text.lower() == 'quit':
                break
            if not input_text:
                continue

            # 1. Encode the input question
            input_embedding_latent = autoencoder.embed(input_text) # Shape [1, 1024]

            # embed() should already return the latent on `device`; moving it again is a harmless safeguard
            input_embedding_latent = input_embedding_latent.to(device)

            # 2. Manipulate the latent vector to get the answer latent
            with torch.no_grad():
                output_embedding_latent = manipulator_model(input_embedding_latent) # Shape [1, 1024]

            # 3. Decode the answer latent back to text
            # The latent is already on `device`, where the autoencoder was also loaded,
            # so it can be passed to generate_from_latent without a CPU transfer.
            output_text = autoencoder.generate_from_latent(output_embedding_latent, temperature=0.5) # Adjust temperature as needed

            print("\nInput:  ", input_text)
            print("Output: ", output_text)
            print("-" * 30)

        except KeyboardInterrupt:
            print("\nExiting.")
            break
        except Exception as e:
            print(f"An error occurred during generation: {e}")

Resources

Conclusion

The Latent Manipulator presents an intriguing alternative to standard sequential text generation. By separating the "thinking" (latent space transformation) from the "speaking" (text decoding), it may allow more efficient, more controllable, and perhaps even language-agnostic reasoning in AI models. The approach is still experimental, but it illustrates one of the directions being explored for how AI models can process and generate information.

This article was compiled with the help of an LLM.

@Gal-Lahat
Author

Have questions or want to help others?
Use this comments section to ask, share tips, or lend a hand to others.

@rENS206

rENS206 commented Apr 11, 2025

Hi, I'm new to AI and find it very fascinating. Lately I've been thinking about creating my own model from scratch. I also have a question about the video you made: how many trained parameters does the model have? I'm curious because OpenAI's second model was GPT-2, and judging from the outputs, the model you made, while not exactly great, does seem very impressive. It could also be a new, very efficient way of creating LLMs.

@MilkywayRides

You can check out 3Blue1Brown's neural network series: https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi
