Skip to content

Instantly share code, notes, and snippets.

@WolfByttner
Created February 28, 2025 13:07
Show Gist options
  • Save WolfByttner/cfc503a0fee37981586988e3ac8a766c to your computer and use it in GitHub Desktop.
Save WolfByttner/cfc503a0fee37981586988e3ac8a766c to your computer and use it in GitHub Desktop.
Sub-ms sine curve prediction (demonstrating M1 hardware performance)
import torch
import torch.nn as nn
import torch.optim as optim
import time
import numpy as np
# Introductory banner; printed whenever this module is executed (it sits at
# module level, so that includes being imported).
print(
    "This code example demonstrates how to train a simple MLP model for price prediction using PyTorch.",
    "The model is trained on a sine wave and tested on a shifted sine wave.",
    "The goal is to demonstrate sub-ms latency for price prediction using PyTorch",
    "on Mac M1-M4 with Metal Performance Shaders (MPS) enabled.",
    sep="\n",
)
# Define a simple MLP network for price prediction
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: number of features in each input row.
        hidden_dim: width of the hidden layer.
        output_dim: number of values produced for each input row.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order (fc1, relu, fc2) fixes both the parameters()
        # order and the printed module summary, so keep it as is.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to shape (..., output_dim)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Example price prediction function
def predict_price(model, test_data, device, input_dim):
    """Evaluate ``model`` as a one-step-ahead predictor over ``test_data``.

    Slides a window of ``input_dim`` consecutive samples across ``test_data``
    and has the model predict the sample that immediately follows each
    window. Prints the mean absolute error plus the average per-window
    host-to-device transfer time and forward-pass time.

    Args:
        model: module taking a (1, input_dim) float32 tensor and returning a
            single value (e.g. shape (1, 1)).
        test_data: 1-D sequence (list or array) of numbers to evaluate on.
        device: torch device, or device string, that ``model`` lives on.
        input_dim: window length fed to the model.

    Returns:
        Tuple ``(mean_abs_error, mean_transfer_ms, mean_compute_ms)``.

    Raises:
        ValueError: if ``test_data`` has no more than ``input_dim`` samples,
            i.e. there is not a single complete (window, target) pair.
    """
    device = torch.device(device)
    # One window per start index whose target test_data[start + input_dim]
    # exists: start = 0 .. len(test_data) - input_dim - 1.
    test_length = len(test_data) - input_dim
    if test_length <= 0:
        raise ValueError(
            f"test_data needs more than input_dim={input_dim} samples, "
            f"got {len(test_data)}"
        )

    def _synchronize():
        # MPS/CUDA work is queued asynchronously; wait for it to finish so the
        # wall-clock timings measure execution, not just kernel submission.
        if device.type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()
        elif device.type == "cuda":
            torch.cuda.synchronize(device)

    model.eval()
    total_error = 0.0
    transfer_time = 0.0
    compute_time = 0.0
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for start_index in range(test_length):
            transfer_start = time.perf_counter()
            window = test_data[start_index:start_index + input_dim]
            input_tensor = torch.tensor(window, dtype=torch.float32).unsqueeze(0).to(device)
            target = torch.tensor(
                [test_data[start_index + input_dim]], dtype=torch.float32
            ).to(device)
            _synchronize()
            transfer_time += time.perf_counter() - transfer_start

            compute_start = time.perf_counter()
            output = model(input_tensor)
            _synchronize()
            compute_time += time.perf_counter() - compute_start

            total_error += torch.abs(output - target).item()

    mean_error = total_error / test_length
    mean_transfer_ms = 1000 * transfer_time / test_length
    mean_compute_ms = 1000 * compute_time / test_length
    print(f"Average prediction error: {mean_error:.2f}")
    print(f"Average data transfer time: {mean_transfer_ms:.2f} ms")
    print(f"Average compute time: {mean_compute_ms:.2f} ms")
    return mean_error, mean_transfer_ms, mean_compute_ms
# Model hyper-parameters
input_dim = 30  # Window length: number of consecutive past samples fed to the model
hidden_dim = 64  # Number of neurons in hidden layer
output_dim = 1  # Output: predicted next value of the series (single value)
def train_model(model, input_data, device, input_dim, *, steps=100, lr=0.001):
    """Fit ``model`` to predict the next sample of ``input_data``.

    Each step draws one random window of ``input_dim`` consecutive samples,
    uses the sample right after it as the target, and takes one Adam step on
    the MSE between prediction and target. Prints a confirmation and the
    model summary when done.

    Args:
        model: module taking a (1, input_dim) float32 tensor and returning a
            (1, 1) prediction.
        input_data: 1-D sequence (list or array) of numbers to train on.
        device: torch device, or device string, that ``model`` lives on.
        input_dim: window length fed to the model.
        steps: number of single-window optimisation steps (default 100).
        lr: Adam learning rate (default 0.001).

    Raises:
        ValueError: if ``input_data`` has no more than ``input_dim`` samples.
    """
    num_windows = len(input_data) - input_dim
    if num_windows <= 0:
        raise ValueError(
            f"input_data needs more than input_dim={input_dim} samples, "
            f"got {len(input_data)}"
        )
    # Define optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # Convert the whole series once and move it to the device, instead of
    # rebuilding two small tensors from Python floats on every step.
    series = torch.as_tensor(np.asarray(input_data, dtype=np.float32), device=device)
    model.train()
    for _ in range(steps):
        # randint's upper bound is exclusive: this samples every valid start,
        # including the last window whose target is the final sample.
        start_index = np.random.randint(0, num_windows)
        end_index = start_index + input_dim
        input_tensor = series[start_index:end_index].unsqueeze(0)
        # Shape (1, 1) to match the model output, so MSELoss does not broadcast.
        target = series[end_index:end_index + 1].unsqueeze(0)
        optimizer.zero_grad()
        output = model(input_tensor)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    print("Model trained successfully")
    # Display model summary
    print(model)
if __name__ == "__main__":
    # Seed both RNGs so repeated demo runs train the same model and are
    # directly comparable.
    torch.manual_seed(0)
    np.random.seed(0)
    # Set up the model
    # Use Metal Performance Shaders (MPS) on Apple-silicon Macs when available.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = MLP(input_dim, hidden_dim, output_dim).to(device)
    # Generate 1000 regularly spaced points on a sine curve (vectorised).
    input_data = np.sin(np.linspace(0, 100, 1000))
    train_model(model, input_data, device, input_dim)
    # Test on the same curve shifted by 0.15 along the x axis.
    test_data = np.sin(np.linspace(0.15, 100.15, 1000))
    # Run price prediction
    predict_price(model, test_data, device, input_dim)
    # This should output the average prediction error, data transfer time, and compute time for the model
    # M1 MacBook Pro results (example):
    # Average prediction error: 0.01
    # Average data transfer time: 0.44 ms
    # Average compute time: 0.13 ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment