Created
May 31, 2025 04:21
-
-
Save secemp9/ceed0bc6ec6f48bca114cc3118a85281 to your computer and use it in GitHub Desktop.
rotation invariant CNN implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
CyCNN implementation for rotation invariant image classification. | |
Based on "CyCNN: A Rotation Invariant CNN using Polar Mapping and Cylindrical Convolution Layers" | |
""" | |
# Core Libraries | |
import os | |
import random | |
import time | |
import gc | |
import math | |
# Image Processing and Data Handling | |
import cv2 | |
from PIL import Image, ImageDraw, ImageFilter | |
import numpy as np | |
import pandas as pd | |
import imagesize | |
from concurrent.futures import ThreadPoolExecutor | |
import shutil | |
# PyTorch | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
from torch.utils.data import DataLoader, Dataset, random_split | |
from torch.amp import GradScaler, autocast | |
# Plotting and Display | |
import matplotlib.pyplot as plt | |
from numpy import arange | |
from tqdm import tqdm | |
from IPython.display import display | |
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``; also applied to every CUDA device when
            CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# Seed all RNGs before any model or data object is created.
set_seed(69)

# Pick the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')
class PolarTransform:
    """Callable transform that unwraps an image into (log-)polar coordinates.

    The result is laid out with **radius along the rows (height)** and
    **angle along the columns (width)**, which is the axis convention
    ``CyConv2d`` documents for its wrap-around padding.

    Args:
        output_size: ``(height, width)`` of the polar image, i.e.
            ``(radius bins, angle bins)``.
        use_log_polar: Use ``cv2.WARP_POLAR_LOG`` instead of
            ``cv2.WARP_POLAR_LINEAR`` for the radial axis.
    """

    def __init__(self, output_size=(224, 224), use_log_polar=False):
        self.output_size = output_size
        self.use_log_polar = use_log_polar

    def cartesian_to_polar(self, image):
        """Return the polar unwrapping of ``image`` as a NumPy array.

        Args:
            image: PIL image or array-like of shape ``(H, W)`` or ``(H, W, C)``.

        Returns:
            Array whose first axis is radius and second axis is angle, sized
            according to ``self.output_size``.
        """
        img = np.asarray(image)
        h, w = img.shape[:2]
        cx, cy = w // 2, h // 2
        # Largest circle centred on the image that still fits inside it.
        max_r = min(cx, cy)
        flag = cv2.WARP_POLAR_LOG if self.use_log_polar else cv2.WARP_POLAR_LINEAR

        radius_bins, angle_bins = self.output_size
        # cv2.warpPolar takes dsize as (width, height) and writes radius along
        # x (columns) and angle along y (rows).
        polar = cv2.warpPolar(img, (radius_bins, angle_bins), (cx, cy), max_r, flag)
        # Swap the first two axes so rows are radius and columns are angle;
        # otherwise the cylindrical wrap in CyConv2d would run along radius.
        return np.ascontiguousarray(np.swapaxes(polar, 0, 1))

    def __call__(self, image):
        """Apply the polar transform and return the result as a PIL image."""
        polar_img = self.cartesian_to_polar(image)
        if polar_img.ndim == 3:
            return Image.fromarray(polar_img)
        return Image.fromarray(polar_img.squeeze(), mode='L')
class CyConv2d(nn.Module):
    """2-D convolution with cylindrical padding.

    The wrapped ``nn.Conv2d`` is built with zero padding of its own; the
    requested ``padding`` is applied by ``apply_cylindrical_padding`` instead:
    zeros along the height axis and wrap-around along the width axis.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups,
        bias: Forwarded to ``nn.Conv2d``.
        padding: Int or ``(rows, cols)`` tuple giving the amount of padding
            added on each side of the height / width axes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()

        def as_pair(value):
            # Broadcast a scalar to (value, value); tuples pass through.
            return value if isinstance(value, tuple) else (value, value)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = as_pair(kernel_size)
        self.stride = as_pair(stride)
        self.padding = as_pair(padding)
        self.dilation = as_pair(dilation)
        self.groups = groups
        # Plain convolution; padding is 0 because forward() pads manually.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Pad ``x`` cylindrically, then run the wrapped convolution."""
        return self.conv(self.apply_cylindrical_padding(x))

    def apply_cylindrical_padding(self, x):
        """Pad an ``(N, C, H, W)`` tensor for cylindrical convolution.

        Height (radius) gets constant-zero rows on top and bottom; width
        (angle) is wrapped so the last columns precede the first and vice
        versa.
        """
        rows, cols = self.padding

        # Height axis: zero rows above and below.
        if rows > 0:
            x = F.pad(x, (0, 0, rows, rows), mode='constant', value=0)

        # Width axis: nothing to wrap when no column padding is requested.
        if cols <= 0:
            return x

        wrapped = (x[..., -cols:], x, x[..., :cols])
        return torch.cat(wrapped, dim=3)
class CyVGG(nn.Module):
    """VGG-style classifier built entirely from ``CyConv2d`` convolutions.

    ``features`` holds four blocks of 3x3 cylindrical convolutions
    (64x2, 128x2, 256x3, 512x3), each convolution followed by BatchNorm and
    ReLU and each block closed by 2x2 max pooling. ``classifier`` global-
    average-pools the result and applies a 512 -> 256 -> 128 -> num_classes
    MLP with dropout.

    Args:
        num_classes: Size of the final linear layer.
        use_polar: Stored on the instance; not consulted inside this class.
    """

    def __init__(self, num_classes=10, use_polar=True):
        super().__init__()
        self.use_polar = use_polar

        # (output channels, number of conv layers) for each block.
        block_plan = ((64, 2), (128, 2), (256, 3), (512, 3))

        stages = []
        channels = 3
        for width, depth in block_plan:
            for _ in range(depth):
                stages.extend((
                    CyConv2d(channels, width, 3, padding=1),
                    nn.BatchNorm2d(width),
                    nn.ReLU(inplace=True),
                ))
                channels = width
            stages.append(nn.MaxPool2d(2, 2))
        self.features = nn.Sequential(*stages)

        # Classifier head: global pooling, two hidden layers, output layer.
        head = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()]
        for fan_in, fan_out in ((512, 256), (256, 128)):
            head.extend((
                nn.Linear(fan_in, fan_out),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
            ))
        head.append(nn.Linear(128, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Run the feature extractor followed by the classifier head."""
        return self.classifier(self.features(x))
class CyEfficientNet(nn.Module):
    """EfficientNet-B0 whose stem convolution is replaced by a ``CyConv2d``.

    The backbone is loaded with the ``IMAGENET1K_V1`` weights. The first
    convolution is swapped for a cylindrical one that mirrors the original
    layer's geometry and inherits its pretrained kernel, and the classifier
    is replaced by a small two-layer head.

    Args:
        num_classes: Size of the final linear layer.
        use_polar: Stored on the instance; not consulted inside this class.
    """

    def __init__(self, num_classes=10, use_polar=True):
        super().__init__()
        self.use_polar = use_polar

        # Load pre-trained EfficientNet
        self.backbone = models.efficientnet_b0(weights="IMAGENET1K_V1")

        # Replace first conv layer with a cylindrical version that mirrors it.
        original_conv = self.backbone.features[0][0]
        cy_conv = CyConv2d(
            in_channels=original_conv.in_channels,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            dilation=original_conv.dilation,
            groups=original_conv.groups,
        )
        # Carry over the pretrained stem kernel instead of starting from a
        # random initialisation. The bias parameter is kept so state-dict keys
        # stay the same, but it is zeroed so the stem starts out computing the
        # same linear map as the pretrained convolution.
        with torch.no_grad():
            cy_conv.conv.weight.copy_(original_conv.weight)
            if cy_conv.conv.bias is not None:
                cy_conv.conv.bias.zero_()
        self.backbone.features[0][0] = cy_conv

        # Modify classifier
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Forward ``x`` through the modified backbone."""
        return self.backbone(x)
class RotationDataset(Dataset):
    """Dataset yielding every image at every requested rotation angle.

    Each sample is ``(image_tensor, angle_class, angle)`` where
    ``angle_class`` is the 10-degree bin of the rotation angle.

    Args:
        df: DataFrame with a ``'filename'`` column of image paths.
        polar_transform: Optional callable applied to the rotated PIL image.
        standard_transform: Optional callable applied last; when omitted the
            image is converted with ``transforms.ToTensor()``.
        rotation_angles: Iterable of angles in degrees. Defaults to every
            10 degrees in ``[0, 360)``.
    """

    def __init__(self, df, polar_transform=None, standard_transform=None, rotation_angles=None):
        self.df = df.reset_index(drop=True)
        self.polar_transform = polar_transform
        self.standard_transform = standard_transform

        # If no specific angles are provided, use a fixed 10-degree grid.
        if rotation_angles is None:
            self.rotation_angles = list(range(0, 360, 10))
        else:
            self.rotation_angles = rotation_angles

        # Cartesian product of images and angles, image-major order.
        self.samples = [
            (filename, angle)
            for filename in self.df['filename']
            for angle in self.rotation_angles
        ]

    def __len__(self):
        return len(self.samples)

    def rotate_image(self, image, angle):
        """Rotate a PIL image by ``angle`` degrees about its centre.

        The output keeps the input's width and height. ``angle == 0`` returns
        the input object unchanged.
        """
        if angle == 0:
            return image

        # Convert PIL to numpy for rotation
        img_array = np.array(image)
        h, w = img_array.shape[:2]
        center = (w // 2, h // 2)

        M = cv2.getRotationMatrix2D(center, angle, 1.0)
        rotated = cv2.warpAffine(img_array, M, (w, h))
        return Image.fromarray(rotated)

    def __getitem__(self, idx):
        filepath, angle = self.samples[idx]

        # Load image; the context manager releases the file handle right away.
        try:
            with Image.open(filepath) as handle:
                image = handle.convert('RGB')
        except (OSError, ValueError) as err:
            # Best-effort fallback for missing/unreadable files, but make the
            # substitution visible instead of silently training on red squares.
            import warnings
            warnings.warn(
                f"Could not load image '{filepath}' ({err}); using a placeholder.",
                RuntimeWarning,
            )
            image = Image.new('RGB', (224, 224), color='red')

        rotated_image = self.rotate_image(image, angle)

        if self.polar_transform is not None:
            rotated_image = self.polar_transform(rotated_image)

        if self.standard_transform is not None:
            rotated_image = self.standard_transform(rotated_image)
        else:
            rotated_image = transforms.ToTensor()(rotated_image)

        # Classification label: 10-degree bin, 36 classes (0-35). The angle is
        # wrapped into [0, 360) and truncated so the label is always an int in
        # range, even for negative, >= 360 or float angles.
        angle_class = int(angle % 360) // 10
        return rotated_image, angle_class, angle
# Training utilities: early stopping, per-epoch train/validate loops, and history plotting
class EarlyStopping:
    """Track validation loss and flag when it has stopped improving.

    A call counts as an improvement when ``val_loss`` is lower than the best
    loss seen so far by more than ``delta``. On improvement the model's
    ``state_dict`` is written to ``save_path`` (if set); otherwise a counter
    is incremented and ``early_stop`` becomes ``True`` once it reaches
    ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        delta: Minimum decrease in loss that counts as an improvement.
        save_path: Checkpoint path; a falsy value disables saving.
        verbose: Print progress messages when ``True``.
    """

    def __init__(self, patience=5, delta=0.001, save_path='best_model.pth', verbose=False):
        self.patience, self.delta = patience, delta
        self.save_path, self.verbose = save_path, verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for this epoch and update the stop flag."""
        improved = val_loss < self.best_loss - self.delta

        if not improved:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("Early stopping")
            return

        # New best: remember it, reset the patience counter, checkpoint.
        self.best_loss = val_loss
        self.counter = 0
        if self.save_path:
            torch.save(model.state_dict(), self.save_path)
        if self.verbose:
            print(f"Validation loss decreased ({self.best_loss:.6f}). Saving model ...")
def train_function(model, loss_fn, optimizer, dataloader, device, scaler, log_interval=None):
    """Run one training epoch with (optional) CUDA mixed precision.

    Args:
        model: Module to train; switched to ``train()`` mode.
        loss_fn: Callable ``loss_fn(output, label)`` returning a scalar loss.
        optimizer: Optimizer stepped through ``scaler``.
        dataloader: Yields ``(data, label)`` or ``(data, label, extra)``
            batches; a third element is ignored.
        device: Target device (``torch.device`` or string).
        scaler: Gradient scaler used for ``scale`` / ``step`` / ``update``.
        log_interval: Accepted for backward compatibility; currently unused.

    Returns:
        Mean of the per-batch losses, or ``0.0`` when the loader is empty.
    """
    model.train()
    num_batches = len(dataloader)
    if num_batches == 0:
        # Nothing to iterate; avoid dividing by zero below.
        return 0.0

    # Autocast is only meaningful here on CUDA; disable it explicitly on other
    # devices instead of relying on PyTorch's warn-and-disable fallback.
    use_amp = torch.device(device).type == 'cuda'

    total_loss = 0.0
    pbar = tqdm(dataloader, total=num_batches, desc="Training")
    for batch in pbar:
        if len(batch) == 3:
            data, label, _ = batch  # third element (actual angle) is not needed
        else:
            data, label = batch
        data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast('cuda', enabled=use_amp):
            output = model(data)
            loss = loss_fn(output, label)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once and reuse the value.
        loss_value = loss.item()
        total_loss += loss_value
        pbar.set_postfix({'loss': loss_value})

    avg_epoch_loss = total_loss / num_batches
    return avg_epoch_loss
def validate_function(model, loss_fn, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Module to evaluate; switched to ``eval()`` mode.
        loss_fn: Callable ``loss_fn(output, label)`` returning a scalar loss.
        dataloader: Yields ``(data, label)`` or ``(data, label, extra)``
            batches; a third element is ignored.
        device: Target device (``torch.device`` or string).

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the mean of the
        per-batch losses and ``accuracy`` is the fraction of samples whose
        arg-max prediction equals the label. Both are ``0`` for an empty
        loader.
    """
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    num_samples = 0
    num_batches = len(dataloader)

    # Autocast is only meaningful here on CUDA; disable it explicitly elsewhere.
    use_amp = torch.device(device).type == 'cuda'

    pbar = tqdm(dataloader, total=num_batches, desc="Validating")
    with torch.no_grad():
        for batch in pbar:
            if len(batch) == 3:
                data, label, _ = batch
            else:
                data, label = batch
            data, label = data.to(device, non_blocking=True), label.to(device, non_blocking=True)

            with autocast('cuda', enabled=use_amp):
                output = model(data)
                loss = loss_fn(output, label)

            # Read the loss once and reuse the value.
            loss_value = loss.item()
            total_loss += loss_value
            predicted_classes = output.argmax(dim=1)
            correct_predictions += (predicted_classes == label).sum().item()
            num_samples += label.size(0)
            pbar.set_postfix({'val_loss': loss_value})

    # Guard both ratios so an empty loader does not raise ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    accuracy = correct_predictions / num_samples if num_samples > 0 else 0
    return avg_loss, accuracy
def train_model_loop(epochs, model, loss_fn, optimizer, train_dataloader, val_dataloader, device,
                     scheduler=None, start_epoch=0, patience=5, save_path='best_model.pth'):
    """Train for up to ``epochs`` epochs with validation and early stopping.

    Args:
        epochs: Exclusive upper bound of the epoch index.
        model: Module to train.
        loss_fn: Loss callable passed to the train/validate helpers.
        optimizer: Optimizer for ``model``.
        train_dataloader: Loader used by ``train_function``.
        val_dataloader: Loader used by ``validate_function``.
        device: Target device (``torch.device`` or string).
        scheduler: Optional LR scheduler, stepped once per epoch.
        start_epoch: Epoch index to start from.
        patience: Early-stopping patience in epochs.
        save_path: Path where the best ``state_dict`` is checkpointed.

    Returns:
        Dict with per-epoch lists under ``'train_loss'``, ``'val_loss'`` and
        ``'val_accuracy'``.
    """
    # The scaler must be bound to a name: train_function receives it below.
    # Scaling is only enabled when training actually runs on CUDA.
    scaler = GradScaler(device='cuda', enabled=torch.device(device).type == 'cuda')
    early_stopping = EarlyStopping(patience=patience, save_path=save_path, verbose=True)
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}

    for epoch_num in range(start_epoch, epochs):
        print(f"\n--- Epoch {epoch_num + 1}/{epochs} ---")
        epoch_start_time = time.time()

        train_loss = train_function(model, loss_fn, optimizer, train_dataloader, device, scaler)
        val_loss, val_accuracy = validate_function(model, loss_fn, val_dataloader, device)

        epoch_duration = time.time() - epoch_start_time
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_accuracy)
        print(f"Epoch Summary: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Time: {epoch_duration:.2f}s")

        if scheduler:
            scheduler.step()

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered. Loading best model weights.")
            if os.path.exists(early_stopping.save_path):
                # map_location keeps the restore working when the checkpoint
                # was written from a different device than the current one.
                model.load_state_dict(torch.load(early_stopping.save_path, map_location=device))
            break

    return history
def plot_training_history(history):
    """Plot loss and accuracy curves, save them to disk and show them.

    Args:
        history: Dict with equally long lists under ``'train_loss'``,
            ``'val_loss'`` and ``'val_accuracy'`` (one entry per epoch).

    The figure is written to ``cycnn_training_history.png`` in the current
    working directory and then displayed.
    """
    epoch_axis = range(1, len(history['train_loss']) + 1)

    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: training vs. validation loss.
    loss_ax.plot(epoch_axis, history['train_loss'], label='Training Loss')
    loss_ax.plot(epoch_axis, history['val_loss'], label='Validation Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Right panel: validation accuracy.
    acc_ax.plot(epoch_axis, history['val_accuracy'], label='Validation Accuracy', color='green')
    acc_ax.set_title('Validation Accuracy')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()
    acc_ax.grid(True)

    plt.tight_layout()
    plt.savefig("cycnn_training_history.png")
    plt.show()
# Dataset discovery utilities
def generate_df_from_directory(base_img_path):
    """Collect all JPEG files under a directory into a DataFrame.

    Args:
        base_img_path: Root directory, searched recursively.

    Returns:
        DataFrame with a single ``'filename'`` column holding the joined
        paths of every file ending in ``.jpg`` / ``.jpeg`` (case-insensitive),
        in sorted order. Empty when the path is not a directory (an error
        message is printed in that case).
    """
    if not os.path.isdir(base_img_path):
        print(f"Error: Image path '{base_img_path}' not found or is not a directory.")
        return pd.DataFrame([], columns=['filename'])

    # os.walk yields entries in filesystem order; sorting makes the row order
    # (and therefore any seeded split built on it) the same on every machine.
    filenames = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(base_img_path)
        for file in files
        if file.lower().endswith((".jpeg", ".jpg"))
    )

    df = pd.DataFrame(filenames, columns=['filename'])
    print(f"Found {len(filenames)} JPEG files in '{base_img_path}'.")
    return df
# Main execution: build the rotation dataset, train a CyCNN model and evaluate it.
if __name__ == '__main__':
    # Configuration
    RAW_IMAGE_DIR = 'imagenet-mini'
    NUM_EPOCHS = 20
    LEARNING_RATE = 1e-3
    BATCH_SIZE = 532  # Reduced for CyCNN
    NUM_CLASSES = 36  # 36 angle classes (10-degree bins)
    USE_POLAR = True
    USE_LOG_POLAR = False

    print("--- CyCNN: Rotation Invariant CNN Training ---")

    # Create directory with dummy image if needed
    if not os.path.exists(RAW_IMAGE_DIR):
        os.makedirs(RAW_IMAGE_DIR)
        try:
            dummy_img = Image.new('RGB', (400, 400), color='blue')
            dummy_img.save(os.path.join(RAW_IMAGE_DIR, "dummy_image.jpg"))
            print(f"Created a dummy image in {RAW_IMAGE_DIR} for testing.")
        except Exception as e:
            print(f"Could not create dummy image: {e}")

    # Load dataset
    df_images = generate_df_from_directory(RAW_IMAGE_DIR)
    if df_images.empty:
        print("No images found. Exiting.")
        # SystemExit instead of exit(): exit() is injected by the `site`
        # module and is not guaranteed to exist in every interpreter setup.
        raise SystemExit

    # Setup transforms
    polar_transform = PolarTransform(output_size=(224, 224), use_log_polar=USE_LOG_POLAR) if USE_POLAR else None
    standard_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Rotation angles used to build the dataset
    train_angles = list(range(0, 360, 15))  # Every 15 degrees for training
    # NOTE(review): test_angles is defined but never used; the test split below
    # is carved out of the train_angles dataset.
    test_angles = list(range(0, 360, 10))  # Every 10 degrees for testing

    dataset = RotationDataset(
        df=df_images,
        polar_transform=polar_transform,
        standard_transform=standard_transforms,
        rotation_angles=train_angles,
    )
    print(f"Total samples in dataset: {len(dataset)}")
    if len(dataset) == 0:
        print("No samples in dataset. Exiting.")
        raise SystemExit

    # Split dataset (70% / 15% / remainder)
    total_size = len(dataset)
    train_size = int(0.7 * total_size)
    val_size = int(0.15 * total_size)
    test_size = total_size - train_size - val_size
    train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size])
    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

    # Create dataloaders
    on_cuda = device.type == 'cuda'
    num_workers = 4 if on_cuda else 2
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=num_workers, pin_memory=on_cuda)
    val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=num_workers, pin_memory=on_cuda)
    test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=num_workers, pin_memory=on_cuda)

    # Initialize model
    print(f"\n--- Initializing CyCNN Model (Use Polar: {USE_POLAR}) ---")
    # Choose between CyVGG or CyEfficientNet
    model = CyEfficientNet(num_classes=NUM_CLASSES, use_polar=USE_POLAR).to(device)
    # model = CyVGG(num_classes=NUM_CLASSES, use_polar=USE_POLAR).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    if USE_POLAR:
        print("✓ Using polar coordinate transformation")
    print("✓ Using cylindrical convolutions for rotation invariance")

    # Train model
    print("\n--- Starting CyCNN Training ---")
    training_history = train_model_loop(
        epochs=NUM_EPOCHS,
        model=model,
        loss_fn=criterion,
        optimizer=optimizer,
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        device=device,
        scheduler=scheduler,
        patience=5,
        save_path='best_cycnn_model.pth',
    )

    # Plot results
    print("\n--- Plotting Training History ---")
    plot_training_history(training_history)

    # Evaluate on the held-out split
    print("\n--- Testing Rotation Invariance ---")
    test_loss, test_accuracy = validate_function(model, criterion, test_dl, device)
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
    print("\n--- CyCNN Training Complete ---")

    # Cleanup
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment