# Source: GitHub Gist secemp9/ceed0bc6ec6f48bca114cc3118a85281
# Author: @secemp9 (created May 31, 2025 04:21)
# Description: rotation invariant CNN implementation
# -*- coding: utf-8 -*-
"""
CyCNN implementation for rotation invariant image classification.
Based on "CyCNN: A Rotation Invariant CNN using Polar Mapping and Cylindrical Convolution Layers"
"""
# Core Libraries
import os
import random
import time
import gc
import math
# Image Processing and Data Handling
import cv2
from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import pandas as pd
import imagesize
from concurrent.futures import ThreadPoolExecutor
import shutil
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
from torch.amp import GradScaler, autocast
# Plotting and Display
import matplotlib.pyplot as plt
from numpy import arange
from tqdm import tqdm
from IPython.display import display
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # The CUDA generators are only seeded when a CUDA runtime is present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# Seed the Python, NumPy and PyTorch generators with a fixed value at import time.
set_seed(69)
# Setup device: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
class PolarTransform:
    """Convert a Cartesian image into a polar (or log-polar) image.

    The returned image is laid out with the *radius* along the rows (height)
    and the *angle* along the columns (width).  This is the layout that
    ``CyConv2d`` in this file expects: it zero-pads the height axis and wraps
    the width axis.

    Args:
        output_size: ``(height, width)`` of the polar image, i.e.
            ``(radius samples, angle samples)``.
        use_log_polar: Use log-polar instead of linear-polar radius sampling.
    """

    def __init__(self, output_size=(224, 224), use_log_polar=False):
        self.output_size = output_size
        self.use_log_polar = use_log_polar

    def cartesian_to_polar(self, image):
        """Return ``image`` resampled on a polar grid as a NumPy array.

        Args:
            image: Anything ``np.asarray`` accepts (PIL image or array) with
                shape ``(H, W)`` or ``(H, W, C)``.

        Returns:
            Array whose first axis is the radius and whose second axis is the
            angle, with the spatial size given by ``self.output_size``.
        """
        img = np.asarray(image)
        h, w = img.shape[:2]
        cx, cy = w // 2, h // 2
        # Radius of the largest circle centred at (cx, cy) that fits the image.
        max_r = min(cx, cy)
        flag = cv2.WARP_POLAR_LOG if self.use_log_polar else cv2.WARP_POLAR_LINEAR

        n_radius, n_angle = self.output_size
        # cv2.warpPolar takes dsize as (width, height) and writes the radius
        # along x (columns) and the angle along y (rows).  Request it that way,
        # then swap the two spatial axes so rows = radius and columns = angle.
        polar = cv2.warpPolar(img, (n_radius, n_angle), (cx, cy), max_r, flag)
        return np.ascontiguousarray(np.swapaxes(polar, 0, 1))

    def __call__(self, image):
        """Apply the polar mapping and return the result as a PIL image."""
        polar_img = self.cartesian_to_polar(image)
        if polar_img.ndim == 3:
            return Image.fromarray(polar_img)
        return Image.fromarray(polar_img.squeeze(), mode='L')
class CyConv2d(nn.Module):
    """Cylindrical Convolution Layer with Cylindrically Sliding Windows (CSW).

    Wraps a standard ``nn.Conv2d`` that is built with ``padding=0``; the
    padding is applied beforehand by ``apply_cylindrical_padding`` so that the
    width axis wraps around while the height axis is zero-padded.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            Forwarded to ``nn.Conv2d``.
        padding: Int or ``(rows, columns)`` tuple. ``rows`` is the zero padding
            added above and below, ``columns`` is the wrap-around padding
            added left and right.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride if isinstance(stride, tuple) else (stride, stride)
        self.padding = padding if isinstance(padding, tuple) else (padding, padding)
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation)
        self.groups = groups
        # Standard convolution layer; padding is 0 because forward() pads manually.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        """Pad ``x`` cylindrically, then run the wrapped convolution."""
        return self.conv(self.apply_cylindrical_padding(x))

    def apply_cylindrical_padding(self, x):
        """
        Cylindrical padding:
          * zero-pad along the *radius* axis (N, C, **H**, W)
          * wrap-around along the *angle* axis (N, C, H, **W**)

        Returns a tensor whose height grew by ``2 * pad_r`` and whose width
        grew by ``2 * pad_a``.
        """
        pad_r, pad_a = self.padding  # pad_r = rows, pad_a = columns

        # 1. zero-pad top & bottom (radius / height)
        if pad_r > 0:
            x = F.pad(x, (0, 0, pad_r, pad_r), mode='constant', value=0)

        # 2. wrap left & right (angle / width)
        if pad_a > 0:
            width = x.size(-1)
            if pad_a <= width:
                # Common case: one lap is enough, so plain slices suffice.
                x = torch.cat([x[..., -pad_a:], x, x[..., :pad_a]], dim=-1)
            else:
                # Padding wider than the tensor: slices would be clamped and
                # give too few columns, so gather with indices taken modulo
                # the width to keep wrapping around.
                columns = torch.arange(-pad_a, width + pad_a, device=x.device) % width
                x = x.index_select(-1, columns)
        return x
class CyVGG(nn.Module):
    """VGG-style network whose convolutions are all ``CyConv2d`` layers.

    Args:
        num_classes: Size of the final linear layer's output.
        use_polar: Stored on the instance; not consulted by this class.
    """

    def __init__(self, num_classes=10, use_polar=True):
        super().__init__()
        self.use_polar = use_polar

        # (output channels, number of conv/BN/ReLU triples) per stage; every
        # stage ends with a 2x2 max-pool.
        stage_plan = ((64, 2), (128, 2), (256, 3), (512, 3))
        feature_layers = []
        channels = 3
        for stage_width, conv_count in stage_plan:
            for _ in range(conv_count):
                feature_layers.append(CyConv2d(channels, stage_width, 3, padding=1))
                feature_layers.append(nn.BatchNorm2d(stage_width))
                feature_layers.append(nn.ReLU(inplace=True))
                channels = stage_width
            feature_layers.append(nn.MaxPool2d(2, 2))
        self.features = nn.Sequential(*feature_layers)

        # Global average pool followed by a three-layer MLP head.
        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the feature extractor, then the classifier head."""
        return self.classifier(self.features(x))
class CyEfficientNet(nn.Module):
    """EfficientNet-B0 whose first convolution is replaced by a ``CyConv2d``.

    Args:
        num_classes: Size of the replacement classifier's output.
        use_polar: Stored on the instance; not consulted by this class.
    """

    def __init__(self, num_classes=10, use_polar=True):
        super().__init__()
        self.use_polar = use_polar

        # Load pre-trained EfficientNet
        self.backbone = models.efficientnet_b0(weights="IMAGENET1K_V1")

        # Replace first conv layer with a cylindrical version that mirrors the
        # original layer's geometry.
        original_conv = self.backbone.features[0][0]
        cy_conv = CyConv2d(
            in_channels=original_conv.in_channels,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            dilation=original_conv.dilation,
            groups=original_conv.groups,
        )
        # Carry the pre-trained stem weights over; a freshly initialised stem
        # would throw them away.  The bias is kept (so checkpoint keys do not
        # change) and starts from the original bias, or zero if there is none.
        with torch.no_grad():
            cy_conv.conv.weight.copy_(original_conv.weight)
            if cy_conv.conv.bias is not None:
                if original_conv.bias is not None:
                    cy_conv.conv.bias.copy_(original_conv.bias)
                else:
                    cy_conv.conv.bias.zero_()
        self.backbone.features[0][0] = cy_conv

        # Modify classifier
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Run the modified backbone on ``x``."""
        return self.backbone(x)
class RotationDataset(Dataset):
    """Dataset that pairs every image with every rotation angle.

    Each item is ``(image_tensor, angle_class, angle)`` where ``angle_class``
    is the index of the 10-degree bin containing ``angle``.

    Args:
        df: DataFrame with a ``'filename'`` column of image paths.
        polar_transform: Optional callable applied to the rotated PIL image.
        standard_transform: Optional callable applied last; when omitted the
            image is converted with ``transforms.ToTensor()``.
        rotation_angles: Iterable of angles in degrees. Defaults to every
            10 degrees from 0 to 350.
    """

    def __init__(self, df, polar_transform=None, standard_transform=None, rotation_angles=None):
        self.df = df.reset_index(drop=True)
        self.polar_transform = polar_transform
        self.standard_transform = standard_transform

        # If no specific angles provided, use every 10 degrees.
        if rotation_angles is None:
            self.rotation_angles = list(range(0, 360, 10))
        else:
            # Materialise once so a one-shot iterable is not exhausted after
            # the first image.
            self.rotation_angles = list(rotation_angles)

        # Create all combinations of images and rotations (image-major order).
        self.samples = [
            (filename, angle)
            for filename in self.df['filename']
            for angle in self.rotation_angles
        ]

    def __len__(self):
        return len(self.samples)

    def rotate_image(self, image, angle):
        """Rotate ``image`` about its centre by ``angle`` degrees.

        An angle of 0 returns the input object unchanged; otherwise a new PIL
        image of the same size is returned.
        """
        if angle == 0:
            return image

        # Convert PIL to numpy for rotation
        img_array = np.array(image)
        h, w = img_array.shape[:2]
        center = (w // 2, h // 2)

        # Rotation matrix (scale 1.0), output canvas keeps the input size.
        M = cv2.getRotationMatrix2D(center, angle, 1.0)
        rotated = cv2.warpAffine(img_array, M, (w, h))
        return Image.fromarray(rotated)

    def __getitem__(self, idx):
        filepath, angle = self.samples[idx]

        # Load image; unreadable files are replaced by a solid red placeholder
        # so one bad file does not abort an epoch, but the failure is reported.
        try:
            with Image.open(filepath) as handle:
                image = handle.convert('RGB')
        except Exception as err:
            import warnings
            warnings.warn(
                f"Could not load image '{filepath}' ({err!r}); using a placeholder image.",
                RuntimeWarning,
            )
            image = Image.new('RGB', (224, 224), color='red')

        # Rotate image
        rotated_image = self.rotate_image(image, angle)

        # Apply polar transformation if provided
        if self.polar_transform is not None:
            rotated_image = self.polar_transform(rotated_image)

        # Apply standard transforms
        if self.standard_transform is not None:
            rotated_image = self.standard_transform(rotated_image)
        else:
            rotated_image = transforms.ToTensor()(rotated_image)

        # Label: index of the 10-degree bin (36 classes, 0-35).  The angle is
        # reduced modulo 360 and cast to int so the label is always a valid
        # integer class index, even for float or out-of-range angles.
        angle_class = int(angle % 360) // 10
        return rotated_image, angle_class, angle
# Training utilities: early stopping, per-epoch train/validation passes, and the epoch loop
class EarlyStopping:
    """Track validation loss and flag when it has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls before
            ``early_stop`` is set.
        delta: Minimum decrease in loss that counts as an improvement.
        save_path: Where to save the model's ``state_dict`` on improvement;
            a falsy value disables saving.
        verbose: Print a status line on every call.
    """

    def __init__(self, patience=5, delta=0.001, save_path='best_model.pth', verbose=False):
        self.patience, self.delta = patience, delta
        self.save_path, self.verbose = save_path, verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        threshold = self.best_loss - self.delta
        if not (val_loss < threshold):
            # No sufficient improvement: count it and possibly raise the flag.
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("Early stopping")
            return

        # Improvement: remember the new best, reset the counter, checkpoint.
        self.best_loss = val_loss
        self.counter = 0
        if self.save_path:
            torch.save(model.state_dict(), self.save_path)
        if self.verbose:
            print(f"Validation loss decreased ({self.best_loss:.6f}). Saving model ...")
def train_function(model, loss_fn, optimizer, dataloader, device, scaler, log_interval=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module to train; switched to train mode.
        loss_fn: Callable ``loss_fn(output, label)`` returning a scalar tensor.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        dataloader: Yields ``(data, label)`` or ``(data, label, extra)``
            batches; a third element is ignored.
        device: Device the batches are moved to.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        log_interval: Accepted for interface compatibility; currently unused.

    Returns:
        Average loss over the batches, or ``nan`` when the dataloader is empty.
    """
    model.train()
    # Mixed precision is only enabled on CUDA; elsewhere the autocast context
    # is a no-op instead of being requested for a device that is not in use.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    total_loss = 0.0
    num_batches = len(dataloader)
    pbar = tqdm(dataloader, total=num_batches, desc="Training")
    for batch in pbar:
        if len(batch) == 3:
            data, label, _ = batch  # angle_class, actual_angle
        else:
            data, label = batch
        data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type, enabled=use_amp):
            output = model(data)
            loss = loss_fn(output, label)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        pbar.set_postfix({'loss': batch_loss})

    # Guard against an empty dataloader instead of dividing by zero.
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches
def validate_function(model, loss_fn, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Module to evaluate; switched to eval mode.
        loss_fn: Callable ``loss_fn(output, label)`` returning a scalar tensor.
        dataloader: Yields ``(data, label)`` or ``(data, label, extra)``
            batches; a third element is ignored.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the mean per-batch loss (``nan`` for an
        empty dataloader) and the fraction of samples whose arg-max
        prediction equals the label (0 when there are no samples).
    """
    model.eval()
    # Mixed precision is only enabled on CUDA; elsewhere autocast is a no-op.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    total_loss = 0.0
    correct_predictions = 0
    num_samples = 0
    num_batches = len(dataloader)
    pbar = tqdm(dataloader, total=num_batches, desc="Validating")
    with torch.no_grad():
        for batch in pbar:
            if len(batch) == 3:
                data, label, _ = batch
            else:
                data, label = batch
            data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True)

            with autocast(device_type, enabled=use_amp):
                output = model(data)
                loss = loss_fn(output, label)

            batch_loss = loss.item()
            total_loss += batch_loss
            predicted_classes = output.argmax(dim=1)
            correct_predictions += (predicted_classes == label).sum().item()
            num_samples += label.size(0)
            pbar.set_postfix({'val_loss': batch_loss})

    # Both averages are guarded so an empty loader cannot divide by zero.
    avg_loss = total_loss / num_batches if num_batches > 0 else float('nan')
    accuracy = correct_predictions / num_samples if num_samples > 0 else 0
    return avg_loss, accuracy
def train_model_loop(epochs, model, loss_fn, optimizer, train_dataloader, val_dataloader, device,
                     scheduler=None, start_epoch=0, patience=5, save_path='best_model.pth'):
    """Train and validate ``model`` for up to ``epochs`` epochs with early stopping.

    Args:
        epochs: Exclusive upper bound of the epoch counter.
        model, loss_fn, optimizer: Passed to the per-epoch train/validate helpers.
        train_dataloader, val_dataloader: Data for training and validation.
        device: Device the batches are moved to.
        scheduler: Optional LR scheduler, stepped once per epoch.
        start_epoch: First epoch index (0-based).
        patience: Early-stopping patience in epochs.
        save_path: Checkpoint path used by early stopping.

    Returns:
        Dict with per-epoch lists ``'train_loss'``, ``'val_loss'`` and
        ``'val_accuracy'``.
    """
    # Keep a reference to the scaler: it is handed to train_function every
    # epoch.  Loss scaling is only enabled when training on CUDA.
    device_type = torch.device(device).type
    scaler = GradScaler(device=device_type, enabled=device_type == 'cuda')
    early_stopping = EarlyStopping(patience=patience, save_path=save_path, verbose=True)
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

    for epoch_num in range(start_epoch, epochs):
        print(f"\n--- Epoch {epoch_num + 1}/{epochs} ---")
        epoch_start_time = time.time()

        train_loss = train_function(model, loss_fn, optimizer, train_dataloader, device, scaler)
        val_loss, val_accuracy = validate_function(model, loss_fn, val_dataloader, device)
        epoch_duration = time.time() - epoch_start_time

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_accuracy)
        print(f"Epoch Summary: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Time: {epoch_duration:.2f}s")

        if scheduler:
            scheduler.step()

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered. Loading best model weights.")
            if os.path.exists(early_stopping.save_path):
                # map_location keeps the reload working when the checkpoint
                # was written from a different device than the current one.
                model.load_state_dict(torch.load(early_stopping.save_path, map_location=device))
            break
    return history
def plot_training_history(history):
    """Plot loss and accuracy curves, save them to a PNG and show the figure.

    Args:
        history: Dict with equally long lists under ``'train_loss'``,
            ``'val_loss'`` and ``'val_accuracy'`` (one entry per epoch).
    """
    def _decorate(title, y_label):
        # Shared finishing touches for both panels.
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)

    epoch_axis = range(1, len(history['train_loss']) + 1)
    plt.figure(figsize=(12, 5))

    # Left panel: training vs. validation loss.
    plt.subplot(1, 2, 1)
    for key, legend_label in (('train_loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
        plt.plot(epoch_axis, history[key], label=legend_label)
    _decorate('Training and Validation Loss', 'Loss')

    # Right panel: validation accuracy.
    plt.subplot(1, 2, 2)
    plt.plot(epoch_axis, history['val_accuracy'], label='Validation Accuracy', color='green')
    _decorate('Validation Accuracy', 'Accuracy')

    plt.tight_layout()
    plt.savefig("cycnn_training_history.png")
    plt.show()
# Dataset discovery utilities
def generate_df_from_directory(base_img_path):
    """Generates a DataFrame with 'filename' column for all JPEGs in a directory.

    The directory is searched recursively and file extensions are matched
    case-insensitively against ``.jpeg`` / ``.jpg``.

    Args:
        base_img_path: Root directory to search.

    Returns:
        DataFrame with a single ``'filename'`` column of paths in sorted
        order; empty (with the same column) when the directory is missing.
    """
    if not os.path.isdir(base_img_path):
        print(f"Error: Image path '{base_img_path}' not found or is not a directory.")
        return pd.DataFrame([], columns=['filename'])

    # os.walk yields entries in a filesystem-dependent order; sorting keeps the
    # row order (and therefore any seeded split built on it) reproducible.
    filenames = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(base_img_path)
        for name in files
        if name.lower().endswith((".jpeg", ".jpg"))
    )
    df = pd.DataFrame(filenames, columns=['filename'])
    print(f"Found {len(filenames)} JPEG files in '{base_img_path}'.")
    return df
# Main execution
if __name__ == '__main__':
    # Script entry point: build the rotation dataset from RAW_IMAGE_DIR, train
    # a CyCNN model to predict the rotation-angle class, plot the training
    # history and report the loss/accuracy on the held-out split.

    # Configuration
    RAW_IMAGE_DIR = 'imagenet-mini'
    NUM_EPOCHS = 20
    LEARNING_RATE = 1e-3
    # NOTE(review): 532 is unusually large for a "reduced" batch size — confirm it is not a typo.
    BATCH_SIZE = 532  # Reduced for CyCNN
    NUM_CLASSES = 36  # 36 angle classes (10-degree bins)
    USE_POLAR = True
    USE_LOG_POLAR = False

    print("--- CyCNN: Rotation Invariant CNN Training ---")

    # Create directory with dummy image if needed
    if not os.path.exists(RAW_IMAGE_DIR):
        os.makedirs(RAW_IMAGE_DIR)
        try:
            dummy_img = Image.new('RGB', (400, 400), color='blue')
            dummy_img.save(os.path.join(RAW_IMAGE_DIR, "dummy_image.jpg"))
            print(f"Created a dummy image in {RAW_IMAGE_DIR} for testing.")
        except Exception as e:
            print(f"Could not create dummy image: {e}")

    # Load dataset
    df_images = generate_df_from_directory(RAW_IMAGE_DIR)
    if df_images.empty:
        print("No images found. Exiting.")
        # SystemExit is used instead of exit(): exit() is a helper installed
        # by the site module for interactive sessions.
        raise SystemExit(0)

    # Setup transforms
    polar_transform = PolarTransform(output_size=(224, 224), use_log_polar=USE_LOG_POLAR) if USE_POLAR else None
    standard_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create dataset with different rotation angles for training/testing
    train_angles = list(range(0, 360, 15))  # Every 15 degrees for training
    # NOTE(review): test_angles is defined but never used below; the test
    # split is drawn from the same train_angles dataset.
    test_angles = list(range(0, 360, 10))  # Every 10 degrees for testing
    dataset = RotationDataset(
        df=df_images,
        polar_transform=polar_transform,
        standard_transform=standard_transforms,
        rotation_angles=train_angles
    )
    print(f"Total samples in dataset: {len(dataset)}")
    if len(dataset) == 0:
        print("No samples in dataset. Exiting.")
        raise SystemExit(0)

    # Split dataset (70% train / 15% validation / remainder test)
    total_size = len(dataset)
    train_size = int(0.7 * total_size)
    val_size = int(0.15 * total_size)
    test_size = total_size - train_size - val_size
    train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

    # Create dataloaders
    on_cuda = device.type == 'cuda'
    num_workers = 4 if on_cuda else 2
    # Pinned host memory is only requested when batches are copied to a GPU.
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=num_workers, pin_memory=on_cuda)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=num_workers, pin_memory=on_cuda)
    test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=num_workers, pin_memory=on_cuda)

    # Initialize model
    print(f"\n--- Initializing CyCNN Model (Use Polar: {USE_POLAR}) ---")
    # Choose between CyVGG or CyEfficientNet
    model = CyEfficientNet(num_classes=NUM_CLASSES, use_polar=USE_POLAR).to(device)
    # model = CyVGG(num_classes=NUM_CLASSES, use_polar=USE_POLAR).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    if USE_POLAR:
        print("✓ Using polar coordinate transformation")
    print("✓ Using cylindrical convolutions for rotation invariance")

    # Train model
    print("\n--- Starting CyCNN Training ---")
    training_history = train_model_loop(
        epochs=NUM_EPOCHS,
        model=model,
        loss_fn=criterion,
        optimizer=optimizer,
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        device=device,
        scheduler=scheduler,
        patience=5,
        save_path='best_cycnn_model.pth'
    )

    # Plot results
    print("\n--- Plotting Training History ---")
    plot_training_history(training_history)

    # Test rotation invariance
    print("\n--- Testing Rotation Invariance ---")
    test_loss, test_accuracy = validate_function(model, criterion, test_dl, device)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
    print("\n--- CyCNN Training Complete ---")

    # Cleanup
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()
# (End of gist. The original page footer invited readers to sign in to GitHub to comment.)